models.py

상솜공방 2024. 2. 19. 13:12
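
models.py defines the two GAN stages of SGGLAT: EdgeModel, which completes the edge map of a masked grayscale image, and InpaintingModel, which fills in the RGB image guided by those edges. Both inherit from BaseModel, which handles checkpoint loading/saving and a small batch-visualization helper.
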
import os
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from .generator import EdgeGenerator, ImageGenerator
from .edge_connect_discriminator import Discriminator
from .loss import AdversarialLoss, PerceptualLoss, StyleLoss


class BaseModel(nn.Module):

    def __init__(self, name, args):
        super(BaseModel, self).__init__()
        self.name = name
        self.args = args
        self.iteration = 0
        self.gen_weights_path = os.path.join(args.checkpoint_dir, name + '_gen.pth') # ./data/output/~_gen.pth
        self.dis_weights_path = os.path.join(args.checkpoint_dir, name + '_dis.pth') # ./data/output/~_dis.pth

    def load(self):
        if os.path.exists(self.gen_weights_path):
            if torch.cuda.is_available():
                data = torch.load(self.gen_weights_path)
            else:
                data = torch.load(self.gen_weights_path, map_location=lambda storage, loc: storage)
            self.generator.load_state_dict(data['generator']) # self.generator is not defined in BaseModel itself; it is created by the subclasses below
            self.iteration = data['iteration']
            print("{} generator successfully loaded".format(self.name))

        if os.path.exists(self.dis_weights_path) and self.args.mode == 1:
            if torch.cuda.is_available():
                data = torch.load(self.dis_weights_path)
            else:
                data = torch.load(self.dis_weights_path, map_location=lambda storage, loc: storage)
            self.discriminator.load_state_dict(data['discriminator']) # self.discriminator is not defined in BaseModel itself; it is created by the subclasses below
            print("{} discriminator successfully loaded".format(self.name))
            
    def save(self):
        print('Saving {}...\n'.format(self.name))
        torch.save({'iteration': self.iteration, 'generator': self.generator.state_dict()}, self.gen_weights_path)
        torch.save({'discriminator': self.discriminator.state_dict()}, self.dis_weights_path)

    def show_batch_images(self, tensors, tensor_names, figsize_per_image=(3, 3)):
        B = tensors[0].shape[0]
        num_tensors = len(tensors)
        
        # Compute the overall figure size
        fig_width = figsize_per_image[0] * num_tensors
        fig_height = figsize_per_image[1] * B
        
        # Create the full grid of subplots
        fig, axs = plt.subplots(B, num_tensors, figsize=(fig_width, fig_height), squeeze=False)
        
        for i in range(B):
            for j, (tensor, name) in enumerate(zip(tensors, tensor_names)):
                ax = axs[i, j]
                # Check each tensor's own channel count; the tensors in the list may differ (e.g. RGB vs. grayscale)
                if tensor.shape[1] == 3:
                    img = tensor[i].permute(1, 2, 0).cpu().detach().numpy()
                else:
                    img = tensor[i].squeeze().cpu().detach().numpy()
                ax.imshow(img, cmap='gray' if tensor.shape[1] == 1 else None)
                # Row labels go on the first column, tensor names as titles on the first row
                if j == 0:
                    ax.set_ylabel(f"Batch {i+1}", rotation=0, size='large', labelpad=20)
                if i == 0:
                    ax.set_title(name)
                # Hide the ticks instead of calling axis('off'), which would also hide the row labels
                ax.set_xticks([])
                ax.set_yticks([])
        
        plt.tight_layout()
        plt.show()
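
For reference, the checkpoints written by save() are plain dictionaries, so they can be inspected directly. A minimal sketch, assuming a generator checkpoint already exists at the path the comments above mention (map_location='cpu' does the same job as the lambda in load()):

import torch

data = torch.load('./data/output/EdgeModel_gen.pth', map_location='cpu')  # load onto CPU regardless of the device it was saved from
print(data['iteration'])             # training step recorded at save time
generator_state = data['generator']  # a regular state_dict, ready for load_state_dict()
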


class EdgeModel(BaseModel):

    def __init__(self, args):
        # Initialize the parent class, passing args along
        super(EdgeModel, self).__init__('EdgeModel', args)

        # Store the values from args as instance variables
        self.img_channel = args.img_channel
        self.mask_channel = args.mask_channel
        self.batch_size = args.batch_size
        self.input_size = args.input_size
        self.patch_size = args.patch_size
        self.latent_vec_dim = args.latent_vec_dim
        self.drop_rate = args.drop_rate
        self.num_heads = args.num_heads
        self.num_layers = args.num_layers
        self.num_patches = int((self.input_size / self.patch_size) * (self.input_size / self.patch_size))
        self.patch_vec_size = self.patch_size * self.patch_size * self.img_channel
        self.mlp_hidden_dim = int(self.latent_vec_dim/2)
        self.device = torch.device(args.device)
        
        # Initialize the generator, discriminator, and loss functions
        # generator input: [grayscale(1) + edge(1), mask(1)]
        # discriminator input: [grayscale(1) + edge(1)]
        self.generator = EdgeGenerator(self.batch_size, self.img_channel, self.input_size, self.input_size, self.patch_size, self.patch_vec_size, self.num_patches,
                                       self.latent_vec_dim, self.num_heads, self.mlp_hidden_dim, self.drop_rate, self.num_layers, self.device).to(self.device)
        self.discriminator = Discriminator(in_channels=2, use_sigmoid=args.gan_loss != 'hinge').to(self.device)
        self.l1_loss = nn.L1Loss()
        self.adversarial_loss = AdversarialLoss(type=args.gan_loss)

        # Set up the optimizers
        self.gen_optimizer = optim.Adam(params=self.generator.parameters(), lr=float(args.lr), betas=(args.beta1, args.beta2))
        self.dis_optimizer = optim.Adam(params=self.discriminator.parameters(), lr=float(args.lr) * float(args.d2g_lr), betas=(args.beta1, args.beta2))

    def forward(self, images, edges, masks):
        images_masked = (images * (1 - masks)) # original code: images_masked = (images * (1 - masks)) + masks
        edges_masked = (edges * (1 - masks))
        input_images = torch.cat((images_masked, edges_masked), dim=1)
        outputs, outputs_list, mask_list, attn_list = self.generator(input_images, masks)
        # self.show_batch_images([images_masked, edges_masked, outputs], ['images_masked', 'edges_masked', 'outputs']) # <== visualize the masked images, masked edges, and outputs
        return outputs
    
    def backward(self, gen_loss=None, dis_loss=None):
        if dis_loss is not None:
            dis_loss.backward()
            self.dis_optimizer.step() # step before gen_loss.backward() so the generator's adversarial gradients don't leak into the discriminator update

        if gen_loss is not None:
            gen_loss.backward()
            self.gen_optimizer.step()

    def process(self, images, edges, masks):
        self.iteration += 1
        gen_loss = 0
        dis_loss = 0
        self.gen_optimizer.zero_grad()
        self.dis_optimizer.zero_grad()

        # self.show_batch_images([images, edges, masks], ['original_images', 'original_edges', 'masks']) # <== visualize the input data

        outputs = self(images, edges, masks) # calling the instance automatically invokes its forward() method
        
        # discriminator loss
        dis_input_real = torch.cat((images, edges), dim=1)
        dis_input_fake = torch.cat((images, outputs.detach()), dim=1)
        dis_real, dis_real_feat = self.discriminator(dis_input_real)        # in: (grayscale(1) + edge(1))
        dis_fake, dis_fake_feat = self.discriminator(dis_input_fake)        # in: (grayscale(1) + edge(1))
        dis_real_loss = self.adversarial_loss(dis_real, True, True)
        dis_fake_loss = self.adversarial_loss(dis_fake, False, True)
        dis_loss += (dis_real_loss + dis_fake_loss) / 2

        # generator adversarial loss
        gen_input_fake = torch.cat((images, outputs), dim=1)
        gen_fake, gen_fake_feat = self.discriminator(gen_input_fake)        # in: (grayscale(1) + edge(1))
        gen_gan_loss = self.adversarial_loss(gen_fake, True, False)
        gen_loss += gen_gan_loss

        # generator feature matching loss
        gen_fm_loss = 0
        for i in range(len(dis_real_feat)):
            gen_fm_loss += self.l1_loss(gen_fake_feat[i], dis_real_feat[i].detach())
        gen_fm_loss = gen_fm_loss * self.args.fm_loss_weight
        gen_loss += gen_fm_loss

        # create logs
        logs = [
            ("l_d1", dis_loss.item()),
            ("l_g1", gen_gan_loss.item()),
            ("l_fm", gen_fm_loss.item()),
        ]

        return outputs, gen_loss, dis_loss, logs
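
Since process() zero-grads both optimizers and returns the losses without applying them, one training step is process() followed by backward(). A minimal sketch of an EdgeModel step; the dataloader and variable names are assumptions, not part of this repository:

edge_model = EdgeModel(args)
edge_model.load()  # resume from ./data/output/EdgeModel_*.pth if checkpoints exist

for images_gray, edges, masks in train_loader:  # hypothetical dataloader yielding grayscale images, edge maps, and masks
    outputs, gen_loss, dis_loss, logs = edge_model.process(images_gray, edges, masks)
    edge_model.backward(gen_loss, dis_loss)  # backprop and step both optimizers
    # logs: [("l_d1", ...), ("l_g1", ...), ("l_fm", ...)]
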


class InpaintingModel(BaseModel):

    def __init__(self, args):
        super(InpaintingModel, self).__init__('InpaintingModel', args)
        self.img_channel = args.img_channel
        self.mask_channel = args.mask_channel
        self.batch_size = args.batch_size
        self.input_size = args.input_size
        self.patch_size = args.patch_size
        self.latent_vec_dim = args.latent_vec_dim
        self.drop_rate = args.drop_rate
        self.num_heads = args.num_heads
        self.num_layers = args.num_layers
        self.num_patches = int((self.input_size / self.patch_size) * (self.input_size / self.patch_size))
        self.patch_vec_size = self.patch_size * self.patch_size * self.img_channel
        self.mlp_hidden_dim = int(self.latent_vec_dim/2)
        self.device = torch.device(args.device)

        # generator input: [rgb(3) + edge(1), mask(1)]
        # discriminator input: [rgb(3)]
        self.generator = ImageGenerator(self.batch_size, self.img_channel, self.input_size, self.input_size, self.patch_size, self.patch_vec_size, self.num_patches,
                                        self.latent_vec_dim, self.num_heads, self.mlp_hidden_dim, self.drop_rate, self.num_layers, self.device).to(self.device)
        self.discriminator = Discriminator(in_channels=3, use_sigmoid=args.gan_loss != 'hinge').to(self.device)
        self.l1_loss = nn.L1Loss()
        self.perceptual_loss = PerceptualLoss()
        self.style_loss = StyleLoss()
        self.adversarial_loss = AdversarialLoss(type=args.gan_loss)
        self.gen_optimizer = optim.Adam(params=self.generator.parameters(), lr=float(args.lr), betas=(args.beta1, args.beta2))
        self.dis_optimizer = optim.Adam(params=self.discriminator.parameters(), lr=float(args.lr) * float(args.d2g_lr), betas=(args.beta1, args.beta2))

    def forward(self, images, edges, masks):
        images_masked = images * (1 - masks) # the original code uses images_masked = (images * (1 - masks).float()) + masks; I'm not sure why
        inputs = torch.cat((images_masked, edges), dim=1)
        outputs, outputs_list, mask_list, attn_list = self.generator(inputs, masks)
        return outputs
    
    def backward(self, gen_loss=None, dis_loss=None):
        if dis_loss is not None:
            dis_loss.backward()
            self.dis_optimizer.step() # step before gen_loss.backward() so the generator's adversarial gradients don't leak into the discriminator update
        if gen_loss is not None:
            gen_loss.backward()
            self.gen_optimizer.step()

    def process(self, images, edges, masks):
        self.iteration += 1 # mirror EdgeModel so the saved checkpoint records the iteration count
        gen_loss = 0
        dis_loss = 0
        self.gen_optimizer.zero_grad()
        self.dis_optimizer.zero_grad()

        outputs = self(images, edges, masks)

        # discriminator loss
        dis_input_real = images
        dis_input_fake = outputs.detach()
        dis_real, _ = self.discriminator(dis_input_real)                    # in: [rgb(3)]
        dis_fake, _ = self.discriminator(dis_input_fake)                    # in: [rgb(3)]
        dis_real_loss = self.adversarial_loss(dis_real, True, True)
        dis_fake_loss = self.adversarial_loss(dis_fake, False, True)
        dis_loss = dis_loss + (dis_real_loss + dis_fake_loss) / 2

        # generator adversarial loss
        gen_input_fake = outputs
        gen_fake, _ = self.discriminator(gen_input_fake)                    # in: [rgb(3)]
        gen_gan_loss = self.adversarial_loss(gen_fake, True, False) * self.args.inpaint_adv_loss_weight
        gen_loss = gen_loss + gen_gan_loss

        # generator l1 loss
        gen_l1_loss = self.l1_loss(outputs, images) * self.args.l1_loss_weight / torch.mean(masks)
        gen_loss = gen_loss + gen_l1_loss

        # generator perceptual loss
        gen_content_loss = self.perceptual_loss(outputs, images)
        gen_content_loss = gen_content_loss * self.args.content_loss_weight
        gen_loss = gen_loss + gen_content_loss

        # generator style loss
        gen_style_loss = self.style_loss(outputs * masks, images * masks)
        gen_style_loss = gen_style_loss * self.args.style_loss_weight
        gen_loss = gen_loss + gen_style_loss

        # create logs
        logs = [
            ("l_d2", dis_loss.item()),
            ("l_g2", gen_gan_loss.item()),
            ("l_l1", gen_l1_loss.item()),
            ("l_per", gen_content_loss.item()),
            ("l_sty", gen_style_loss.item()),
        ]

        return outputs, gen_loss, dis_loss, logs
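
At inference time the two stages would presumably be chained in the EdgeConnect fashion: EdgeModel hallucinates edges inside the hole, and InpaintingModel fills in the RGB image guided by them. A sketch under that assumption (variable names are mine), with the final composite keeping the known pixels from the input:

edge_model.eval()
inpaint_model.eval()
with torch.no_grad():
    pred_edges = edge_model(images_gray, edges, masks)      # in: grayscale(1) + edge(1), mask -> completed edge map
    pred_images = inpaint_model(images, pred_edges, masks)  # in: rgb(3) + edge(1), mask -> completed image
    merged = pred_images * masks + images * (1 - masks)     # keep original pixels outside the mask (masks == 1 in the hole)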
