코드 분석/CMT

CMT: utils.py

상솜공방 2024. 2. 2. 16:08

코드 독해에 필요한 배경지식

더보기

torch.load() 함수
torch.load() 함수는 PyTorch에서 모델이나 텐서 등을 저장한 파일을 로드할 때 사용됩니다. 이 함수는 저장된 객체를 직렬화된 형태에서 다시 Python 객체로 복원합니다. 주로 모델의 가중치, 옵티마이저의 상태 등을 저장한 체크포인트 파일을 로드하는 데 사용됩니다. 예를 들어, 학습 중에 모델의 상태를 파일에 저장했다면, 이후에 torch.load()를 사용하여 해당 상태를 다시 로드할 수 있습니다.

 

load_state_dict() 함수와 strict=True
load_state_dict() 함수는 모델의 매개변수(가중치와 편향)를 로드하기 위해 사용됩니다. 이 함수는 인자로 전달된 상태 딕셔너리(state_dict)를 현재 모델의 매개변수와 매핑합니다.
strict=True 옵션을 사용하면, 로드하려는 상태 딕셔너리와 모델의 매개변수 사이에 완벽한 일치가 필요합니다. 즉, 상태 딕셔너리에 있는 모든 키가 모델에 존재해야 하며, 모델에 있는 모든 키가 상태 딕셔너리에도 있어야 합니다. 이 옵션이 False일 경우, 일치하지 않는 키가 있어도 오류가 발생하지 않고 무시됩니다. 하지만 일반적으로 모델을 정확히 복원하기 위해 strict=True를 사용합니다.

체크포인트의 데이터
체크포인트 파일은 학습된 모델의 매개변수(가중치와 편향), 옵티마이저의 상태, 그리고 필요에 따라 학습 중에 추적한 다른 메타데이터(예: 에포크 수, 최고 성능 지표 등)를 포함할 수 있습니다. 체크포인트를 통해 학습 과정을 중단했다가 다시 시작할 수 있으며, 학습된 모델을 다른 환경으로 이전하거나 배포할 수 있습니다.
state_dict: 모델의 매개변수를 포함하는 딕셔너리입니다.
optimizer: 옵티마이저의 상태를 포함하는 딕셔너리입니다. 학습률과 같은 옵티마이저 설정, 누적된 모멘텀 등을 포함할 수 있습니다.
추가적으로 사용자가 정의한 다른 값들도 포함될 수 있습니다. 예를 들어, disc 키는 판별자의 상태를 저장할 때 사용될 수 있습니다(주로 GANs에서 사용).

import torch
import numpy as np

def _load(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    # checkpoint: 모델 파라미터, 옵티마이저 상태 외의 추가적인 정보(에포크 수, 최고 성능 지표)등을 포함할 수 있다.
    # 해당 파일을 통해 학습 과정을 중단했다가 다시 시작하거나, 학습된 모델을 배포할 수 있다.
    return checkpoint

def load_checkpoint(path, model, optimizer=None, reset_optimizer=True, is_dis=False):
    # path: 체크포인트 파일 경로
    # model: 업데이트할 모델
    # optimizer: 업데이트 할 옵티마이저
    # reset_optimizer: 옵티마이저 리셋 여부
    # is_dis: 판별자의 체크포인트를 로드할지 여부

    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)

    if is_dis:
        s = checkpoint["disc"] # 판별기의 상태를 로드
    else:
        s = checkpoint["state_dict"] # 모델의 상태를 로드
    
    new_s = {}
    # 체크포인트에서 로드한 상태 딕셔너리 s의 각 항목을 순회하며, 키 이름에서 'module.' 접두사를 제거한 새로운 상태 딕셔너리 new_s를 생성
    # 분산 학습을 위해 torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel을 사용하면 모델의 각 모듈 앞에 module.이란 접두사가 추가됨
    # 이를 단일 GPU에서 로드하기 위해 제거
    for k, v in s.items():
        new_s[k.replace('module.', '')] = v
    model.load_state_dict(new_s, strict=True) # load_state_dict()는 모델의 파라미터(가중치와 편향)을 로드하는데 사용된다.
    # 상태 딕셔너리로 전달받은 정보를 모델의 파라미터에 매핑한다.
    # strict=True 옵션을 적용하면, 상태 딕셔너리와 모델 매개변수 사이에 완벽한 일치가 필요하다.

    if not reset_optimizer: # 옵티마이저를 리셋하지 않으면
        optimizer_state = checkpoint["optimizer"] # 옵티마이저의 상태를 변수에 저장
        if optimizer_state is not None: # 상태가 로드 되었다면
            print("Load optimizer state from {}".format(path))
            optimizer.load_state_dict(checkpoint["optimizer"]) # 옵티마이저의 상태를 로드
    return model

def psnr(img1, img2):
    mse = np.mean((img1-img2)** 2) # 두 이미지간의 MSE(Mean Squared Error)를 계산
    if mse == 0: # mse가 0이면, 두 이미지는 같은 이미지이므로 100을 반환
        return 100
    PIXEL_MAX = 255.0 # 이미지 픽셀이 가질 수 있는 최대값
    return 20 * np.log10(PIXEL_MAX / np.sqrt(mse)) # MSE가 0이 아닌 경우, PSNR을 계산하여 반환, 이 때 값이 높을수록 두 이미지간의 차이는 적다.