코드 분석/MAT

MAT: train.py

상솜공방 2023. 12. 27. 23:19

해당 코드는 모델을 훈련시키는 플로우를 담고 있다.

코드를 이해하기 위한 기본 지식을 먼저 정리한 뒤, 하단에 코드를 첨부한다.

 

더보기

경로 길이 정규화 (Path Length Regularization)
목적: 경로 길이 정규화는 생성된 이미지의 품질을 향상시키기 위해 사용됩니다. 이 기법은 잠재 공간에서의 작은 변화가 출력 이미지에 얼마나 큰 영향을 미치는지를 측정하고, 이를 최적화하는 데 목적이 있습니다.
작동 원리: 잠재 공간의 두 점 사이의 거리와 이에 해당하는 출력 이미지 사이의 거리를 비교합니다. 이 거리가 일정하게 유지되도록 모델을 정규화함으로써, 잠재 공간에서의 변화가 출력 이미지에 일관되게 반영되도록 합니다.
효과: 이를 통해 모델이 더 안정적이고 일관된 출력을 생성하며, 특히 복잡한 데이터셋에서 더 나은 결과를 얻을 수 있습니다.


트렁케이션 (Truncation)
목적: 트렁케이션 트릭은 생성된 이미지의 다양성과 현실성 사이의 균형을 맞추기 위해 사용됩니다.
작동 원리: 잠재 공간의 벡터를 특정 임계값으로 제한(트렁케이션)하여, 모델이 너무 극단적인 또는 비현실적인 이미지를 생성하는 것을 방지합니다.
효과: 트렁케이션을 적용하면 이미지의 현실성은 증가하지만, 다양성은 감소할 수 있습니다. 따라서, 이 트릭은 주로 품질이 중요한 경우에 사용됩니다.


지수 이동 평균 (Exponential Moving Average, EMA)
목적: 생성기의 가중치에 대한 지수 이동 평균을 계산하여 훈련 과정에서의 모델 가중치 변동을 안정화시키는 데 사용됩니다.
작동 원리: 각 훈련 단계에서의 모델 가중치를 지수적으로 감소하는 가중치로 평균화합니다. 즉, 최근 가중치에 더 높은 가중치를 두고, 오래된 가중치에는 낮은 가중치를 부여합니다.
효과: 이 방법은 훈련 과정에서 발생할 수 있는 급격한 가중치 변화를 완화시켜, 최종 모델이 더 안정적인 결과를 생성하도록 돕습니다. 이는 특히 장기간 훈련에서 유용합니다. 이러한 기법들은 GAN의 성능을 최적화하고, 생성된 이미지의 품질을 향상시키는 데 중요한 역할을 합니다.

 

데이터 증강 모드

1. ADA (Adaptive Data Augmentation)
목적: 'ADA'는 적응적 데이터 증강(Adaptive Data Augmentation)의 약자로, 훈련 과정에서 데이터 증강의 강도를 동적으로 조절합니다.
작동 원리: ADA는 훈련 과정에서 판별기(Discriminator)의 성능을 모니터링하고, 판별기가 훈련 데이터를 너무 쉽게 구별하는 경우(즉, 과적합이 발생하는 경우) 데이터 증강의 강도를 증가시킵니다. 반대로, 판별기가 충분히 학습되지 않은 경우 증강의 강도를 감소시킵니다.
효과: 이 방법은 특히 데이터셋의 크기가 작을 때 유용하며, 과적합을 방지하고 모델의 일반화 능력을 향상시킵니다.


2. NoAug (No Augmentation)
목적: 'NoAug'는 데이터 증강을 전혀 사용하지 않는 모드입니다.
작동 원리: 이 모드에서는 원본 훈련 데이터를 그대로 사용하며, 어떠한 형태의 증강도 적용하지 않습니다.
적용 상황: 데이터셋이 충분히 크고 다양할 때, 또는 모델이 데이터의 특정 특성을 학습하는 것이 중요할 때 사용됩니다. 데이터 증강이 모델 성능에 부정적인 영향을 미칠 수 있는 경우에도 적합합니다.


3. Fixed (Fixed Augmentation)
목적: 'Fixed' 모드는 고정된 강도의 데이터 증강을 적용합니다.
작동 원리: 사용자가 설정한 고정된 강도(p 값)로 데이터 증강을 적용합니다. 이 강도는 훈련 과정 전반에 걸쳐 일정하게 유지됩니다.
적용 상황: 특정 수준의 데이터 증강이 필요하다고 판단될 때 사용됩니다. 예를 들어, 실험적인 목적으로 특정 증강 기법의 영향을 분석하고자 할 때 유용합니다. 각 증강 모드는 훈련 데이터셋의 크기, 다양성, 특정 실험적 요구 사항에 따라 선택될 수 있으며, GAN의 성능과 생성된 이미지의 품질에 영향을 미칩니다.

 

데이터 증강 파이프라인

데이터 증강 파이프라인은 GAN 훈련에서 사용되는 다양한 이미지 변환 기법을 의미합니다. 'blit', 'geom', 'color'는 이러한 증강 파이프라인의 예시입니다. 각각은 다음과 같은 방식으로 이미지를 변환합니다:


1. Blit (블리팅)
목적: 'Blit' 증강은 이미지의 일부를 다른 이미지로 대체하는 기법입니다.
작동 원리: 이 방법은 주로 이미지의 일부 영역을 가리거나, 다른 이미지의 일부를 현재 이미지에 오버레이하는 데 사용됩니다. 예를 들어, 이미지의 일부를 다른 배경으로 대체할 수 있습니다.
효과: 이 증강은 모델이 이미지의 주요 객체에 집중하도록 돕고, 배경이나 불필요한 부분에 대한 과적합을 방지하는 데 도움이 될 수 있습니다.


2. Geom (기하학적 변환)
목적: 'Geom' 증강은 이미지에 기하학적 변환을 적용하는 기법입니다.
작동 원리: 이 방법에는 회전, 크기 조절, 비율 변경, 이미지 자르기 등이 포함될 수 있습니다. 예를 들어, 이미지를 회전시키거나, 크기를 조절하여 다양한 크기의 객체를 학습할 수 있도록 합니다.
효과: 기하학적 변환은 모델이 다양한 방향과 크기의 객체를 인식하고, 더 강력한 일반화 능력을 갖도록 돕습니다.


3. Color (색상 변환)
목적: 'Color' 증강은 이미지의 색상을 변환하는 기법입니다.
작동 원리: 이 방법에는 밝기, 대비, 채도, 색조 등의 조정이 포함될 수 있습니다. 예를 들어, 이미지의 밝기를 조절하거나, 색상을 변화시켜 다양한 조명 조건이나 색상 환경에서도 객체를 인식할 수 있도록 합니다.
효과: 색상 변환은 모델이 다양한 색상 조건에서도 안정적으로 객체를 인식하고, 색상에 대한 과적합을 방지하는 데 도움이 됩니다.
이러한 증강 파이프라인은 GAN 모델이 더 다양한 시나리오와 조건에서 안정적으로 작동하도록 돕고, 생성된 이미지의 질을 향상시키는 데 중요한 역할을 합니다.

데이터 로드

데이터 로딩 과정에서 사용되는 '메모리 핀 설정', '워커 수', '프리페치 인자'는 데이터 로딩의 효율성과 성능을 최적화하기 위한 중요한 설정들입니다. 각각에 대해 자세히 설명하겠습니다:

1. 메모리 핀 설정 (Pin Memory)
목적: 메모리 핀 설정은 데이터 로더가 텐서를 CPU 메모리에 고정(pin)하여 GPU로의 데이터 전송 속도를 높이는 데 사용됩니다.
작동 원리: '핀된' 메모리는 페이지 교체(page swapping)가 발생하지 않아, GPU로 데이터를 더 빠르게 전송할 수 있습니다. 이는 특히 대규모 데이터셋을 사용할 때 GPU의 데이터 로딩 병목 현상을 줄이는 데 도움이 됩니다.
적용 방식: PyTorch에서는 DataLoader의 pin_memory 옵션을 True로 설정하여 이 기능을 활성화할 수 있습니다.

 

2. 워커 수 (Number of Workers)
목적: 워커 수는 데이터 로딩을 위해 생성되는 별도의 프로세스 수를 지정합니다.
작동 원리: 여러 워커를 사용하면 데이터 로딩 작업을 병렬로 수행할 수 있어, CPU와 I/O 리소스를 보다 효율적으로 활용할 수 있습니다.
적용 방식: PyTorch의 DataLoader에서 num_workers 옵션으로 워커의 수를 설정할 수 있습니다. 워커 수는 시스템의 CPU 코어 수와 I/O 성능에 따라 최적화될 수 있습니다.

 

3. 프리페치 인자 (Prefetch Factor)
목적: 프리페치 인자는 각 워커가 미리 로드해야 할 배치 수를 결정합니다.
작동 원리: 프리페치는 현재 처리 중인 배치와 동시에 다음 배치를 미리 로드함으로써, 데이터 로딩 대기 시간을 최소화합니다.
적용 방식: PyTorch 1.7 이상에서 DataLoader의 prefetch_factor 옵션을 사용하여 이 값을 설정할 수 있습니다. 높은 프리페치 인자는 더 많은 메모리를 사용하지만, 데이터 로딩의 지연 시간을 줄일 수 있습니다.이러한 설정들은 데이터 로딩 과정의 성능을 최적화하여 전체적인 모델 훈련 시간을 단축시키는 데 중요한 역할을 합니다. 특히 대규모 데이터셋과 복잡한 모델 구조를 사용할 때 이러한 설정의 조정은 훈련 효율성에 큰 영향을 미칠 수 있습니다.

 

desc

desc는 "description"의 약자로 사용되며, 훈련 데이터셋의 이름을 나타내는 문자열 변수입니다. 이 변수는 일반적으로 훈련 과정의 로깅, 결과 저장, 또는 훈련 실행에 대한 식별 정보를 제공하는 데 사용됩니다.

목적: desc 변수는 훈련 세션을 설명하거나 식별하는 데 사용되는 짧은 문자열로, 훈련 과정을 나타내는 데 도움이 됩니다. 예를 들어, 훈련 중 생성된 파일 이름이나 로그에 이 정보를 포함시켜 어떤 데이터셋을 사용했는지 쉽게 식별할 수 있습니다.
사용 방식: training_set.name은 training_set 객체의 name 속성에서 데이터셋의 이름을 가져옵니다. 이 이름은 데이터셋을 구별하는 데 유용한 정보를 포함할 수 있으며, 훈련 과정이나 결과를 기록할 때 참조됩니다.
예시: 만약 training_set이 "CIFAR-10" 데이터셋을 나타낸다면, training_set.name은 "CIFAR-10"이 될 수 있고, desc 변수는 이 값을 저장하여 훈련 로그나 결과 파일에 "CIFAR-10"이라는 식별자를 포함시킬 수 있습니다. 결국, desc는 훈련 과정을 보다 명확하게 기록하고 관리하는 데 도움을 주는 간단하지만 유용한 정보를 제공합니다.

 

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Train a GAN using the techniques described in the paper
"Training Generative Adversarial Networks with Limited Data"."""

import os # 운영 체제와 상호 작용하는 기능을 제공합니다. 파일 시스템을 탐색하고, 파일 경로를 관리하며, 디렉토리를 생성하는 등의 작업에 사용됩니다.
import click # 파이썬 스크립트에 명령줄 인터페이스를 쉽게 추가할 수 있게 해주는 라이브러리입니다. 사용자로부터 입력을 받고, 스크립트의 매개변수를 설정하는 데 사용됩니다.
import re # 정규 표현식을 사용하여 문자열을 처리합니다. 이는 파일 이름에서 특정 패턴을 찾거나, 문자열을 분석하고 변환하는 데 유용합니다.
import json # JSON 데이터를 파싱하고 생성하는 기능을 제공합니다. 설정 파일을 읽거나, 훈련 옵션을 JSON 형식으로 저장하는 데 사용됩니다.
import tempfile # 시 파일과 디렉토리를 생성하는 데 사용됩니다. 이는 훈련 중 임시 데이터를 저장하거나, 분산 훈련 시 필요한 임시 파일을 관리하는 데 사용될 수 있습니다.
import torch # PyTorch 라이브러리로, 딥러닝 모델을 구축하고 훈련하는 데 사용됩니다.
import dnnlib # NVIDIA에서 제공하는 딥러닝 네트워크 라이브러리입니다. 이는 GAN 훈련에 필요한 다양한 도구와 유틸리티 함수를 제공합니다.

from training import training_loop # GAN 훈련을 위한 주요 루프를 포함하는 모듈입니다. 이 모듈은 실제로 모델을 훈련시키는 과정을 관리합니다.
# from training import training_loop_simmim as training_loop
# from training import training_loop_woMap as training_loop
from metrics import metric_main # 훈련된 모델의 성능을 평가하는 데 사용되는 메트릭을 계산하는 함수를 포함합니다.
from torch_utils import training_stats # PyTorch와 관련된 유틸리티 함수를 제공합니다. 이는 훈련 통계를 관리하고, 사용자 정의 PyTorch 연산을 지원하는 데 사용됩니다.
from torch_utils import custom_ops

#----------------------------------------------------------------------------

class UserError(Exception):
    pass

#----------------------------------------------------------------------------

def setup_training_loop_kwargs(
    # General options (not included in desc).
    gpus       = None, # Number of GPUs: <int>, default = 1 gpu
    snap       = None, # Snapshot interval: <int>, default = 50 ticks
    metrics    = None, # List of metric names: [], ['fid50k_full'] (default), ...
    seed       = None, # Random seed: <int>, default = 0

    # Dataset.
    data       = None, # Training dataset (required): <path>
    data_val   = None, # Validation dataset: <path>, default = None. If none, data_val = data
    dataloader = None, # Dataloader, string
    cond       = None, # Train conditional model based on dataset labels: <bool>, default = False
                       # (조건부 훈련) 조건부 모델은 데이터셋의 레이블이나 클래스 정보를 사용하여 특정 유형의 이미지를 생성하는 것을 의미한다
    
    subset     = None, # Train with only N images: <int>, default = all
                       # (부분 집합 훈련) 전체 데이터셋이 아닌, 데이터셋의 부분 집합을 사용하여 모델을 훈련시키는 데 사용
                       # 예를 들어, 데이터셋의 일부만 사용하여 빠른 실험을 수행하고자 할 때 유용
    
    mirror     = None, # Augment dataset with x-flips: <bool>, default = False

    # Base config.
    cfg        = None, # Base config: 'auto' (default), 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar'.
                       # 훈련에 사용될 기본 구성을 결정. 'auto', 'stylegan2', 'paper256' 등과 같은 사전 정의된 구성 중에서 선택할 수 있으며,
                       # 이는 네트워크 아키텍처, 학습률, 배치 크기 등의 기본 설정을 결정
    
    generator  = None, # Path of the generator class, 생성기 클래스의 경로
    wdim       = None, # 생성기에서 사용하는 w 벡터의 차원 크기
                       # w 벡터는 잠재 공간에서 스타일을 표현하는 데 사용되며, 이 차원은 생성된 이미지의 다양성과 품질에 영향을 미친다
    
    zdim       = None, # 잠재 공간의 벡터 차원 크기
    discriminator = None, # Path of the discriminator class, 판별기 클래스의 경로
    loss = None,       # 사용할 손실 함수의 경로
    gamma      = None, # Override R1 gamma: <float>, R1 정규화 항의 가중치
    pr         = None, # 경로 길이 정규화(path length regularization)의 비율
    pl         = None, # Train with path length regularization: <bool>, default = True, 경로 길이 정규화 여부
    kimg       = None, # Override training duration: <int>, 훈련 기간을 이미지 1000 개의 단위로 표현
    batch      = None, # Override batch size: <int>, 배치 사이즈
    truncation = None, # truncation for training: <float>, 훈련 중 트렁케이션 여부
    style_mix  = None, # style mixing probability for training: <float>, 스타일 믹스
                       # 스타일 믹싱은 다양한 잠재 벡터의 스타일을 혼합하여 더 다양한 이미지를 생성하는 기법
    
    ema        = None, # Half-life of the exponential moving average (EMA) of generator weights: <int>, 생성기 가중치의 지수 이동 평균의 반감기
    lr         = None, # learning rate
    lrt        = None, # learning rate of transformer: <float>

    # Discriminator augmentation.
    aug        = None, # Augmentation mode: 'ada' (default), 'noaug', 'fixed', 데이터 증강 모드
    p          = None, # Specify p for 'fixed' (required): <float>, fixed 증강 모드에서 사용할 증강 확률
    target     = None, # Override ADA target for 'ada': <float>, default = depends on aug, ADA 모드에서 사용할 ADA 목표값
    augpipe    = None, # Augmentation pipeline: 'blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc' (default), ..., 'bgcfnc', 사용할 증강 파이프라인

    # Transfer learning.
    resume     = None, # Load previous network: 'noresume' (default), 'ffhq256', 'ffhq512', 'ffhq1024', 'celebahq256', 'lsundog256', <file>, <url>,
                       # 전이학습을 할 때, 이전에 훈련된 네트워크에서 훈련을 재개할지 여부.
    
    freezed    = None, # Freeze-D: <int>, default = 0 discriminator layers, 판별기의 몇 개의 레이어를 동결시킬지 결정하는 정수 값

    # Performance options (not included in desc).
    fp32       = None, # Disable mixed-precision training: <bool>, default = False, 혼합 정밀도 훈련을 사용할지 여부를 나타내는 불리언.
    nhwc       = None, # Use NHWC memory format with FP16: <bool>, default = False, FP16과 함께 NHWC 메모리 포맷을 사용할지 여부를 나타내는 불리언.
    allow_tf32 = None, # Allow PyTorch to use TF32 for matmul and convolutions: <bool>, default = False, PyTorch가 내부적으로 TF32를 사용할 수 있게 할지 여부를 나타내는 불리언.
    nobench    = None, # Disable cuDNN benchmarking: <bool>, default = False, cuDNN 벤치마킹을 사용하지 않을지 여부를 나타내는 불리언.
    workers    = None, # Override number of DataLoader workers: <int>, default = 3, 데이터 로더의 워커 수를 오버라이드하는 정수 값
):
    args = dnnlib.EasyDict()

    # ------------------------------------------
    # General options: gpus, snap, metrics, seed
    # ------------------------------------------

    if gpus is None:
        gpus = 1 # gpus가 None이면 (즉 default로) 일단 1 할당.
    assert isinstance(gpus, int) # 그 후 int type인지 확인.
    if not (gpus >= 1 and gpus & (gpus - 1) == 0): # 멀티 GPU 요구사항 확인.
        raise UserError('--gpus must be a power of two') # 오류 메시지 출력.
    args.num_gpus = gpus # args에 할당.

    if snap is None:
        snap = 50 # 디폴트값 50
    assert isinstance(snap, int) # int인지 확인.
    if snap < 1:
        raise UserError('--snap must be at least 1') # 오류 메시지
    args.image_snapshot_ticks = snap
    args.network_snapshot_ticks = snap
    # 이미지와 네트워크의 스냅샷 tick을 해당 값으로 설정.

    if metrics is None: # # 디폴트로 FID(Fréchet Inception Distance) 메트릭 사용.
        metrics = ['fid50k_full']
    assert isinstance(metrics, list) # 리스트 자료형인지 확인.
    if not all(metric_main.is_valid_metric(metric) for metric in metrics):
        raise UserError('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
        # 모든 메트릭이 유효한지 확인하고, 하나라도 이상이 있으면 오류 메시지.
    args.metrics = metrics

    if seed is None:
        seed = 0 # 디폴트 0
    assert isinstance(seed, int)
    args.random_seed = seed

    # -----------------------------------
    # Dataset: data, cond, subset, mirror
    # -----------------------------------

    assert data is not None # 훈련데이터 경로가 없으면 오류 메시지 출력
    assert isinstance(data, str) # 문자열인지 데이터 타입 확인
    if data_val is None: # data_val에 대한 추가 정보가 없으면
        data_val = data # 그냥 validation data를 data에서 사용
    if dataloader is None: # 데이터 로더의 이름은 디폴트로 다음을 사용한다.
        dataloader = 'datasets.dataset_512.ImageFolderMaskDataset'
        '''이 데이터 로더 살펴볼 필요 있을 듯'''

    args.training_set_kwargs = dnnlib.EasyDict(class_name=dataloader, path=data,
                                               use_labels=True, max_size=None, xflip=False)
    args.val_set_kwargs = dnnlib.EasyDict(class_name=dataloader, path=data_val,
                                          use_labels=True, max_size=None, xflip=False)
    # 훈련 및 검증 데이터셋을 로드할 떄 필요한 추가적인 설정을 아규먼트에 저장.
    # 여기에는 데이터 로더 클래스의 이름, 데이터 경로, 레이블 사용 여부, 최대 데이터 크기, x-flip 사용 여부가 포함됨.
    # dnnlib.EasyDict는 간단하게, 데이터를 받아 딕셔너리 형태로 만들어주는 함수라고 생각하자.
    args.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=3, prefetch_factor=2)
    # 데이터 로딩에 필요한 추가 조건을 아규먼트에 저장.
    # 메모리 핀 설정, 워커 수, 프리페치 인자.
    # 지금 여기까지는 아규먼트를 저장한 것.

    try:
        # training part
        '''construct_class_by_name 부분 좀 자세히 봐둘 것!'''
        training_set = dnnlib.util.construct_class_by_name(**args.training_set_kwargs) # subclass of training.dataset.Dataset
        # 위에서 입력했던 아규먼트를 바탕으로 실제 데이터셋 객체를 생성.
        args.training_set_kwargs.resolution = training_set.resolution # be explicit about resolution
        args.training_set_kwargs.use_labels = training_set.has_labels # be explicit about labels
        args.training_set_kwargs.max_size = len(training_set) # be explicit about dataset size
        # 생성된 데이터셋 객체의 정보를 바탕으로 아규먼트 내용을 추가 저장.
        # 데이터셋의 해상도, 레이블 사용 여부, 데이터셋의 개수가 이에 해당됨.
        # 이 아규먼트를 나중에 추가하는 이유는, 직접 데이터를 불러온 후에야 알 수 있는 정보들이기 때문이다.
        
        desc = training_set.name
        # validation part
        val_set = dnnlib.util.construct_class_by_name(**args.val_set_kwargs)
        # 마찬가지로 검증 데이터의 아규먼트를 바탕으로 데이터셋 객체 생성.
        args.val_set_kwargs.resolution = val_set.resolution
        args.val_set_kwargs.use_labels = val_set.has_labels
        args.val_set_kwargs.max_size = len(val_set)
        # 객체를 참조하여 해상도, 레이블 여부, 이미지 개수 추가 정보를 아규먼트에 기입
        del training_set, val_set # conserve memory
        # 생성된 객체를 삭제.
    except IOError as err:
        raise UserError(f'--data: {err}')
        # 해당 과정에서 오류가 있었다면 오류 메시지 추력.

    if cond is None:
        cond = False # 디폴트는 클래스 분류가 없는 데이터에 대한 학습으로 진행
    assert isinstance(cond, bool) # 자료형 확인
    if cond: # 만약 클래스가 분류된 데이터라고 옵션에 기입했는데
        if not args.training_set_kwargs.use_labels or not args.val_set_kwargs.use_labels:
            raise UserError('--cond=True requires labels specified in labels.json')
        desc += '-cond' # 실제 데이터는 클래스가 없다면 오류 메시지 출력
    else:
        args.training_set_kwargs.use_labels = False
        args.val_set_kwargs.use_labels = False
        # 아니면 아규먼트를 False로 하고 넘어감.

    if subset is not None: # 부분 집합 훈련 옵션을 줬다면
        assert isinstance(subset, int) # 데이터 타입 확인
        if not 1 <= subset <= args.training_set_kwargs.max_size:
            raise UserError(f'--subset must be between 1 and {args.training_set_kwargs.max_size}') # 일단 1 이하면 오류
        desc += f'-subset{subset}' # desc에 정보 추가
        if subset < args.training_set_kwargs.max_size:
            args.training_set_kwargs.max_size = subset
            args.training_set_kwargs.random_seed = args.random_seed
            # 그 후, 트레이닝 셋의 데이터 크기와 시드를 입력 받은 것으로 축소.

    if mirror is None:
        mirror = False # 디폴트는 거짓
    assert isinstance(mirror, bool)
    if mirror: # 미러를 쓸 거면
        desc += '-mirror'
        args.training_set_kwargs.xflip = True # 정보 추가

    # ------------------------------------
    # Base config: cfg, gamma, kimg, batch
    # ------------------------------------

    if cfg is None:
        cfg = 'auto' # 디폴트는 auto
    assert isinstance(cfg, str)
    desc += f'-{cfg}'

    cfg_specs = {
        'auto':      dict(ref_gpus=-1, kimg=25000,  mb=-1, mbstd=-1, fmaps=-1,  lrate=-1,     gamma=-1,   ema=-1,  ramp=0.05, map=2), # Populated dynamically based on resolution and GPU count.
        # 해상도와 GPU 수에 따라 동적으로 하이퍼 파라미터를 설정.
        'stylegan2': dict(ref_gpus=8,  kimg=25000,  mb=32, mbstd=4,  fmaps=1,   lrate=0.002,  gamma=10,   ema=10,  ramp=None, map=8), # Uses mixed-precision, unlike the original StyleGAN2.
        'places256': dict(ref_gpus=8,  kimg=50000,  mb=64, mbstd=8,  fmaps=1,   lrate=0.002,  gamma=10,   ema=10,  ramp=None, map=8),
        'places512': dict(ref_gpus=8,  kimg=50000,  mb=64, mbstd=8,  fmaps=1,   lrate=0.002,  gamma=10,   ema=10,  ramp=None, map=8),
        'celeba512': dict(ref_gpus=8,  kimg=25000,  mb=64, mbstd=8,  fmaps=1,   lrate=0.002,  gamma=10,   ema=10,  ramp=None, map=8),
    }

    assert cfg in cfg_specs
    spec = dnnlib.EasyDict(cfg_specs[cfg])
    if cfg == 'auto': # auto는 스펙에 따라 알아서 맞춰준다.
        desc += f'{gpus:d}'
        spec.ref_gpus = gpus
        res = args.training_set_kwargs.resolution
        spec.mb = max(min(gpus * min(4096 // res, 32), 64), gpus) # keep gpu memory consumption at bay
        spec.mbstd = min(spec.mb // gpus, 4) # other hyperparams behave more predictably if mbstd group size remains fixed
        spec.fmaps = 1 if res >= 512 else 0.5
        spec.lrate = 0.002 if res >= 1024 else 0.0025
        spec.gamma = 0.0002 * (res ** 2) / spec.mb # heuristic formula
        spec.ema = spec.mb * 10 / 32

    # 생성기 디폴트는 mat 모델에서 가져온다.
    if generator is None:
        generator = 'networks.mat.Generator'
    else:
        desc += '-' + generator.split('.')[1]
    # 판별기 디폴트는 mat 모델에서 가져온다.
    if discriminator is None:
        discriminator = 'networks.mat.Discriminator'
    # 생성기의 w 차원 디폴트 값 할당.
    if wdim is None:
        wdim = 512
    # 판별기의 z 차원 디폴트 값 할당.
    if zdim is None:
        zdim = 512
    
    '''이 부분 좀 자세히 봐둘 것!'''
    # 그리고 생성기와 판별기의 정보를 아규먼트에 저장.
    args.G_kwargs = dnnlib.EasyDict(class_name=generator, z_dim=zdim, w_dim=wdim, mapping_kwargs=dnnlib.EasyDict(), synthesis_kwargs=dnnlib.EasyDict())
    args.D_kwargs = dnnlib.EasyDict(class_name=discriminator)
    args.G_kwargs.synthesis_kwargs.channel_base = args.D_kwargs.channel_base = int(spec.fmaps * 32768)
    args.G_kwargs.synthesis_kwargs.channel_max = args.D_kwargs.channel_max = 512
    args.G_kwargs.mapping_kwargs.num_layers = spec.map
    # args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 4 # enable mixed-precision training
    # args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = 256 # clamp activations to avoid float16 overflow
    # args.D_kwargs.epilogue_kwargs.mbstd_group_size = spec.mbstd
    args.D_kwargs.mbstd_group_size = spec.mbstd

    # 생성기와 판별기의 러닝 레이트 설정
    if lr is not None:
        assert isinstance(lr, float)
        spec.lrate = lr
        desc += f'-lr{lr:g}'
        
    # 변형기의 러닝 레이트 설정
    if lrt is not None:
        assert isinstance(lrt, float)
        spec.lrt = lrt
        desc += f'-lrt{lrt:g}'

    # 생성기와 판별기의 최적화 함수 설정.
    if lrt is None:
        args.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0, 0.99], eps=1e-8)
    else:
        args.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, lrt=spec.lrt, betas=[0, 0.99], eps=1e-8)
    args.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0, 0.99], eps=1e-8)

    # 오차 함수
    if loss is None:
        loss = 'losses.loss.TwoStageLoss' # 디폴트
    else:
        desc += '-' + loss.split('.')[-1]
    args.loss_kwargs = dnnlib.EasyDict(class_name=loss, r1_gamma=spec.gamma)

    args.total_kimg = spec.kimg
    args.batch_size = spec.mb
    args.batch_gpu = spec.mb // spec.ref_gpus
    args.ema_kimg = spec.ema
    args.ema_rampup = spec.ramp

    # cfg가 cifar 데이터셋 최적화 모델이라면, 다음 제한을 걸어준다.
    if cfg == 'cifar':
        args.loss_kwargs.pl_weight = 0 # disable path length regularization
        args.loss_kwargs.style_mixing_prob = 0 # disable style mixing
        args.D_kwargs.architecture = 'orig' # disable residual skip connections

    # R1 정규화 가중치
    if gamma is not None:
        assert isinstance(gamma, float)
        if not gamma >= 0:
            raise UserError('--gamma must be non-negative')
        desc += f'-gamma{gamma:g}'
        args.loss_kwargs.r1_gamma = gamma

    # 경로 길이 정규화 비율
    if pr is not None:
        assert isinstance(pr, float)
        desc += f'-pr{pr:g}'
        args.loss_kwargs.pcp_ratio = pr

    # 경로 길이 정규화 사용 여부
    if pl is None:
        pl = True
    assert isinstance(pl, bool)
    if pl is False:
        desc += f'-nopl'
        args.loss_kwargs.pl_weight = 0 # disable path length regularization

    # 훈련 기간을 '천 이미지' 단위로 설정
    if kimg is not None:
        assert isinstance(kimg, int)
        if not kimg >= 1:
            raise UserError('--kimg must be at least 1')
        desc += f'-kimg{kimg:d}'
        args.total_kimg = kimg

    # 배치 사이즈
    if batch is not None:
        assert isinstance(batch, int)
        if not (batch >= 1 and batch % gpus == 0):
            raise UserError('--batch must be at least 1 and divisible by --gpus')
        desc += f'-batch{batch}'
        args.batch_size = batch
        args.batch_gpu = batch // gpus

    # 트렁케이션 계수
    if truncation is not None:
        assert isinstance(truncation, float)
        desc += '-tc' + str(truncation)
        args.loss_kwargs.truncation_psi = truncation

    # 스타일 믹스 확률
    if style_mix is not None:
        assert isinstance(style_mix, float)
        desc += '-sm' + str(style_mix)
        args.loss_kwargs.style_mixing_prob = style_mix

    # 생성기 weights의 지수 이동 평균 반감기를 설정
    if ema is not None:
        assert isinstance(ema, int)
        desc += '-ema' + str(ema)
        args.ema_kimg = ema

    # ---------------------------------------------------
    # Discriminator augmentation: aug, p, target, augpipe
    # ---------------------------------------------------

    if aug is None:
        aug = 'ada' # 데이터 증강은 디폴트가 ada
    else:
        assert isinstance(aug, str)
        desc += f'-{aug}'

    if aug == 'ada':
        args.ada_target = 0.6

    elif aug == 'noaug':
        pass # 아무것도 안 하므로

    # fixed augmentation에 대해선 p 값을 검증하고 아규먼트에 넣어준다.
    elif aug == 'fixed':
        if p is None:
            raise UserError(f'--aug={aug} requires specifying --p')

    else:
        raise UserError(f'--aug={aug} not supported')

    if p is not None:
        assert isinstance(p, float)
        if aug != 'fixed':
            raise UserError('--p can only be specified with --aug=fixed')
        if not 0 <= p <= 1:
            raise UserError('--p must be between 0 and 1')
        desc += f'-p{p:g}'
        args.augment_p = p

    # ada aug일 경우 target을 검증하고 아규먼트에 넣어준다.
    if target is not None:
        assert isinstance(target, float)
        if aug != 'ada':
            raise UserError('--target can only be specified with --aug=ada')
        if not 0 <= target <= 1:
            raise UserError('--target must be between 0 and 1')
        desc += f'-target{target:g}'
        args.ada_target = target
        

    # 증강 파이프라인에 대한 검증 및 아규먼트에 정보 저장
    assert augpipe is None or isinstance(augpipe, str)
    if augpipe is None:
        augpipe = 'bgc'
    else:
        if aug == 'noaug':
            raise UserError('--augpipe cannot be specified with --aug=noaug')
        desc += f'-{augpipe}'

    # 다양한 파이프라인 종류들
    augpipe_specs = {
        'blit':   dict(xflip=1, rotate90=1, xint=1),
        'geom':   dict(scale=1, rotate=1, aniso=1, xfrac=1),
        'color':  dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'filter': dict(imgfilter=1),
        'noise':  dict(noise=1),
        'cutout': dict(cutout=1),
        'bg':     dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1),
        'bgc':    dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'bgcf':   dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1),
        'bgcfn':  dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1),
        'bgcfnc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1, cutout=1),
    }

    assert augpipe in augpipe_specs
    if aug != 'noaug':
        args.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', **augpipe_specs[augpipe])

    # ----------------------------------
    # Transfer learning: resume, freezed
    # ----------------------------------

    # 사전 학습된 stylegan2 모델의 url을 포함하는 딕셔너리
    resume_specs = {
        'ffhq256':     'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl',
        'ffhq512':     'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl',
        'ffhq1024':    'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl',
        'celebahq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl',
        'lsundog256':  'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl',
    }

    assert resume is None or isinstance(resume, str)
    # None이거나 noresume이면 새로운 훈련 시작
    if resume is None:
        resume = 'noresume'
    elif resume == 'noresume':
        desc += '-noresume'
    # 만약 위의 스펙 중에 하나를 지시했으면
    elif resume in resume_specs:
        desc += f'-resume{resume}'
        args.resume_pkl = resume_specs[resume] # predefined url
    else: # 스펙 외의 정보를 입력했으면, 따로 그 url로 가서 모델 불러올 것.
        desc += '-resumecustom'
        args.resume_pkl = resume # custom path or url

    # 전이학습을 시작할 경우, ada가 더 빠르게 반응하도록, EMA 램프업을 비활성화 함.
    if resume != 'noresume':
        args.ada_kimg = 100 # make ADA react faster at the beginning
        args.ema_rampup = None # disable EMA rampup
    '''
    적응적 데이터 증강(ADA)의 빠른 반응 설정
    목적: 전이 학습에서는 이미 훈련된 모델을 새로운 데이터셋에 적용합니다. 이 경우, 모델이 새로운 데이터셋의 특성에 빠르게 적응하는 것이 중요합니다.
    ADA의 역할: ADA는 데이터 증강의 강도를 동적으로 조절하여 과적합을 방지합니다.
    전이 학습에서는 모델이 새로운 데이터에 대해 과적합할 가능성이 높기 때문에, ADA를 통해 증강 강도를 빠르게 조절하는 것이 유용합니다.
    빠른 반응: 이미 훈련된 모델은 기존 데이터셋에 대해 이미 최적화되어 있으므로, 새로운 데이터셋에 대한 적응을 가속화하기 위해 ADA가 더 민감하게 반응하도록 설정하는 것이 도움이 됩니다.
    
    지수 이동 평균(EMA) 램프업 비활성화
    EMA 램프업: EMA 램프업은 훈련 초기에 EMA 가중치를 점진적으로 적용하는 과정입니다. 이는 훈련 초기의 불안정성을 완화하기 위해 사용됩니다.
    비활성화 이유: 전이 학습에서는 이미 안정적으로 훈련된 모델을 사용하기 때문에, 초기 불안정성이 덜 문제가 됩니다.
    따라서, EMA 램프업을 비활성화하여 모델이 새로운 데이터에 대한 학습을 더 빠르게 반영하도록 할 수 있습니다.
    효과: EMA 램프업을 비활성화함으로써, 전이 학습 과정에서 모델의 가중치 업데이트가 새로운 데이터셋에 더 직접적으로 반영됩니다.
    이는 특히 새로운 데이터셋의 특성이 기존 데이터셋과 상당히 다를 때 유용합니다.
    '''

    # Freeze-D 설정 (Freezing Discriminator Layers)
    # 판별기의 일부 레이어르 동결할지 여부.
    if freezed is not None:
        assert isinstance(freezed, int)
        if not freezed >= 0:
            raise UserError('--freezed must be non-negative')
        desc += f'-freezed{freezed:d}'
        args.D_kwargs.block_kwargs.freeze_layers = freezed
        # freeze가 None이 아니고, 0 이상의 정수이면 해당 수만큼의 레이어를 고정.
        
    # -------------------------------------------------
    # Performance options: fp32, nhwc, nobench, workers
    # -------------------------------------------------

    # fp32: 32비트 부동 소수점. 이게 True이면 연산은 정확하나 속도가 느리다.
    # 이걸 False로 하면 fp16(16비트 부동 소수점)으로 연산을 진행한다.
    if fp32 is None:
        fp32 = False
    assert isinstance(fp32, bool)
    if fp32:
        args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 0
        args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = None
        desc += '-fp32'

        
    # nhwc(메모리 포맷 종류로, 메모리 배열 방식을 나타냄)
    # 이걸 사용하면 fp16 연산에서 메모리 접근 효율성을 높일 수 있다.
    # N은 배치 크기, H는 높이, W는 너비, C는 채널을 의미.
    if nhwc is None:
        nhwc = False
    assert isinstance(nhwc, bool)
    if nhwc:
        args.G_kwargs.synthesis_kwargs.fp16_channels_last = args.D_kwargs.block_kwargs.fp16_channels_last = True

    
    # nobench(cuDNN 벤치마킹)
    # nobench가 True일 경우, cuDNN 벤치마킹 기능 비활성화.
    # 벤치마킹은 실시간 시간을 최적화 하기 위해 여러 알고리즘 중 최적의 것을 선택하는 과정.
    # 이걸 비활성화 하면 일관된 성능은 보장되나, 최적의 성능을 보장하진 못한다.
    if nobench is None:
        nobench = False
    assert isinstance(nobench, bool)
    if nobench:
        args.cudnn_benchmark = False

    # TF32 사용 허용
    # 이걸 True로 세팅할 경우, 파이토치 내부적으로 TF32 데이터 타입을 사용할 수 있도록 한다.
    # TF32는 NVIDIA의 Ampere GPU 아키텍처에서 지원되며, FP32와 유사한 정확도를 유지하면서 FP16과 유사한 성능을 제공한다.
    if allow_tf32 is None:
        allow_tf32 = False
    assert isinstance(allow_tf32, bool)
    if allow_tf32:
        args.allow_tf32 = True

        
    # 데이터 로딩에 사용되는 병렬 프로세스의 수를 설정합니다.
    # 더 많은 워커를 사용하면 데이터 로딩 속도가 향상될 수 있습니다.
    if workers is not None:
        assert isinstance(workers, int)
        if not workers >= 1:
            raise UserError('--workers must be at least 1')
        args.data_loader_kwargs.num_workers = workers

    return desc, args

#----------------------------------------------------------------------------
# 이 함수는 분산 훈련을 최적화한 뒤 모델 트레이닝을 실행한다.
def subprocess_fn(rank, args, temp_dir):
    # 훈련 로그를 파일에 기록하기 위한 로거를 설정.
    # 로그 파일은 args.run_dir에 log.txt 파일로 저장된다.
    dnnlib.util.Logger(file_name=os.path.join(args.run_dir, 'log.txt'), file_mode='a', should_flush=True)

    # Init torch.distributed.
    # 여러 GPU에서 분산 훈련을 하기 위한 파이토치 분산 훈련 기능을 초기화.
    if args.num_gpus > 1:
        init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
        if os.name == 'nt':
            init_method = 'file:///' + init_file.replace('\\', '/')
            torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
        else:
            init_method = f'file://{init_file}'
            torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus)

    # Init torch_utils.
    # 다중 프로세싱 환경에서 훈련 통계를 동기화하기 위한 설정 초기화
    sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
    training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
    if rank != 0:
        custom_ops.verbosity = 'none'

    # Execute training loop.
    # 드디어 훈련 루프를 돌린다. 이 함수는 rank와 args를 인자로 받아 실행된다.
    training_loop.training_loop(rank=rank, **args)

#----------------------------------------------------------------------------
# click 라이브러리를 통한 CLI 환경을 구축하기 위해 전처리 작업을 한다.
class CommaSeparatedList(click.ParamType):
    name = 'list'

    def convert(self, value, param, ctx):
        _ = param, ctx
        if value is None or value.lower() == 'none' or value == '':
            return []
        return value.split(',')

#----------------------------------------------------------------------------

@click.command()
@click.pass_context

# General options.
@click.option('--outdir', help='Where to save the results', required=True, metavar='DIR')
@click.option('--gpus', help='Number of GPUs to use [default: 1]', type=int, metavar='INT')
@click.option('--snap', help='Snapshot interval [default: 50 ticks]', type=int, metavar='INT')
@click.option('--metrics', help='Comma-separated list or "none" [default: fid50k_full]', type=CommaSeparatedList())
@click.option('--seed', help='Random seed [default: 0]', type=int, metavar='INT')
@click.option('-n', '--dry-run', help='Print training options and exit', is_flag=True)

# Dataset.
@click.option('--data', help='Training data (directory or zip)', metavar='PATH', required=True)
@click.option('--data_val', help='Validation data (directory or zip)', metavar='PATH')
@click.option('--dataloader', help='dataloader', type=str, metavar='STRING')
@click.option('--cond', help='Train conditional model based on dataset labels [default: false]', type=bool, metavar='BOOL')
@click.option('--subset', help='Train with only N images [default: all]', type=int, metavar='INT')
@click.option('--mirror', help='Enable dataset x-flips [default: false]', type=bool, metavar='BOOL')

# Base config.
@click.option('--cfg', help='Base config [default: auto]', type=click.Choice(['auto', 'stylegan2', 'paper256', 'paper512', 'inp512', 'paper1024', 'cifar', 'places256', 'places512', 'celeba512']))
@click.option('--generator', help='the path of generator', type=str, metavar='STRING')
@click.option('--wdim', help='dimension of w', type=int, metavar='INT')
@click.option('--zdim', help='dimension of noise input', type=int, metavar='INT')
@click.option('--discriminator', help='the path of discriminator', type=str, metavar='STRING')
@click.option('--loss', help='the path of loss', type=str, metavar='STRING')
@click.option('--gamma', help='Override R1 gamma', type=float)
@click.option('--pr', help='Override ratio of pcp loss', type=float)
@click.option('--pl', help='Enable path length regularization [default: true]', type=bool, metavar='BOOL')
@click.option('--kimg', help='Override training duration', type=int, metavar='INT')
@click.option('--batch', help='Override batch size', type=int, metavar='INT')
@click.option('--truncation', help='truncation for training', type=float)
@click.option('--style_mix', help='style mixing probability for training', type=float)
@click.option('--ema', help='Half-life of the exponential moving average (EMA) of generator weights', type=int, metavar='INT')
@click.option('--lr', help='learning rate', type=float)
@click.option('--lrt', help='learning rate', type=float)

# Discriminator augmentation.
@click.option('--aug', help='Augmentation mode [default: ada]', type=click.Choice(['noaug', 'ada', 'fixed']))
@click.option('--p', help='Augmentation probability for --aug=fixed', type=float)
@click.option('--target', help='ADA target value for --aug=ada', type=float)
@click.option('--augpipe', help='Augmentation pipeline [default: bgc]', type=click.Choice(['blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc', 'bgcf', 'bgcfn', 'bgcfnc']))

# Transfer learning.
@click.option('--resume', help='Resume training [default: noresume]', metavar='PKL')
@click.option('--freezed', help='Freeze-D [default: 0 layers]', type=int, metavar='INT')

# Performance options.
@click.option('--fp32', help='Disable mixed-precision training', type=bool, metavar='BOOL')
@click.option('--nhwc', help='Use NHWC memory format with FP16', type=bool, metavar='BOOL')
@click.option('--nobench', help='Disable cuDNN benchmarking', type=bool, metavar='BOOL')
@click.option('--allow-tf32', help='Allow PyTorch to use TF32 internally', type=bool, metavar='BOOL')
@click.option('--workers', help='Override number of DataLoader workers', type=int, metavar='INT')

def main(ctx, outdir, dry_run, **config_kwargs):
    # ctx: 현재 커맨드 라인 데이터
    # outdir: 훈련 결과를 저장할 디렉토리
    # dry_run: 실제로 훈련을 시작하지 않고, 설정만 확인하는 옵션
    # **config_kwargs: 훈련에 필요한 다양한 설정을 포함하는 키워드 인자
    
    """Train a GAN using the techniques described in the paper
    "Training Generative Adversarial Networks with Limited Data".

    Examples:

    \b
    # Train with custom dataset using 1 GPU.
    python train.py --outdir=~/training-runs --data=~/mydataset.zip --gpus=1

    \b
    # Train class-conditional CIFAR-10 using 2 GPUs.
    python train.py --outdir=~/training-runs --data=~/datasets/cifar10.zip \\
        --gpus=2 --cfg=cifar --cond=1

    \b
    # Transfer learn MetFaces from FFHQ using 4 GPUs.
    python train.py --outdir=~/training-runs --data=~/datasets/metfaces.zip \\
        --gpus=4 --cfg=paper1024 --mirror=1 --resume=ffhq1024 --snap=10

    \b
    # Reproduce original StyleGAN2 config F.
    python train.py --outdir=~/training-runs --data=~/datasets/ffhq.zip \\
        --gpus=8 --cfg=stylegan2 --mirror=1 --aug=noaug

    \b
    Base configs (--cfg):
      auto       Automatically select reasonable defaults based on resolution
                 and GPU count. Good starting point for new datasets.
      stylegan2  Reproduce results for StyleGAN2 config F at 1024x1024.
      paper256   Reproduce results for FFHQ and LSUN Cat at 256x256.
      paper512   Reproduce results for BreCaHAD and AFHQ at 512x512.
      paper1024  Reproduce results for MetFaces at 1024x1024.
      cifar      Reproduce results for CIFAR-10 at 32x32.

    \b
    Transfer learning source networks (--resume):
      ffhq256        FFHQ trained at 256x256 resolution.
      ffhq512        FFHQ trained at 512x512 resolution.
      ffhq1024       FFHQ trained at 1024x1024 resolution.
      celebahq256    CelebA-HQ trained at 256x256 resolution.
      lsundog256     LSUN Dog trained at 256x256 resolution.
      <PATH or URL>  Custom network pickle.
    """
    
    print('Start') # 훈련 시작 메시지
    dnnlib.util.Logger(should_flush=True) # 훈련 로그 기록

    # Setup training options.
    try:
        run_desc, args = setup_training_loop_kwargs(**config_kwargs)
        # 아규먼트를 받아 훈련에 필요한 설정 초기화
    except UserError as err:
        ctx.fail(err) # 이 과정에서 오류 발생시 오류 처리

    # Pick output directory.
    prev_run_dirs = []
    if os.path.isdir(outdir):
        prev_run_dirs = [x for x in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, x))]
    prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
    prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
    cur_run_id = max(prev_run_ids, default=-1) + 1
    args.run_dir = os.path.join(outdir, f'{cur_run_id:05d}-{run_desc}')
    assert not os.path.exists(args.run_dir)
    '''
    출력 디렉토리 설정 과정
    이전 실행 디렉토리 확인:

    prev_run_dirs는 이전에 생성된 훈련 실행 디렉토리들을 저장하는 리스트입니다.
    os.path.isdir(outdir)를 사용하여 outdir 경로가 디렉토리인지 확인합니다.
    os.listdir(outdir)를 사용하여 outdir에 있는 모든 파일과 디렉토리를 나열합니다.
    리스트 컴프리헨션을 사용하여 outdir 내의 모든 디렉토리를 prev_run_dirs에 저장합니다.
    이전 실행 ID 추출:

    prev_run_ids는 이전 실행들의 고유 ID를 저장하는 리스트입니다.
    re.match(r'^\d+', x)를 사용하여 디렉토리 이름에서 숫자로 시작하는 부분을 찾아냅니다. 이 숫자는 실행 ID를 나타냅니다.
    int(x.group())를 사용하여 추출된 문자열을 정수로 변환합니다.
    현재 실행 ID 결정:

    cur_run_id는 현재 실행에 대한 고유 ID입니다.
    max(prev_run_ids, default=-1) + 1을 사용하여 가장 큰 이전 실행 ID에 1을 더해 새로운 ID를 생성합니다.
    새로운 실행 디렉토리 생성:

    args.run_dir는 현재 실행에 대한 디렉토리 경로입니다.
    os.path.join(outdir, f'{cur_run_id:05d}-{run_desc}')를 사용하여 새로운 디렉토리 경로를 생성합니다. 여기서 {cur_run_id:05d}는 실행 ID를 5자리 숫자 형식으로 표시합니다.
    assert not os.path.exists(args.run_dir)를 사용하여 새 디렉토리가 아직 존재하지 않는지 확인합니다.
    '''

    # Print options.
    print()
    print('Training options:')
    print(json.dumps(args, indent=2))
    print()
    print(f'Output directory:   {args.run_dir}')
    print(f'Training data:      {args.training_set_kwargs.path}')
    print(f'Training duration:  {args.total_kimg} kimg')
    print(f'Number of GPUs:     {args.num_gpus}')
    print(f'Number of images:   {args.training_set_kwargs.max_size}')
    print(f'Image resolution:   {args.training_set_kwargs.resolution}')
    print(f'Conditional model:  {args.training_set_kwargs.use_labels}')
    print(f'Dataset x-flips:    {args.training_set_kwargs.xflip}')
    print()
    print('Validation options:')
    print(f'Validation data:      {args.val_set_kwargs.path}')
    print(f'Number of images:   {args.val_set_kwargs.max_size}')
    print(f'Image resolution:   {args.val_set_kwargs.resolution}')
    print(f'Conditional model:  {args.val_set_kwargs.use_labels}')
    print(f'Dataset x-flips:    {args.val_set_kwargs.xflip}')
    print()

    # Dry run?
    # 이 경우 실제 모델 훈련을 진행하진 않고 함수를 빠져나감.
    if dry_run:
        print('Dry run; exiting.')
        return

    # Create output directory.
    print('Creating output directory...')
    os.makedirs(args.run_dir) # 아웃풋 디렉토리 생성
    with open(os.path.join(args.run_dir, 'training_options.json'), 'wt') as f:
        json.dump(args, f, indent=2)
        # 훈련 옵션을 json 파일로 저장해둔다.

    # Launch processes.
    print('Launching processes...')
    torch.multiprocessing.set_start_method('spawn') # 멀티 프로세싱 방법 설정.
    # 임시 디렉토리 설정.
    with tempfile.TemporaryDirectory() as temp_dir:
        if args.num_gpus == 1: # GPU가 1개면 해당 조건으로 트레이닝 루프 진행.
            subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
        else:
            torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)

#----------------------------------------------------------------------------

if __name__ == "__main__":
    main() # pylint: disable=no-value-for-parameter

#----------------------------------------------------------------------------

 

setup_training_loop_kwargs가 받는 인자

더보기
parameter 설명 arguments
gpus GPU 개수 args.num_gpus
snap 이미지와 네트워크의 스냅샷 저장 tick args.image_snapshot_ticks
args.network_snapshot_ticks
metrics 평가 메트릭 지정 args.metrics = metrics
seed 랜덤 시드 설정 args.random_seed
data 트레이닝 데이터 경로 args.training_set_kwargs
data_val 밸리데이션 데이터 경로 args.val_set_kwargs
dataloader 데이터 로더 이름 args.data_loader_kwargs
cond 레이블 데이터 여부 args.training_set_kwargs.use_labels
args.val_set_kwargs.use_labels
subset 데이터 일부만을 활용할 것인가? args.training_set_kwargs.max_size
args.training_set_kwargs.random_seed
mirror 미러를 통한 데이터 증강 여부 args.training_set_kwargs.xflip
cfg 데이터에 따른 최적의 하이퍼 파라미터 결정
러닝 레이트, 배치 사이즈 등을 설정
args.training_set_kwargs.resolution
generator 생성기 클래스 경로 args.G_kwargs
wdim 생성기의 w 벡터 차원 args.G_kwargs
zdim 잠재 공간의 z 벡터 차원 args.G_kwargs
discriminator 판별기 클래스 경로 args.D_kwargs
lr 생성기와 판별기의 러닝 레이트 args.G_kwargs
args.D_kwargs
lrt 변형기의 러닝 레이트 args.G_kwargs
args.D_kwargs
loss 손실 함수 경로 args.loss_kwargs
gamma R1 정규화 가중치 args.loss_kwargs
pr 경로 길이 정규화 비율 args.loss_kwargs
pl 경로 길이 정규화 여부 args.loss_kwargs
truncation 생성기 훈련 중 트렁케이션 여부 args.loss_kwargs
style_mix 스타일 믹스 비율 args.loss_kwargs
ema 생성기 weights의 지수 이동 평균 반감기 args.ema_kimg
kimg 훈련 기간을 1000개의 이미지 단위로 결정 args.total_kimg
batch 배치 사이즈 args.batch_size
aug 데이터 증강 모드(ada, noaug, fixed) None
p fixed 증강에서의 증강 확률 args.augment_p
target ada 증강에서 ada 목표값 args.ada_target
augpipe 증강 파이프라인(blit, color, filter 등) args.augment_kwargs
resume 전이학습 여부 args.resume_pkl
args.ada_kimg
args.ema_rampup
freezed 판별기 동결 레이어 개수 args.D_kwargs
fp32 32비트 부동 소수점 사용 여부 args.G_kwargs
nhwc 메모리 포맷을 nhwc로 할지 여부 args.G_kwargs
allow_tf32 TF32 데이터타입 사용 여부 rgs.allow_tf32
nobench cuDNN 벤치마크 여부 args.cudnn_benchmark
workers 데이터 로더 개수 args.data_loader_kwargs