코드 분석/Edge-Connect

Edge-Connect: main.py

상솜공방 2024. 1. 25. 14:49
import os
import cv2
import random
import numpy as np
import torch
import argparse
from shutil import copyfile
from src.config import Config # yml 파일을 dict로 변환하여 가지고 있는 객체.
from src.edge_connect import EdgeConnect


def main(mode=None):
    r"""starts the model

    Args:
        mode (int): 1: train, 2: test, 3: eval, reads from config file if not specified
    """
    # train의 경우 mode = 1로 설정된 뒤 main 함수가 호출된다.
    # mode, args, yml, 이 세 가지의 정보를 모두 dict로 변환하여 가지고 있는 config 객체를 생성.
    config = load_config(mode)


    # cuda visble devices
    # cofig.GPU 리스트에 있는 GPU 환경 변수를 받아와 사용할 GPU 장치를 설정한다.
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(e) for e in config.GPU)


    # init device
    if torch.cuda.is_available():
        #  CUDA가 사용 가능하면, 디바이스를 CUDA로 설정
        config.DEVICE = torch.device("cuda")
        # cuDNN 자동 튜너를 활성화하여, 실행 시간이 빠른 알고리즘을 자동으로 선택하도록
        torch.backends.cudnn.benchmark = True   # cudnn auto-tuner
    else:
        config.DEVICE = torch.device("cpu")



    # set cv2 running threads to 1 (prevents deadlocks with pytorch dataloader)
    # OpenCV가 사용하는 스레드 수를 1로 설정하여, PyTorch 데이터 로더와의 데드락을 방지
    cv2.setNumThreads(0)


    # initialize random seed
    torch.manual_seed(config.SEED) # PyTorch의 난수 생성기에 시드를 설정
    torch.cuda.manual_seed_all(config.SEED) # 모든 CUDA 장치에 대한 난수 생성기 시드를 설정
    np.random.seed(config.SEED) # NumPy의 난수 생성기에 시드를 설정
    random.seed(config.SEED) # Python 내장 random 모듈의 난수 생성기에 시드를 설정


    # build the model and initialize
    model = EdgeConnect(config) # config 정보를 넣고 EdgeConnect 모델을 초기화
    model.load() # 모델의 가중치를 로드


    # model training
    if config.MODE == 1:
        config.print() # config._yml 파일을 콘솔에 출력
        print('\nstart training...\n')
        model.train()

    # model test
    elif config.MODE == 2:
        print('\nstart testing...\n')
        model.test()

    # eval mode
    else:
        print('\nstart eval...\n')
        model.eval()


# input: mode, output: config
# mode 정보와 parser로 유저의 정보를 받고, yml 파일에 있는 정보로 config 객체를 만든 다음, parser의 정보를 추가적으로 저장한 뒤 출력한다.
def load_config(mode=None):
    r"""loads model config

    Args:
        mode (int): 1: train, 2: test, 3: eval, reads from config file if not specified
    """
    # 명령줄 인수를 파싱하기 위한 ArgumentParser 객체를 생성
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', '--checkpoints', type=str, default='./checkpoints', help='model checkpoints path (default: ./checkpoints)')
    parser.add_argument('--model', type=int, choices=[1, 2, 3, 4], help='1: edge model, 2: inpaint model, 3: edge-inpaint model, 4: joint model')

    # test mode
    if mode == 2:
        parser.add_argument('--input', type=str, help='path to the input images directory or an input image')
        parser.add_argument('--mask', type=str, help='path to the masks directory or a mask file')
        parser.add_argument('--edge', type=str, help='path to the edges directory or an edge file')
        parser.add_argument('--output', type=str, help='path to the output directory')

    args = parser.parse_args() # 명령줄 인수를 파싱하여 args 객체에 저장
    # path에 설정 yml 파일의 경로를 결정
    # 해당 디렉토리에 chekcpoint도 생성된다.
    config_path = os.path.join(args.path, 'config.yml')

    # create checkpoints path if does't exist
    if not os.path.exists(args.path):
        os.makedirs(args.path)

    # copy config template if does't exist
    # 만약 yml 파일을 넣어두지 않았다면, example 파일의 탬플릿을 복사해서 설정 파일을 생성한다.
    if not os.path.exists(config_path):
        copyfile('./config.yml.example', config_path)

    # load config file
    # 그 후, 해당 yml 파일을 받아와 Config 클래스에 넣고, 이를 바탕으로 config 인스턴스를 생성한다.
    # Config() 클래스는 모델의 정보와 훈련 여부 등, 모든 옵션이 적힌 yml 파일을 받은 뒤, 이를 딕셔너리 형태로 변환해준다.
    # 그리고 그 정보는 인스턴스 변수인 self._dict에 모두 저장된다.
    # 예를 들어, config.MODE, MODEL, MASK 등등으로 접근할 수 있다.
    config = Config(config_path)

    # train mode
    if mode == 1:
        config.MODE = 1
        if args.model: # model에 대한 정보가 존재한다면
            config.MODEL = args.model # 그 정보를 그대로

    # test mode
    elif mode == 2:
        config.MODE = 2
        # 모델에 대한 정보가 없다면 3(edge-inpaint model)로 세팅한다.
        config.MODEL = args.model if args.model is not None else 3
        config.INPUT_SIZE = 0 # 0이면 원본 이미지 크기

        if args.input is not None: # 인풋 이미지 경로를 넣어줬다면
            config.TEST_FLIST = args.input # 이를 TEST_FLIST에 복사

        if args.mask is not None: # 마스크 이미지 경로를 넣어줬다면
            config.TEST_MASK_FLIST = args.mask # 이를 TEST_MASK_FLIST에 복사

        if args.edge is not None: # 에시 파일 경로를 넣어줬다면
            config.TEST_EDGE_FLIST = args.edge # 이를 TEST_EDGE_FLIST에 복사

        if args.output is not None: # 아웃풋 이미지의 경로를 넣어줬다면
            config.RESULTS = args.output # 이를 RESULTS에 복사

    # eval mode
    elif mode == 3:
        config.MODE = 3
        # # 모델에 대한 정보가 없다면 3(edge-inpaint model)로 세팅한다.
        config.MODEL = args.model if args.model is not None else 3

    return config


if __name__ == "__main__":
    main()