코드 분석/MAT

MAT: training_loop.py

상솜공방 2023. 12. 28. 20:40

모델 훈련에 대한 프로세스를 담당하는 코드이다.

해당 코드를 이해하기 위한 정보를 먼저 정리한 후 코드 분석으로 넘어간다.

 

이론에 대한 이해

더보기

생성기 가중치의 EMA 반감기

EMA(Exponential Moving Average)의 가중치 반감기는 EMA 계산에서 새로운 데이터 포인트가 이전 데이터에 비해 얼마나 강한 영향을 미치는지를 결정하는 지표입니다. 반감기(half-life)는 EMA 가중치가 절반으로 줄어드는 데 필요한 시간을 나타냅니다.
EMA 정의: EMA는 시계열 데이터에서 최근의 관측값에 더 큰 가중치를 부여하여 평균을 계산하는 방법입니다. 이는 데이터의 최근 추세를 더 잘 반영할 수 있게 해줍니다.
반감기의 역할: 반감기는 EMA를 계산할 때 적용되는 감쇠율(decay rate)을 결정합니다. 반감기가 길수록 과거 데이터에 더 많은 가중치가 부여되며, 반감기가 짧을수록 최근 데이터에 더 많은 가중치가 부여됩니다.

모델 가중치에 적용: 딥러닝에서 EMA는 모델의 가중치를 안정화시키는 데 사용됩니다.  특히, GAN(Generative Adversarial Networks)과 같이 불안정한 학습 경향이 있는 모델에서 유용합니다.
가중치 반감기: 이는 모델 가중치의 업데이트 속도를 조절합니다. 반감기가 길면 모델이 과거 학습 경험을 더 오래 유지하고, 반감기가 짧으면 최근의 학습 경험에 더 빠르게 반응합니다.
안정성과 반응성의 균형: EMA의 가중치 반감기를 조절함으로써, 모델의 안정성과 반응성 사이의 균형을 맞출 수 있습니다. 적절한 반감기 설정은 모델이 최신 데이터의 변화에 효과적으로 적응하면서도 과거 데이터로부터 얻은 안정적인 학습을 유지하도록 도와줍니다. 따라서 EMA의 가중치 반감기는 모델 학습 과정에서 중요한 하이퍼파라미터로 작용하며, 모델의 성능과 일반화 능력에 영향을 미칠 수 있습니다.

 

 

EMA(Exponential Moving Average)의 램프업(ramp-up) 계수

EMA를 적용할 때 초기 학습 단계에서 EMA 가중치를 어떻게 조절할지 결정하는 값입니다. 램프업은 특히 모델이 초기 학습 단계에 있을 때, EMA를 부드럽게 적용하기 위해 사용됩니다.


EMA 램프업의 목적
1. 초기 불안정성 완화: 모델이 초기 학습 단계에 있을 때, 가중치는 종종 불안정할 수 있습니다. 램프업은 이러한 초기 불안정성을 완화하는 데 도움을 줍니다.
2. 부드러운 전환: EMA 램프업은 모델이 학습 초기에 과거 가중치와 새로운 가중치 사이의 급격한 변화를 피하고, 점진적으로 EMA를 적용하도록 합니다.


램프업 계수의 역할
1. EMA 적용 강도 조절: 램프업 계수는 학습 초기에 EMA의 적용 강도를 조절합니다. 램프업 기간 동안 EMA 가중치는 점차적으로 증가하며, 이는 초기 가중치 변화를 더 부드럽게 만듭니다.
2. 학습 진행에 따른 조정: 램프업 기간이 끝나면, EMA는 정상적인 감쇠율로 돌아가 모델 가중치에 영향을 미칩니다.


실용적 의미
1. 모델 성능 향상: 램프업을 사용함으로써, 모델은 초기 학습 단계에서 발생할 수 있는 급격한 가중치 변화로 인한 부정적인 영향을 피할 수 있습니다. 이는 특히 GAN과 같이 민감한 모델에서 중요합니다.
2. 하이퍼파라미터 조정: 램프업 계수는 하이퍼파라미터로서, 모델의 초기 학습 과정과 EMA 적용 방식을 조정하는 데 사용됩니다. 램프업 계수를 적절히 설정하는 것은 모델의 초기 학습 과정을 안정화시키고, 전체적인 학습 성능을 향상시키는 데 중요한 역할을 합니다.

 

 

생성기(Generator)와 판별기(Discriminator)의 정규화

모델의 학습 안정성을 향상시키고 고품질의 결과를 얻기 위해 필요합니다. 각각의 정규화 방법은 다음과 같은 목적을 가지고 있습니다:

생성기의 정규화
과적합 방지: 생성기가 훈련 데이터에 과도하게 적응하는 것을 방지합니다. 이는 특히 제한된 데이터셋에서 중요합니다.
다양성 유지: 생성기가 다양한 출력을 생성할 수 있도록 도와줍니다. 이는 모델이 더 다양한 이미지를 생성하는 데 도움이 됩니다.
학습 안정성 개선: 생성기의 정규화는 학습 과정 중 발생할 수 있는 수렴 문제를 완화시키는 데 도움이 됩니다.


판별기의 정규화
과적합 방지: 판별기가 훈련 데이터의 특정 패턴에 지나치게 의존하는 것을 방지합니다.
학습 균형 유지: GAN에서는 생성기와 판별기 간의 균형이 중요합니다. 판별기가 너무 강해지면 생성기가 충분히 좋은 결과를 생성하지 못할 수 있습니다.
모델 안정성 개선: 판별기의 정규화는 전체 GAN 시스템의 안정성을 개선하는 데 기여합니다.

 

정규화 기법
드롭아웃(Dropout): 무작위로 뉴런을 비활성화하여 과적합을 방지합니다.
배치 정규화(Batch Normalization): 미니배치의 출력을 정규화하여 학습을 안정화시킵니다.
스펙트럴 정규화(Spectral Normalization): 판별기에서 주로 사용되며, 가중치의 스펙트럼을 제한하여 학습 안정성을 개선합니다. 정규화는 GAN의 학습 과정을 안정화시키고, 생성된 이미지의 품질을 높이는 데 중요한 역할을 합니다. 적절한 정규화 기법을 사용함으로써, 모델의 성능을 최적화하고 원하는 결과를 얻을 수 있습니다.

 

 

ADA(Adaptive Data Augmentation)

생성적 적대 신경망(GAN)과 같은 딥러닝 모델에서 데이터 증강의 효과를 동적으로 조절하는 기법입니다. ADA는 모델이 과적합을 방지하고 더 나은 일반화를 달성할 수 있도록 돕습니다. ADA의 주요 구성 요소는 다음과 같습니다:

ADA의 목표값 (ada_target):
이 값은 데이터 증강이 적용되어야 하는 정도를 나타냅니다. 예를 들어, ada_target=0.6은 판별기가 60%의 시간 동안 증강된 데이터에 대해 '진짜'라고 판단하도록 학습하는 것을 목표로 합니다. 목표값은 증강의 강도를 조절하는 데 사용되며, 이를 통해 모델이 훈련 데이터에 과적합되는 것을 방지합니다.


ADA 조정 간격 (ada_interval):
이는 ADA가 얼마나 자주 조정되어야 하는지를 나타내는 값입니다. 예를 들어, ada_interval=4는 매 4번의 훈련 반복마다 ADA를 조정한다는 의미입니다. 조정 간격은 증강의 동적 조절을 위한 빈도를 결정합니다.


ADA 조정 속도 (ada_kimg):
이 값은 ADA가 얼마나 빠르게 증강 확률을 조절하는지를 나타냅니다. 예를 들어, ada_kimg=500은 증강 확률이 1단위 변화하는 데 500K 이미지(킬로 이미지)가 필요하다는 것을 의미합니다. 조정 속도는 증강 확률의 변화 속도를 결정하며, 너무 빠른 조정은 학습에 불안정성을 초래할 수 있습니다.


ADA의 중요성
과적합 방지: ADA는 특히 데이터가 제한적인 경우에 모델이 훈련 데이터에 과적합되는 것을 방지하는 데 도움을 줍니다.
학습 안정성 향상: 적절한 데이터 증강은 모델의 학습 안정성을 개선하고, 결과적으로 더 나은 성능을 달성할 수 있도록 합니다.
동적 조절: ADA는 학습 과정을 통해 증강의 강도를 동적으로 조절함으로써, 모델이 다양한 데이터 조건에 적응할 수 있도록 합니다.

 

커스터마이즈 된 데이터셋 생성 및 로딩

더보기
# 훈련 데이터셋 생성
training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs)

# 무한 데이터 샘플러 생성
training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed)

# 데이터 반복자(iterator) 생성
training_set_iterator = iter(torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler, batch_size=batch_size//num_gpus, **data_loader_kwargs))

 

훈련 데이터셋 생성

폴더에 따라 데이터셋을 제작하는 기능은 파이토치에서 다음과 같이 기본적으로 제공한다.

trainset = torchvision.datasets.ImageFolder(root = '/home/white/Desktop/ship', transform = transform)

그러나 해당 코드에선 dnnlib.util.construct_class_by_name()라는 커스텀 함수를 사용해 객체를 생성한다.

 

무한 데이터 샘플러 생성

데이터를 한 번만 순회할 때는 아래와 같이 적는 게 보편적이다. 해당 코드는 기본적인 샘플러를 사용하기 때문에 따로 아규먼트를 적지 않았다.

trainloader = Dataloader(trainset, batch_size = 16, shuffle = False, num_workers = 4)

그러나 GAN에서는 무한하게 데이터를 로드해줘야 한다.  그 이유는 다음과 같다.

1. 연속적인 훈련: GAN은 생성자(Generator)와 판별자(Discriminator) 두 네트워크가 서로 경쟁하며 학습합니다. 이 과정은 연속적이며, 한 에폭(epoch)에서 다음 에폭으로 넘어갈 때 중단 없이 데이터를 지속적으로 공급받아야 합니다. 무한 샘플러는 데이터셋의 끝에 도달해도 자동으로 다시 시작하여 이러한 연속적인 훈련을 지원합니다.

2. 동적 균형 유지: GAN 훈련은 생성자와 판별자 간의 동적 균형을 유지하는 것이 중요합니다. 무한 샘플러를 사용하면, 훈련 데이터셋을 여러 번 반복하여 사용함으로써 두 네트워크가 충분한 학습을 할 수 있도록 합니다. 이는 특히 데이터셋이 작을 때 중요합니다.
3. 효율적인 GPU 사용: GAN 훈련은 계산적으로 매우 집약적입니다. 무한 샘플러를 사용하면 GPU가 데이터 로딩으로 인해 유휴 상태에 빠지는 것을 최소화하고, 항상 훈련 데이터를 처리할 수 있도록 합니다.
4. 분산 훈련 지원: 대규모 GAN 모델의 경우, 여러 GPU 또는 노드에서 훈련을 분산시키는 것이 일반적입니다. 무한 샘플러는 각 GPU가 데이터셋의 다른 부분을 독립적으로 처리할 수 있도록 하여, 분산 훈련의 효율성을 높입니다.
5. 에폭 관리의 단순화: 전통적인 데이터 로더를 사용할 경우, 각 에폭의 끝에서 추가적인 관리가 필요할 수 있습니다. 하지만 무한 샘플러를 사용하면 이러한 관리가 필요 없어져, 훈련 코드를 단순화할 수 있습니다.

 

데이터 반복자 생성

training_set_iterator는 훈련 데이터셋에 대한 반복자(iterator) 역할을 합니다. 이 반복자는 훈련 과정 중에 데이터를 효율적으로 로드하고 처리하는 데 사용됩니다. training_set_iterator의 주요 역할과 특징은 다음과 같습니다:

배치 단위 데이터 제공: training_set_iterator는 DataLoader를 사용하여 데이터셋에서 데이터를 배치(batch) 단위로 로드합니다. 각 배치는 모델 훈련에 필요한 일정량의 데이터(예: 이미지, 레이블)를 포함합니다. 참고로 데이터 로더는 해당 함수 안에 있는 코드 'torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler, batch_size=batch_size//num_gpus, **data_loader_kwargs)'를 통해 만들어진다.
데이터 셔플링 및 샘플링: DataLoader는 데이터셋에서 데이터를 무작위로 셔플링하여 모델이 데이터의 순서에 의존하지 않도록 합니다. 또한, 사용자 정의 샘플러(InfiniteSampler 등)를 통해 데이터를 특정 방식으로 샘플링할 수 있습니다.
멀티스레딩 및 멀티프로세싱 지원: DataLoader는 멀티스레딩 또는 멀티프로세싱을 사용하여 데이터 로딩을 병렬화할 수 있습니다. 이를 통해 데이터 로딩 시간을 줄이고 GPU가 효율적으로 활용될 수 있도록 합니다.
무한 반복 지원: InfiniteSampler와 같은 사용자 정의 샘플러를 사용할 경우, training_set_iterator는 데이터셋을 무한히 순회할 수 있습니다. 이는 특히 GAN과 같은 모델에서 중요한데, 이는 훈련 과정이 데이터셋의 한 번의 순회로 제한되지 않기 때문입니다.
효율적인 리소스 활용: training_set_iterator를 사용하면 데이터 로딩과 모델 훈련을 동시에 진행할 수 있어, GPU와 CPU 리소스를 효율적으로 활용할 수 있습니다. 요약하자면, training_set_iterator는 훈련 데이터를 효율적으로 로드하고 관리하는 데 필수적인 역할을 하며, 특히 대규모 데이터셋과 복잡한 모델 훈련에 있어서 중요한 구성 요소입니다.

 

생성기와 판별기

더보기
    # Construct networks.
    if rank == 0: # GPU 랭크가 0일 때, 즉 첫 시작임을 알리는 것.
        print('Constructing networks...')
    # common_kwargs에는 생성기와 판별기에 모두 들어가는 아규먼트(레이블 차원), 이미지 해상도, 이미지 채널 수를 저장.
    common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels)
    # 생성기 옵션과 공통 옵션을 인자로 받아 생성기 객체를 생성. train()을 통해 훈련 모드로 설정 후 to.(device)로 GPU로 옮긴다.
    G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
    # 판별기 옵션과 공통 옵션을 인자로 받아 판별기 객체를 생성.
    D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
    # G의 모델 구조를 복사해 G_ema를 생성.
    G_ema = copy.deepcopy(G).eval()

 

G_ema

코드에선 생성기인 G, 판별기인 D, 그리고 지수 이동 평균 버전의 생성기 G_ema, 이렇게 세 개의 객체를 생성합니다. G_ema (Exponential Moving Average 버전의 생성자)를 별도로 만드는 이유는 GAN(Generative Adversarial Networks) 훈련 과정에서 생성자의 안정성과 결과물의 질을 향상시키기 위함입니다. G_ema는 원래 생성자 G의 가중치를 지수 이동 평균 방식으로 평활화한 버전입니다. 이 방법은 다음과 같은 이유로 사용됩니다:

 

결과의 안정성 향상: GAN 훈련 과정은 종종 매우 불안정할 수 있습니다. 생성자의 가중치는 훈련 과정에서 크게 변동할 수 있는데, 이는 최종 생성 이미지의 질에 영향을 미칩니다. G_ema는 이러한 변동을 평활화하여 더 안정적인 결과를 생성합니다.
질적 개선: 지수 이동 평균은 최근 가중치에 더 큰 가중치를 두면서도 이전 가중치들을 고려합니다. 이는 생성자가 최근의 학습을 반영하면서도 과거의 정보를 유지할 수 있게 해, 생성된 이미지의 질을 개선합니다.
과적합 방지: 특히 훈련 데이터가 제한적인 경우, G_ema는 과적합을 방지하는 데 도움이 될 수 있습니다. 지수 이동 평균은 모델이 특정 훈련 샘플에 지나치게 최적화되는 것을 방지합니다.
평가 및 테스트에 적합: 일반적으로 모델을 평가하거나 실제 사용 환경에 배포할 때는 G_ema와 같이 더 안정적이고 일관된 결과를 생성하는 모델 버전을 사용하는 것이 바람직합니다. G_ema의 사용은 GAN의 특성상 발생할 수 있는 훈련 과정의 불안정성을 완화하고, 최종 생성물의 질을 높이는 데 중요한 역할을 합니다. 따라서, G와 별도로 G_ema를 유지하고 관리하는 것은 GAN 모델의 성능을 최적화하는 데 중요한 전략 중 하나입니다.

 

훈련 과정: G_ema가 G의 deepcopy로 시작하므로 둘은 동일한 네트워크 구조와 가중치를 가집니다. 그러나 훈련 과정에서 G_ema의 업데이트 방식이 G와 다릅니다. G_ema는 G의 가중치를 직접적으로 학습하지 않고, 대신 G의 가중치 변화를 지수 이동 평균(Exponential Moving Average, EMA) 방식으로 추적합니다. 이 과정은 다음과 같이 이루어집니다:

1. EMA 가중치 업데이트: 각 훈련 단계에서 G의 가중치가 업데이트됩니다. 이때, G_ema의 가중치는 G의 새로운 가중치와 G_ema의 이전 가중치를 결합하여 업데이트됩니다. 이 결합은 EMA 공식에 따라 수행됩니다.
2. EMA 공식: EMA는 다음과 같은 공식을 사용합니다: ema_weight = beta * ema_weight + (1 - beta) * current_weight. 여기서 ema_weight는 G_ema의 가중치, current_weight는 G의 현재 가중치, 그리고 beta는 0과 1 사이의 감쇠 계수입니다. 이 계수는 G_ema가 과거 가중치에 얼마나 많은 가중치를 두는지 결정합니다.
3. 가중치 평활화: EMA는 최근 가중치에 더 큰 가중치를 두면서도 이전 가중치들을 고려합니다. 이는 G_ema가 G의 최신 학습을 반영하면서도 과거의 정보를 유지하게 합니다. 결과적으로, G_ema는 G의 가중치 변화를 더 부드럽게 추적하는 버전이 됩니다. 이는 G의 순간적인 변동성을 줄이고, 훈련 과정에서 발생할 수 있는 불안정성을 완화하는 데 도움이 됩니다. G_ema는 특히 생성된 이미지의 질을 평가하거나 모델을 실제 환경에 배포할 때 유용하게 사용됩니다.

 

G_ema의 deepcopy()

네트워크의 깊은 복사본(deep copy)이란, 네트워크의 모든 구성 요소(가중치, 구조, 상태 등)를 완전히 복제하는 것을 의미합니다. 깊은 복사는 원본 객체의 정확한 복제본을 생성하며, 복제된 객체는 원본 객체와 완전히 독립적입니다. 이는 Python의 copy.deepcopy() 함수를 사용하여 수행됩니다.
깊은 복사의 특징:
완전한 독립성: 깊은 복사본은 원본 객체와 완전히 독립적입니다. 원본 객체의 변경이 복사본에 영향을 미치지 않으며, 복사본의 변경이 원본에 영향을 미치지 않습니다.
모든 하위 요소 포함: 깊은 복사는 객체의 모든 하위 요소(중첩된 객체, 리스트, 딕셔너리 등)도 함께 복사합니다. 이는 단순히 최상위 레벨의 객체만 복사하는 얕은 복사(shallow copy)와 대비됩니다.
메모리 상의 별도 위치: 깊은 복사본은 메모리 상에서 원본 객체와 별도의 위치에 저장됩니다. 이는 두 객체가 서로 독립적인 메모리 공간을 가진다는 것을 의미합니다.
네트워크에서의 깊은 복사:
딥러닝에서 네트워크의 깊은 복사본을 만들 때, 이는 네트워크의 모든 가중치, 파라미터, 구조, 학습 상태 등을 포함합니다. 예를 들어, PyTorch에서 생성자 네트워크 G의 깊은 복사본을 만들 경우, G의 모든 층, 가중치, 학습된 파라미터 등이 복사되어 완전히 독립적인 새로운 네트워크가 생성됩니다.
이러한 깊은 복사는 특히 모델의 상태를 특정 시점에 보존하거나, 원본 모델을 변경하지 않고 실험을 진행하고자 할 때 유용합니다. 예를 들어, GAN에서 생성자 G의 상태를 보존하면서 지수 이동 평균 버전 G_ema를 별도로 업데이트하고자 할 때 깊은 복사를 사용합니다.

 

G_ema를 eval()로 설정하는 이유

G_ema (Exponential Moving Average 버전의 생성자)를 eval() 모드로 설정하는 이유는 이 모델이 훈련 과정에서 학습되지 않고, 대신 추론(inference) 또는 평가(evaluation) 목적으로만 사용되기 때문입니다. eval() 모드로 설정하는 것은 다음과 같은 중요한 목적을 가집니다:
배치 정규화(Batch Normalization) 및 드롭아웃(Dropout) 비활성화: eval() 모드는 배치 정규화와 드롭아웃과 같은 특정 층들이 훈련 모드와 다르게 동작하도록 합니다. 훈련 모드에서는 이러한 층들이 활성화되어 있지만, 추론 시에는 이들을 비활성화하여 모델의 일관된 출력을 보장합니다.
추론 및 평가에 최적화: G_ema는 주로 생성된 이미지의 질을 평가하거나 실제 환경에서 모델을 사용할 때 활용됩니다. eval() 모드는 이러한 추론 및 평가 작업에 최적화된 상태를 제공합니다.
학습 과정에서의 변경 방지: G_ema는 G의 가중치를 지수 이동 평균으로 평활화한 버전으로, 훈련 과정에서 직접적으로 학습되지 않습니다. eval() 모드는 G_ema가 훈련 과정 중에 가중치가 변경되는 것을 방지합니다.
성능 평가의 일관성: G_ema를 사용하여 생성된 이미지의 질을 평가하거나, 모델의 성능을 다른 모델과 비교할 때, eval() 모드는 모델의 출력이 훈련 데이터의 특정 배치에 의존하지 않고 일관되게 유지되도록 합니다. 요약하자면, G_ema를 eval() 모드로 설정하는 것은 모델이 학습 과정에서 변경되지 않고, 추론 및 평가 작업에 최적화되며, 일관된 결과를 제공하도록 하는 데 중요합니다. 이는 특히 생성자의 성능을 평가하고, 실제 환경에서의 사용에 있어서 중요한 역할을 합니다.

 

전체 코드

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os # 파일 시스템 및 경로 관리
import time # 프로그램 실행 시간 측정용
import copy # 객체 복사
import json # json 데이터를 다루기 위한 라이브러리
import pickle # 파이썬 객체를 저장 및 로드 하기 위한 라이브러리
import psutil # 시스템 메모리 사용량, CPU 사용량 등을 측정
import PIL.Image # Python Image Library의 이미지 처리 모듈. 이미지 로드 및 변환, 저장에 사용
import numpy as np # 다차원 행렬 연산용 라이브러리
import torch # 파이토치
import dnnlib # NVIDIA의 GAN 훈련과 관련된 라이브러리
from torch_utils import misc # 일반적인 유틸리티 기능 포함. 텐서와 모델을 GPU, CPU로 이동. 모델의 가중치 초기화. 분산 훈련 지원. 데이터 타입 변환, 텐서 크기 조정 등
from torch_utils import training_stats # 훈련 과정의 다양한 통계를 기록. loss, acc 등의 메트릭을 기록 및 시각화. 성능 모니터링 및 분산 훈련의 통계 동기화 및 통합
from torch_utils.ops import conv2d_gradfix # 컨볼루션 연산의 그래디언트를 보정
from torch_utils.ops import grid_sample_gradfix # 그리드 샘플링의 그래디언트를 보정

import legacy # 이전 버전의 머신러닝 모델이나 코드를 가져와서 쓸 수 있도록 해줌
from metrics import metric_main # 모델 성능 평가 메트릭

#----------------------------------------------------------------------------

def setup_snapshot_image_grid(training_set, random_seed=0):
    rnd = np.random.RandomState(random_seed) # 랜덤 값 생성
    gw = np.clip(7680 // training_set.image_shape[2], 7, 32) # 그리드 가로 길이 / 이미지 가로 길이, 최소 7, 최대 32로 제한
    gh = np.clip(4320 // training_set.image_shape[1], 4, 32) # 그리드 세로 길이 / 이미지 세로 길이, 최소 4, 최대 32로 제한

    # 데이터 레이블이 없다면
    # No labels => show random subset of training samples.
    if not training_set.has_labels:
        all_indices = list(range(len(training_set))) # 훈련 데이터셋의 모든 인덱스를 리스트로 생성
        rnd.shuffle(all_indices) # 생성된 인덱스 리스트를 무작위로 섞는다
        grid_indices = [all_indices[i % len(all_indices)] for i in range(gw * gh)] # 섞인 리스트에서 그리드 개수만큼 인덱스를 뽑는다.

    # 데이터 레이블이 있다면
    # Group training samples by label.
    else:
        label_groups = dict() # label => [idx, ...] 클래스를 담을 딕셔너리 선언
        for idx in range(len(training_set)): # 데이터 전체를 순환하면서
            label = tuple(training_set.get_details(idx).raw_label.flat[::-1]) # 각 이미지의 레이블 추출
            if label not in label_groups: # 뽑은 레이블이 레이블 그룹에 없다면
                label_groups[label] = [] # 새로운 레이블 그룹을 생성한 뒤 추가한다
            label_groups[label].append(idx) # 그리고 그 그룹에 현재 이미지의 인덱스를 추가한다

        # Reorder.
        label_order = sorted(label_groups.keys()) # 레이블 그룹을 키에 따라 순서대로 정렬하여 label_order를 생성
        for label in label_order: # 여기의 각 카테고리에 대하여
            rnd.shuffle(label_groups[label]) # 인덱스를 셔플하고

        # Organize into grid.
        grid_indices = [] # 그리드에 사용될 인덱스를 저장할 리스트 생성
        for y in range(gh): # 그리드의 세로 크기만큼 돌면서
            label = label_order[y % len(label_order)] # 현재 행에 해당하는 레이블 선택
            indices = label_groups[label] # 그 레이블의 인덱스 그룹에서 인덱스 리스트를 가져옴
            grid_indices += [indices[x % len(indices)] for x in range(gw)] # 현재 행에 해당하는 이미지 인덱스를 그리드 인덱스 리스트에 추가
            label_groups[label] = [indices[(i + gw) % len(indices)] for i in range(len(indices))] # 사용된 인덱스 업데이트

    # Load data.
    images, masks, labels = zip(*[training_set[i] for i in grid_indices]) # 인덱스에 해당하는 이미지, 마스크, 레이블을 로드
    return (gw, gh), np.stack(images), np.stack(masks), np.stack(labels) # 그리드 크기와 이미지, 마스크, 레이블을 반환

#----------------------------------------------------------------------------

def save_image_grid(img, fname, drange, grid_size):
    lo, hi = drange # 이미지 데이터의 동적 범위(dynamic range)를 lo(최소값)와 hi(최대값)로 설정
    img = np.asarray(img, dtype=np.float32) # 입력된 이미지를 NumPy 배열로 변환
    img = (img - lo) * (255 / (hi - lo)) # 이미지 데이터를 [lo, hi] 범위에서 [0, 255] 범위로 정규화
    img = np.rint(img).clip(0, 255).astype(np.uint8) # 정규화된 이미지 데이터를 반올림하고, 0과 255 사이의 값으로 제한한 후, 8비트 정수형으로 변환

    gw, gh = grid_size # 그리드의 가로(gw)와 세로(gh) 크기를 설정합니다.
    _N, C, H, W = img.shape # 이미지 배열의 차원을 N(이미지 수), C(채널 수), H(높이), W(너비)로 분해
    img = img.reshape(gh, gw, C, H, W) # 이미지 배열을 그리드 형태로 재배열
    img = img.transpose(0, 3, 1, 4, 2) # 재배열된 이미지의 차원을 조정하여 그리드 형태로 만듦
    img = img.reshape(gh * H, gw * W, C) # 그리드 이미지를 2D 이미지 형태로 변환

    assert C in [1, 3] # 이미지가 그레이스케일(1채널) 또는 RGB(3채널)인지 확인
    if C == 1: # 그레이스케일 이미지인 경우, PIL 라이브러리를 사용하여 이미지를 저장
        PIL.Image.fromarray(img[:, :, 0], 'L').save(fname)
    if C == 3: # RGB 이미지인 경우, PIL 라이브러리를 사용하여 이미지를 저장
        PIL.Image.fromarray(img, 'RGB').save(fname)

#----------------------------------------------------------------------------

def training_loop(
    run_dir                 = '.',      # Output directory. 훈련 결과(모델 체크포인트, 이미지 등)를 저장할 디렉토리.

    # kwargs
    training_set_kwargs     = {},       # Options for training set. 훈련 데이터셋에 대한 설정.
    val_set_kwargs          = {},       # Options for validation set. 검증 데이터셋에 대한 설정.
    data_loader_kwargs      = {},       # Options for torch.utils.data.DataLoader. 데이터 로더의 데이터 로딩과 배치 생성 방식을 정의.
    G_kwargs                = {},       # Options for generator network. 생성기(Generator) 네트워크에 대한 설정.
    D_kwargs                = {},       # Options for discriminator network. 판별기(Discriminator) 네트워크에 대한 설정.
    G_opt_kwargs            = {},       # Options for generator optimizer. 생성기의 옵티마이저(optimizer) 설정.
    D_opt_kwargs            = {},       # Options for discriminator optimizer. 판별기의 옵티마이저 설정.
    augment_kwargs          = None,     # Options for augmentation pipeline. None = disable. 데이터 증강 파이프라인에 대한 설정.
    loss_kwargs             = {},       # Options for loss function. 손실 함수에 대한 설정.

    metrics                 = [],       # Metrics to evaluate during training. 훈련 중 평가할 메트릭의 리스트.
    random_seed             = 0,        # Global random seed. 전역 난수 시드.
    num_gpus                = 1,        # Number of GPUs participating in the training. 훈련에 사용할 GPU의 수.
    rank                    = 0,        # Rank of the current process in [0, num_gpus]. 현재 프로세스의 순위입니다. 분산 훈련 시 사용.
    batch_size              = 4,        # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus. 한 번의 훈련 반복(iteration)에 사용할 총 배치 크기.
    batch_gpu               = 4,        # Number of samples processed at a time by one GPU. 각 GPU가 한 번에 처리할 샘플의 수.
    ema_kimg                = 10,       # Half-life of the exponential moving average (EMA) of generator weights. 생성기 가중치의 지수 이동 평균(EMA)의 반감기.
    ema_rampup              = None,     # EMA ramp-up coefficient. EMA의 램프업 계수.
    G_reg_interval          = 4,        # How often to perform regularization for G? None = disable lazy regularization. 생성기의 정규화를 수행하는 간격.
    D_reg_interval          = 16,       # How often to perform regularization for D? None = disable lazy regularization. 판별기의 정규화를 수행하는 간격.
    augment_p               = 0,        # Initial value of augmentation probability. fixed 증강의 초기 확률.
    ada_target              = None,     # ADA target value. None = fixed p. ada 증강의 목표값.
    ada_interval            = 4,        # How often to perform ADA adjustment? ada 조정 간격.
    ada_kimg                = 500,      # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit. ada 조정 속도.
    total_kimg              = 25000,    # Total length of the training, measured in thousands of real images. 총 훈련 길이.
    kimg_per_tick           = 4,        # Progress snapshot interval. 스냅샷 저장 인터벌.
    image_snapshot_ticks    = 50,       # How often to save image snapshots? None = disable. 이미지 스냅샷 저장 인터벌.
    network_snapshot_ticks  = 50,       # How often to save network snapshots? None = disable. 네트워크 스냅샷 저장 인터벌.
    resume_pkl              = None,     # Network pickle to resume training from. 전이 학습 등 추가적인 훈련을 하기 위해 가져올 네트워크 피클 파일.
    cudnn_benchmark         = True,     # Enable torch.backends.cudnn.benchmark? cuDNN 벤치마크를 활성화 할지 여부.
    allow_tf32              = False,    # Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32? 파이토치의 TF32 자료형을 사용할 것인지 여부.
    abort_fn                = None,     # Callback function for determining whether to abort training. Must return consistent results across ranks. 훈련 중단을 결정하는 콜백 함수.
    progress_fn             = None,     # Callback function for updating training progress. Called for all ranks. 훈련 진행 상황을 업데이트하는 콜백 함수.
):
    # Initialize.
    start_time = time.time() # 현재 시간을 기록하여 훈련 시작 시간을 저장
    device = torch.device('cuda', rank) # GPU 장치 설정, rank는 여러 GPU 중 어느 것을 사용할지 결정하는 인덱스
    
    # 난수 생성기 시드 설정
    np.random.seed(random_seed * num_gpus + rank) # 넘파이 랜덤 시드 설정, 이 시드는 여러 GPU를 사용할 때 각 GPU에서 다른 난수 시퀀스를 생성하기 위해 다음 수식을 사용
    torch.manual_seed(random_seed * num_gpus + rank) # 파이토치 랜덤 시드 설정
    
    # 성능 최적화 설정
    torch.backends.cudnn.benchmark = cudnn_benchmark    # Improves training speed. cuDNN 벤치마크 모드 활성화 여부 결정.
                                                        # cuDNN은 여러 알고리즘을 자동으로 시도하고 가장 빠른 것을 선택하여 합성곱 연산의 속도를 향상하나 정밀도가 낮아질 수 있다.
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32  # Allow PyTorch to internally use tf32 for matmul. 행렬 곱 연산에 tf32 사용 여부 결정. (연산 속도 향상을 위함)
    torch.backends.cudnn.allow_tf32 = allow_tf32        # Allow PyTorch to internally use tf32 for convolutions. 합성 곱 연산에 tf32 사용 여부 결정. (연산 속도 향상을 위함)

    # 그래디언트 계산 관련 버그 수정 활성화
    conv2d_gradfix.enabled = True                       # Improves training speed. 합성곱 계층의 그래디언드 버그 수정.
    grid_sample_gradfix.enabled = True                  # Avoids errors with the augmentation pipe. 그리드 샘플링 연산의 그래디언트 버그 수정.

    # Load training set.
    if rank == 0:
        print('Loading training set...') # 첫 번째 GPU에서만 훈련 데이터셋 로딩 메시지 출력

    # 훈련 데이터셋 생성
    training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs) # subclass of training.dataset.Dataset
    # trian.py의 160 번째 줄과 동일한 코드이다. training_set_kwargs에는 훈련 데이터를 만들기 위한 다양한 옵션 사항이 존재한다.
    # 이 아규먼트는 trian.py의 145 줄 args.training_set_kwargs = dnnlib.EasyDict(class_name=dataloader, path=data, use_labels=True, max_size=None, xflip=False)를 통해 생성된다.
    # 해당 아규먼트는 CLI 인터페이스를 만드는 click이 유저의 명령어를 받아와 EasyDict 형태로 바꿔 할당한다.
    # 그리고 construct_class_by_name 함수가 이 아규먼트를 받아서 실질적인 트레이닝 셋 객체를 만드는 것이다.
    # 파이토치 내장함수로 따지자면 torchvision.dataset() 함수가 여러 인자를 받아 데이터셋을 만드는 원리인 것이다.

    # 검증 데이터셋 생성
    val_set = dnnlib.util.construct_class_by_name(**val_set_kwargs) # subclass of training.dataset.Dataset

    # 데이터셋 샘플러 설정
    # 훈련 데이터셋에 대한 무한 샘플러 생성. 이는 GPU에서 데이터를 분산 처리 하기 위한 것.
    training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed)
    # 파이토치 데이터로더 함수를 이용해 훈련 데이터셋에 대한 반복자 생성.
    training_set_iterator = iter(torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler, batch_size=batch_size//num_gpus, **data_loader_kwargs))
    
    # 데이터셋 정보 출력
    if rank == 0: # 첫 번째 GPU에서
        print()
        print('Num images: ', len(training_set)) # 훈련 데이터셋 이미지 수 출력
        print('Image shape:', training_set.image_shape) # 이미지의 형태 출력
        print('Label shape:', training_set.label_shape) # 레이블의 형태 출력
        print()

    
    # Construct networks.
    if rank == 0: # GPU 랭크가 0일 때, 즉 첫 시작임을 알리는 것.
        print('Constructing networks...')
    # common_kwargs에는 생성기와 판별기에 모두 들어가는 아규먼트(레이블 차원), 이미지 해상도, 이미지 채널 수를 저장.
    common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels)
    # 생성기 옵션과 공통 옵션을 인자로 받아 생성기 객체를 생성. train()을 통해 훈련 모드로 설정 후 to.(device)로 GPU로 옮긴다.
    G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
    # 판별기 옵션과 공통 옵션을 인자로 받아 판별기 객체를 생성.
    D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
    # G의 모델 구조를 복사해 G_ema를 생성.
    G_ema = copy.deepcopy(G).eval()

    """
    1. 왜 G와 D를 생성할 때 requires_grad_(False)로 설정하지?
    """

    # Resume from existing pickle.
    if (resume_pkl is not None) and (rank == 0): # 전이학습 할 데이터 객체의 피클의 경로가 존재하고, GPU 랭크가 0일 때(전이 학습은 첫 번째 GPU에서만 진행)
        print(f'Resuming from "{resume_pkl}"')
        with dnnlib.util.open_url(resume_pkl) as f: # 피클 객체를 가져와서 읽는다. open_ulr() 함수는 로컬 파일 뿐만 아니라 url을 통해 원격 파일에도 접근할 수 있다.
            resume_data = legacy.load_network_pkl(f) # 피클 객체를 읽기.
        for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]: # 'G'는 이름을, G는 네트워크 객체를 나타낸다. 이들을 하나씩 가져와서
            misc.copy_params_and_buffers(resume_data[name], module, require_all=False) # 피클 객체의 파라미터와 버퍼를 똑같이 G에 덮어 씌운다.

    # Print network summary tables.
    if rank == 0: # 네트워크 요약은 첫 번째 GPU에서만 실행 된다.
        # G에 입력될 빈 잠재 벡터 생성
        z = torch.empty([batch_gpu, G.z_dim], device=device) # G에 입력될 잠재 벡터 z를 위한 빈 텐서를 초기화. 텐서의 크기는 [batch_gpu, G.z_dim(잠재 벡터 차원)]
        c = torch.empty([batch_gpu, G.c_dim], device=device) # G에 입력될 조건 벡터 c를 위한 빈 텐서를 초기화. 텐서의 크기는 [batch_gpu, G.c_dim(클래스 차원)]
        # adaptation to inpainting config
        
        # G에 입력될 데이터 생성 후 print_model_summary() 함수 호출
        img_in = torch.empty([batch_gpu, training_set.num_channels, training_set.resolution, training_set.resolution], device=device) # 예를 들면 [4, 3, 512, 512]의 빈 텐서
        mask_in = torch.empty([batch_gpu, 1, training_set.resolution, training_set.resolution], device=device) # 예를 들면 [4, 1, 512, 512]의 빈 텐서
        img = misc.print_module_summary(G, [img_in, mask_in, z, c]) # 해당 함수를 통해 모델과 인풋 데이터를 넣었을 때 모델의 요약 정보를 출력한다.
        
        # D
        img_stg1 = torch.empty([batch_gpu, 3, training_set.resolution, training_set.resolution], device=device) # [4, 3, 512, 512]의 빈 텐서
        misc.print_module_summary(D, [img, mask_in, img_stg1, c]) # 마찬가지로 D의 정보 요약 출력.

    # Setup augmentation.
    if rank == 0:
        print('Setting up augmentation...')
    # 우선 데이터 증강 파이프 라인과 ada 관련 통계 정보를 None으로 초기화 한다.
    augment_pipe = None
    ada_stats = None
    # augment_kwargs가 None이 아니고, augment_p (증강 확률)가 0보다 크거나 ada_target이 설정된 경우에만 데이터 증강을 설정
    if (augment_kwargs is not None) and (augment_p > 0 or ada_target is not None):
        # 데이터 증강을 위한 파이프라인을 동적으로 생성하고 초기화
        augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
        # 증강 파이프라인의 확률 파라미터 p를 augment_p 값으로 설정
        augment_pipe.p.copy_(torch.as_tensor(augment_p))
        # ADA가 활성화된 경우(ada_target이 None이 아닌 경우), 손실 관련 통계를 수집하기 위한 Collector 객체를 생성
        # 이는 판별자의 손실 신호를 추적하여 증강 확률을 적응적으로 조정하는 데 사용
        if ada_target is not None:
            ada_stats = training_stats.Collector(regex='Loss/signs/real')

    # Distribute across GPUs.
    if rank == 0: #  첫 번째 GPU에서 실행 중일 때("rank == 0") 여러 GPU에 걸쳐 분산 처리를 시작한다는 메시지를 출력
        print(f'Distributing across {num_gpus} GPUs...')
    ddp_modules = dict() # 산 데이터 병렬 처리(Distributed Data Parallel, DDP)를 위한 모듈을 저장할 딕셔너리를 초기화
    # 생성자 G의 mapping과 synthesis 부분, 판별자 D, EMA 버전의 생성자 G_ema, 그리고 데이터 증강 파이프라인 augment_pipe에 대해 반복 
    for name, module in [('G_mapping', G.mapping), ('G_synthesis', G.synthesis), ('D', D), (None, G_ema), ('augment_pipe', augment_pipe)]:
        # GPU가 여러 개 있고, 모듈이 None이 아니며, 모듈에 파라미터가 있는 경우에만 분산 처리를 수행
        if (num_gpus > 1) and (module is not None) and len(list(module.parameters())) != 0:
            module.requires_grad_(True) # 모듈의 파라미터에 대해 기울기 계산을 활성화. 이는 분산 데이터 병렬 처리를 위한 사전 설정.
            # 이 클래스는 여러 GPU에 걸쳐 모듈의 작업을 분산시키고, 각 GPU에서의 기울기를 집계하여 모델을 효율적으로 훈련할 수 있게 함.
            module = torch.nn.parallel.DistributedDataParallel(module, device_ids=[device], broadcast_buffers=False)
            # 기울기 계산을 다시 비활성화. 이는 분산 처리 후의 정리 작업.
            module.requires_grad_(False)
        if name is not None: # 이름이 지정된 모듈을 ddp_modules 딕셔너리에 저장
            ddp_modules[name] = module # 이 딕셔너리는 나중에 모듈에 접근할 때 사용

    # Setup training phases.
    # 해당 코드 블럭에선 G와 D의 정규화 인터벌을 정하고, 이에 따른 최적화된 다양한 옵션을 설정한다.
    if rank == 0:
        print('Setting up training phases...')
    # 손실 함수를 동적으로 생성하고 초기화
    loss = dnnlib.util.construct_class_by_name(device=device, **ddp_modules, **loss_kwargs) # subclass of training.loss.Loss
    phases = [] # 훈련 단계를 저장할 리스트를 초기화
    for name, module, opt_kwargs, reg_interval in [('G', G, G_opt_kwargs, G_reg_interval), ('D', D, D_opt_kwargs, D_reg_interval)]:
        # 튜플: (모듈의 이름, 해당 객체, 최적화 옵션, 정규화 간격), 이걸 G와 D에 대해 수행.
        
        # 정규화가 필요 없는 경우(interval == None인 경우)
        if reg_interval is None:
            # 최적화 함수를 동적으로 생성하고 초기화. 이 함수는 모듈의 파라미터와 opt_kwargs를 사용하여 최적화 클래스를 생성.
            opt = dnnlib.util.construct_class_by_name(params=module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer
            # 생성된 최적화 함수를 훈련 단계 리스트에 추가
            phases += [dnnlib.EasyDict(name=name+'both', module=module, opt=opt, interval=1)]

        # 정규화가 필요한 경우의 설정: 이는 "Lazy regularization"이라고 불리며, 정규화를 덜 자주 수행하여 훈련 속도를 향상.
        else: # Lazy regularization.
            mb_ratio = reg_interval / (reg_interval + 1) # 정규화 간격에 따라 최적화 파라미터를 조정
            opt_kwargs = dnnlib.EasyDict(opt_kwargs) # 최적화 옵션을 EasyDict로 변환
            opt_kwargs.lr = opt_kwargs.lr * mb_ratio
            opt_kwargs.betas = [beta ** mb_ratio for beta in opt_kwargs.betas]
            if 'lrt' in opt_kwargs:
                filter_list = ['tran', 'Tran']
                base_params = []
                tran_params = []
                for pname, param in module.named_parameters():
                    flag = False
                    for fname in filter_list:
                        if fname in pname:
                            flag = True
                    if flag:
                        tran_params.append(param)
                    else:
                        base_params.append(param)
                optim_params = [{'params': base_params}, {'params': tran_params, 'lr': opt_kwargs.lrt * mb_ratio}]
                optim_kwargs = dnnlib.EasyDict()
                for key, val in opt_kwargs.items():
                    if 'lrt' != key:
                        optim_kwargs[key] = val
            else:
                optim_params = module.parameters()
                optim_kwargs = opt_kwargs

            opt = dnnlib.util.construct_class_by_name(optim_params, **optim_kwargs) # 정규화를 위한 최적화 함수를 동적으로 생성
            phases += [dnnlib.EasyDict(name=name+'main', module=module, opt=opt, interval=1)]
            phases += [dnnlib.EasyDict(name=name+'reg', module=module, opt=opt, interval=reg_interval)]

    for phase in phases: # 각 훈련 단계에 대해 반복
        phase.start_event = None
        phase.end_event = None
        if rank == 0: # 첫 번째 GPU에서 실행 중일 때, 각 단계의 시작과 종료 시간을 측정할 이벤트를 생성
            phase.start_event = torch.cuda.Event(enable_timing=True)
            phase.end_event = torch.cuda.Event(enable_timing=True)

    
    # Export sample images.
    # 이미지 그리드와 관련된 변수들을 초기화
    grid_size = None
    grid_z = None
    grid_c = None
    grid_img = None
    grid_mask = None
    
    if rank == 0:
        print('Exporting sample images...')
        # 검증 데이터셋에서 이미지, 마스크, 레이블을 추출하고, 이를 그리드 형태로 구성. 이 함수는 그리드 크기(grid_size), 이미지, 마스크, 레이블을 반환
        grid_size, images, masks, labels = setup_snapshot_image_grid(training_set=val_set)
        # 실제 이미지를 그리드 형태로 저장
        save_image_grid(images, os.path.join(run_dir, 'reals.png'), drange=[0, 255], grid_size=grid_size)
        # adaptation to inpainting config
        # 마스크를 그리드 형태로 저장
        save_image_grid(masks, os.path.join(run_dir, 'masks.png'), drange=[0, 1], grid_size=grid_size)
        
        # --------------------
        # 성자 G에 입력될 잠재 벡터 z, c를 무작위로 생성하고 배치 크기에 맞게 분할
        grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(batch_gpu)
        grid_c = torch.from_numpy(labels).to(device).split(batch_gpu)
        
        # 이미지를 텐서로 변환하고 정규화한 후 배치 크기에 맞게 분할
        grid_img = (torch.from_numpy(images).to(device) / 127.5 - 1).split(batch_gpu)  # [-1, 1]
        # 마스크를 텐서로 변환하고 배치 크기에 맞게 분할
        grid_mask = torch.from_numpy(masks).to(device).split(batch_gpu)  # {0, 1}
        # 생성자 G_ema를 사용하여 샘플 이미지를 생성하고, 이를 하나의 텐서로 결합한 후 NumPy 배열로 변환
        images = torch.cat([G_ema(img_in, mask_in, z, c, noise_mode='const').cpu() \
                            for img_in, mask_in, z, c in zip(grid_img, grid_mask, grid_z, grid_c)]).numpy()
        # --------------------
        # 성된 이미지를 그리드 형태로 저장
        save_image_grid(images, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size)

    # Initialize logs.
    if rank == 0:
        print('Initializing logs...')
    stats_collector = training_stats.Collector(regex='.*') # 객체를 생성하여 훈련 과정에서 발생하는 다양한 통계를 수집. regex='.*'는 모든 종류의 통계를 수집하겠다는 의미.
    stats_metrics = dict() # 훈련 과정에서 계산된 메트릭을 저장할 딕셔너리를 초기화
    # JSONL 형식의 로그 파일(stats_jsonl)과 TensorBoard 이벤트 파일(stats_tfevents)을 위한 변수를 초기화
    stats_jsonl = None
    stats_tfevents = None
    
    if rank == 0: # 첫 번째 GPU에서 실행 중일 때만 다음의 로그 관련 작업을 수행
        stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt') # JSONL 형식의 로그 파일을 쓰기 모드로 열어 stats_jsonl에 할당. 이 파일은 훈련 과정의 통계를 저장.
        try: # TensorBoard를 사용하기 위해 필요한 모듈을 임포트
            import torch.utils.tensorboard as tensorboard
            # TensorBoard의 SummaryWriter를 생성하여 훈련 과정의 통계를 TensorBoard 형식으로 기록
            stats_tfevents = tensorboard.SummaryWriter(run_dir)
        except ImportError as err:
            # TensorBoard 모듈이 없는 경우, 이를 알리는 메시지를 출력하고 TensorBoard 로그 기록을 건너뜀.
            print('Skipping tfevents export:', err)

    # Train.
    if rank == 0:
        print(f'Training for {total_kimg} kimg...')
        print()
    cur_nimg = 0
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    maintenance_time = tick_start_time - start_time
    batch_idx = 0
    if progress_fn is not None:
        progress_fn(0, total_kimg)
    while True:

        # Fetch training data.
        with torch.autograd.profiler.record_function('data_fetch'):
            phase_real_img, phase_mask, phase_real_c = next(training_set_iterator)
            phase_real_img = (phase_real_img.to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu)
            # adaptation to inpainting config
            phase_mask = phase_mask.to(device).to(torch.float32).split(batch_gpu)
            # --------------------
            phase_real_c = phase_real_c.to(device).split(batch_gpu)
            all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim], device=device)
            all_gen_z = [phase_gen_z.split(batch_gpu) for phase_gen_z in all_gen_z.split(batch_size)]
            all_gen_c = [training_set.get_label(np.random.randint(len(training_set))) for _ in range(len(phases) * batch_size)]
            all_gen_c = torch.from_numpy(np.stack(all_gen_c)).pin_memory().to(device)
            all_gen_c = [phase_gen_c.split(batch_gpu) for phase_gen_c in all_gen_c.split(batch_size)]

        # Execute training phases.
        for phase, phase_gen_z, phase_gen_c in zip(phases, all_gen_z, all_gen_c):
            if batch_idx % phase.interval != 0:
                continue

            # Initialize gradient accumulation.
            if phase.start_event is not None:
                phase.start_event.record(torch.cuda.current_stream(device))
            phase.opt.zero_grad(set_to_none=True)
            phase.module.requires_grad_(True)

            # Accumulate gradients over multiple rounds.
            for round_idx, (real_img, mask, real_c, gen_z, gen_c) in enumerate(zip(phase_real_img, phase_mask, phase_real_c, phase_gen_z, phase_gen_c)):
                sync = (round_idx == batch_size // (batch_gpu * num_gpus) - 1)
                gain = phase.interval
                loss.accumulate_gradients(phase=phase.name, real_img=real_img, mask=mask, real_c=real_c, gen_z=gen_z, gen_c=gen_c, sync=sync, gain=gain)

            # Update weights.
            phase.module.requires_grad_(False)
            with torch.autograd.profiler.record_function(phase.name + '_opt'):
                for param in phase.module.parameters():
                    if param.grad is not None:
                        misc.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
                phase.opt.step()
            if phase.end_event is not None:
                phase.end_event.record(torch.cuda.current_stream(device))

        # Update G_ema.
        with torch.autograd.profiler.record_function('Gema'):
            ema_nimg = ema_kimg * 1000
            if ema_rampup is not None:
                ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
            ema_beta = 0.5 ** (batch_size / max(ema_nimg, 1e-8))
            for p_ema, p in zip(G_ema.parameters(), G.parameters()):
                p_ema.copy_(p.lerp(p_ema, ema_beta))
            for b_ema, b in zip(G_ema.buffers(), G.buffers()):
                b_ema.copy_(b)

        # Update state.
        cur_nimg += batch_size
        batch_idx += 1

        # Execute ADA heuristic.
        if (ada_stats is not None) and (batch_idx % ada_interval == 0):
            ada_stats.update()
            adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) * (batch_size * ada_interval) / (ada_kimg * 1000)
            augment_pipe.p.copy_((augment_pipe.p + adjust).max(misc.constant(0, device=device)))

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
            continue

        # Print status line, accumulating the same information in stats_collector.
        tick_end_time = time.time()
        fields = []
        fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
        fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"]
        fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
        fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
        fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
        fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
        fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"]
        fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
        torch.cuda.reset_peak_memory_stats()
        fields += [f"augment {training_stats.report0('Progress/augment', float(augment_pipe.p.cpu()) if augment_pipe is not None else 0):.3f}"]
        training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60))
        training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60))
        if rank == 0:
            print(' '.join(fields))

        # Check for abort.
        if (not done) and (abort_fn is not None) and abort_fn():
            done = True
            if rank == 0:
                print()
                print('Aborting...')

        # Save image snapshot.
        if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0):
            images = torch.cat([G_ema(img_in, mask_in, z, c, noise_mode='const').cpu() \
                                for img_in, mask_in, z, c in zip(grid_img, grid_mask, grid_z, grid_c)]).numpy()
            save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'), drange=[-1,1], grid_size=grid_size)

        # Save network snapshot.
        snapshot_pkl = None
        snapshot_data = None
        if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0):
            snapshot_data = dict(training_set_kwargs=dict(training_set_kwargs), val_set_kwargs=dict(val_set_kwargs))
            for name, module in [('G', G), ('D', D), ('G_ema', G_ema), ('augment_pipe', augment_pipe)]:
                if module is not None:
                    if num_gpus > 1:
                        misc.check_ddp_consistency(module, ignore_regex=[r'.*\.w_avg', r'.*\.relative_position_index', r'.*\.avg_weight', r'.*\.attn_mask', r'.*\.resample_filter'])
                    module = copy.deepcopy(module).eval().requires_grad_(False).cpu()
                snapshot_data[name] = module
                del module # conserve memory
            snapshot_pkl = os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl')
            if rank == 0:
                with open(snapshot_pkl, 'wb') as f:
                    pickle.dump(snapshot_data, f)

        # Evaluate metrics.
        if (snapshot_data is not None) and (len(metrics) > 0):
            if rank == 0:
                print('Evaluating metrics...')
            for metric in metrics:
                result_dict = metric_main.calc_metric(metric=metric, G=snapshot_data['G_ema'],
                    dataset_kwargs=val_set_kwargs, num_gpus=num_gpus, rank=rank, device=device)
                if rank == 0:
                    metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl)
                stats_metrics.update(result_dict.results)
        del snapshot_data # conserve memory

        # Collect statistics.
        for phase in phases:
            value = []
            if (phase.start_event is not None) and (phase.end_event is not None):
                phase.end_event.synchronize()
                value = phase.start_event.elapsed_time(phase.end_event)
            training_stats.report0('Timing/' + phase.name, value)
        stats_collector.update()
        stats_dict = stats_collector.as_dict()

        # Update logs.
        timestamp = time.time()
        if stats_jsonl is not None:
            fields = dict(stats_dict, timestamp=timestamp)
            stats_jsonl.write(json.dumps(fields) + '\n')
            stats_jsonl.flush()
        if stats_tfevents is not None:
            global_step = int(cur_nimg / 1e3)
            walltime = timestamp - start_time
            for name, value in stats_dict.items():
                stats_tfevents.add_scalar(name, value.mean, global_step=global_step, walltime=walltime)
            for name, value in stats_metrics.items():
                stats_tfevents.add_scalar(f'Metrics/{name}', value, global_step=global_step, walltime=walltime)
            stats_tfevents.flush()
        if progress_fn is not None:
            progress_fn(cur_nimg // 1000, total_kimg)

        # Update state.
        cur_tick += 1
        tick_start_nimg = cur_nimg
        tick_start_time = time.time()
        maintenance_time = tick_start_time - tick_end_time
        if done:
            break

    # Done.
    if rank == 0:
        print()
        print('Exiting...')

#----------------------------------------------------------------------------

'코드 분석 > MAT' 카테고리의 다른 글

MAT 환경설정 및 버그 리포트  (0) 2024.01.22
MAT: mat.py  (1) 2024.01.11
MAT: basic_module.py  (0) 2024.01.03
MAT: train.py  (1) 2023.12.27
MAT: generate_image.py  (0) 2023.12.26