import osimport globimport torchimport matplotlib.pyplot as pltfrom tqdm import tqdmfrom torch.utils.data import DataLoaderfrom src.utils import imsavefrom src.dataset import Datasetfrom src.metrics import PSNR, EdgeAccuracyfrom src.models import EdgeModel, InpaintingModelclass SGGLAT(): def __init__(self, args): self.args = args if self.args.model == 1: model_name = ..
import osimport torchimport torch.nn as nnimport torch.optim as optimimport matplotlib.pyplot as pltfrom .generator import EdgeGenerator, ImageGeneratorfrom .edge_connect_discriminator import Discriminatorfrom .loss import AdversarialLoss, PerceptualLoss, StyleLossclass BaseModel(nn.Module): def __init__(self, name, args): super(BaseModel, self).__init__() self.name = name ..
import torchimport argparsedef get_args(): parser = argparse.ArgumentParser() # 실험 세팅 제어 parser.add_argument('--mode', type=int, default = 1, choices=[1, 2], help='1: train, 2:test') parser.add_argument('--model', type=int, default = 2, choices=[1, 2, 3, 4], help='1: edge model, 2: inpaint model, 3: edge-inpaint model, 4: joint model') parser.add_argument('--mask', type=int, defa..
import torchimport randomimport numpy as npimport os, cv2, globclass Dataset(torch.utils.data.Dataset): def __init__(self, args, image_flist, edge_flist, mask_flist, augment=True, training=True): # args: 다양한 모델 옵션을 딕셔너리로 저장한 객체 # image_flist, edge_flist, mask_flist: 이미지, 엣지, 마스크 파일이 들어있는 폴더까지의 경로 # augment: 데이터 증강 여부 # training: 훈련 모드 여부 super(Dataset, self)..