코드 분석/Edge-Connect

Edge-Connect: networks.py

상솜공방 2024. 1. 24. 11:45

알아둬야 할 지식

더보기

self.apply(init_func)의 역할: self.apply 메서드는 PyTorch의 nn.Module 클래스에 정의된 메서드로, 주어진 함수(init_func)를 현재 모듈(self)의 모든 하위 모듈에 재귀적으로 적용합니다. 즉, init_func 함수는 네트워크의 모든 레이어에 대해 호출되어 가중치 초기화 등의 작업을 수행합니다.

init_func(m)에서 m의 의미: m은 현재 모듈의 하위 요소를 나타냅니다. self.apply 메서드가 init_func를 호출할 때, 네트워크의 각 레이어(하위 모듈)가 m으로 전달됩니다. m은 컨볼루션 레이어, 배치 정규화 레이어 등 신경망의 구성 요소를 나타낼 수 있습니다.

classname = m.class.__name__의 의미: m.__class__는 m 객체의 클래스를 나타냅니다. __name__ 속성은 클래스의 이름을 문자열로 반환합니다. 따라서 classname은 현재 처리 중인 모듈의 클래스 이름을 문자열로 저장합니다.

함수 내부에 함수를 정의하는 경우: 함수 내부에 다른 함수를 정의하는 것을 '내부 함수' 또는 '지역 함수'라고 합니다. 이는 특정 함수 내에서만 필요한 기능을 캡슐화하고, 외부에서의 접근을 제한하기 위해 사용됩니다. init_func는 init_weights 함수 내에서만 사용되므로 내부 함수로 정의됩니다.

**self.middle = nn.Sequential(blocks)에서 의 의미: * 연산자는 파이썬에서 '언패킹' 연산자로 사용됩니다. 이는 blocks 리스트의 각 요소를 nn.Sequential 생성자에 개별 인자로 전달합니다. 즉, 리스트의 요소들을 별도의 인자로 분리하여 함수에 전달합니다.

nn.ConvTranspose2d 메서드: nn.ConvTranspose2d는 2D 전치 컨볼루션(때로는 역 컨볼루션 또는 디컨볼루션으로 불림)을 수행하는 레이어입니다. 이는 이미지의 공간적 차원을 확장하는 데 사용되며, 일반적으로 이미지 생성 작업에 사용됩니다.

nn.InstanceNorm2d 메서드: nn.InstanceNorm2d는 인스턴스 정규화를 수행하는 레이어입니다. 이는 배치 내 각 이미지의 각 채널에 대해 독립적으로 정규화를 수행합니다. 이는 스타일 전송과 같은 작업에서 유용하게 사용됩니다.

nn.ReflectionPad2d 메서드: nn.ReflectionPad2d는 입력의 가장자리를 반사하여 패딩하는 레이어입니다. 이는 입력 이미지의 가장자리를 반사하여 패딩을 추가함으로써 이미지의 크기를 조정합니다.

x = (torch.tanh(x) + 1) / 2의 목적: 이는 Tanh 활성화 함수의 출력을 [0, 1] 범위로 조정하는 데 사용됩니다. Tanh 함수는 [-1, 1] 범위의 값을 출력하는데, 이를 [0, 1] 범위로 조정하여 이미지 데이터와 호환되도록 만듭니다. 이는 이미지 데이터가 보통 0과 1 사이의 값을 가지기 때문에 필요합니다.

 

Tanh 대신 Sigmoid 사용 여부: 실제로 Sigmoid 함수는 출력을 [0, 1] 범위로 조정하는 데 자주 사용됩니다. 그러나 Tanh 함수를 사용하는 경우도 있습니다. Tanh는 출력이 [-1, 1] 범위이며, 때때로 네트워크가 학습하는 데 더 유리할 수 있습니다. Tanh는 중심이 0이기 때문에, 학습 초기에 더 균형 잡힌 그래디언트를 제공할 수 있습니다. (torch.tanh(x) + 1) / 2는 Tanh의 출력을 [0, 1] 범위로 조정합니다. 선택은 특정 작업의 요구 사항과 네트워크 아키텍처에 따라 달라질 수 있습니다.

Spectral Norm의 정의: Spectral Norm은 신경망의 가중치 행렬에 적용되는 정규화 기법 중 하나입니다. 이는 가중치 행렬의 스펙트럼(고유값)을 기반으로 정규화하여, 훈련 중에 네트워크의 레이어가 너무 극단적인 값을 출력하지 않도록 합니다. 이는 특히 GAN(Generative Adversarial Networks)에서 발산을 방지하고 안정적인 학습을 돕는 데 유용합니다.

EdgeGenerator에서 Spectral Norm 사용 이유: EdgeGenerator에서 Spectral Norm을 사용하는 이유는 네트워크가 안정적으로 학습되도록 하기 위함입니다. GAN과 같은 모델에서는 생성자와 판별자 간의 균형을 유지하는 것이 중요한데, Spectral Norm은 이 균형을 유지하는 데 도움을 줄 수 있습니다. 특히 에지 생성과 같은 섬세한 작업에서는 출력의 안정성이 중요할 수 있습니다.

InpaintGenerator와 EdgeGenerator에서 활성화 함수 선택: InpaintGenerator와 EdgeGenerator에서 서로 다른 활성화 함수(Tanh와 Sigmoid)를 사용하는 이유는 각 네트워크의 목적과 출력의 특성에 따라 다릅니다. InpaintGenerator는 이미지의 픽셀 값을 생성하는 데 사용되며, Tanh 함수는 이 경우 더 균형 잡힌 출력을 제공할 수 있습니다. 반면, EdgeGenerator는 이미지의 에지를 나타내는 이진 형태의 출력을 생성하는 데 사용되며, Sigmoid 함수는 이러한 이진 분류 작업에 더 적합합니다. Sigmoid는 출력을 [0, 1] 범위로 제한하여, 에지가 있거나 없음을 나타내는 데 유용합니다.

 

LeakyReLU 사용 이유: LeakyReLU (Leaky Rectified Linear Unit)는 일반 ReLU 활성화 함수의 변형으로, 입력 값이 음수일 때도 작은 기울기를 허용합니다 (ReLU는 입력이 음수일 때 0을 출력합니다). LeakyReLU를 사용하는 주된 이유는 다음과 같습니다:

소실된 그래디언트 문제 완화: 신경망이 깊어질수록 ReLU를 사용하면 음수 입력에 대해 그래디언트가 0이 되어 역전파 동안 그래디언트가 소실될 수 있습니다. LeakyReLU는 이러한 문제를 완화하여 네트워크가 더 깊어질 때도 효과적으로 학습할 수 있도록 돕습니다.


비선형성 유지: LeakyReLU는 비선형 활성화 함수이므로, 신경망이 복잡한 패턴을 학습하는 데 도움이 됩니다.
[conv1, conv2, conv3, conv4, conv5] 반환 이유: 판별자에서 각 컨볼루션 레이어의 출력을 반환하는 것은 일반적으로 다음과 같은 이유로 사용됩니다:

특징 시각화: 각 레이어의 출력을 분석하면 네트워크가 어떤 특징을 학습하고 있는지 이해할 수 있습니다. 이는 네트워크의 동작을 시각화하고 해석하는 데 유용합니다.
중간 레이어의 정보 활용: 특정 응용 프로그램에서는 중간 레이어의 출력이 유용한 정보를 포함할 수 있습니다. 예를 들어, 일부 GAN 구조에서는 중간 레이어의 출력을 사용하여 생성자의 학습을 안내합니다.
디버깅 및 분석: 네트워크의 각 부분이 어떻게 동작하는지 이해하고, 문제가 있는 부분을 식별하는 데 도움이 됩니다.

import torch
import torch.nn as nn

class BaseNetwork(nn.Module):
    def __init__(self):
        super(BaseNetwork, self).__init__() # 상위 클래스인 nn.Module 생성자를 호출하여 초기화.
    
    def init_weights(self, init_type='normal', gain=0.02):
        # 가중치 초기화 메서드 정의. init_type은 초기화 유형을, gain은 초기화에 사용되는 스케일 팩터.
        '''
        initialize network's weights
        init_type: normal | xavier | kaiming | orthogonal
        https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
        '''
        # 가중치 초기화를 수행하는 내부 함수.
        # 해당 함수는 init_weigts() 함수 내에서만 사용되므로 이렇게 내부에서 선언하였다.
        def init_func(m): # m은 각 모듈의 하위 요소, 즉 신경망 레이어들이다.
            classname = m.__class__.__name__ # m.__class__: 어떤 클래스인지, m.__class__.__name__: 클래스의 이름을 문자열로 반환.
            # m이 가중치를 가지고 있고, Conv 혹은 Linear를 포함하는가?
            if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
                # 컨볼루션 혹은 리니어 네트워크의 웨이트를 초기화.
                if init_type == 'normal':
                    nn.init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    nn.init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'kaiming':
                    nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    nn.init.orthogonal_(m.weight.data, gain=gain)
                # 컨볼루션, 혹은 리니어 네트워크의 바이어스를 초기화.
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            # 컨볼루션이나 리니어가 아니라, 배치 정규화 레이어라면, 웨이트와 바이어스를 정규화 한다.
            elif classname.find('BatchNorm2d') != -1:
                nn.init.normal_(m.weight.data, 1.0, gain)
                nn.init.constant_(m.bias.data, 0.0)

        self.apply(init_func) # init_func 함수를 현재 모듈의 모든 하위 모듈에 적용하여 가중치를 초기화 한다.
        # apply() 메서드는 주어진 함수를 현재 모듈(self)의 모든 하위 모듈에 재귀적으로 적용하는 함수이다.

class InpaintGenerator(BaseNetwork): # 위에서 정의한 BaseNetwork를 상속받는다.
    def __init__(self, residual_blocks=8, init_weights=True):
        # 추가적으로 받는 변수는 잔차 블록의 수와 가중치 초기화 여부이다.
        super(InpaintGenerator, self).__init__()

        # 인코딩 블록. 이미지를 저차원 특징 공간으로 변환한다. 첫 번째 블록의 인풋 차원이 4차원임을 짚고 넘어갈 것!
        self.encoder = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels=4, out_channels=64, kernel_size=7, padding=0),
            nn.InstanceNorm2d(64, track_running_stats=False),
            nn.ReLU(True),

            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128, track_running_stats=False),
            nn.ReLU(True),

            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256, track_running_stats=False),
            nn.ReLU(True)
        )

        # 잔차 블록을 저장할 리스트를 생성하고, 여기에 이들을 추가함.
        blocks = []
        for _ in range(residual_blocks): # 기본 값은 8이다.
            block = ResnetBlock(256, 2) # 인풋 채널은 256, dilation은 2인 간단한 잔차 블록 객체를 생성해 blocks에 넣기.
            blocks.append(block)

        self.middle = nn.Sequential(*blocks)
        # 여기서 *는 '언패킹 연산자'로 blocks 리스트의 각 요소를 별도의 인자로 분리하여 nn.Sequential() 함수에 개별적으로 전달한다.

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1), # 전치 컨볼루션 = 역컨볼루션 = 디컨볼루션
            nn.InstanceNorm2d(128, track_running_stats=False), # 인스턴스 정규화: 배치 내 각 이미지 채널에 대해 독립적으로 정규화를 수행한다.
            nn.ReLU(True),

            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(64, track_running_stats=False),
            nn.ReLU(True),
            
            nn.ReflectionPad2d(3), # 입력의 가장자리를 반사하여 패딩. 이미지 크기를 조절한다.
            nn.Conv2d(in_channels=64, out_channels=3, kernel_size=7, padding=0),
        )

        if init_weights:
            self.init_weights()

    def forward(self, x):
        x = self.encoder(x)
        x = self.middle(x)
        x = self.decoder(x)
        x = (torch.tanh(x) + 1) / 2 # tanh() 함수는 (-1, 1)까지의 값을 갖는데, 이 값을 [0, 1] 사이로 줄이기 위함이다.

        return x


class EdgeGenerator(BaseNetwork):
    def __init__(self, residual_blocks=8, use_spectral_norm=True, init_weights=True):
        super(EdgeGenerator, self).__init__()

        # 여기선 input channel이 3이다.
        self.encoder = nn.Sequential(
            nn.ReflectionPad2d(3),
            spectral_norm(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, padding=0), use_spectral_norm),
            nn.InstanceNorm2d(64, track_running_stats=False),
            nn.ReLU(True),

            spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm),
            nn.InstanceNorm2d(128, track_running_stats=False),
            nn.ReLU(True),

            spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), use_spectral_norm),
            nn.InstanceNorm2d(256, track_running_stats=False),
            nn.ReLU(True)
        )

        blocks = []
        for _ in range(residual_blocks):
            block = ResnetBlock(256, 2, use_spectral_norm=use_spectral_norm)
            blocks.append(block) # input, output dim이 256이고, dilation은 2, use_spectral_norm=True.

        self.middle = nn.Sequential(*blocks)

        self.decoder = nn.Sequential(
            spectral_norm(nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm),
            nn.InstanceNorm2d(128, track_running_stats=False),
            nn.ReLU(True),

            spectral_norm(nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1), use_spectral_norm),
            nn.InstanceNorm2d(64, track_running_stats=False),
            nn.ReLU(True),

            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels=64, out_channels=1, kernel_size=7, padding=0),
        )

        if init_weights:
            self.init_weights()

    def forward(self, x):
        x = self.encoder(x)
        x = self.middle(x)
        x = self.decoder(x)
        x = torch.sigmoid(x)
        return x


class Discriminator(BaseNetwork):
    def __init__(self, in_channels, use_sigmoid=True, use_spectral_norm=True, init_weights=True):
        super(Discriminator, self).__init__()
        self.use_sigmoid = use_sigmoid

        self.conv1 = self.features = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.conv2 = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.conv3 = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.conv4 = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=1, padding=1, bias=not use_spectral_norm), use_spectral_norm),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.conv5 = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=1, bias=not use_spectral_norm), use_spectral_norm),
        )

        if init_weights:
            self.init_weights()

    def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv5 = self.conv5(conv4)

        outputs = conv5
        if self.use_sigmoid:
            outputs = torch.sigmoid(conv5)

        return outputs, [conv1, conv2, conv3, conv4, conv5]


class ResnetBlock(nn.Module):
    def __init__(self, dim, dilation=1, use_spectral_norm=False):
        # dim: 컨볼루션 레이어 채널 수, dilation: 컨볼루션의 dilation 값, use_spectral_norm: 스펙트럴 정규화 사용 여부.
        super(ResnetBlock, self).__init__() # 상위 생성자 호출.
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(dilation),
            spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=dilation, bias=not use_spectral_norm), use_spectral_norm),
            nn.InstanceNorm2d(dim, track_running_stats=False),
            nn.ReLU(True),

            nn.ReflectionPad2d(1),
            spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=1, bias=not use_spectral_norm), use_spectral_norm),
            nn.InstanceNorm2d(dim, track_running_stats=False),
        )

    def forward(self, x):
        out = x + self.conv_block(x) # 잔차 연결.

        # Remove ReLU at the end of the residual block
        # http://torch.ch/blog/2016/02/04/resnets.html

        return out


def spectral_norm(module, mode=True):
    if mode:
        return nn.utils.spectral_norm(module)

    return module