Edge-Connect: models.py
코드를 이해하기 위한 사전 지식
nn.DataParallel() 함수
nn.DataParallel은 PyTorch에서 제공하는 함수로, 모델을 여러 GPU에 분산시켜 병렬로 학습할 수 있게 해줍니다. 이 함수의 주요 특징은 다음과 같습니다:
모델의 복사본을 여러 GPU에 분산시킵니다.
각 GPU에서는 데이터의 서브셋(subset)에 대해 연산을 수행합니다.
모든 GPU에서의 연산 결과는 자동으로 합쳐져 최종 결과를 생성합니다.
이를 통해 대규모 데이터셋과 복잡한 모델을 더 빠르게 학습할 수 있습니다.
self.add_module() 함수와 모델 초기화 방식
self.add_module('generator', generator)는 nn.Module 클래스의 메서드로, 모델의 서브모듈을 동적으로 추가하는 데 사용됩니다. 이 메서드를 사용하는 이유는 다음과 같습니다:
유연성: add_module을 사용하면 모델의 구성 요소를 더 유연하게 관리할 수 있습니다. 예를 들어, 모델의 특정 부분을 조건부로 추가하거나 변경할 수 있습니다.
동적 구성: 모델의 구성 요소를 실행 시간(runtime)에 결정하거나 변경할 수 있습니다. 이는 복잡한 모델 구조에서 특히 유용합니다.
self.generator = EdgeGenerator(use_spectral_norm=True)와 같이 직접 할당하는 방식도 가능하지만, add_module을 사용하면 모듈의 이름을 명시적으로 지정할 수 있고, 모듈 관리가 더 명확해집니다. nn.Module 클래스에 존재하는 add_module 메서드는 모듈의 이름과 인스턴스를 매핑하여 모듈을 관리하는 데 도움을 줍니다.
import os
import torch
import torch.nn as nn
import torch.optim as optim
from .networks import InpaintGenerator, EdgeGenerator, Discriminator
from .loss import AdversarialLoss, PerceptualLoss, StyleLoss
class BaseModel(nn.Module):
def __init__(self, name, config): # name: 모델 이름, config: 설정
super(BaseModel, self).__init__()
self.name = name
self.config = config
self.iteration = 0
# 생성자와 판별자의 가중치를 저장할 경로 설정
self.gen_weights_path = os.path.join(config.PATH, name + '_gen.pth')
self.dis_weights_path = os.path.join(config.PATH, name + '_dis.pth')
# 모델의 가중치를 로드하는 함수
def load(self):
if os.path.exists(self.gen_weights_path): # 생성자의 웨이트 파일이 존재한다면,
print('Loading %s generator...' % self.name) # Loading model name generator
if torch.cuda.is_available(): # CUDA를 사용 가능하다면
data = torch.load(self.gen_weights_path) #
else:
data = torch.load(self.gen_weights_path, map_location=lambda storage, loc: storage)
# 베이스 모델에선 self.generator를 정의하지 않고 바로 .load_state_dict() 함수를 사용한다.
# 아마 이를 상속받는 모델에서는 사전에 정의를 할 수 있다. --> 79줄을 참조.
self.generator.load_state_dict(data['generator'])
self.iteration = data['iteration']
# load discriminator only when training
if self.config.MODE == 1 and os.path.exists(self.dis_weights_path):
print('Loading %s discriminator...' % self.name)
if torch.cuda.is_available():
data = torch.load(self.dis_weights_path)
else:
data = torch.load(self.dis_weights_path, map_location=lambda storage, loc: storage)
self.discriminator.load_state_dict(data['discriminator'])
# 모델의 상태를 저장하는 함수
def save(self):
print('\nsaving %s...\n' % self.name)
torch.save({
'iteration': self.iteration, # 이터레이션 저장
'generator': self.generator.state_dict() # 생성자 모델의 상태 딕셔너리를 저장. 이는 모델 각 레이어의 가중치와 바이어스 등의 파라미터를 포함한다.
}, self.gen_weights_path)
torch.save({
'discriminator': self.discriminator.state_dict()
}, self.dis_weights_path)
class EdgeModel(BaseModel):
def __init__(self, config):
super(EdgeModel, self).__init__('EdgeModel', config)
# 생성자, 판별자 초기화
# generator input: [grayscale(1) + edge(1) + mask(1)]
# discriminator input: (grayscale(1) + edge(1))
generator = EdgeGenerator(use_spectral_norm=True)
discriminator = Discriminator(in_channels=2, use_sigmoid=config.GAN_LOSS != 'hinge') # 설정된 GAN 손실 함수 유형이 'hinge'가 아니면 True
# GPU 분산 처리
if len(config.GPU) > 1: # GPU가 여러개인 경우
generator = nn.DataParallel(generator, config.GPU)
discriminator = nn.DataParallel(discriminator, config.GPU)
# 로스 초기화
l1_loss = nn.L1Loss()
adversarial_loss = AdversarialLoss(type=config.GAN_LOSS)
# 모듈 추가 함수를 통해 인스턴스 변수로 변환
self.add_module('generator', generator)
self.add_module('discriminator', discriminator)
self.add_module('l1_loss', l1_loss)
self.add_module('adversarial_loss', adversarial_loss)
self.gen_optimizer = optim.Adam(
params=generator.parameters(),
lr=float(config.LR),
betas=(config.BETA1, config.BETA2)
)
self.dis_optimizer = optim.Adam(
params=discriminator.parameters(),
lr=float(config.LR) * float(config.D2G_LR),
betas=(config.BETA1, config.BETA2)
)
# 이미지, 에지, 마스크를 입력으로 받아 생성자(generator)와 판별자(discriminator)의 손실을 계산하고, 최적화 과정을 수행하는 함수
def process(self, images, edges, masks):
self.iteration += 1
# zero optimizers
# zero_grad()는 이전의 기울기를 0으로 초기화하여 새로운 최적화를 하는 것.
self.gen_optimizer.zero_grad()
self.dis_optimizer.zero_grad()
# process outputs
outputs = self(images, edges, masks) # 모델에 이미지, 에지, 마스크를 입력하여 출력을 생성
# 여기서 self는 EdgeModel 클래스의 인스턴스를 의미한다. 이 구문은 Python의 클래스 내에서 해당 인스턴스의 '__call__' 메서드를 호출하는 것과 동일하다.
# nn.Module의 클래스 인스턴스는 '__call__' 메서드를 통해 자신의 forward 메서드를 호출한다.
# 따라서 해당 코드는 self.forward(image, edges, masks)를 호출하는 것과 동일하다.
gen_loss = 0
dis_loss = 0
# discriminator loss
# 판별자에 입력할 데이터부터
dis_input_real = torch.cat((images, edges), dim=1) # 실제 이미지와 실제 에지를 채널 차원(dim=1)을 따라 연결(concatenate)하여 판별자의 입력으로 사용
dis_input_fake = torch.cat((images, outputs.detach()), dim=1) # 실제 이미지와 생성자가 생성한 가짜 에지를 연결하여 판별자의 또 다른 입력으로 사용
# detach()는 생성된 에지가 그래디언트 계산에서 제외되도록 함
# 판별자에 실제 입력을 넣어 결과(dis_real)와 중간 특징(dis_real_feat)을 얻음.
dis_real, dis_real_feat = self.discriminator(dis_input_real) # in: (grayscale(1) + edge(1))
# 판별자에 가짜 입력을 넣어 결과(dis_fake)와 중간 특징(dis_fake_feat)을 얻음.
dis_fake, dis_fake_feat = self.discriminator(dis_input_fake) # in: (grayscale(1) + edge(1))
# 실제 입력에 대한 판별자의 손실을 계산합니다. 여기서 True는 실제 데이터임을 나타냄.
dis_real_loss = self.adversarial_loss(dis_real, True, True)
# 가짜 입력에 대한 판별자의 손실을 계산합니다. 여기서 False는 가짜 데이터임을 나타냄.
dis_fake_loss = self.adversarial_loss(dis_fake, False, True)
# 실제와 가짜 입력에 대한 손실의 평균을 취하여 판별자의 총 손실을 업데이트
dis_loss = dis_loss + (dis_real_loss + dis_fake_loss) / 2
# generator adversarial loss
# 실제 이미지와 생성된 에지를 연결하여 생성자의 입력으로 사용.
gen_input_fake = torch.cat((images, outputs), dim=1)
# 생성된 입력을 판별자에 넣어 결과(gen_fake)와 중간 특징(gen_fake_feat)을 얻음.
gen_fake, gen_fake_feat = self.discriminator(gen_input_fake) # in: (grayscale(1) + edge(1))
# 생성자의 적대적 손실을 계산합니다. 여기서 True는 생성자가 판별자를 속이려는 목표.
gen_gan_loss = self.adversarial_loss(gen_fake, True, False)
# 생성자 손실을 더하여 업데이트.
gen_loss = gen_loss + gen_gan_loss
# generator feature matching loss
gen_fm_loss = 0 # 피처 매칭 로스 초기화
for i in range(len(dis_real_feat)): # 판별자의 실제 및 가짜 입력에 대한 피처들 간의 L1 손실을 계산
gen_fm_loss = gen_fm_loss + self.l1_loss(gen_fake_feat[i], dis_real_feat[i].detach())
gen_fm_loss = gen_fm_loss * self.config.FM_LOSS_WEIGHT # 설정된 가중치를 적용
gen_loss = gen_loss + gen_fm_loss # 생성자의 총 손실에 특징 매칭 손실을 추가
# create logs
logs = [
("l_d1", dis_loss.item()),
("l_g1", gen_gan_loss.item()),
("l_fm", gen_fm_loss.item()),
]
return outputs, gen_loss, dis_loss, logs
# 모델 순전파 함수 오버라이딩
def forward(self, images, edges, masks): # 이미지, 에지, 마스크를 입력으로 받음.
edges_masked = (edges * (1 - masks)) # 에지에 마스크를 적용. 마스크된 영역은 0이 되고, 나머지 영역은 원래의 에지 값이 유지
images_masked = (images * (1 - masks)) + masks # 이미지에 마스크를 적용. 마스크된 영역은 1이 되고, 나머지 영역은 원래의 이미지 값이 유지
inputs = torch.cat((images_masked, edges_masked, masks), dim=1) # 마스크된 이미지, 마스크된 에지, 마스크를 채널 차원(dim=1)을 따라 연결하여 생성자의 입력으로 사용
outputs = self.generator(inputs) # in: [grayscale(1) + edge(1) + mask(1)]
# 모델의 순전파 함수에서 outputs를 self.discriminator에 넣는 과정이 없는 이유는 이 과정이 상단의 process 메서드에서 수행되기 때문이다.
return outputs
# 모델 역전파 함수
def backward(self, gen_loss=None, dis_loss=None):
if dis_loss is not None: # 판별자의 손실이 제공된 경우, 해당 손실에 대한 역전파를 수행
dis_loss.backward() # 판별자의 손실에 대해 역전파를 수행합니다. 이는 판별자의 가중치에 대한 그래디언트를 계산
if gen_loss is not None: # 생성자의 손실이 제공된 경우, 해당 손실에 대한 역전파를 수행
gen_loss.backward() # 생성자의 손실에 대해 역전파를 수행합니다. 이는 생성자의 가중치에 대한 그래디언트를 계산
self.dis_optimizer.step() # 판별자의 최적화기를 한 단계 업데이트합니다. 이는 계산된 그래디언트를 사용하여 판별자의 가중치를 조정
self.gen_optimizer.step() # 생성자의 최적화기를 한 단계 업데이트합니다. 이는 계산된 그래디언트를 사용하여 생성자의 가중치를 조정
class InpaintingModel(BaseModel):
def __init__(self, config): # config 정보를 받아 인스턴스 초기화.
super(InpaintingModel, self).__init__('InpaintingModel', config) # 부모 클래스 생성자를 호출하여 인스턴스 초기화.
# 생성자, 판별자 초기화
# generator input: [rgb(3) + edge(1)]
# discriminator input: [rgb(3)]
generator = InpaintGenerator()
discriminator = Discriminator(in_channels=3, use_sigmoid=config.GAN_LOSS != 'hinge')
# GPU 분산 처리 옵션
if len(config.GPU) > 1:
generator = nn.DataParallel(generator, config.GPU)
discriminator = nn.DataParallel(discriminator , config.GPU)
# 로스 초기화
l1_loss = nn.L1Loss()
perceptual_loss = PerceptualLoss()
style_loss = StyleLoss()
adversarial_loss = AdversarialLoss(type=config.GAN_LOSS)
# 모듈 추가
self.add_module('generator', generator)
self.add_module('discriminator', discriminator)
self.add_module('l1_loss', l1_loss)
self.add_module('perceptual_loss', perceptual_loss)
self.add_module('style_loss', style_loss)
self.add_module('adversarial_loss', adversarial_loss)
# 옵티마이저 설정
self.gen_optimizer = optim.Adam(
params=generator.parameters(),
lr=float(config.LR),
betas=(config.BETA1, config.BETA2)
)
self.dis_optimizer = optim.Adam(
params=discriminator.parameters(),
lr=float(config.LR) * float(config.D2G_LR),
betas=(config.BETA1, config.BETA2)
)
def process(self, images, edges, masks):
self.iteration += 1
# zero optimizers
self.gen_optimizer.zero_grad()
self.dis_optimizer.zero_grad()
# process outputs
outputs = self(images, edges, masks) # outputs = self.forward(images, edges, masks)와 동일한 명령어.
gen_loss = 0
dis_loss = 0
# discriminator loss
dis_input_real = images
dis_input_fake = outputs.detach()
dis_real, _ = self.discriminator(dis_input_real) # in: [rgb(3)]
dis_fake, _ = self.discriminator(dis_input_fake) # in: [rgb(3)]
dis_real_loss = self.adversarial_loss(dis_real, True, True)
dis_fake_loss = self.adversarial_loss(dis_fake, False, True)
dis_loss += (dis_real_loss + dis_fake_loss) / 2
# generator adversarial loss
gen_input_fake = outputs
gen_fake, _ = self.discriminator(gen_input_fake) # in: [rgb(3)]
gen_gan_loss = self.adversarial_loss(gen_fake, True, False) * self.config.INPAINT_ADV_LOSS_WEIGHT
gen_loss = gen_loss + gen_gan_loss
# generator l1 loss
gen_l1_loss = self.l1_loss(outputs, images) * self.config.L1_LOSS_WEIGHT / torch.mean(masks)
gen_loss = gen_loss + gen_l1_loss
# generator perceptual loss
gen_content_loss = self.perceptual_loss(outputs, images)
gen_content_loss = gen_content_loss * self.config.CONTENT_LOSS_WEIGHT
gen_loss = gen_loss + gen_content_loss
# generator style loss
gen_style_loss = self.style_loss(outputs * masks, images * masks)
gen_style_loss = gen_style_loss * self.config.STYLE_LOSS_WEIGHT
gen_loss = gen_loss + gen_style_loss
# create logs
logs = [
("l_d2", dis_loss.item()),
("l_g2", gen_gan_loss.item()),
("l_l1", gen_l1_loss.item()),
("l_per", gen_content_loss.item()),
("l_sty", gen_style_loss.item()),
]
return outputs, gen_loss, dis_loss, logs
def forward(self, images, edges, masks):
images_masked = (images * (1 - masks).float()) + masks
inputs = torch.cat((images_masked, edges), dim=1)
outputs = self.generator(inputs) # in: [rgb(3) + edge(1)]
return outputs
def backward(self, gen_loss=None, dis_loss=None):
dis_loss.backward()
self.dis_optimizer.step()
gen_loss.backward()
self.gen_optimizer.step()