코드 분석/Edge-Connect

Edge-Connect: models.py

상솜공방 2024. 1. 27. 12:01

코드를 이해하기 위한 사전 지식

더보기

nn.DataParallel() 함수
nn.DataParallel은 PyTorch에서 제공하는 함수로, 모델을 여러 GPU에 분산시켜 병렬로 학습할 수 있게 해줍니다. 이 함수의 주요 특징은 다음과 같습니다:

모델의 복사본을 여러 GPU에 분산시킵니다.
각 GPU에서는 데이터의 서브셋(subset)에 대해 연산을 수행합니다.
모든 GPU에서의 연산 결과는 자동으로 합쳐져 최종 결과를 생성합니다.
이를 통해 대규모 데이터셋과 복잡한 모델을 더 빠르게 학습할 수 있습니다.

 

self.add_module() 함수와 모델 초기화 방식
self.add_module('generator', generator)는 nn.Module 클래스의 메서드로, 모델의 서브모듈을 동적으로 추가하는 데 사용됩니다. 이 메서드를 사용하는 이유는 다음과 같습니다:

유연성: add_module을 사용하면 모델의 구성 요소를 더 유연하게 관리할 수 있습니다. 예를 들어, 모델의 특정 부분을 조건부로 추가하거나 변경할 수 있습니다.
동적 구성: 모델의 구성 요소를 실행 시간(runtime)에 결정하거나 변경할 수 있습니다. 이는 복잡한 모델 구조에서 특히 유용합니다.


self.generator = EdgeGenerator(use_spectral_norm=True)와 같이 직접 할당하는 방식도 가능하지만, add_module을 사용하면 모듈의 이름을 명시적으로 지정할 수 있고, 모듈 관리가 더 명확해집니다. nn.Module 클래스에 존재하는 add_module 메서드는 모듈의 이름과 인스턴스를 매핑하여 모듈을 관리하는 데 도움을 줍니다.

import os
import torch
import torch.nn as nn
import torch.optim as optim
from .networks import InpaintGenerator, EdgeGenerator, Discriminator
from .loss import AdversarialLoss, PerceptualLoss, StyleLoss


class BaseModel(nn.Module):
    def __init__(self, name, config): # name: 모델 이름, config: 설정
        super(BaseModel, self).__init__()

        self.name = name
        self.config = config
        self.iteration = 0

        # 생성자와 판별자의 가중치를 저장할 경로 설정
        self.gen_weights_path = os.path.join(config.PATH, name + '_gen.pth')
        self.dis_weights_path = os.path.join(config.PATH, name + '_dis.pth')

    # 모델의 가중치를 로드하는 함수
    def load(self):
        if os.path.exists(self.gen_weights_path): # 생성자의 웨이트 파일이 존재한다면,
            print('Loading %s generator...' % self.name) # Loading model name generator

            if torch.cuda.is_available(): # CUDA를 사용 가능하다면
                data = torch.load(self.gen_weights_path) # 
            else:
                data = torch.load(self.gen_weights_path, map_location=lambda storage, loc: storage)

            # 베이스 모델에선 self.generator를 정의하지 않고 바로 .load_state_dict() 함수를 사용한다.
            # 아마 이를 상속받는 모델에서는 사전에 정의를 할 수 있다. --> 79줄을 참조.
            self.generator.load_state_dict(data['generator'])
            self.iteration = data['iteration']

        # load discriminator only when training
        if self.config.MODE == 1 and os.path.exists(self.dis_weights_path):
            print('Loading %s discriminator...' % self.name)

            if torch.cuda.is_available():
                data = torch.load(self.dis_weights_path)
            else:
                data = torch.load(self.dis_weights_path, map_location=lambda storage, loc: storage)

            self.discriminator.load_state_dict(data['discriminator'])

    # 모델의 상태를 저장하는 함수
    def save(self):
        print('\nsaving %s...\n' % self.name)
        torch.save({
            'iteration': self.iteration, # 이터레이션 저장
            'generator': self.generator.state_dict() # 생성자 모델의 상태 딕셔너리를 저장. 이는 모델 각 레이어의 가중치와 바이어스 등의 파라미터를 포함한다.
        }, self.gen_weights_path)

        torch.save({
            'discriminator': self.discriminator.state_dict()
        }, self.dis_weights_path)


class EdgeModel(BaseModel):
    def __init__(self, config):
        super(EdgeModel, self).__init__('EdgeModel', config)

        # 생성자, 판별자 초기화
        # generator input: [grayscale(1) + edge(1) + mask(1)]
        # discriminator input: (grayscale(1) + edge(1))
        generator = EdgeGenerator(use_spectral_norm=True)
        discriminator = Discriminator(in_channels=2, use_sigmoid=config.GAN_LOSS != 'hinge') # 설정된 GAN 손실 함수 유형이 'hinge'가 아니면 True

        # GPU 분산 처리
        if len(config.GPU) > 1: # GPU가 여러개인 경우 
            generator = nn.DataParallel(generator, config.GPU)
            discriminator = nn.DataParallel(discriminator, config.GPU)
        
        # 로스 초기화
        l1_loss = nn.L1Loss()
        adversarial_loss = AdversarialLoss(type=config.GAN_LOSS)

        # 모듈 추가 함수를 통해 인스턴스 변수로 변환
        self.add_module('generator', generator)
        self.add_module('discriminator', discriminator)

        self.add_module('l1_loss', l1_loss)
        self.add_module('adversarial_loss', adversarial_loss)

        self.gen_optimizer = optim.Adam(
            params=generator.parameters(),
            lr=float(config.LR),
            betas=(config.BETA1, config.BETA2)
        )

        self.dis_optimizer = optim.Adam(
            params=discriminator.parameters(),
            lr=float(config.LR) * float(config.D2G_LR),
            betas=(config.BETA1, config.BETA2)
        )

    # 이미지, 에지, 마스크를 입력으로 받아 생성자(generator)와 판별자(discriminator)의 손실을 계산하고, 최적화 과정을 수행하는 함수
    def process(self, images, edges, masks):
        self.iteration += 1


        # zero optimizers
        # zero_grad()는 이전의 기울기를 0으로 초기화하여 새로운 최적화를 하는 것.
        self.gen_optimizer.zero_grad()
        self.dis_optimizer.zero_grad()


        # process outputs
        outputs = self(images, edges, masks) # 모델에 이미지, 에지, 마스크를 입력하여 출력을 생성
        # 여기서 self는 EdgeModel 클래스의 인스턴스를 의미한다. 이 구문은 Python의 클래스 내에서 해당 인스턴스의 '__call__' 메서드를 호출하는 것과 동일하다.
        # nn.Module의 클래스 인스턴스는 '__call__' 메서드를 통해 자신의 forward 메서드를 호출한다.
        # 따라서 해당 코드는 self.forward(image, edges, masks)를 호출하는 것과 동일하다.
        gen_loss = 0
        dis_loss = 0


        # discriminator loss
        # 판별자에 입력할 데이터부터 
        dis_input_real = torch.cat((images, edges), dim=1) # 실제 이미지와 실제 에지를 채널 차원(dim=1)을 따라 연결(concatenate)하여 판별자의 입력으로 사용
        dis_input_fake = torch.cat((images, outputs.detach()), dim=1) # 실제 이미지와 생성자가 생성한 가짜 에지를 연결하여 판별자의 또 다른 입력으로 사용
        # detach()는 생성된 에지가 그래디언트 계산에서 제외되도록 함

        # 판별자에 실제 입력을 넣어 결과(dis_real)와 중간 특징(dis_real_feat)을 얻음.
        dis_real, dis_real_feat = self.discriminator(dis_input_real)        # in: (grayscale(1) + edge(1))
        # 판별자에 가짜 입력을 넣어 결과(dis_fake)와 중간 특징(dis_fake_feat)을 얻음.
        dis_fake, dis_fake_feat = self.discriminator(dis_input_fake)        # in: (grayscale(1) + edge(1))
        
        # 실제 입력에 대한 판별자의 손실을 계산합니다. 여기서 True는 실제 데이터임을 나타냄.
        dis_real_loss = self.adversarial_loss(dis_real, True, True)
        # 가짜 입력에 대한 판별자의 손실을 계산합니다. 여기서 False는 가짜 데이터임을 나타냄.
        dis_fake_loss = self.adversarial_loss(dis_fake, False, True)
        # 실제와 가짜 입력에 대한 손실의 평균을 취하여 판별자의 총 손실을 업데이트
        dis_loss = dis_loss + (dis_real_loss + dis_fake_loss) / 2


        # generator adversarial loss
        # 실제 이미지와 생성된 에지를 연결하여 생성자의 입력으로 사용.
        gen_input_fake = torch.cat((images, outputs), dim=1)
        # 생성된 입력을 판별자에 넣어 결과(gen_fake)와 중간 특징(gen_fake_feat)을 얻음.
        gen_fake, gen_fake_feat = self.discriminator(gen_input_fake)        # in: (grayscale(1) + edge(1))
        # 생성자의 적대적 손실을 계산합니다. 여기서 True는 생성자가 판별자를 속이려는 목표.
        gen_gan_loss = self.adversarial_loss(gen_fake, True, False)
        # 생성자 손실을 더하여 업데이트.
        gen_loss = gen_loss + gen_gan_loss


        # generator feature matching loss
        gen_fm_loss = 0 # 피처 매칭 로스 초기화
        for i in range(len(dis_real_feat)): # 판별자의 실제 및 가짜 입력에 대한 피처들 간의 L1 손실을 계산
            gen_fm_loss = gen_fm_loss + self.l1_loss(gen_fake_feat[i], dis_real_feat[i].detach())
        gen_fm_loss = gen_fm_loss * self.config.FM_LOSS_WEIGHT # 설정된 가중치를 적용
        gen_loss = gen_loss + gen_fm_loss # 생성자의 총 손실에 특징 매칭 손실을 추가


        # create logs
        logs = [
            ("l_d1", dis_loss.item()),
            ("l_g1", gen_gan_loss.item()),
            ("l_fm", gen_fm_loss.item()),
        ]

        return outputs, gen_loss, dis_loss, logs

    # 모델 순전파 함수 오버라이딩
    def forward(self, images, edges, masks): # 이미지, 에지, 마스크를 입력으로 받음.
        edges_masked = (edges * (1 - masks)) # 에지에 마스크를 적용. 마스크된 영역은 0이 되고, 나머지 영역은 원래의 에지 값이 유지
        images_masked = (images * (1 - masks)) + masks # 이미지에 마스크를 적용. 마스크된 영역은 1이 되고, 나머지 영역은 원래의 이미지 값이 유지
        inputs = torch.cat((images_masked, edges_masked, masks), dim=1) # 마스크된 이미지, 마스크된 에지, 마스크를 채널 차원(dim=1)을 따라 연결하여 생성자의 입력으로 사용
        outputs = self.generator(inputs)                                    # in: [grayscale(1) + edge(1) + mask(1)]
        
        # 모델의 순전파 함수에서 outputs를 self.discriminator에 넣는 과정이 없는 이유는 이 과정이 상단의 process 메서드에서 수행되기 때문이다.
        return outputs

    # 모델 역전파 함수
    def backward(self, gen_loss=None, dis_loss=None):
        if dis_loss is not None: # 판별자의 손실이 제공된 경우, 해당 손실에 대한 역전파를 수행
            dis_loss.backward() # 판별자의 손실에 대해 역전파를 수행합니다. 이는 판별자의 가중치에 대한 그래디언트를 계산
            
        if gen_loss is not None: # 생성자의 손실이 제공된 경우, 해당 손실에 대한 역전파를 수행
            gen_loss.backward() # 생성자의 손실에 대해 역전파를 수행합니다. 이는 생성자의 가중치에 대한 그래디언트를 계산

        self.dis_optimizer.step() # 판별자의 최적화기를 한 단계 업데이트합니다. 이는 계산된 그래디언트를 사용하여 판별자의 가중치를 조정
        self.gen_optimizer.step() # 생성자의 최적화기를 한 단계 업데이트합니다. 이는 계산된 그래디언트를 사용하여 생성자의 가중치를 조정


class InpaintingModel(BaseModel):
    def __init__(self, config): # config 정보를 받아 인스턴스 초기화.
        super(InpaintingModel, self).__init__('InpaintingModel', config) # 부모 클래스 생성자를 호출하여 인스턴스 초기화.

        # 생성자, 판별자 초기화
        # generator input: [rgb(3) + edge(1)]
        # discriminator input: [rgb(3)]
        generator = InpaintGenerator()
        discriminator = Discriminator(in_channels=3, use_sigmoid=config.GAN_LOSS != 'hinge')

        # GPU 분산 처리 옵션
        if len(config.GPU) > 1:
            generator = nn.DataParallel(generator, config.GPU)
            discriminator = nn.DataParallel(discriminator , config.GPU)

        # 로스 초기화
        l1_loss = nn.L1Loss()
        perceptual_loss = PerceptualLoss()
        style_loss = StyleLoss()
        adversarial_loss = AdversarialLoss(type=config.GAN_LOSS)

        # 모듈 추가
        self.add_module('generator', generator)
        self.add_module('discriminator', discriminator)

        self.add_module('l1_loss', l1_loss)
        self.add_module('perceptual_loss', perceptual_loss)
        self.add_module('style_loss', style_loss)
        self.add_module('adversarial_loss', adversarial_loss)

        # 옵티마이저 설정
        self.gen_optimizer = optim.Adam(
            params=generator.parameters(),
            lr=float(config.LR),
            betas=(config.BETA1, config.BETA2)
        )

        self.dis_optimizer = optim.Adam(
            params=discriminator.parameters(),
            lr=float(config.LR) * float(config.D2G_LR),
            betas=(config.BETA1, config.BETA2)
        )

    def process(self, images, edges, masks):
        self.iteration += 1

        # zero optimizers
        self.gen_optimizer.zero_grad()
        self.dis_optimizer.zero_grad()


        # process outputs
        outputs = self(images, edges, masks) # outputs = self.forward(images, edges, masks)와 동일한 명령어.
        gen_loss = 0
        dis_loss = 0


        # discriminator loss
        dis_input_real = images
        dis_input_fake = outputs.detach()
        dis_real, _ = self.discriminator(dis_input_real)                    # in: [rgb(3)]
        dis_fake, _ = self.discriminator(dis_input_fake)                    # in: [rgb(3)]
        dis_real_loss = self.adversarial_loss(dis_real, True, True)
        dis_fake_loss = self.adversarial_loss(dis_fake, False, True)
        dis_loss += (dis_real_loss + dis_fake_loss) / 2


        # generator adversarial loss
        gen_input_fake = outputs
        gen_fake, _ = self.discriminator(gen_input_fake)                    # in: [rgb(3)]
        gen_gan_loss = self.adversarial_loss(gen_fake, True, False) * self.config.INPAINT_ADV_LOSS_WEIGHT
        gen_loss = gen_loss + gen_gan_loss


        # generator l1 loss
        gen_l1_loss = self.l1_loss(outputs, images) * self.config.L1_LOSS_WEIGHT / torch.mean(masks)
        gen_loss = gen_loss + gen_l1_loss


        # generator perceptual loss
        gen_content_loss = self.perceptual_loss(outputs, images)
        gen_content_loss = gen_content_loss * self.config.CONTENT_LOSS_WEIGHT
        gen_loss = gen_loss + gen_content_loss


        # generator style loss
        gen_style_loss = self.style_loss(outputs * masks, images * masks)
        gen_style_loss = gen_style_loss * self.config.STYLE_LOSS_WEIGHT
        gen_loss = gen_loss + gen_style_loss


        # create logs
        logs = [
            ("l_d2", dis_loss.item()),
            ("l_g2", gen_gan_loss.item()),
            ("l_l1", gen_l1_loss.item()),
            ("l_per", gen_content_loss.item()),
            ("l_sty", gen_style_loss.item()),
        ]

        return outputs, gen_loss, dis_loss, logs

    def forward(self, images, edges, masks):
        images_masked = (images * (1 - masks).float()) + masks
        inputs = torch.cat((images_masked, edges), dim=1)
        outputs = self.generator(inputs)                                    # in: [rgb(3) + edge(1)]
        return outputs

    def backward(self, gen_loss=None, dis_loss=None):
        dis_loss.backward()
        self.dis_optimizer.step()

        gen_loss.backward()
        self.gen_optimizer.step()

'코드 분석 > Edge-Connect' 카테고리의 다른 글

Edge-Connect: 모듈 분석  (0) 2024.01.30
Edge-Connect: edge_connect.py  (2) 2024.01.27
Edge-Connect: dataset.py  (0) 2024.01.25
Edge-Connect: utils.py  (1) 2024.01.25
Edge-Connect: main.py  (0) 2024.01.25