코드 분석/MAT

MAT: generate_image.py

상솜공방 2023. 12. 26. 17:27

해당 파일은 MAT 논문의 테스트 코드인 generate_image.py이다.

한 줄씩 꼼꼼히 공부하며 해당 코드가 어떻게 동작하는지 주석으로 설명하였다.

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Generate images using pretrained network pickle."""
# pickle은 객체 자체를 저장했다 불러올 수 있는 편한 라이브러리
import cv2 # 대표적인 이미지 처리
import pyspng # PNG 이미지 처리에 사용
import glob # 디렉토리 순회하며 파일 찾음
import os # 운영체제와 상호작용하며 파일 경로를 받아옴
import re # 정규 표현식(패턴 매칭, 문자열 조작)에 사용
import random # 랜덤값 생성
from typing import List, Optional

import click # CLI 생성
import dnnlib # NVIDA의 GAN과 관련된 라이브러리
import numpy as np # 행렬 연산
import PIL.Image # 이미지 처리 작업
import torch # 파이토치
import torch.nn.functional as F # 파이토치

import legacy # 이전 버전의 모델이나 라이브러리를 지원할 때 사용
from datasets.mask_generator_512 import RandomMask # 512x512 크기의 랜덤 마스크 생성
from networks.mat import Generator # 생성자 클래스

def num_range(s: str) -> List[int]:
    '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
    # 문자열 s를 받아 정수 리스트 List[int]를 반환하는 함수. 화살표를 통해 input, output data를 지정하는 문법.
    # '1-5'를 입력하면 [1, 2, 3, 4, 5] 반환, '3, 7, 9'를 입력하면 [3, 7, 9] 반환.
    range_re = re.compile(r'^(\d+)-(\d+)$')
    m = range_re.match(s)
    if m:
        return list(range(int(m.group(1)), int(m.group(2))+1))
    vals = s.split(',')
    return [int(x) for x in vals]

# 두 모델의 구조가 다를 때에도 사전 학습 된 매개변수 및 버퍼를 가져오기 위해, 레이어의 이름을 key로 하여 복사하는 아래의 함수 세 개를 제안한다.
# 두 모델에서 동일한 이름을 가진 매개변수 또는 버퍼가 있다면, 해당 파라미터 또는 버퍼의 값을 복사하게 됩니다.
# 그러나 두 모델 사이의 구조가 완전히 다르면, 일부 매개변수 또는 버퍼는 복사되지 않을 것입니다.
# 예를 들어, 두 모델이 각각 conv1 및 conv2라는 동일한 이름의 컨볼루션 레이어를 가지고 있다면, 해당 레이어의 가중치와 편향이 복사될 것입니다.
# 그러나 모델 간의 일부 레이어가 서로 다른 이름을 가지고 있다면, 해당 레이어의 가중치와 편향은 복사되지 않을 것입니다.
def copy_params_and_buffers(src_module, dst_module, require_all=False):
    # 이 함수는 두 개의 PyTorch 신경망 모듈(src_module 및 dst_module)을 가져와서, src_module의 매개변수와 버퍼를 dst_module로 복사합니다. 
    # assert는 해당 조건문이 거짓인 경우 Assert error를 출력하는 함수.
    assert isinstance(src_module, torch.nn.Module) # src_module이 torch.nn.Module 클래스의 인스턴스인지 확인합니다. 만약 아니라면 AssertionError가 발생합니다.
    assert isinstance(dst_module, torch.nn.Module) # dst_module이 torch.nn.Module 클래스의 인스턴스인지 확인합니다. 만약 아니라면 AssertionError가 발생합니다.
    src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)} # src_module에서 이름과 텐서로 이루어진 딕셔너리를 생성합니다.
    for name, tensor in named_params_and_buffers(dst_module): # dst_module의 각 매개변수와 버퍼에 대해 반복합니다.
        assert (name in src_tensors) or (not require_all)
        # src_tensor의 name이 dst_tensor의 name과 매칭이 안 되고, require_all이 False일 때에만 오류 메시지 출력.
        if name in src_tensors:
            tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
            # 소스 텐서를 목표 텐서에 복사하고 detach()함수를 사용하여 연결을 끊는다. requires_grad_는 소스 텐서의 것과 동일하게 세팅한다.


# 매개변수: 학습 중에 조절되는 가중치 ex) weight, bias. nn.Parameter 클래스의 인스턴스로, back propagation과 gradient descent를 통해 최적회 됨.
# 버퍼: 매개변수 이외의 것. 모델의 상태 등의 정보를 내포. ex) CNN 필터 중 학습 중에 변하지 않는 필터.
def params_and_buffers(module):
    assert isinstance(module, torch.nn.Module)
    return list(module.parameters()) + list(module.buffers())
    # 결과: [param1, param2, buffer1, buffer2, ...]
    # 매개변수와 버퍼를 각각의 객체로 리스트에 넣은 뒤 반환.

def named_params_and_buffers(module):
    # 모듈의 매개변수와 버퍼를 모두 포함하는 리스트를 반환.
    assert isinstance(module, torch.nn.Module)
    return list(module.named_parameters()) + list(module.named_buffers())
    # 결과: [(name1, param1), (name2, param2), (name3, buffer1), (name4, buffer2), ...]
    # 매개변수와 버퍼를 이름과 함께 튜플로써 리스트에 넣은 뒤 반환.

# argparse와 동일하게 CLI 인터페이스를 만들어주는 라이브러리이다.
@click.command() # 명령(Command)을 정의하는 데 사용되는 데코레이터(decorator)입니다. 이 명령은 해당 스크립트를 실행할 때 사용자가 제공해야 하는 옵션과 인자들을 정의합니다.
@click.pass_context # ctx라는 컨텍스트 객체를 생성해 아래의 option을 하나씩 추가한다.
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True) # 네트워크 피클의 파일 경로. required=True는 필수 정보라는 의미.
@click.option('--dpath', help='the path of the input image', required=True) # 인풋 이미지 경로
@click.option('--mpath', help='the path of the mask') # 마스크 경로
@click.option('--resolution', type=int, help='resolution of input image', default=512, show_default=True) # 이미지 해상도
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) # 부동 소수점 truncation psi 값.
@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True) # 노이즈 모드
@click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
# random noise: 각 픽셀마다 랜덤한 노이즈를 적용해 다양한 이미지를 생성할 수 있다.
# const noise: 모든 픽셀에 동일한 값의 노이즈를 적용해 일관된 패턴을 부여하거나, 특정 효과를 확인할 수 있다.


# 명령어 예)
# python generate_image.py --network pretrained/CelebA-HQ.pkl --dpath test_sets/CelebA-HQ/images --mpath test_sets/CelebA-HQ/masks --outdir samples
def generate_images(
    ctx: click.Context,
    network_pkl: str, # 네트워크 객체 피클의 경로
    dpath: str, # 원본 이미지 경로
    mpath: Optional[str], # 마스크 경로
    resolution: int, # 이미지 해상도
    truncation_psi: float, # 부동 소수점 값
    noise_mode: str, # 노이즈 모드
    outdir: str, # 결과 저장 공간
):
    #Generate images using pretrained network pickle.

    """난수 고정"""
    seed = 240  # pick up a random number. seed가 동일하면 생성 난수도 동일하다.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # 랜덤, 넘파이, 파이토치, GPU 관련 난수 생성 모두에서 동일한 값을 내기 위해 세팅을 하는 것.

    
    """이미지 및 마스크 불러오기"""
    # 이미지 폴더에 있는 파일을 모두 리스트에 저장.
    print(f'Loading data from: {dpath}')
    img_list = sorted(glob.glob(dpath + '/*.png') + glob.glob(dpath + '/*.jpg'))
    # 마스크 폴더에 있는 파일을 모두 리스트에 저장.
    if mpath is not None:
        print(f'Loading mask from: {mpath}')
        mask_list = sorted(glob.glob(mpath + '/*.png') + glob.glob(mpath + '/*.jpg'))
        assert len(img_list) == len(mask_list), 'illegal mapping' # 이미지와 마스크의 개수가 다르면 오류 메시지 출력.

    
    """pickle을 통해 생성기 불러오기"""
    print(f'Loading networks from: {network_pkl}') # 네트워크 피클을 통해 객체를 그대로 불러온다.
    device = torch.device('cuda') #torch.device는 연산이 수행되는 위치를 나타낸다. 이는 CPU 혹은 GPU이다.
    # with: 파일 처리와 같은 리소스 관리 구문, 리소스를 조작하는 파일 핸들러는 with 구분이 끝나는 순간 닫힌다.
    with dnnlib.util.open_url(network_pkl) as f: # 파일 핸들러는 f라는 변수에 할당되고, open_url 함수를 통해 해당 파일에 접근한다.
        G_saved = legacy.load_network_pkl(f)['G_ema'].to(device).eval().requires_grad_(False) # type: ignore
        # 해당 생성자는 styleGAN2의 모델이다.
        # G_ema는 이동 평균(EMA, Exponential Moving Average)을 적용한 생성자 모델이다.
        # 파일 핸들러 f의 G_ema 키를 이용하여 밸류 값인 파일을 가져와 GPU에 탑재하고, 테스트 모드로 변경한다.
        # model.eval(): 훈련 중에 적용되는 드롭아웃(dropout) 및 배치 정규화(batch normalization)와 같은 학습 시에만 활성화되는 기능들이 비활성화됩니다.
        #               이는 학습이 아닌 추론 단계에서 일관된 결과를 얻기 위해 사용됩니다. 다시 학습 단계로 돌리려면 model.train() 명령어 사용.
        # model.requires_grad_(False): 그래디언트 업데이트 비활성화.
    net_res = 512 if resolution > 512 else resolution # 이미지 해상도를 최대 512로 설정.
    G = Generator(z_dim=512, c_dim=0, w_dim=512, img_resolution=net_res, img_channels=3).to(device).eval().requires_grad_(False)
    # MAT 코드를 통해 생성된 모델 객체를 GPU에 탑재. c_dim은 클래스 디멘션을 나타낸다. 어차피 테스트 이미지이므로 클래스는 고려하지 않기 때문에 0으로 둔다.
    copy_params_and_buffers(G_saved, G, require_all=True) #require_all=True 인자는 모든 매개변수 및 버퍼를 복사해야 함을 나타냅니다.
    # G_saved 모델의 매개변수와 버퍼를 G 모델로 복사합니다.
    
    os.makedirs(outdir, exist_ok=True)

    # no Labels.
    label = torch.zeros([1, G.c_dim], device=device) # 이미지에 대한 레이블은 없으므로 이에 대한 정보를 담은 벡터는 [1, 0]의 벡터이다.
    #1행 0열이라는 표현이 이상하지만, 결국 아무 것도 없는 빈 텐서를 의미한다.

    # 이미지 파일을 읽어와서 넘파이 어레이로 변환하는 함수.
    # 이렇게 함수 내부에 함수를 구현할 수도 있다.
    def read_image(image_path):
        with open(image_path, 'rb') as f:
            if pyspng is not None and image_path.endswith('.png'):
                image = pyspng.load(f.read())
            else:
                image = np.array(PIL.Image.open(f))
        if image.ndim == 2: # 흑백 이미지일 경우 3차원으로 복제
            image = image[:, :, np.newaxis] # HW => HWC
            image = np.repeat(image, 3, axis=2)
        image = image.transpose(2, 0, 1) # HWC => CHW
        image = image[:3]
        return image

    # 넘파이 어레이를 이미지 파일로 다시 바꿔주는 함수
    def to_image(image, lo, hi):
        image = np.asarray(image, dtype=np.float32) # 일단은 실수로 받고
        image = (image - lo) * (255 / (hi - lo)) # low, high 임계값을 안에서 정규화
        image = np.rint(image).clip(0, 255).astype(np.uint8) # 그 후 이미지를 담은 unsigned int 8 bit로 형변환
        image = np.transpose(image, (1, 2, 0)) # HWC
        if image.shape[2] == 1: # 흑백이면 3차원으로
            image = np.repeat(image, 3, axis=2)
        return image

    if resolution != 512:
        noise_mode = 'random' # 512 해상도가 아니면 랜덤 노이즈 적용

    """모델에 데이터를 넣고 이미지 생성!"""
    with torch.no_grad():
        for i, ipath in enumerate(img_list):
            # 이미지 불러오기
            # 이미지 파일의 기본 이름을 추출한 뒤, jpg 확장자를 png로 바꾼다.
            iname = os.path.basename(ipath).replace('.jpg', '.png') # jpg를 png로 바꾸는 이유는 아마 투명도 때문일 것이다.
            print(f'Prcessing: {iname}')
            image = read_image(ipath) # 이미지를 읽어서 넘파이 배열로 만들어준다.
            image = (torch.from_numpy(image).float().to(device) / 127.5 - 1).unsqueeze(0)
            # 넘파이를 파이토치 텐서로 변환하고, 값의 범위를 -1부터 1 사이로 정규화 한다.
            # unsqueeze(0)은 미니 배치의 차원이 추가된 4차원 텐서로의 변형을 의미한다.

            # 마스크 불러오기
            if mpath is not None: # 마스크 파일이 주어진 경우
                mask = cv2.imread(mask_list[i], cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.0 # 회색조로 읽어온 뒤 0부터 1 사이 값으로 정규화
                mask = torch.from_numpy(mask).float().to(device).unsqueeze(0).unsqueeze(0) # 이를 텐서로 바꾼 뒤 디바이스로 전송.
                # unsqueeze(0)를 두 번 해서 미니 배치를 포함한 4차원 텐서로 변형
            
            else: # 마스크 파일이 주어지지 않은 경우
                mask = RandomMask(resolution) # adjust the masking ratio by using 'hole_range'
                mask = torch.from_numpy(mask).float().to(device).unsqueeze(0)

            # 생성기에 넣고 생성된 이미지 저장
            z = torch.from_numpy(np.random.randn(1, G.z_dim)).to(device)
            # # 평균이 0이고 표준 편차가 1인 정규 분포를 따르는 난수를 생성하고, 이를 PyTorch 텐서로 변환하여 디바이스에 전송합니다. 생성된 텐서의 크기는 (1, G.z_dim)
            output = G(image, mask, z, label, truncation_psi=truncation_psi, noise_mode=noise_mode) # 이미지 생성
            output = (output.permute(0, 2, 3, 1) * 127.5 + 127.5).round().clamp(0, 255).to(torch.uint8)
            # 생성된 이미지의 채널 차원을 변경하고, 값을 0에서 255 범위로 스케일링합니다.
            # .round()를 사용하여 반올림하고, .clamp(0, 255)를 사용하여 값을 0과 255 사이로 제한합니다. 마지막으로 데이터 타입을 torch.uint8로 변환합니다.
            output = output[0].cpu().numpy()
            PIL.Image.fromarray(output, 'RGB').save(f'{outdir}/{iname}')
            # 이미지를 NumPy 배열로 변환하고, 이를 PIL.Image.fromarray를 사용하여 PIL 이미지로 변환합니다.
            #이후, .save(f'{outdir}/{iname}')를 사용하여 해당 경로에 이미지를 저장합니다.

if __name__ == "__main__":
    generate_images() # pylint: disable=no-value-for-parameter

'코드 분석 > MAT' 카테고리의 다른 글

MAT 환경설정 및 버그 리포트  (0) 2024.01.22
MAT: mat.py  (1) 2024.01.11
MAT: basic_module.py  (0) 2024.01.03
MAT: training_loop.py  (0) 2023.12.28
MAT: train.py  (1) 2023.12.27