코드 분석/MAT

MAT: basic_module.py

상솜공방 2024. 1. 3. 13:52

코드를 이해하기 위한 기본 지식

더보기

sys.path.insert(0, '../')

이 코드는 Python의 모듈 검색 경로를 수정하는 데 사용됩니다. 이해를 돕기 위해 각 부분을 자세히 설명하겠습니다:
sys.path: Python 인터프리터가 모듈을 검색할 때 참조하는 경로 목록입니다. 이 목록에 포함된 디렉토리들은 Python이 모듈을 임포트할 때 찾아보는 위치들입니다.
insert(0, '../'): 이 메소드는 sys.path 리스트의 맨 앞에 새로운 경로를 추가합니다. 리스트의 인덱스 0은 첫 번째 위치를 의미하므로, 이 코드는 새 경로를 리스트의 가장 앞에 삽입합니다.
'../': 이 문자열은 상위 디렉토리를 가리키는 상대 경로입니다. 현재 스크립트가 있는 위치에서 한 단계 위의 디렉토리를 의미합니다.

작동 원리
Python 스크립트에서 모듈을 임포트할 때, Python 인터프리터는 sys.path에 있는 디렉토리들을 순서대로 검색하여 해당 모듈을 찾습니다. sys.path.insert(0, '../')를 사용하면, 상위 디렉토리가 검색 경로의 맨 앞에 추가됩니다. 이는 Python이 모듈을 찾을 때 상위 디렉토리를 먼저 확인하게 만듭니다. 이렇게 하면 현재 스크립트와 같은 위치에 없는 모듈이나 패키지를 임포트할 수 있게 됩니다. 특히, 프로젝트의 다른 부분에 있는 코드를 재사용하거나 모듈화된 구조를 유지할 때 유용합니다.

 

사용 예시
예를 들어, 다음과 같은 디렉토리 구조가 있다고 가정해 보겠습니다:

project/
│
├── module/
│   └── my_module.py
│
└── script/
    └── my_script.py

my_script.py에서 my_module.py를 임포트하려면, my_script.py 파일에 sys.path.insert(0, '../')를 추가한 후 from module import my_module와 같이 임포트할 수 있습니다.
이렇게 sys.path를 조작하는 것은 특정 상황에서 유용하지만, 코드의 이식성과 가독성을 저해할 수 있으므로 신중하게 사용해야 합니다. 가능하다면 상대적으로 더 안정적인 방법, 예를 들어 패키지 구조를 설정하거나 환경 변수를 사용하는 방법을 고려하는 것이 좋습니다.

 

conv2d_resample

기능: conv2d_resample은 2D 컨볼루션 연산을 수행하면서 동시에 입력 텐서를 리샘플링(업샘플링 또는 다운샘플링)하는 연산입니다.
사용 예: 이 함수는 이미지의 해상도를 변경하면서 특징을 추출할 때 유용합니다. 예를 들어, 이미지의 크기를 줄이면서 동시에 특정 특징을 강조하기 위한 컨볼루션 필터를 적용할 수 있습니다.
작동 원리: 일반적으로 컨볼루션 연산은 입력 이미지에 필터를 적용하여 특징을 추출하지만, conv2d_resample은 이 과정에 추가적으로 입력 이미지의 크기를 조정합니다. 이는 메모리 사용량을 최적화하고 연산 효율을 높이는 데 도움이 됩니다.

 

upfirdn2d

기능: upfirdn2d (Up-sample, FIR filter, Down-sample)는 업샘플링, FIR 필터링, 다운샘플링을 결합한 연산입니다.
사용 예: 이 연산은 이미지 처리에서 흔히 사용되는 기법으로, 이미지의 해상도를 변경하거나 스무딩(smoothing) 및 샤프닝(sharpening)과 같은 필터링 작업에 사용됩니다.
작동 원리: upfirdn2d는 먼저 입력 데이터를 업샘플링하여 해상도를 높인 다음, FIR 필터를 적용하여 이미지를 스무딩하거나 샤프닝합니다. 마지막으로 필요한 경우 다운샘플링을 통해 해상도를 조정합니다.

 

bias_act

기능: bias_act는 신경망의 활성화 함수에 바이어스를 추가하는 연산입니다.
사용 예: 이 함수는 신경망의 각 레이어에서 활성화 함수를 적용하기 전에 바이어스를 추가하는 데 사용됩니다. 이는 신경망의 비선형성을 증가시키고 모델의 표현력을 향상시키는 데 도움이 됩니다.
작동 원리: bias_act는 입력 텐서에 바이어스를 더한 후, 지정된 활성화 함수를 적용합니다. 이는 신경망이 더 복잡한 함수를 학습할 수 있게 하여 성능을 향상시킬 수 있습니다.

 

decorator

데코레이터는 Python에서 함수나 메소드의 동작을 수정하거나 확장하는데 사용되는 매우 유용한 도구입니다. 데코레이터는 기본적으로 다른 함수를 인자로 받아 새로운 기능을 추가한 후, 수정된 함수를 반환하는 함수입니다. 이를 통해 기존 코드를 변경하지 않고도 추가적인 기능을 쉽게 적용할 수 있습니다.

데코레이터의 기본 구조
데코레이터는 @ 기호를 사용하여 함수 정의 바로 위에 위치합니다. 기본적인 구조는 다음과 같습니다:

def my_decorator(func):
    def wrapper():
        # 여기에 추가 기능을 구현
        return func()
    return wrapper

@my_decorator
def my_function():
    print("Hello, World!")

이 예시에서 my_decorator는 데코레이터 함수이며, my_function은 데코레이터를 적용받는 대상 함수입니다. my_decorator 내부의 wrapper 함수는 my_function을 감싸는 역할을 하며, 여기에 추가적인 기능을 구현할 수 있습니다.

데코레이터의 사용 이유
코드 재사용성 향상: 데코레이터를 사용하면 특정 기능을 여러 함수에 걸쳐 재사용할 수 있습니다. 이는 코드 중복을 줄이고 유지보수를 용이하게 합니다.
코드의 가독성 및 관리: 기능을 데코레이터로 분리함으로써, 주 함수의 로직을 간결하게 유지할 수 있습니다. 또한, 데코레이터를 사용하면 기능의 추가나 제거가 간편해집니다.
확장성: 기존 함수에 영향을 주지 않고 새로운 기능을 추가할 수 있습니다. 이는 특히 프레임워크나 라이브러리에서 함수의 동작을 사용자 정의할 때 유용합니다.


@misc.profiled_function 데코레이터
@misc.profiled_function 데코레이터는 함수의 성능을 프로파일링하는 기능을 제공합니다. 이 데코레이터를 사용하면 해당 함수의 실행 시간, 메모리 사용량 등을 측정할 수 있으며, 이는 성능 최적화에 도움을 줍니다. 이 데코레이터는 아마도 함수의 성능을 분석하고 로깅하는 기능을 내부적으로 구현하고 있을 것입니다.

 

@persistence.persistent_class 데코레이터
@persistence.persistent_class 데코레이터는 클래스에 지속성(persistence) 기능을 추가하는 데 사용됩니다. 이 데코레이터의 정확한 작동 방식은 persistence 모듈의 구현에 따라 다르지만, 일반적으로 다음과 같은 기능을 수행합니다:

직렬화 지원: 클래스의 인스턴스를 저장하거나 네트워크를 통해 전송할 수 있도록 직렬화를 지원합니다. 이는 객체의 상태를 바이트 스트림으로 변환하여 파일에 저장하거나 다른 시스템으로 전송할 수 있게 해줍니다.
객체 상태의 복원: 저장된 객체의 상태를 다시 로드하여 객체를 이전 상태로 복원할 수 있습니다.
이 데코레이터는 주로 모델의 가중치와 같은 중요한 데이터를 저장하고 로드하는 데 사용되며, 모델의 학습 상태를 유지하고 필요할 때 다시 로드할 수 있는 기능을 제공합니다.

 

각 코드 블럭에 대한 설명

더보기

FullyConnectedLayer

단순 전연결 계층

 

ToRGB

중간 특징 맵을 최종적인 RGB 이미지로 변환하는 역할을 수행합니다. 이 클래스는 스타일 정보를 통합하고, 선택적으로 스킵 연결을 적용하여, 신경망의 출력을 시각적으로 풍부한 이미지로 변환하는 데 중요한 역할을 합니다.

 

Conv2dLayer
기본 컨볼루션 레이어: Conv2dLayer는 가장 기본적인 형태의 2D 컨볼루션 레이어입니다. 이 레이어는 표준적인 컨볼루션 연산을 수행합니다.
직접적인 특징 추출: 입력 텐서에 직접 컨볼루션 필터를 적용하여 특징을 추출합니다.
활성화 함수 적용: 선택적으로 활성화 함수를 적용할 수 있습니다.


ModulatedConv2d
스타일 기반 컨볼루션 레이어: ModulatedConv2d는 스타일 코드에 의해 조절되는 컨볼루션 레이어입니다. 이 레이어는 입력 텐서에 대한 컨볼루션 연산을 수행하기 전에 가중치를 스타일 코드로 조절합니다.
스타일 조절: 스타일 코드를 사용하여 컨볼루션 필터의 가중치를 동적으로 조절합니다. 이를 통해 동일한 네트워크가 다양한 스타일의 출력을 생성할 수 있습니다.
디모듈레이션: 선택적으로 디모듈레이션을 적용하여 가중치를 정규화할 수 있습니다.


StyleConv
통합된 스타일 컨볼루션 레이어: StyleConv는 ModulatedConv2d의 기능에 더해 추가적인 기능을 통합한 레이어입니다. 이 레이어는 스타일 기반의 컨볼루션 연산과 함께 노이즈 추가 및 활성화 함수 적용을 수행합니다.
노이즈 추가: 선택적으로 노이즈를 입력에 추가할 수 있습니다. 이는 출력에 무작위성을 부여하여 더 자연스러운 결과를 생성하는 데 도움이 됩니다.
활성화 및 출력 제한: 활성화 함수를 적용하고, 필요한 경우 출력값을 제한합니다.


결론
Conv2dLayer는 기본적인 컨볼루션 연산을 수행하는 레이어입니다.
ModulatedConv2d는 스타일 코드에 의해 조절되는 컨볼루션 연산을 수행합니다.
StyleConv는 ModulatedConv2d의 기능에 노이즈 추가 및 활성화 함수 적용을 통합한 레이어입니다.

 

DecBlock
DecBlock 클래스는 일반적인 디코딩 블록으로, 여러 해상도에서 사용될 수 있습니다.
이 클래스는 두 개의 StyleConv 레이어(conv0와 conv1)를 사용하여 입력 텐서에 스타일 기반의 컨볼루션을 적용합니다.
ToRGB 레이어를 통해 최종적으로 RGB 이미지를 생성합니다.
이 클래스는 다양한 해상도의 특징을 처리할 수 있도록 설계되었습니다.


DecBlockFirstV2
DecBlockFirstV2 클래스는 디코딩 과정의 첫 번째 블록으로 특별히 설계되었습니다.
이 클래스는 하나의 일반 2D 컨볼루션 레이어(conv0)와 하나의 StyleConv 레이어(conv1)를 사용합니다.
ToRGB 레이어를 통해 RGB 이미지를 생성합니다.
DecBlockFirstV2는 초기 단계의 디코딩에 특화되어 있으며, 입력 텐서를 초기 처리하는 데 중점을 둡니다.


DecBlockFirst
DecBlockFirst 클래스도 디코딩 과정의 첫 번째 블록으로 사용됩니다.
이 클래스는 FullyConnectedLayer를 사용하여 입력 텐서를 변환하고, StyleConv 레이어를 통해 스타일 기반의 컨볼루션을 적용합니다.
ToRGB 레이어를 통해 최종적으로 RGB 이미지를 생성합니다.
DecBlockFirst는 입력 텐서를 고차원 특징 공간으로 매핑하는 데 중점을 두고 있습니다.


결론
세 클래스 모두 디코딩 과정에 사용되지만, 각각의 구조와 초점이 다릅니다. DecBlock은 다양한 해상도에서 사용될 수 있는 범용적인 디코딩 블록이며, DecBlockFirstV2와 DecBlockFirst는 디코딩 과정의 첫 번째 단계에 특화된 구조를 가지고 있습니다. 이들의 주된 차이는 사용되는 레이어의 종류와 입력 텐서를 처리하는 방식에 있습니다.

 

MappingNet

생성적 적대 신경망(GAN)과 같은 생성 모델에서 중요한 역할을 하는 매핑 네트워크(mapping network)를 구현하는 PyTorch 모듈입니다. 이 클래스의 주요 목적은 입력 잠재 벡터(latent vector)와 조건 벡터(conditioning vector)를 받아, 중간 잠재 공간(intermediate latent space)으로 변환하는 것입니다. 이 변환 과정은 GAN의 스타일 변조(style modulation)에 사용됩니다.

import sys # 시스템 경로 조작 및 명령줄 인자 처리
sys.path.insert(0, '../') # sys.path 리스트에 새로운 경로인 '../'을 추가한다. 이걸 추가함으로써 상위 디렉토리에 있는 모듈들을 현재 스크립트에 직접 임포트할 수 있다.
from collections import OrderedDict # 요소가 추가된 순서를 기억하는 기능이 탑재된 딕셔너리
import numpy as np # 넘파이 연산

import torch # 파이토치 라이브러리
import torch.nn as nn # 신경망 구축에 필요한 클래스와 함수를 제공
import torch.nn.functional as F # 활성화 함수, 손실 함수 등 신경망의 기능적인 부분
from torch_utils import misc # 파이토치
from torch_utils import persistence # 모델 저장 및 로딩 등 모델의 지속성 관련
from torch_utils.ops import conv2d_resample # 2d conv의 리샘플 기능을 제공
from torch_utils.ops import upfirdn2d # 업 샘플링, 다운 샘플링 연산
from torch_utils.ops import bias_act # 활성화 함수에 바이어스를 추가하여 연산

#----------------------------------------------------------------------------

@misc.profiled_function
def normalize_2nd_moment(x, dim=1, eps=1e-8):
    # x: 정규화할 텐서, dim: 정규화를 수행할 차원, eps: 수치적 안정성을 위해 분모에 추가되는 작은 상수.
    # x의 각 요소를 제곱(x.squire())한 뒤 지정된 차원에 대해 평균을 계산한다.
    # keepdim=True는 원래 텐서의 차원을 유지하면서 평균을 계산하는 것을 의미한다.
    # 계산된 평균에 eps를 더해 수치적 안정성을 보장하고, 제곱근의 역수인 rsqret()를 취함.
    # 그리고 이 값을 기존의 x에 곱한 뒤 반환

    # 2차 모멘트 정규화
    # 2차 모멘트 정규화는 텐서의 각 요소가 평균 0과 분산 1을 갖도록 조정합니다.
    # 이는 특히 딥러닝에서 데이터의 스케일을 조정하여 학습 과정의 안정성을 높이고, 수렴 속도를 개선하는 데 도움을 줍니다.
    # 이러한 정규화는 배치 정규화(Batch Normalization)와 유사한 개념이지만, 여기서는 단일 데이터 샘플에 대해 적용됩니다.
    return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()

#----------------------------------------------------------------------------

# '지속성'이란 네트워크 객체의 상태를 저장하고 필요할 때마다 다시 로드할 수 있는 기능을 의미한다. 이는 특히 신경망 모델의 가중치와 같은 중요한 데이터를 저장하고 로드하는데 유용하다.
@persistence.persistent_class
class FullyConnectedLayer(nn.Module):
    # nn.Module을 상속받아 정의된다.
    # 생성자 함수
    def __init__(self,
                 in_features,                # Number of input features.
                 out_features,               # Number of output features.
                 bias            = True,     # Apply additive bias before the activation function?
                 activation      = 'linear', # Activation function: 'relu', 'lrelu', etc.
                 lr_multiplier   = 1,        # Learning rate multiplier. 러닝 레이트 조정을 위한 계수.
                 bias_init       = 0,        # Initial value for the additive bias. 바이어스 초기값 결정.
                 ):
        super().__init__() # nn.Module의 생성자를 호출해 상속받은 변수를 모두 초기화한다.
        # 학습 가능한 매개변수 weight, bias를 초기화. weight는 정규분포를 사용하여 초기화 하고, bias는 모든 값이 bias_init_으로 초기화된다.
        # 또한 활성화 함수도 인자를 받아와 초기화 한다.
        self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
        self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
        self.activation = activation

        # 학습률 조정을 위해 사용되는 변수.
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        # lr_multiplier를 입력 특성의 수의 제곱근으로 나누어 계산합니다.
        # 이는 가중치 초기화 시 He 초기화 또는 Xavier 초기화와 같은 방법을 따르는 것으로, 각 뉴런의 입력 특성 수에 따라 가중치의 스케일을 조정하여 그래디언트의 분산을 안정화시킵니다.
        self.bias_gain = lr_multiplier

    # 순전파 함수
    def forward(self, x):
        w = self.weight * self.weight_gain # weight_gain은 학습률 조정을 위한 계수. 가중치 조정에 사용됨.
        b = self.bias
        if b is not None and self.bias_gain != 1: # bias 옵션이 있고, bias_gain이 1이 아닌 경우에는 (1인 경우에는 어차피 계산 값이 같으므로)
            b = b * self.bias_gain # 바이어스 조정

        if self.activation == 'linear' and b is not None: # 활성화 함수가 linear이고 b가 존재한다면
            # out = torch.addmm(b.unsqueeze(0), x, w.t())
            x = x.matmul(w.t()) # 웨이트를 곱하고
            out = x + b.reshape([-1 if i == x.ndim-1 else 1 for i in range(x.ndim)]) # b를 x의 차원에 맞게 reshape한 뒤 x에 더한다.
        else: # 활성화 함수가 linear가 아니라면
            x = x.matmul(w.t()) # x에 w를 곱하고
            out = bias_act.bias_act(x, b, act=self.activation, dim=x.ndim-1) # 바이어스 적용 후 활성화 함수 적용
        return out

#----------------------------------------------------------------------------

@persistence.persistent_class
class Conv2dLayer(nn.Module):
    def __init__(self,
                 in_channels,                    # Number of input channels.
                 out_channels,                   # Number of output channels.
                 kernel_size,                    # Width and height of the convolution kernel.
                 bias            = True,         # Apply additive bias before the activation function?
                 activation      = 'linear',     # Activation function: 'relu', 'lrelu', etc.
                 up              = 1,            # Integer upsampling factor.
                 down            = 1,            # Integer downsampling factor.
                 resample_filter = [1,3,3,1],    # Low-pass filter to apply when resampling activations.
                 conv_clamp      = None,         # Clamp the output to +-X, None = disable clamping.
                 trainable       = True,         # Update the weights of this layer during training? (학습 가능 여부)
                 ):
        super().__init__()
        self.activation = activation
        self.up = up # up: 업 샘플링 요소로, 1보다 크면 입력 데이터의 해상도를 높임.
        self.down = down # down: 다운 샘플링 요소로, 1보다 작으면 입력 데이터의 해상도를 낮춤.

        # PyTorch의 버퍼 시스템을 사용하여 resample_filter를 등록합니다. 이 필터는 리샘플링 시 사용됩니다.
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))

        self.conv_clamp = conv_clamp
        self.padding = kernel_size // 2

        # 가중치와 활성화 함수의 스케일 조정 계수
        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
        self.act_gain = bias_act.activation_funcs[activation].def_gain

        # 이번에는 가중치를 무작위로 초기화, 바이어스는 0으로 초기화.
        # 커널 사이즈 크기의 웨이트 랜덤 초기화. 이것이 nn.Parameter로 초기화 되었으므로 컨볼루션 필터의 가중치는 학습 가능하다.
        weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size])
        bias = torch.zeros([out_channels]) if bias else None


        # trainable이 True일 경우 커널의 가중치와 바이어스를 학습 가능한 매개변수로 등록.
        if trainable:
            self.weight = torch.nn.Parameter(weight)
            self.bias = torch.nn.Parameter(bias) if bias is not None else None
        # 그렇지 않다면 그냥 버퍼에 저장한다.
        else:
            self.register_buffer('weight', weight)
            if bias is not None:
                self.register_buffer('bias', bias)
            else:
                self.bias = None

    def forward(self, x, gain=1):
        w = self.weight * self.weight_gain
        x = conv2d_resample.conv2d_resample(x=x, w=w, f=self.resample_filter, up=self.up, down=self.down,
                                            padding=self.padding)
        # x: 입력 데이터, w: 컨볼루션 가중치, f: 리샘플링 필터(업 샘플링, 혹은 다운 샘플링 시 적용되는 필터로 데이터의 해상도 변경에 영향을 줌)
        # up: 업 샘플링 요소로, 1보다 크면 입력 데이터의 해상도를 높임.
        # down: 다운 샘플링 요소로, 1보다 작으면 입력 데이터의 해상도를 낮춤.
        # padding: 컨볼루션 연산을 수행하기 전에 입력 데이터 주변에 적용할 패딩 크기.

        act_gain = self.act_gain * gain
        # self.conv_clamp가 6이라면, 활성화 함수의 출력값은 -6과 6 사이의 값으로 제한된다.
        act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
        # bias_act 함수는 활성화 함수를 적용하면서 clamp 매개변수를 사용하여 출력값을 제한
        out = bias_act.bias_act(x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp)
        return out

#----------------------------------------------------------------------------

@persistence.persistent_class
class ModulatedConv2d(nn.Module):
    def __init__(self,
                 in_channels,                   # Number of input channels.
                 out_channels,                  # Number of output channels.
                 kernel_size,                   # Width and height of the convolution kernel.
                 style_dim,                     # dimension of the style code. 스타일 코드의 크기. 스타일 코드는 신경망의 출력에 영향을 미치는 추가적인 정보를 제공.
                 demodulate=True,               # perfrom demodulation. 디모듈레이션은 가중치를 정규화하는 과정. 각 출력 채널의 가중치 벡터의 크기를 1로 조정해 신경망 안정성을 향상.
                 up=1,                          # Integer upsampling factor.
                 down=1,                        # Integer downsampling factor.
                 resample_filter=[1,3,3,1],     # Low-pass filter to apply when resampling activations. 리샘플링은 입력 데이터의 해상도를 변경하는 과정. 데이터 해상도 변경시 발생하는 왜곡을 최소화 하는 역할.
                 conv_clamp=None,               # Clamp the output to +-X, None = disable clamping.
                 ):
        super().__init__()
        self.demodulate = demodulate

        # 커널 사이즈 크기의 웨이트 랜덤 초기화. 이것이 nn.Parameter로 초기화 되었으므로 컨볼루션 필터의 가중치는 학습 가능하다.
        # 위의 Conv2dLayer 클래스는 커널 가중치의 학습 여부를 trainable 변수를 통해 바꿀 수 있었으나 여기선 그런 옵션은 없다.
        self.weight = torch.nn.Parameter(torch.randn([1, out_channels, in_channels, kernel_size, kernel_size]))
        self.out_channels = out_channels # 아웃풋 채널 설정
        self.kernel_size = kernel_size # 커널 사이즈 설정
        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) # 웨이트 게인 설정
        self.padding = self.kernel_size // 2 # 패딩 설정
        self.up = up # up: 업 샘플링 요소로, 1보다 크면 입력 데이터의 해상도를 높임.
        self.down = down # down: 다운 샘플링 요소로, 1보다 작으면 입력 데이터의 해상도를 낮춤.

        # resample_filter를 버퍼에 저장하는 이유. 버퍼는 학습 가능한 파라미터는 아니나 모델의 일부로 저장되어야 할 텐서를 의미한다.
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        
        self.conv_clamp = conv_clamp # 활성화 함수를 통과한 값을 클램핑 해주는 임계값.
        
        # 스타일 텐서를 받아 입력 차원의 텐서로 키워주는 전연결층. 나중에 입력 값인 x 텐서와 융합된다.
        self.affine = FullyConnectedLayer(style_dim, in_channels, bias_init=1)

    def forward(self, x, style):
        # 입력 텐서인 x는 x.shape를 통해 배치, 입력 채널, 높이, 너비로 변환된다.
        batch, in_channels, height, width = x.shape
        # style 텐서는 affine이라는 전연결 계층을 통해 x와 동일한 차원으로 생성된 되 x의 차원에 맞게 변형된다.
        style = self.affine(style).view(batch, 1, in_channels, 1, 1)
        # 그리고 웨이트 수정
        weight = self.weight * self.weight_gain * style

        if self.demodulate: # 디모듈레이션 옵션이 있다면
            decoefs = (weight.pow(2).sum(dim=[2, 3, 4]) + 1e-8).rsqrt() # 웨이트의 각 요소를 제곱한 후, 커널의 차원에 따라 합산하여 정규회 계수를 계산.
            weight = weight * decoefs.view(batch, self.out_channels, 1, 1, 1) # 이 계수를 사용하여 가중치를 정규화.

        # 웨이트를 커널 사이즈의 크기로 적절하게 바꿈.
        weight = weight.view(batch * self.out_channels, in_channels, self.kernel_size, self.kernel_size)
        # x도 컨볼루션 연산을 위해 차원 재배열.
        x = x.view(1, batch * in_channels, height, width)
        # conv2d_resample 함수를 이용해서 x에 weight로 초기화된 컨볼루션 필터를 적용하여 CNN 연산을 진행한다.
        x = conv2d_resample.conv2d_resample(x=x, w=weight, f=self.resample_filter, up=self.up, down=self.down,
                                            padding=self.padding, groups=batch)
        # 다시 피처 차원으로 x를 reshape.
        out = x.view(batch, self.out_channels, *x.shape[2:])

        return out

#----------------------------------------------------------------------------

@persistence.persistent_class
class StyleConv(torch.nn.Module):
# 입력 텐서에 스타일 정보를 적용하고, 선택적으로 노이즈를 추가한 후, 활성화 함수를 적용하는 레이어.
# 스타일 전이나 GAN과 같은 분야에 사용된다.
    def __init__(self,
        in_channels,                    # Number of input channels.
        out_channels,                   # Number of output channels.
        style_dim,                      # Intermediate latent (W) dimensionality.
        resolution,                     # Resolution of this layer.
        kernel_size     = 3,            # Convolution kernel size.
        up              = 1,            # Integer upsampling factor.
        use_noise       = True,         # Enable noise input?
        activation      = 'lrelu',      # Activation function: 'relu', 'lrelu', etc.
        resample_filter = [1,3,3,1],    # Low-pass filter to apply when resampling activations.
        conv_clamp      = None,         # Clamp the output of convolution layers to +-X, None = disable clamping.
        demodulate      = True,         # perform demodulation
    ):
        super().__init__()

        # 위에서 살펴본 ModulatedConv2d를 이용하여 conv 객체를 생성.
        self.conv = ModulatedConv2d(in_channels=in_channels,
                                    out_channels=out_channels,
                                    kernel_size=kernel_size,
                                    style_dim=style_dim,
                                    demodulate=demodulate,
                                    up=up,
                                    resample_filter=resample_filter,
                                    conv_clamp=conv_clamp)

        self.use_noise = use_noise # 노이즈 생성 여부
        self.resolution = resolution # 해당 레이어의 해상도

        if use_noise: # 노이즈를 생성할 것이라면
            # 해상도x해상도 사이즈의 노이즈 텐서는 학습할 파라미터가 아니므로 레지스터 버퍼에 저장한다.
            self.register_buffer('noise_const', torch.randn([resolution, resolution]))
            # 노이즈의 강도는 학습 가능한 파라미터로 정한다.
            self.noise_strength = torch.nn.Parameter(torch.zeros([]))
        
        self.bias = torch.nn.Parameter(torch.zeros([out_channels])) # 학습 가능한 파라미터인 바이어스
        self.activation = activation # 활성화 함수
        self.act_gain = bias_act.activation_funcs[activation].def_gain # 활성화 함수의 강도 조절
        self.conv_clamp = conv_clamp # 클램핑 임계값

    def forward(self, x, style, noise_mode='random', gain=1):
        x = self.conv(x, style) # 우선 x와 style 텐서를 받아 모듈레이티드 컨볼루션 통과

        assert noise_mode in ['random', 'const', 'none'] # 노이즈 옵션이 셋 중 하나인지 확인

        if self.use_noise: # 노이즈를 활용할 거라면
            if noise_mode == 'random': # 노이즈 옵션이 random일 때
                xh, xw = x.size()[-2:] # 노이즈 무작위 설정
                noise = torch.randn([x.shape[0], 1, xh, xw], device=x.device) \
                        * self.noise_strength
            if noise_mode == 'const': # 노이즈 옵션이 const일 때
                noise = self.noise_const * self.noise_strength # const에 strenght를 곱하는 것으로 끝
            x = x + noise # 노이즈를 더한다

        act_gain = self.act_gain * gain # act_gain 곱해주고
        act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None #클램핑 한 뒤
        out = bias_act.bias_act(x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp) # 아웃풋 출력

        return out

#----------------------------------------------------------------------------

@persistence.persistent_class
class ToRGB(torch.nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 style_dim,
                 kernel_size=1,
                 resample_filter=[1,3,3,1],
                 conv_clamp=None,
                 demodulate=False):
        super().__init__()

        self.conv = ModulatedConv2d(in_channels=in_channels,
                                    out_channels=out_channels,
                                    kernel_size=kernel_size,
                                    style_dim=style_dim,
                                    demodulate=demodulate,
                                    resample_filter=resample_filter,
                                    conv_clamp=conv_clamp)
        self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.conv_clamp = conv_clamp

    def forward(self, x, style, skip=None):
        x = self.conv(x, style) # 우선 x와 style 텐서를 modulated conv에 통과시킴.
        # bias_act: 입력 텐서에 바이어스를 추가하고 활성화 함수를 적용한 뒤, 결과를 클램핑하여 제한한다.
        out = bias_act.bias_act(x, self.bias, clamp=self.conv_clamp)

        if skip is not None:
            if skip.shape != out.shape:
                skip = upfirdn2d.upsample2d(skip, self.resample_filter)
            out = out + skip
        # out을 RGB의 3차원으로 만들어서 ToRGB라는 이름이 붙은 것 같다.
        return out

#----------------------------------------------------------------------------

@misc.profiled_function
def get_style_code(a, b):
    return torch.cat([a, b], dim=1)

#----------------------------------------------------------------------------

@persistence.persistent_class
class DecBlockFirst(nn.Module):
# 디코더 부분의 첫 번째 블록. 입력 텐서에 스타일 코드를 적용한 뒤 최종적으로 RGB 이미지를 생성하는 것.
    def __init__(self, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels):
        super().__init__()
        # x를 전연결 계층에 통과시켜 더 높은 차원으로 매핑한다.
        self.fc = FullyConnectedLayer(in_features=in_channels*2,
                                      out_features=in_channels*4**2,
                                      activation=activation)
        
        # 스타일 코드를 기반으로 컨볼루션 연산을 수행. 이를 통해 텐서에 스타일 정보를 통합하여 더 풍부한 특징을 추출.
        self.conv = StyleConv(in_channels=in_channels,
                              out_channels=out_channels,
                              style_dim=style_dim,
                              resolution=4,
                              kernel_size=3,
                              use_noise=use_noise,
                              activation=activation,
                              demodulate=demodulate,
                              )
        
        # RGB 채널에 해당하는 출력으로 변환.
        self.toRGB = ToRGB(in_channels=out_channels,
                           out_channels=img_channels,
                           style_dim=style_dim,
                           kernel_size=1,
                           demodulate=False,
                           )

    def forward(self, x, ws, gs, E_features, noise_mode='random'):
        x = self.fc(x).view(x.shape[0], -1, 4, 4)
        x = x + E_features[2] # 인코더의 특징을 skip connection으로 더한다.
        style = get_style_code(ws[:, 0], gs)
        x = self.conv(x, style, noise_mode=noise_mode)
        style = get_style_code(ws[:, 1], gs)
        img = self.toRGB(x, style, skip=None)

        return x, img


@persistence.persistent_class
class DecBlockFirstV2(nn.Module):
    def __init__(self, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels):
        super().__init__()
        # 기본 컨볼루션
        self.conv0 = Conv2dLayer(in_channels=in_channels,
                                out_channels=in_channels,
                                kernel_size=3,
                                activation=activation,
                                )
        # 스타일 컨볼루션                        
        self.conv1 = StyleConv(in_channels=in_channels,
                              out_channels=out_channels,
                              style_dim=style_dim,
                              resolution=4,
                              kernel_size=3,
                              use_noise=use_noise,
                              activation=activation,
                              demodulate=demodulate,
                              )
        self.toRGB = ToRGB(in_channels=out_channels,
                           out_channels=img_channels,
                           style_dim=style_dim,
                           kernel_size=1,
                           demodulate=False,
                           )

    def forward(self, x, ws, gs, E_features, noise_mode='random'):
        # x = self.fc(x).view(x.shape[0], -1, 4, 4)
        x = self.conv0(x)
        x = x + E_features[2]
        style = get_style_code(ws[:, 0], gs)
        x = self.conv1(x, style, noise_mode=noise_mode)
        style = get_style_code(ws[:, 1], gs)
        img = self.toRGB(x, style, skip=None)

        return x, img

#----------------------------------------------------------------------------

@persistence.persistent_class
class DecBlock(nn.Module):
    def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels):  # res = 2, ..., resolution_log2
        super().__init__()
        self.res = res

        self.conv0 = StyleConv(in_channels=in_channels,
                               out_channels=out_channels,
                               style_dim=style_dim,
                               resolution=2**res,
                               kernel_size=3,
                               up=2,
                               use_noise=use_noise,
                               activation=activation,
                               demodulate=demodulate,
                               )
        self.conv1 = StyleConv(in_channels=out_channels,
                               out_channels=out_channels,
                               style_dim=style_dim,
                               resolution=2**res,
                               kernel_size=3,
                               use_noise=use_noise,
                               activation=activation,
                               demodulate=demodulate,
                               )
        self.toRGB = ToRGB(in_channels=out_channels,
                           out_channels=img_channels,
                           style_dim=style_dim,
                           kernel_size=1,
                           demodulate=False,
                           )

    def forward(self, x, img, ws, gs, E_features, noise_mode='random'):
        style = get_style_code(ws[:, self.res * 2 - 5], gs)
        x = self.conv0(x, style, noise_mode=noise_mode)
        x = x + E_features[self.res]
        style = get_style_code(ws[:, self.res * 2 - 4], gs)
        x = self.conv1(x, style, noise_mode=noise_mode)
        style = get_style_code(ws[:, self.res * 2 - 3], gs)
        img = self.toRGB(x, style, skip=img)

        return x, img

#----------------------------------------------------------------------------

@persistence.persistent_class
class MappingNet(torch.nn.Module):
# 입력된 잠재 벡터(latent vector) z와 조건 벡터(conditioning vector) c를 받아 중간 잠재 공간(intermediate latent space) w로 변환.
# 이 과정은 GAN의 스타일 변조(style modulation)에 사용된다.
    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality, 0 = no latent. 입력 잠재 벡터 차원
        c_dim,                      # Conditioning label (C) dimensionality, 0 = no label. 조건 레이블 차원, 0이면 레이블 없음을 의미
        w_dim,                      # Intermediate latent (W) dimensionality. 중간 잠재 벡터 차원
        num_ws,                     # Number of intermediate latents to output, None = do not broadcast. 출력할 중간 잼재 벡터의 수, None이면 broadcast 하지 않음.
        num_layers      = 8,        # Number of mapping layers. 매핑 레이어 수.
        embed_features  = None,     # Label embedding dimensionality, None = same as w_dim. 레이블 임베딩 차원. None이면 w_dim과 동일.
        layer_features  = None,     # Number of intermediate features in the mapping layers, None = same as w_dim. 매핑 레이어의 중간 특징 수. None이면 w_dim과 동일.
        activation      = 'lrelu',  # Activation function: 'relu', 'lrelu', etc. 활성화 함수.
        lr_multiplier   = 0.01,     # Learning rate multiplier for the mapping layers. 매핑 레이어의 학습률 조정 인자.
        w_avg_beta      = 0.995,    # Decay for tracking the moving average of W during training, None = do not track. 중간 잠재 벡터 W의 이동 평균 감쇠율. None이면 추적 안 함.
    ):
        # 클래스 변수 초기화.
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta

        if embed_features is None:
            embed_features = w_dim # 우선 w_dim으로 초기화
        if c_dim == 0:
            embed_features = 0 # 클래스 레이블이 없으면 0으로 업데이트
        if layer_features is None:
            layer_features = w_dim # 매핑 레이어의 중간 특징 수가 없다면 w_dim으로 초기화.

        features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
        # [z_dim + embed_features]: 입력 잠재 벡터 z와 c가 FCN을 거쳐 임베딩 된 차원을 더하는 첫 번째 레이어
        # [layer_features] * (num_layers - 1): layer_features는 중간 레이어의 특징 수. 이걸 첫 번째 레이어를 제외한 나머지 개수에 곱해준다.
        # [w_dim]: 매핑 네트워크의 마지막 레이어의 출력 차원.

        # 클래스 레이블이 있다면
        if c_dim > 0: # 클래스 레이블 디멘션을 임베딩 피처 디멘션으로 늘리는 FCN을 설정
            self.embed = FullyConnectedLayer(c_dim, embed_features)
        for idx in range(num_layers): # 레이어 개수를 돌면서
            in_features = features_list[idx] # 입력 차원
            out_features = features_list[idx + 1] # 출력 차원
            layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier) # FCN 레이어를 만든다.
            setattr(self, f'fc{idx}', layer) # setattr 함수를 이용하여 각 레이어를 동적으로 클래스 인스턴스로 할당.

        # 중간 잠재 벡터의 이동 평균 버퍼 초기화
        if num_ws is not None and w_avg_beta is not None:
            self.register_buffer('w_avg', torch.zeros([w_dim]))

    # truncation_psi, truncation_cutoff는 스타일 변조를 위한 매개변수이다.
    # skip_w_avg_update는 중간 잠재 벡터의 이동 평균 업데이트를 건너 뛸지 결정한다.
    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
        # Embed, normalize, and concat inputs.
        # 입력 잠재 벡터와 조건 벡터를 정규화한 뒤, 조건 벡터가 존재할 경우 이들을 결합(concatenation) 한다.

        x = None
        with torch.autograd.profiler.record_function('input'):
            if self.z_dim > 0: # 입력된 잠재 벡터 z가 주어졌다면 이를 정규화 한다.
                x = normalize_2nd_moment(z.to(torch.float32))
            if self.c_dim > 0: # 조건 벡터 c가 주어진 경우 이를 FCN에 통과시켜 임베딩 한 후 정규화한 뒤, x와 결합한다.
                y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
                x = torch.cat([x, y], dim=1) if x is not None else y

        # Main layers.
        for idx in range(self.num_layers): # 레이어 개수를 순회하면서
            layer = getattr(self, f'fc{idx}') # self.fc0, self.fc1, ... 이런 식으로 레이어를 가져온다.
            x = layer(x) # 각 레이어를 순차적으로 통과시키며 중간 잠재 공간의 벡터로 변환한다.

        # Update moving average of W.
        # 모델 훈련 중이고, 이동 평균 업데이트를 건너뛰지 않는 경우, 중간 잠재 벡터 w의 이동 평균을 업데이트 한다.
        # 이는 학습 과정에서 잠재 공간의 안정성을 향상시키는데 도움이 된다.
        if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
            with torch.autograd.profiler.record_function('update_w_avg'):
                self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))

        # Broadcast.
        # 중간 잠재 벡터 w를 필요한 수만큼 복제하여 네트워크의 다음 단계로 전달할 준비를 한다.
        if self.num_ws is not None: # self.num_ws는 브로드캐스팅 할 잠재 벡터의 수를 의미.
            with torch.autograd.profiler.record_function('broadcast'): # profiler는 성능을 기록하는 기능
                x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
                # x.unsqueeze(1): x.shape = [batch_size, 1, w_dim]
                # repeat([1, self.num_ws, 1]): x를 self.num_ws만큼 복제
                # x.shape = [batch_size, self.num_ws, w_dim] 형태의 텐서가 됨.

        # Apply truncation.
        # 스타일 변조(truncation)을 적용하여 출력 벡터 w를 조정한다.
        # 이는 생성된 이미지의 다양성과 현실성 사이의 균형을 맞춘다.
        if truncation_psi != 1:
            with torch.autograd.profiler.record_function('truncate'):
                assert self.w_avg_beta is not None
                if self.num_ws is None or truncation_cutoff is None:
                    x = self.w_avg.lerp(x, truncation_psi)
                else:
                    x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)

        return x

#----------------------------------------------------------------------------

@persistence.persistent_class
class DisFromRGB(nn.Module):
# 판별기 네트워크의 일부로, RGB 이미지를 입력으로 받아 네트워크 초기 단계에서 처리하는 것.
    def __init__(self, in_channels, out_channels, activation):  # res = 2, ..., resolution_log2
        super().__init__()
        # 신기하게 1x1 커널을 가진 CNN 필터를 만든다.
        # 이러한 이유는 채널간의 정보를 효과적으로 통합하기 위함이다.
        self.conv = Conv2dLayer(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=1,
                                activation=activation,
                                )
    def forward(self, x):
        return self.conv(x)

#----------------------------------------------------------------------------

@persistence.persistent_class
class DisBlock(nn.Module):
# 판별기 네트워크에서 피처 맵을 처리하고 다운 샘플링을 통해 해상도를 줄이는 역할을 한다.
    def __init__(self, in_channels, out_channels, activation):  # res = 2, ..., resolution_log2
        super().__init__()
        # 입력 채널을 그대로 유지하면서 특징만 추출
        self.conv0 = Conv2dLayer(in_channels=in_channels,
                                 out_channels=in_channels,
                                 kernel_size=3,
                                 activation=activation,
                                 )
        # down=2를 통해 해상도를 2배로 줄임
        self.conv1 = Conv2dLayer(in_channels=in_channels,
                                 out_channels=out_channels,
                                 kernel_size=3,
                                 down=2,
                                 activation=activation,
                                 )
        # 커널 사이즈가 1이고, 다시 down을 통해 해상도를 2배로 줄임
        self.skip = Conv2dLayer(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=1,
                                down=2,
                                bias=False,
                             )

    def forward(self, x):
        skip = self.skip(x, gain=np.sqrt(0.5))
        x = self.conv0(x)
        x = self.conv1(x, gain=np.sqrt(0.5))
        out = skip + x

        return out

#----------------------------------------------------------------------------

@persistence.persistent_class
class MinibatchStdLayer(torch.nn.Module):
# 판별기 입력에 미니배치 내의 샘플 간 다양성에 대한 정보를 제공한다.
# 이는 판별기가 미니배치 내의 샘플들이 얼마나 다른지 파악하고, 이를 통해 실제 데이터와 생성된 데이터 간의 차이를 더 잘 구별하도록 만든다.
    def __init__(self, group_size, num_channels=1):
        super().__init__()
        self.group_size = group_size
        self.num_channels = num_channels

    def forward(self, x):
        N, C, H, W = x.shape
        with misc.suppress_tracer_warnings():  # as_tensor results are registered as constants
            G = torch.min(torch.as_tensor(self.group_size),
                          torch.as_tensor(N)) if self.group_size is not None else N
        F = self.num_channels
        c = C // F

        y = x.reshape(G, -1, F, c, H,
                      W)  # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
        y = y - y.mean(dim=0)  # [GnFcHW] Subtract mean over group.
        y = y.square().mean(dim=0)  # [nFcHW]  Calc variance over group.
        y = (y + 1e-8).sqrt()  # [nFcHW]  Calc stddev over group.
        y = y.mean(dim=[2, 3, 4])  # [nF]     Take average over channels and pixels.
        y = y.reshape(-1, F, 1, 1)  # [nF11]   Add missing dimensions.
        y = y.repeat(G, 1, H, W)  # [NFHW]   Replicate over group and pixels.
        x = torch.cat([x, y], dim=1)  # [NCHW]   Append to input as new channels.
        return x

#----------------------------------------------------------------------------

@persistence.persistent_class
class Discriminator(torch.nn.Module):
    def __init__(self,
                 c_dim,                         # Conditioning label (C) dimensionality.
                 img_resolution,                # Input resolution.
                 img_channels,                  # Number of input color channels.
                 channel_base       = 32768,    # Overall multiplier for the number of channels.
                 channel_max        = 512,      # Maximum number of channels in any layer.
                 channel_decay      = 1,
                 cmap_dim           = None,     # Dimensionality of mapped conditioning label, None = default.
                 activation         = 'lrelu',
                 mbstd_group_size   = 4,        # Group size for the minibatch standard deviation layer, None = entire minibatch.
                 mbstd_num_channels = 1,        # Number of features for the minibatch standard deviation layer, 0 = disable.
                 ):
        super().__init__()
        self.c_dim = c_dim # 클래스 디멘션
        self.img_resolution = img_resolution # 이미지 해상도
        self.img_channels = img_channels # 입력 이미지 채널

        # 입력 이미지의 해상도에 따라 네트워크의 구조를 동적으로 다르게 구성하기 위해 해상도의 로그 값을 구한다.
        resolution_log2 = int(np.log2(img_resolution))
        assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
        self.resolution_log2 = resolution_log2

        # 네트워크의 초기 단계에서는 더 많은 채널을 사용하여 복잡한 특징을 추출하고, 깊은 단계에서는 채널 수를 줄여 계산 부담을 감소시킨다.
        def nf(stage):
            return np.clip(int(channel_base / 2 ** (stage * channel_decay)), 1, channel_max)

        # 클래스 레이블을 처리하기 위한 중간 차원의 크기
        if cmap_dim == None:
            cmap_dim = nf(2)
        if c_dim == 0:
            cmap_dim = 0
        self.cmap_dim = cmap_dim

        # 클래스 레이블이 있을 경우 매핑 네트워크 초기화
        if c_dim > 0:
            self.mapping = MappingNet(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None)

        # 판별기 초기의 피처 추출 및 다운 샘플링 레이어를 초기화
        Dis = [DisFromRGB(img_channels+1, nf(resolution_log2), activation)]
        for res in range(resolution_log2, 2, -1):
            Dis.append(DisBlock(nf(res), nf(res-1), activation))

        # 미니 배치의 표준편차 정보를 반영하는 레이어와 CNN 레이어를 추가로 삽입
        if mbstd_num_channels > 0:
            Dis.append(MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels))
        Dis.append(Conv2dLayer(nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation))
        self.Dis = nn.Sequential(*Dis)

        # FCN 레이어를 초기화
        self.fc0 = FullyConnectedLayer(nf(2)*4**2, nf(2), activation=activation)
        self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim)

    def forward(self, images_in, masks_in, c):
        x = torch.cat([masks_in - 0.5, images_in], dim=1) # 마스크와 이미지를 채널 차원을 따라 연결
        x = self.Dis(x) # 판별자 레이어를 통과
        x = self.fc1(self.fc0(x.flatten(start_dim=1))) # FNC 레이어 통과 ouput은 1이다.

        if self.c_dim > 0: # 조건 레이블이 있는 경우 매핑 네트워크 통과
            cmap = self.mapping(None, c)

        if self.cmap_dim > 0: # 조건 레이블이 있는 경우 이 정보를 판별자의 출력에 통합
            x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))

        return x

'코드 분석 > MAT' 카테고리의 다른 글

MAT 환경설정 및 버그 리포트  (0) 2024.01.22
MAT: mat.py  (1) 2024.01.11
MAT: training_loop.py  (0) 2023.12.28
MAT: train.py  (1) 2023.12.27
MAT: generate_image.py  (0) 2023.12.26