import os
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from .generator import EdgeGenerator, ImageGenerator
from .edge_connect_discriminator import Discriminator
from .loss import AdversarialLoss, PerceptualLoss, StyleLoss
class BaseModel(nn.Module):
def __init__(self, name, args):
super(BaseModel, self).__init__()
self.name = name
self.args = args
self.iteration = 0
self.gen_weights_path = os.path.join(args.checkpoint_dir, name + '_gen.pth') # ./data/output/~_gen.pth
self.dis_weights_path = os.path.join(args.checkpoint_dir, name + '_dis.pth') # ./data/output/~_dis.pth
def load(self):
if os.path.exists(self.gen_weights_path):
if torch.cuda.is_available():
data = torch.load(self.gen_weights_path)
else:
data = torch.load(self.gen_weights_path, map_location=lambda storage, loc: storage)
            self.generator.load_state_dict(data['generator'])  # self.generator is not defined in BaseModel; the subclasses below create it
self.iteration = data['iteration']
print("{} generator successfully loaded".format(self.name))
if os.path.exists(self.dis_weights_path) and self.args.mode == 1:
if torch.cuda.is_available():
data = torch.load(self.dis_weights_path)
else:
data = torch.load(self.dis_weights_path, map_location=lambda storage, loc: storage)
            self.discriminator.load_state_dict(data['discriminator'])  # self.discriminator is not defined in BaseModel; the subclasses below create it
print("{} discriminator successfully loaded".format(self.name))
def save(self):
print('Saving {}...\n'.format(self.name))
torch.save({'iteration': self.iteration, 'generator': self.generator.state_dict()}, self.gen_weights_path)
torch.save({'discriminator': self.discriminator.state_dict()}, self.dis_weights_path)
def show_batch_images(self, tensors, tensor_names, figsize_per_image=(3, 3)):
        B = tensors[0].shape[0]  # batch size, taken from the first tensor
        num_tensors = len(tensors)
        # Compute the overall figure size
        fig_width = figsize_per_image[0] * num_tensors
        fig_height = figsize_per_image[1] * B
        # Create the grid of subplots
        fig, axs = plt.subplots(B, num_tensors, figsize=(fig_width, fig_height), squeeze=False)
for i in range(B):
for j, (tensor, name) in enumerate(zip(tensors, tensor_names)):
ax = axs[i, j]
                # Channel counts can differ between tensors, so read each one per tensor
                c = tensor.shape[1]
                img = tensor[i].permute(1, 2, 0).cpu().detach().numpy() if c == 3 else tensor[i].squeeze().cpu().detach().numpy()
                ax.imshow(img, cmap='gray' if c == 1 else None)
                # Row labels on the first column; tensor names as titles on the first row
                if j == 0:
                    ax.set_ylabel(f"Batch {i+1}", rotation=0, size='large', labelpad=20)
                if i == 0:
                    ax.set_title(name)
                # Hide the ticks but keep the frame so the row labels remain visible
                ax.set_xticks([])
                ax.set_yticks([])
plt.tight_layout()
plt.show()
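# Example: show_batch_images can be called from either subclass to inspect a
# batch during training. A minimal sketch (`model` and the shapes are
# illustrative; any list of (B, C, H, W) tensors sharing a batch size works):
#
#   images = torch.rand(4, 3, 256, 256)
#   masks = torch.rand(4, 1, 256, 256)
#   model.show_batch_images([images, masks], ['images', 'masks'])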
class EdgeModel(BaseModel):
def __init__(self, args):
        # Initialize the parent class, passing along args
super(EdgeModel, self).__init__('EdgeModel', args)
        # Store the settings from args as instance variables
self.img_channel = args.img_channel
self.mask_channel = args.mask_channel
self.batch_size = args.batch_size
self.input_size = args.input_size
self.patch_size = args.patch_size
self.latent_vec_dim = args.latent_vec_dim
self.drop_rate = args.drop_rate
self.num_heads = args.num_heads
self.num_layers = args.num_layers
        self.num_patches = int((self.input_size / self.patch_size) ** 2)
        self.patch_vec_size = self.patch_size * self.patch_size * self.img_channel
        self.mlp_hidden_dim = self.latent_vec_dim // 2
self.device = torch.device(args.device)
        # Initialize the generator, discriminator, and loss functions
# generator input: [grayscale(1) + edge(1), mask(1)]
# discriminator input: [grayscale(1) + edge(1)]
        self.generator = EdgeGenerator(self.batch_size, self.img_channel, self.input_size, self.input_size, self.patch_size, self.patch_vec_size, self.num_patches,
                                       self.latent_vec_dim, self.num_heads, self.mlp_hidden_dim, self.drop_rate, self.num_layers, self.device).to(self.device)
self.discriminator = Discriminator(in_channels=2, use_sigmoid=args.gan_loss != 'hinge').to(self.device)
self.l1_loss = nn.L1Loss()
self.adversarial_loss = AdversarialLoss(type=args.gan_loss)
        # Initialize the optimizers
        self.gen_optimizer = optim.Adam(params=self.generator.parameters(), lr=float(args.lr), betas=(args.beta1, args.beta2))
        self.dis_optimizer = optim.Adam(params=self.discriminator.parameters(), lr=float(args.lr) * float(args.d2g_lr), betas=(args.beta1, args.beta2))
def forward(self, images, edges, masks):
        images_masked = images * (1 - masks)  # Original code: images_masked = (images * (1 - masks)) + masks
        edges_masked = edges * (1 - masks)
        input_images = torch.cat((images_masked, edges_masked), dim=1)
        outputs, outputs_list, mask_list, attn_list = self.generator(input_images, masks)
        # self.show_batch_images([images_masked, edges_masked, outputs], ['images_masked', 'edges_masked', 'outputs'])  # visualize the masked images, masked edges, and outputs
return outputs
    def backward(self, gen_loss=None, dis_loss=None):
        # Backpropagate both losses first, then step the two optimizers
        if dis_loss is not None:
            dis_loss.backward()
        if gen_loss is not None:
            gen_loss.backward()
        self.dis_optimizer.step()
        self.gen_optimizer.step()
def process(self, images, edges, masks):
self.iteration += 1
gen_loss = 0
dis_loss = 0
self.gen_optimizer.zero_grad()
self.dis_optimizer.zero_grad()
        # self.show_batch_images([images, edges, masks], ['original_images', 'original_edges', 'masks'])  # visualize the input data
        outputs = self(images, edges, masks)  # calling the instance invokes self.forward() automatically
# discriminator loss
dis_input_real = torch.cat((images, edges), dim=1)
dis_input_fake = torch.cat((images, outputs.detach()), dim=1)
dis_real, dis_real_feat = self.discriminator(dis_input_real) # in: (grayscale(1) + edge(1))
dis_fake, dis_fake_feat = self.discriminator(dis_input_fake) # in: (grayscale(1) + edge(1))
dis_real_loss = self.adversarial_loss(dis_real, True, True)
dis_fake_loss = self.adversarial_loss(dis_fake, False, True)
dis_loss += (dis_real_loss + dis_fake_loss) / 2
# generator adversarial loss
gen_input_fake = torch.cat((images, outputs), dim=1)
gen_fake, gen_fake_feat = self.discriminator(gen_input_fake) # in: (grayscale(1) + edge(1))
gen_gan_loss = self.adversarial_loss(gen_fake, True, False)
gen_loss += gen_gan_loss
# generator feature matching loss
gen_fm_loss = 0
for i in range(len(dis_real_feat)):
gen_fm_loss += self.l1_loss(gen_fake_feat[i], dis_real_feat[i].detach())
gen_fm_loss = gen_fm_loss * self.args.fm_loss_weight
gen_loss += gen_fm_loss
# create logs
logs = [
("l_d1", dis_loss.item()),
("l_g1", gen_gan_loss.item()),
("l_fm", gen_fm_loss.item()),
]
return outputs, gen_loss, dis_loss, logs
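# Typical training-loop call pattern for EdgeModel, as a sketch (the trainer
# script and `save_interval` are assumed, not defined in this module):
#
#   outputs, gen_loss, dis_loss, logs = edge_model.process(images, edges, masks)
#   edge_model.backward(gen_loss, dis_loss)
#   if edge_model.iteration % save_interval == 0:
#       edge_model.save()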
class InpaintingModel(BaseModel):
def __init__(self, args):
super(InpaintingModel, self).__init__('InpaintingModel', args)
self.img_channel = args.img_channel
self.mask_channel = args.mask_channel
self.batch_size = args.batch_size
self.input_size = args.input_size
self.patch_size = args.patch_size
self.latent_vec_dim = args.latent_vec_dim
self.drop_rate = args.drop_rate
self.num_heads = args.num_heads
self.num_layers = args.num_layers
        self.num_patches = int((self.input_size / self.patch_size) ** 2)
        self.patch_vec_size = self.patch_size * self.patch_size * self.img_channel
        self.mlp_hidden_dim = self.latent_vec_dim // 2
self.device = torch.device(args.device)
# generator input: [rgb(3) + edge(1), mask(1)]
# discriminator input: [rgb(3)]
self.generator = ImageGenerator(self.batch_size, self.img_channel, self.input_size, self.input_size, self.patch_size, self.patch_vec_size, self.num_patches,
self.latent_vec_dim, self.num_heads, self.mlp_hidden_dim, self.drop_rate, self.num_layers, self.device).to(self.device)
self.discriminator = Discriminator(in_channels=3, use_sigmoid=args.gan_loss != 'hinge').to(self.device)
self.l1_loss = nn.L1Loss()
self.perceptual_loss = PerceptualLoss()
self.style_loss = StyleLoss()
self.adversarial_loss = AdversarialLoss(type=args.gan_loss)
self.gen_optimizer = optim.Adam(params=self.generator.parameters(), lr=float(args.lr), betas=(args.beta1, args.beta2))
self.dis_optimizer = optim.Adam(params=self.discriminator.parameters(), lr=float(args.lr) * float(args.d2g_lr), betas=(args.beta1, args.beta2))
def forward(self, images, edges, masks):
        images_masked = images * (1 - masks)  # Original code: images_masked = (images * (1 - masks).float()) + masks; the reason for the "+ masks" term is unclear
        inputs = torch.cat((images_masked, edges), dim=1)
outputs, outputs_list, mask_list, attn_list = self.generator(inputs, masks)
return outputs
    def backward(self, gen_loss=None, dis_loss=None):
        # Guard against missing losses, mirroring EdgeModel.backward
        if dis_loss is not None:
            dis_loss.backward()
        if gen_loss is not None:
            gen_loss.backward()
        self.dis_optimizer.step()
        self.gen_optimizer.step()
    def process(self, images, edges, masks):
        self.iteration += 1  # keep the saved checkpoint's iteration count in sync, as in EdgeModel
        gen_loss = 0
        dis_loss = 0
self.gen_optimizer.zero_grad()
self.dis_optimizer.zero_grad()
outputs = self(images, edges, masks)
# discriminator loss
dis_input_real = images
dis_input_fake = outputs.detach()
dis_real, _ = self.discriminator(dis_input_real) # in: [rgb(3)]
dis_fake, _ = self.discriminator(dis_input_fake) # in: [rgb(3)]
dis_real_loss = self.adversarial_loss(dis_real, True, True)
dis_fake_loss = self.adversarial_loss(dis_fake, False, True)
dis_loss = dis_loss + (dis_real_loss + dis_fake_loss) / 2
# generator adversarial loss
gen_input_fake = outputs
gen_fake, _ = self.discriminator(gen_input_fake) # in: [rgb(3)]
gen_gan_loss = self.adversarial_loss(gen_fake, True, False) * self.args.inpaint_adv_loss_weight
gen_loss = gen_loss + gen_gan_loss
# generator l1 loss
gen_l1_loss = self.l1_loss(outputs, images) * self.args.l1_loss_weight / torch.mean(masks)
gen_loss = gen_loss + gen_l1_loss
# generator perceptual loss
gen_content_loss = self.perceptual_loss(outputs, images)
gen_content_loss = gen_content_loss * self.args.content_loss_weight
gen_loss = gen_loss + gen_content_loss
# generator style loss
gen_style_loss = self.style_loss(outputs * masks, images * masks)
gen_style_loss = gen_style_loss * self.args.style_loss_weight
gen_loss = gen_loss + gen_style_loss
# create logs
logs = [
("l_d2", dis_loss.item()),
("l_g2", gen_gan_loss.item()),
("l_l1", gen_l1_loss.item()),
("l_per", gen_content_loss.item()),
("l_sty", gen_style_loss.item()),
]
return outputs, gen_loss, dis_loss, logs
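if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the training pipeline. Every
    # hyperparameter value below is an illustrative assumption (including that
    # AdversarialLoss supports 'nsgan'); adjust them to match the real config.
    # Because this module uses relative imports, run it with
    # `python -m <package>.<module>` rather than directly.
    from argparse import Namespace

    args = Namespace(
        checkpoint_dir='./data/output', mode=1, img_channel=1, mask_channel=1,
        batch_size=2, input_size=256, patch_size=16, latent_vec_dim=256,
        drop_rate=0.1, num_heads=8, num_layers=6, device='cpu',
        gan_loss='nsgan', lr=1e-4, d2g_lr=0.1, beta1=0.0, beta2=0.9,
        fm_loss_weight=10.0,
    )

    model = EdgeModel(args)
    # Grayscale images, edge maps, and binary masks of matching shape
    images = torch.rand(args.batch_size, 1, args.input_size, args.input_size)
    edges = torch.rand(args.batch_size, 1, args.input_size, args.input_size)
    masks = (torch.rand(args.batch_size, 1, args.input_size, args.input_size) > 0.5).float()

    outputs, gen_loss, dis_loss, logs = model.process(images, edges, masks)
    model.backward(gen_loss, dis_loss)
    print(logs)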