import torchimport numpy as npfrom torch import nnimport torch.nn.functional as Ffrom torch_utils import misc# d_01과 다르게 쓰이지 않는 라이브러리를 지우고 코드가 동작하게끔 바꾸었다.# 그러나 모델 성능이 떨어지거나 사소한 오류가 발생한다면 d_01을 사용하도록 하자.def nf(stage, channel_base=32768, channel_decay=1.0, channel_max=512): NF = {512: 64, 256: 128, 128: 256, 64: 512, 32: 512, 16: 512, 8: 512, 4: 512} return NF[2 ** stage]def normalize_2nd_mo..
import torchfrom torch import nn#============================================================= Basic Module =============================================================#class MaskUpdator(nn.Module): def __init__(self, patch_size): super(MaskUpdator, self).__init__() self.patch_size = patch_size def forward(self, x): x = self.mask_to_patches(x) x = self.remove_m..