코드 분석/Edge-Connect

Edge-Connect: loss.py

상솜공방 2024. 1. 24. 16:56

알아둬야 할 지식

더보기

self.register_buffer 사용 이유: self.register_buffer는 PyTorch의 nn.Module에서 제공하는 메서드로, 모듈의 상태(state)에 텐서를 저장하지만, 이 텐서는 모델 파라미터로 간주되지 않습니다. 즉, 이 텐서는 학습 과정에서 업데이트되지 않습니다.
target_real_label을 텐서로 변환하고 register_buffer를 사용하여 저장하는 이유는, 이 값을 모델의 일부로 유지하면서도 역전파(backpropagation)에서는 무시하기 위함입니다. 이렇게 함으로써, 이 값은 모델 저장 및 로딩 과정에서 자동으로 처리되며, GPU 연산에도 사용될 수 있습니다.
NSGAN, LSGAN, Hinge Loss의 차이:

NSGAN (Non-Saturating GAN): 이진 교차 엔트로피(Binary Cross Entropy, BCE) 손실을 사용합니다. 이는 생성자가 판별자를 속이려고 할 때 효과적인 그래디언트를 제공합니다.
LSGAN (Least Squares GAN): 평균 제곱 오차(Mean Squared Error, MSE) 손실을 사용합니다. 이는 판별자의 출력이 타겟 레이블과의 차이의 제곱으로 계산되며, 판별자의 결정 경계 주변에서 더 부드러운 그래디언트를 제공합니다.
Hinge Loss: 이는 주로 판별자에 사용되며, 실제 이미지에 대해서는 -min(0, -1 + D(x)), 가짜 이미지에 대해서는 -min(0, -1 - D(G(z)))를 계산합니다. 이는 판별자가 더 확실한 결정을 내리도록 유도합니다.


call 메서드의 의미: __call__ 메서드는 파이썬의 특수 메서드로, 객체를 함수처럼 호출할 수 있게 해줍니다. 즉, AdversarialLoss 클래스의 인스턴스를 함수처럼 사용할 수 있게 해주며, 이 때 __call__ 메서드가 실행됩니다.
is_disc의 의미: is_disc는 메서드가 판별자(Discriminator)에 의해 호출되는지 여부를 나타냅니다. GAN에서는 판별자와 생성자(Generator)가 서로 다른 목적을 가지고 있으며, is_disc는 현재 손실 계산이 판별자를 위한 것인지를 구분하는 데 사용됩니다.
Hinge Loss에서 실제 이미지에 -값 적용 이유: Hinge Loss에서 실제 이미지에 대한 판별자의 출력 D(x)는 1보다 크게 만들려는 목표를 가집니다. -outputs와 ReLU(1 + outputs)를 사용함으로써, 판별자가 실제 이미지에 대해 1보다 큰 값을 출력하도록 유도합니다. 이는 판별자가 실제 이미지와 가짜 이미지를 더 명확하게 구분하도록 돕습니다.

 

클래스 내부 변수와 인스턴스 변수

클래스 내부에는 인스턴스 변수와는 다른 유형의 변수들이 존재할 수 있습니다. 이들은 주로 클래스 변수라고 불리며, 클래스 정의 내에서 self를 사용하지 않고 선언됩니다. 인스턴스 변수와 클래스 변수 간의 주요 차이점은 다음과 같습니다:

인스턴스 변수:
인스턴스 변수는 클래스의 각 인스턴스(객체)에 속한 변수입니다.
이 변수들은 객체가 생성될 때마다 각 인스턴스에 대해 별도로 존재합니다.
인스턴스 변수는 self 키워드를 사용하여 클래스의 메서드 내에서 정의되며, 각 인스턴스의 상태를 저장하는 데 사용됩니다.
예: self.name = "example"


클래스 변수:
클래스 변수는 클래스 자체에 속한 변수로, 모든 인스턴스 간에 공유됩니다.
클래스 변수는 클래스 정의 내에서 self를 사용하지 않고 선언됩니다.
클래스 변수는 해당 클래스의 모든 인스턴스에 대해 동일한 값을 가지며, 클래스 이름을 통해 접근할 수 있습니다.
예: class Example: shared_variable = 5

class MyClass:
    shared_variable = 10  # 클래스 변수

    def __init__(self, value):
        self.instance_variable = value  # 인스턴스 변수

# 클래스 변수는 클래스 이름을 통해 접근 가능
print(MyClass.shared_variable)  # 출력: 10

# 인스턴스 변수는 객체 인스턴스를 통해 접근
my_instance = MyClass(20)
print(my_instance.instance_variable)  # 출력: 20

이 예시에서 shared_variable은 모든 MyClass 인스턴스에 대해 공유되는 반면, instance_variable은 각 MyClass 인스턴스에 대해 고유합니다. 클래스 변수는 클래스의 모든 인스턴스에 공통적인 값을 저장하는 데 사용되며, 인스턴스 변수는 각 인스턴스의 개별적인 상태를 저장하는 데 사용됩니다.

 

features = models.vgg19(pretrained=True).features에서 features의 정보:
models.vgg19(pretrained=True)는 사전 훈련된 VGG-19 모델을 로드합니다. 여기서 pretrained=True는 모델이 이미 학습된 가중치를 가지고 로드되어야 함을 의미합니다.
.features는 VGG-19 모델의 "특징 추출" 부분을 나타냅니다. VGG-19 모델은 크게 두 부분으로 구성됩니다: 특징 추출 부분(features)과 분류 부분(classifier). features 부분은 여러 컨볼루션 레이어와 풀링 레이어로 구성되어 있으며, 이미지에서 중요한 특징을 추출하는 데 사용됩니다.
이 features 부분은 이미지의 스타일과 콘텐츠를 분석하는 데 사용되며, 스타일 전이나 지각적 손실 계산에 주로 활용됩니다. 각 레이어는 이미지의 다른 수준의 특징을 나타내며, 이를 통해 이미지의 깊은 특성을 이해할 수 있습니다.

 

import torch
import torch.nn as nn
import torchvision.models as models

class AdversarialLoss(nn.Module): # nn.Module을 상속받는다.
    r"""
    Adversarial loss
    https://arxiv.org/abs/1711.10337
    """

    def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
        r"""
        type = nsgan | lsgan | hinge
        """
        super(AdversarialLoss, self).__init__() # 상위 클래스 생성자 호출

        self.type = type
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))

        if type == 'nsgan':
            self.criterion = nn.BCELoss()

        elif type == 'lsgan':
            self.criterion = nn.MSELoss()

        elif type == 'hinge':
            self.criterion = nn.ReLU()

    def __call__(self, outputs, is_real, is_disc=None):
        if self.type == 'hinge': # loss = max{0, 1-(y' * y)}, 이 때, is_real이라면 y == 1 --> loss = max{0, 1-y'}
            if is_disc: # 호출이 판별자에 의한 것이라면 (로스를 최대화 하려고 한다.)
                if is_real: # 그리고 이미지가 실제라면
                    outputs = -outputs # 실제 이미지에 대한 출력을 음수로 반전한다.
                return self.criterion(1 + outputs).mean() # 판별자에 대한 hinge 손실을 계산합니다.
                # 1 + outputs는 실제 이미지에 대해 1보다 큰 값, 가짜 이미지에 대해 1보다 작은 값을 갖도록 만듭니다. 그리고 이 값의 평균을 반환합니다.
            else: # 호출이 생성자에 의한 것이라면 (로스를 최소화 하려고 한다.)
                return (-outputs).mean() # - 생성자에 대한 hinge 손실 계산입니다.

        else: # 손실 함수의 유형이 nsgan 또는 lsgan인 경우.
            labels = (self.real_label if is_real else self.fake_label).expand_as(outputs)
            # 이미지가 실제라면 labels = self.real_label
            # 이미지가 가짜라면 labels = self.fake_label
            loss = self.criterion(outputs, labels)
            return loss


class StyleLoss(nn.Module):
    r"""
    Perceptual loss, VGG-based
    https://arxiv.org/abs/1603.08155
    https://github.com/dxyang/StyleTransfer/blob/master/utils.py
    """

    def __init__(self):
        super(StyleLoss, self).__init__()
        self.add_module('vgg', VGG19()) # VGG19 모델을 클래스 하위 모듈로 추가. 이는 사전 훈련된 컨볼루션 신경망으로, 이미지의 스타일과 콘텐츠를 분석하는데 사용된다.
        self.criterion = torch.nn.L1Loss() # 예측값과 목표값 사이의 절대값의 평균 오차를 계산한다.

    def compute_gram(self, x): # gram matrix 계산
        b, ch, h, w = x.size()
        f = x.view(b, ch, w * h) # 피처맵을 2차원의 행렬로 변환
        f_T = f.transpose(1, 2) # 피처맵의 전치행렬을 계산
        G = f.bmm(f_T) / (h * w * ch) # Gram 행렬을 계산합니다. bmm은 배치 행렬 곱셈을 의미하며, 결과는 특징 맵의 크기로 정규화됩니다.

        return G

    def __call__(self, x, y):
        # Compute features
        x_vgg, y_vgg = self.vgg(x), self.vgg(y) # x는 입력 이미지, y는 스타일 참조 이미지.

        # Compute loss
        style_loss = 0.0
        style_loss += self.criterion(self.compute_gram(x_vgg['relu2_2']), self.compute_gram(y_vgg['relu2_2']))
        style_loss += self.criterion(self.compute_gram(x_vgg['relu3_4']), self.compute_gram(y_vgg['relu3_4']))
        style_loss += self.criterion(self.compute_gram(x_vgg['relu4_4']), self.compute_gram(y_vgg['relu4_4']))
        style_loss += self.criterion(self.compute_gram(x_vgg['relu5_2']), self.compute_gram(y_vgg['relu5_2']))

        return style_loss # 입력 이미지가 참조 이미지의 스타일을 얼마나 잘 모방하는가를 측정하는 로스



class PerceptualLoss(nn.Module):
    r"""
    Perceptual loss, VGG-based
    https://arxiv.org/abs/1603.08155
    https://github.com/dxyang/StyleTransfer/blob/master/utils.py
    """

    def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]):
        # weights는 VGG 네트워크의 각 레이어에 적용될 가중치를 나타냅니다.
        super(PerceptualLoss, self).__init__()
        self.add_module('vgg', VGG19()) # VGG19 모델을 클래스 하위 모듈로 추가. 이는 사전 훈련된 컨볼루션 신경망으로, 이미지의 스타일과 콘텐츠를 분석하는데 사용된다.
        self.criterion = torch.nn.L1Loss()
        self.weights = weights # VGG 네트워크의 각 레이어에 적용될 가중치를 인스턴스 변수에 저장합니다.

    def __call__(self, x, y):
        # Compute features
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)

        content_loss = 0.0
        content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1'])
        content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1'])
        content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1'])
        content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1'])
        content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1'])


        return content_loss



class VGG19(torch.nn.Module):
    def __init__(self):
        super(VGG19, self).__init__()
        features = models.vgg19(pretrained=True).features
        self.relu1_1 = torch.nn.Sequential()
        self.relu1_2 = torch.nn.Sequential()

        self.relu2_1 = torch.nn.Sequential()
        self.relu2_2 = torch.nn.Sequential()

        self.relu3_1 = torch.nn.Sequential()
        self.relu3_2 = torch.nn.Sequential()
        self.relu3_3 = torch.nn.Sequential()
        self.relu3_4 = torch.nn.Sequential()

        self.relu4_1 = torch.nn.Sequential()
        self.relu4_2 = torch.nn.Sequential()
        self.relu4_3 = torch.nn.Sequential()
        self.relu4_4 = torch.nn.Sequential()

        self.relu5_1 = torch.nn.Sequential()
        self.relu5_2 = torch.nn.Sequential()
        self.relu5_3 = torch.nn.Sequential()
        self.relu5_4 = torch.nn.Sequential()

        for x in range(2):
            self.relu1_1.add_module(str(x), features[x])

        for x in range(2, 4):
            self.relu1_2.add_module(str(x), features[x])

        for x in range(4, 7):
            self.relu2_1.add_module(str(x), features[x])

        for x in range(7, 9):
            self.relu2_2.add_module(str(x), features[x])

        for x in range(9, 12):
            self.relu3_1.add_module(str(x), features[x])

        for x in range(12, 14):
            self.relu3_2.add_module(str(x), features[x])

        for x in range(14, 16):
            self.relu3_3.add_module(str(x), features[x])

        for x in range(16, 18):
            self.relu3_4.add_module(str(x), features[x])

        for x in range(18, 21):
            self.relu4_1.add_module(str(x), features[x])

        for x in range(21, 23):
            self.relu4_2.add_module(str(x), features[x])

        for x in range(23, 25):
            self.relu4_3.add_module(str(x), features[x])

        for x in range(25, 27):
            self.relu4_4.add_module(str(x), features[x])

        for x in range(27, 30):
            self.relu5_1.add_module(str(x), features[x])

        for x in range(30, 32):
            self.relu5_2.add_module(str(x), features[x])

        for x in range(32, 34):
            self.relu5_3.add_module(str(x), features[x])

        for x in range(34, 36):
            self.relu5_4.add_module(str(x), features[x])

        # don't need the gradients, just want the features
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        relu1_1 = self.relu1_1(x)
        relu1_2 = self.relu1_2(relu1_1)

        relu2_1 = self.relu2_1(relu1_2)
        relu2_2 = self.relu2_2(relu2_1)

        relu3_1 = self.relu3_1(relu2_2)
        relu3_2 = self.relu3_2(relu3_1)
        relu3_3 = self.relu3_3(relu3_2)
        relu3_4 = self.relu3_4(relu3_3)

        relu4_1 = self.relu4_1(relu3_4)
        relu4_2 = self.relu4_2(relu4_1)
        relu4_3 = self.relu4_3(relu4_2)
        relu4_4 = self.relu4_4(relu4_3)

        relu5_1 = self.relu5_1(relu4_4)
        relu5_2 = self.relu5_2(relu5_1)
        relu5_3 = self.relu5_3(relu5_2)
        relu5_4 = self.relu5_4(relu5_3)

        out = {
            'relu1_1': relu1_1,
            'relu1_2': relu1_2,

            'relu2_1': relu2_1,
            'relu2_2': relu2_2,

            'relu3_1': relu3_1,
            'relu3_2': relu3_2,
            'relu3_3': relu3_3,
            'relu3_4': relu3_4,

            'relu4_1': relu4_1,
            'relu4_2': relu4_2,
            'relu4_3': relu4_3,
            'relu4_4': relu4_4,

            'relu5_1': relu5_1,
            'relu5_2': relu5_2,
            'relu5_3': relu5_3,
            'relu5_4': relu5_4,
        }
        return out

'코드 분석 > Edge-Connect' 카테고리의 다른 글

Edge-Connect: utils.py  (1) 2024.01.25
Edge-Connect: main.py  (0) 2024.01.25
Edge-Connect: metrics.py  (0) 2024.01.25
Edge-Connect: config.py  (1) 2024.01.24
Edge-Connect: networks.py  (1) 2024.01.24