나만의 머신러닝 코드를 작성하기 위해선 기존에 완성된 코드를 분석하는 것이 우선이라고 생각하였다. 따라서 2019년에 발표된 edge connect 논문의 코드를 먼저 깊게 분석하기로 한다. 메인 함수 더보기 main.py 개요 모델을 불러와서 argument를 넣고 train 혹은 test를 진행하는 함수 보조 모듈 config.py, edge_conncect.py 함수 main(mode = None) 1. arguements 받기 (import config.py) 2. CUDA 설정 3. 랜덤 시드 설정 4. 모델 초기화 (import edge_connect.py) 5. train/test/eval에 따른 실험 진행 load_config(mode = None) config.yml 파일을 받아서 arg..
import os import numpy as np import torch from torch.utils.data import DataLoader from .dataset import Dataset from .models import EdgeModel, InpaintingModel from .utils import Progbar, create_dir, stitch_images, imsave from .metrics import PSNR, EdgeAccuracy class EdgeConnect(): def __init__(self, config): self.config = config # config.MODEL에 따라 어떤 것을 수행할지 결정. if config.MODEL == 1: model_name =..
코드를 이해하기 위한 사전 지식 더보기 nn.DataParallel() 함수 nn.DataParallel은 PyTorch에서 제공하는 함수로, 모델을 여러 GPU에 분산시켜 병렬로 학습할 수 있게 해줍니다. 이 함수의 주요 특징은 다음과 같습니다: 모델의 복사본을 여러 GPU에 분산시킵니다. 각 GPU에서는 데이터의 서브셋(subset)에 대해 연산을 수행합니다. 모든 GPU에서의 연산 결과는 자동으로 합쳐져 최종 결과를 생성합니다. 이를 통해 대규모 데이터셋과 복잡한 모델을 더 빠르게 학습할 수 있습니다. self.add_module() 함수와 모델 초기화 방식 self.add_module('generator', generator)는 nn.Module 클래스의 메서드로, 모델의 서브모듈을 동적으로 추가..
코드를 이해하기 위해 필요한 지식 더보기 특수 메서드 __len__과 __getitem__ __len__ 메서드: __len__은 컨테이너 타입의 객체(예: 리스트, 튜플, 세트, 사용자 정의 컬렉션 등)의 길이를 반환하는 특수 메서드입니다. 예를 들어, len(obj)를 호출하면 내부적으로 obj.__len__()이 호출됩니다. PyTorch의 Dataset 클래스에서 __len__은 데이터셋의 총 항목 수를 반환합니다. __getitem__ 메서드: __getitem__은 컨테이너 타입의 객체에서 특정 인덱스의 항목을 가져오는 특수 메서드입니다. 예를 들어, obj[index]를 호출하면 내부적으로 obj.__getitem__(index)가 호출됩니다. PyTorch의 Dataset 클래스에서 __g..
import os import sys import time import random import numpy as np import matplotlib.pyplot as plt from PIL import Image # 디렉토리를 받아 폴더를 만들어주는 함수 def create_dir(dir): if not os.path.exists(dir): os.makedirs(dir) # def create_mask(width, height, mask_width, mask_height, x=None, y=None): mask = np.zeros((height, width)) # ex) 512x512가 0으로 채워진 이미지 mask_x = x if x is not None else random.randint(0, wi..
import os import cv2 import random import numpy as np import torch import argparse from shutil import copyfile from src.config import Config # yml 파일을 dict로 변환하여 가지고 있는 객체. from src.edge_connect import EdgeConnect def main(mode=None): r"""starts the model Args: mode (int): 1: train, 2: test, 3: eval, reads from config file if not specified """ # train의 경우 mode = 1로 설정된 뒤 main 함수가 호출된다. # mode, a..