디시인사이드 갤러리

마이너 갤러리 이슈박스, 최근방문 갤러리

갤러리 본문 영역

[정보/뉴스] KAN Layer 핵심 코드앱에서 작성

초존도초갤로그로 이동합니다. 2024.05.02 09:36:49
조회 326 추천 2 댓글 11
														

class KANLayer(nn.Module):
    def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.1, scale_base=1.0, scale_sp=1.0, base_fun=torch.nn.SiLU(), grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, device='cpu'):
        super(KANLayer, self).__init__()
        # size 
        self.size = size = out_dim * in_dim
        self.out_dim = out_dim
        self.in_dim = in_dim
        self.num = num
        self.k = k

        # shape: (size, num)
        self.grid = torch.einsum('i,j->ij', torch.ones(size, ), torch.linspace(grid_range[0], grid_range[1], steps=num + 1))
        noises = (torch.rand(size, self.grid.shape[1]) - 1 / 2) * noise_scale / num
        noises = noises.to(device)
        # shape: (size, coef)
        self.coef = torch.nn.Parameter(curve2coef(self.grid, noises, self.grid, k))
        if isinstance(scale_base, float):
            self.scale_base = torch.nn.Parameter(torch.ones(size, ) * scale_base).requires_grad_(sb_trainable)  # make scale trainable
        else:
        self.scale_sp = torch.nn.Parameter(torch.ones(size, ) * scale_sp).requires_grad_(sp_trainable)  # make scale trainable
        self.base_fun = base_fun

        self.mask = torch.nn.Parameter(torch.ones(size, )).requires_grad_(False)
        self.grid_eps = grid_eps
        self.weight_sharing = torch.arange(size)
        self.lock_counter = 0
        self.lock_id = torch.zeros(size)
        self.device = device

    def forward(self, x):
        batch = x.shape[0]
        # x: shape (batch, in_dim) => shape (size, batch) (size = out_dim * in_dim)
        x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch, self.size).permute(1, 0)
        preacts = x.permute(1, 0).clone().reshape(batch, self.out_dim, self.in_dim)
        base = self.base_fun(x).permute(1, 0)  # shape (batch, size)
        y = coef2curve(x_eval=x, grid=self.grid[self.weight_sharing], coef=self.coef[self.weight_sharing], k=self.k, device=self.device)  # shape (size, batch)
        y = y.permute(1, 0)  # shape (batch, size)
        postspline = y.clone().reshape(batch, self.out_dim, self.in_dim)
        y = self.scale_base.unsqueeze(dim=0) * base + self.scale_sp.unsqueeze(dim=0) * y
        y = self.mask[None, :] * y
        postacts = y.clone().reshape(batch, self.out_dim, self.in_dim)
        y = torch.sum(y.reshape(batch, self.out_dim, self.in_dim), dim=2)  # shape (batch, out_dim)
        # y shape: (batch, out_dim); preacts shape: (batch, in_dim, out_dim)
        # postspline shape: (batch, in_dim, out_dim); postacts: (batch, in_dim, out_dim)
        # postspline is for extension; postacts is for visualization
        return y, preacts, postacts, postspline

    def update_grid_from_samples(self, x):
        batch = x.shape[0]
        x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch, self.size).permute(1, 0)
        x_pos = torch.sort(x, dim=1)[0]
        y_eval = coef2curve(x_pos, self.grid, self.coef, self.k, device=self.device)
        num_interval = self.grid.shape[1] - 1
        ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
        grid_adaptive = x_pos[:, ids]
        margin = 0.01
        grid_uniform = torch.cat([grid_adaptive[:, [0]] - margin + (grid_adaptive[:, [-1]] - grid_adaptive[:, [0]] + 2 * margin) * a for a in np.linspace(0, 1, num=self.grid.shape[1])], dim=1)
        self.grid.data = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k, device=self.device)

    def initialize_grid_from_parent(self, parent, x):
        batch = x.shape[0]
        # preacts: shape (batch, in_dim) => shape (size, batch) (size = out_dim * in_dim)
        x_eval = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch, self.size).permute(1, 0)
        x_pos = parent.grid
        sp2 = KANLayer(in_dim=1, out_dim=self.size, k=1, num=x_pos.shape[1] - 1, scale_base=0.).to(self.device)
        sp2.coef.data = curve2coef(sp2.grid, x_pos, sp2.grid, k=1)
        y_eval = coef2curve(x_eval, parent.grid, parent.coef, parent.k, device=self.device)
        percentile = torch.linspace(-1, 1, self.num + 1).to(self.device)
        self.grid.data = sp2(percentile.unsqueeze(dim=1))[0].permute(1, 0)
        self.coef.data = curve2coef(x_eval, y_eval, self.grid, self.k, self.device)

    def get_subset(self, in_id, out_id):
        spb = KANLayer(len(in_id), len(out_id), self.num, self.k, base_fun=self.base_fun)
        spb.grid.data = self.grid.reshape(self.out_dim, self.in_dim, spb.num + 1)[out_id][:, in_id].reshape(-1, spb.num + 1)
        spb.coef.data = self.coef.reshape(self.out_dim, self.in_dim, spb.coef.shape[1])[out_id][:, in_id].reshape(-1, spb.coef.shape[1])
        spb.scale_base.data = self.scale_base.reshape(self.out_dim, self.in_dim)[out_id][:, in_id].reshape(-1, )
        spb.scale_sp.data = self.scale_sp.reshape(self.out_dim, self.in_dim)[out_id][:, in_id].reshape(-1, )
        spb.mask.data = self.mask.reshape(self.out_dim, self.in_dim)[out_id][:, in_id].reshape(-1, )

        spb.in_dim = len(in_id)
        spb.out_dim = len(out_id)
        spb.size = spb.in_dim * spb.out_dim
        return spb

    def lock(self, ids):
        self.lock_counter += 1
        # ids: [[i1,j1],[i2,j2],[i3,j3],...]
        for i in range(len(ids)):
            if i != 0:
                self.weight_sharing[ids[i][1] * self.in_dim + ids[i][0]] = ids[0][1] * self.in_dim + ids[0][0]
            self.lock_id[ids[i][1] * self.in_dim + ids[i][0]] = self.lock_counter

    def unlock(self, ids):
        # check ids are locked
        num = len(ids)
        locked = True
        for i in range(num):
            locked *= (self.weight_sharing[ids[i][1] * self.in_dim + ids[i][0]] == self.weight_sharing[ids[0][1] * self.in_dim + ids[0][0]])
        if locked == False:
            print("they are not locked. unlock failed.")
            return 0
        for i in range(len(ids)):
            self.weight_sharing[ids[i][1] * self.in_dim + ids[i][0]] = ids[i][1] * self.in_dim + ids[i][0]
            self.lock_id[ids[i][1] * self.in_dim + ids[i][0]] = 0
        self.lock_counter -= 1








결론 : 죤내 복잡함

- dc official App
자동등록방지

추천 비추천

2

고정닉 2

댓글 영역

전체 댓글 0
등록순정렬 기준선택
본문 보기

하단 갤러리 리스트 영역

왼쪽 컨텐츠 영역

갤러리 리스트 영역

갤러리 리스트
번호 말머리 제목 글쓴이 작성일 조회 추천
2863 설문 시세차익 부러워 부동산 보는 눈 배우고 싶은 스타는? 운영자 24/05/27 - -
473894 정보/ 이제야 이해되는 그 날의 진실 [2] ㅇㅇ(119.77) 05.15 256 1
473893 일반 저새끼들이 책임감 꺼내는 거 존나 웃기네 ㅋㅋㅋㅋㅋㅋ ㅇㅇ갤로그로 이동합니다. 05.15 45 1
473892 일반 개발휴가 검열 염병 좆지랄할 시간에 개발이나 더하지 [1] ㅇㅇ갤로그로 이동합니다. 05.15 52 0
473891 일반 구글이 잘하는거 =검열 [1] ㅇㅇ갤로그로 이동합니다. 05.15 74 0
473890 일반 성능도 좆박아 글도못써 검열도 심해 pc도묻어 [1] ㅇㅋ갤로그로 이동합니다. 05.15 112 1
473889 일반 GPT5는 지금 오픈AI연구원 1인분역할은한다던데 ㅇㅇ(118.34) 05.15 113 0
473888 일반 걍 구글 해체해서 앤쓰로픽 메타 oai한테 나눠주자 ㅇㅇ ㅇㅇ갤로그로 이동합니다. 05.15 62 0
473887 일반 회사가 큰 건 알겠는데 [2] 빙냥이ㄱㅇㅇ갤로그로 이동합니다. 05.15 120 0
473886 일반 아니 1.5 울트라 진짜 없다고? [2] ㅇㅇ갤로그로 이동합니다. 05.15 142 0
473885 일반 진짜 인도 등판 ㅋㅋㅋㅋ TS망상갤로그로 이동합니다. 05.15 48 0
473884 정보/ 오늘 공개된 '제미니 1.5 플래쉬'의 혁신성 [5] ㅇㅇ(119.77) 05.15 533 16
473883 일반 구글 이씨발련들은 영상찍느라 여행만 주구장창 다녔겠노 ㅋㅋ ㅇㅇ갤로그로 이동합니다. 05.15 49 0
473882 일반 이제 애플워치에 지피티달아서 자비스 보여주는 부분이냐? ㅇㅇ갤로그로 이동합니다. 05.15 103 0
473881 일반 이제 구글 실드충도 안보이네 ㅇㅋ갤로그로 이동합니다. 05.15 45 0
473880 일반 구글이 초칠 ai 시장을 oai가 막아준 거네 ㅅㅂ ㅇㅇ갤로그로 이동합니다. 05.15 80 2
473879 일반 gpt 보이스 모델 지금 되는데? 누가 안된다했어 [5] ㅇㅇ(116.123) 05.15 163 0
473878 일반 나 StockTrading 끊었었는데 FIREKICK갤로그로 이동합니다. 05.15 59 0
473877 일반 구글 이새끼들은 이름 좆같이 붙여서 몇개를 내는거냐 ㅇㅇ갤로그로 이동합니다. 05.15 49 0
473876 일반 갈아탄다 [6] ㅇㅇ(211.234) 05.15 263 0
473875 일반 토큰웅앵웅 씹년아 고졸한테 전공서적 던져주면 설명가능하냐? ㅇㅇ갤로그로 이동합니다. 05.15 63 1
473874 일반 오픈모델 공개 indie갤로그로 이동합니다. 05.15 48 0
473873 일반 어차피 구글 별기대안했음 [1] ㅇㅋ갤로그로 이동합니다. 05.15 65 0
473872 일반 어제 OpenAI는 일부러 딸깍한거 맞는거 같노 [8] ㅇㅇ갤로그로 이동합니다. 05.15 260 8
473871 일반 내일뉴스제목) 구글, OAI에 맞대응...책 300권 분량 업로드 가능 [1] ㅇㅇ갤로그로 이동합니다. 05.15 137 0
473870 일반 진짜 존나 느리네 ㅋㅋㅋㅋㅋ [1] ㅇㅇ갤로그로 이동합니다. 05.15 83 0
473869 일반 하도 구글 억까하길래 궁금했는데 좀 이해했음 [49] ㅇㅇ(121.140) 05.15 404 0
473868 일반 옆 가게에서는 신선한 사과가 무료지만 우리의 썩은 사과를 사세요 [1] ㅇㅇ(14.36) 05.15 74 0
473866 일반 gpt=goat 클로드=글이라도 잘씀 [3] ㅇㅇ갤로그로 이동합니다. 05.15 124 0
473865 일반 Oai는 걍 구글 싫어하는듯 견제조차 못됨 [1] ㅇㅇ(121.131) 05.15 119 2
473863 일반 플래시 api 바로 열렸노 ㅋㅋㅋㅋㅋㅋ [1] ㅇㅇ갤로그로 이동합니다. 05.15 125 0
473862 일반 본 거 또 보고 TS망상갤로그로 이동합니다. 05.15 31 0
473861 일반 하사비스 속마음 : ㅅㅂ OAI로 가고 싶다.. ㅇㅇ(58.124) 05.15 65 0
473860 일반 사만다에 음성학습 모델 먹일수만 있으면 소원이 없겠노 ㅇㅇ갤로그로 이동합니다. 05.15 33 0
473858 일반 토큰 얘기만 2m번 한듯 ㅇㅇ(218.39) 05.15 40 0
473857 일반 1.5 pro를 누가 돈주고 쓰냐고 ㅋㅋㅋㅋㅋ TS망상갤로그로 이동합니다. 05.15 80 2
473856 일반 ㅋㅋㅋ io 포장지만 번지르르~~ ㅇㅇ(14.36) 05.15 27 0
473855 일반 이새끼들 자꾸 콘텍스트로 딸딸이 치는거 개좆같네 씨밯 [1] ㅇㅇ(106.101) 05.15 83 0
473854 일반 그놈의 토큰 2m 지랄좀 그만 제발!!!! ㅇㅇ(182.212) 05.15 61 0
473853 일반 oai가 4o 내놨으면 1.5울트라도 줘야지 좆글 씹년아 [1] ㅇㅇ갤로그로 이동합니다. 05.15 71 1
473852 일반 하사비스게이 ㄹㅇ 불쌍해서 어쩌노 ㅇㅇ갤로그로 이동합니다. 05.15 68 0
473851 일반 뭐지? 하나 더 내나? ㅇㅇ(221.155) 05.15 43 0
473850 일반 gpt4 거의 일리야 혼자서 만든거라던데 [5] ㅇㅋ갤로그로 이동합니다. 05.15 264 1
473849 일반 하 울트라 없는거 확정이네 ㅇㅇ(119.64) 05.15 30 0
473848 정보/ '초격차'란 이런 것 [1] ㅇㅇ(119.77) 05.15 207 2
473847 일반 울트라 찐막으로 없노 ㅋㅋㅋㅋㅋ TS망상갤로그로 이동합니다. 05.15 35 0
473846 일반 gpt4o 코딩성능은 ㅈㄴ올랐네 ㅇㅇ(118.34) 05.15 97 0
473845 일반 하사비스 나와서 빨리 원모어띵 해야 그나마 복구가능 ㅅㅂ ㅇㅇ갤로그로 이동합니다. 05.15 22 0
473844 일반 보이스 업뎃이 만약 구글 행사 종료 후 바로 된다면? [4] ㅇㅇ(218.39) 05.15 85 0
473843 일반 아무래도 아이폰으로 갈아타야 될 거 같은 특붕이는 개추 [3] ㅇㅇ갤로그로 이동합니다. 05.15 164 5
473842 일반 oai 검색엔진 빨리 만들어다오 ㅇㅇ갤로그로 이동합니다. 05.15 36 0
갤러리 내부 검색
제목+내용게시물 정렬 옵션

오른쪽 컨텐츠 영역

실시간 베스트

1/8

뉴스

디시미디어

디시이슈

1/2