티스토리 뷰

코드 분석/SGGLAT

sgglat.py

상솜공방 2024. 2. 19. 13:12
import os
import glob
import torch
import matplotlib.pyplot as plt

from tqdm import tqdm
from torch.utils.data import DataLoader

from src.utils import imsave
from src.dataset import Dataset
from src.metrics import PSNR, EdgeAccuracy
from src.models import EdgeModel, InpaintingModel


class SGGLAT():
    """Driver that trains, validates and tests the edge / inpainting models.

    ``args.model`` selects the stage (see ``MODEL_NAMES``). ``args.mode == 2``
    builds only the test dataset; any other mode builds the training and
    validation datasets. ``train``/``eval``/``test`` implement only stage 1
    (edge) and stage 2 (inpaint); stages 3 and 4 can be constructed, loaded
    and saved, but those loops raise ``NotImplementedError`` for them.
    """

    # args.model -> short name used in the log file names.
    MODEL_NAMES = {1: 'edge', 2: 'inpaint', 3: 'edge_inpaint', 4: 'joint'}

    def __init__(self, args):
        """Build both models, the metrics, the datasets and the output paths.

        Raises:
            ValueError: if ``args.model`` is not a key of ``MODEL_NAMES``.
        """
        self.args = args
        if self.args.model not in self.MODEL_NAMES:
            raise ValueError(
                "Unsupported model type {!r}; expected one of {}".format(
                    self.args.model, sorted(self.MODEL_NAMES)))
        model_name = self.MODEL_NAMES[self.args.model]

        # Models and metrics: both networks are built regardless of the stage.
        self.model_name = model_name
        self.device = torch.device(self.args.device)
        self.edge_model = EdgeModel(self.args).to(self.device)
        self.inpaint_model = InpaintingModel(self.args).to(self.device)
        self.edgeacc = EdgeAccuracy(self.args.edge_threshold).to(self.device)
        self.psnr = PSNR(255.0).to(self.device)

        # Datasets: mode 2 builds the test set only; other modes build train + validation.
        if self.args.mode == 2:
            self.test_dataset = Dataset(self.args, self.args.test_img_dir, self.args.test_edge_dir,
                                        self.args.test_mask_dir, augment=False, training=False)
        else:
            self.train_dataset = Dataset(self.args, self.args.train_img_dir, self.args.train_edge_dir,
                                         self.args.train_mask_dir, augment=True, training=True)
            self.val_dataset = Dataset(self.args, self.args.valid_img_dir, self.args.valid_edge_dir,
                                       self.args.valid_mask_dir, augment=False, training=True)

        # Per-epoch log history (key -> list of values) and output locations.
        self.train_logs_dict = {}
        self.val_logs_dict = {}
        self.samples_path = os.path.join(self.args.output_dir, 'validation_samples')
        self.samples_outputs_only_path = os.path.join(self.args.output_dir, 'validation_samples_outputs_only')
        self.results_path = os.path.join(self.args.output_dir, 'test_results')
        self.plots_path = os.path.join(self.args.output_dir, 'training_plots')
        self.train_log_file = os.path.join(self.args.output_dir, 'train_log_' + model_name + '.txt')
        self.val_log_file = os.path.join(self.args.output_dir, 'val_log_' + model_name + '.txt')

    def load(self):
        """Load weights: edge model for stage 1, inpaint model for stage 2, both otherwise."""
        if self.args.model == 1:
            self.edge_model.load()
        elif self.args.model == 2:
            self.inpaint_model.load()
        else:
            self.edge_model.load()
            self.inpaint_model.load()

    def save(self):
        """Save weights: edge model for stage 1, inpaint model for stage 2, both for 3/4."""
        if self.args.model == 1:
            self.edge_model.save()
        elif self.args.model == 2:
            self.inpaint_model.save()
        elif self.args.model in (3, 4):
            self.edge_model.save()
            self.inpaint_model.save()
        else:
            print("No valid model type specified for saving.")

    def train(self):
        """Train the selected stage for ``int(float(args.max_epoch))`` epochs.

        Each epoch does the following:
          * processes and back-propagates every batch;
          * averages the per-batch logs (float values) and records them with
            the epoch number;
          * runs ``eval`` (which may save a checkpoint);
          * redraws the train/validation plots.
        A non-positive ``max_epoch`` trains nothing. The DataLoader uses
        ``args.num_workers`` workers when that attribute exists, otherwise 4.

        Raises:
            NotImplementedError: for stages other than 1 and 2.
            ValueError: if the training set is smaller than one batch
                (``drop_last=True`` would then yield no batches).
        """
        model = self.args.model
        self._check_supported(model)
        train_loader = DataLoader(dataset=self.train_dataset, batch_size=self.args.batch_size,
                                  num_workers=getattr(self.args, 'num_workers', 4),
                                  drop_last=True, shuffle=True)
        total = len(self.train_dataset)
        print("Training data length: {}".format(total))
        if len(train_loader) == 0:
            raise ValueError(
                "Training set ({} samples) is smaller than batch_size={}; no batches to train on".format(
                    total, self.args.batch_size))

        max_epoch = int(float(self.args.max_epoch))

        for epoch in range(1, max_epoch + 1):
            self.edge_model.train()
            self.inpaint_model.train()
            batch_logs = []

            for items in tqdm(train_loader, desc=f"Epoch {epoch}", total=len(train_loader)):
                images, images_gray, edges, masks = [item.to(self.device) for item in items]

                if model == 1:
                    outputs, gen_loss, dis_loss, logs = self.edge_model.process(images_gray, edges, masks)
                    logs.extend(self._edge_metrics(edges, outputs, masks))
                    self.edge_model.backward(gen_loss, dis_loss)
                else:
                    outputs, gen_loss, dis_loss, logs = self.inpaint_model.process(images, edges, masks)
                    # Keep known pixels from the input; take only the hole region from the network.
                    outputs_merged = (outputs * masks) + (images * (1 - masks))
                    logs.extend(self._inpaint_metrics(images, outputs_merged))
                    self.inpaint_model.backward(gen_loss, dis_loss)

                batch_logs.append(logs)

            logs = [("epoch", epoch)] + self._average_logs(batch_logs)

            self.add_logs_to_dict(self.train_logs_dict, logs)
            self.save_log(self.train_log_file, logs)
            self.eval()
            self.save_plot()

        print('End training...')

    def eval(self):
        """Run one validation pass, record averaged metrics, save samples and maybe a checkpoint.

        Sample figures are drawn from the last validation batch. A checkpoint
        is saved (via ``save``) when the stage's key metric (F1 for stage 1,
        PSNR for stage 2) is the best value recorded so far in this run. This
        includes the first validation, so a checkpoint always exists.

        Raises:
            NotImplementedError: for stages other than 1 and 2.
            ValueError: if the validation set is smaller than one batch.
        """
        model = self.args.model
        self._check_supported(model)
        val_loader = DataLoader(dataset=self.val_dataset, batch_size=self.args.batch_size,
                                drop_last=True, shuffle=True)
        if len(val_loader) == 0:
            raise ValueError(
                "Validation set ({} samples) is smaller than batch_size={}; no batches to evaluate".format(
                    len(self.val_dataset), self.args.batch_size))
        self.edge_model.eval()
        self.inpaint_model.eval()

        # NOTE(review): this loop runs with autograd enabled. Wrapping it in
        # torch.no_grad() would save memory, but only if the models' process()
        # does not need gradients (e.g. a gradient-penalty term) — confirm
        # against src/models.py before adding it.
        batch_logs = []
        for items in val_loader:
            images, images_gray, edges, masks = [item.to(self.device) for item in items]

            if model == 1:
                outputs, _, _, logs = self.edge_model.process(images_gray, edges, masks)
                outputs_merged = (outputs * masks) + (edges * (1 - masks))
                logs.extend(self._edge_metrics(edges, outputs, masks))
            else:
                outputs, _, _, logs = self.inpaint_model.process(images, edges, masks)
                outputs_merged = (outputs * masks) + (images * (1 - masks))
                logs.extend(self._inpaint_metrics(images, outputs_merged))

            batch_logs.append(logs)

        # Average over all batches; a single batch of a shuffled loader is too noisy
        # to drive checkpoint selection.
        logs = self._average_logs(batch_logs)
        self.add_logs_to_dict(self.val_logs_dict, logs)
        self.save_log(self.val_log_file, logs)
        self.save_batch_images([images, images_gray, edges, masks, outputs, outputs_merged],
                               ['images', 'images_gray', 'edges', 'masks', 'outputs', 'outputs_merged'],
                               self.samples_path)
        is_gray_scale = model == 1
        self.save_single_image(outputs[0], self.samples_outputs_only_path, (4, 4), is_gray_scale)

        # Checkpoint saving: compare against the best value so far, not just the previous epoch.
        if model == 1:
            metric, message = 'f1_score', "Model 1: F1 Score improved."
        else:
            metric, message = 'psnr', "Model 2: PSNR improved."
        if self._is_improvement(self.val_logs_dict.get(metric, [])):
            print(message, end=" ")
            self.save()

    def test(self):
        """Run the selected stage over the test set and write one merged output per image.

        Each output is passed through ``postprocess`` and written with
        ``imsave`` to ``results_path`` (created if missing). The file name
        comes from ``test_dataset.load_name``, and the 1-based index and name
        are printed.

        Raises:
            NotImplementedError: for stages other than 1 and 2.
        """
        model = self.args.model
        self._check_supported(model)
        # batch_size=1 with the default (unshuffled) sampler keeps batch order
        # aligned with load_name(index).
        test_loader = DataLoader(dataset=self.test_dataset, batch_size=1)
        self.edge_model.eval()
        self.inpaint_model.eval()
        print("Model:", model)
        os.makedirs(self.results_path, exist_ok=True)

        # NOTE(review): runs with autograd enabled. torch.no_grad() would save
        # memory if the models' process() does not need gradients — confirm.
        for index, items in enumerate(test_loader):
            name = self.test_dataset.load_name(index)
            images, images_gray, edges, masks = [item.to(self.device) for item in items]

            if model == 1:
                outputs, _, _, _ = self.edge_model.process(images_gray, edges, masks)
                outputs_merged = (outputs * masks) + (edges * (1 - masks))
            else:
                outputs, _, _, _ = self.inpaint_model.process(images, edges, masks)
                outputs_merged = (outputs * masks) + (images * (1 - masks))

            output = self.postprocess(outputs_merged)[0]
            path = os.path.join(self.results_path, name)
            print(index + 1, name)
            imsave(output, path)

    #===================================================== Utils =====================================================#

    @staticmethod
    def _check_supported(model):
        """Raise NotImplementedError unless *model* is a stage the train/eval/test loops handle (1 or 2)."""
        if model not in (1, 2):
            raise NotImplementedError(
                "train/eval/test are implemented only for model 1 (edge) and 2 (inpaint); got {!r}".format(model))

    def _edge_metrics(self, edges, outputs, masks):
        """Return ``[('precision', p), ('recall', r), ('f1_score', f)]`` for edges inside the mask."""
        precision, recall, f1_score = self.edgeacc(edges * masks, outputs * masks)
        return [('precision', precision.item()),
                ('recall', recall.item()),
                ('f1_score', f1_score.item())]

    def _inpaint_metrics(self, images, outputs_merged):
        """Return ``[('psnr', ...), ('mae', ...)]`` comparing the merged output with the ground truth."""
        psnr = self.psnr(self.postprocess(images), self.postprocess(outputs_merged))
        # MAE normalised by the total intensity of the ground-truth images.
        mae = (torch.sum(torch.abs(images - outputs_merged)) / torch.sum(images)).float()
        return [('psnr', psnr.item()), ('mae', mae.item())]

    @staticmethod
    def _average_logs(batch_logs):
        """Merge per-batch ``[(name, value), ...]`` lists into one list.

        Float values are averaged over the batches in which they appear. Any
        other value (ints such as counters, tensors, strings) keeps its last
        value. Names keep the order in which they were first seen.
        """
        last = {}
        sums = {}
        counts = {}
        for logs in batch_logs:
            for log in logs:
                name, value = log[0], log[1]
                last[name] = value  # re-assigning keeps the first-seen position
                if isinstance(value, float):
                    sums[name] = sums.get(name, 0.0) + value
                    counts[name] = counts.get(name, 0) + 1
        return [(name, sums[name] / counts[name] if name in counts else value)
                for name, value in last.items()]

    @staticmethod
    def _is_improvement(history):
        """Return True if the latest value in *history* is strictly the best so far (or the first)."""
        if not history:
            return False
        return len(history) == 1 or history[-1] > max(history[:-1])

    def add_logs_to_dict(self, list_logs, logs):
        """Append each ``(key, value)`` of *logs* to ``list_logs[key]``, creating lists as needed."""
        for log in logs:
            list_logs.setdefault(log[0], []).append(log[1])

    def postprocess(self, img):
        """Convert a ``(B, C, H, W)`` tensor in [0, 1] to channels-last [0, 255] int32 (truncating)."""
        img = img * 255.0  # [0, 1] => [0, 255]
        img = img.permute(0, 2, 3, 1)
        return img.int()

    def save_log(self, log_file, logs):
        """Append *logs* to *log_file* as one ``name: value, ...`` line, creating its directory."""
        log_str = ', '.join(f"{name}: {value}" for name, value in logs)
        log_dir = os.path.dirname(log_file)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        with open(log_file, 'a', encoding='utf-8') as f:
            f.write(log_str + '\n')

    def save_plot(self):
        """Redraw one ``<key>.jpg`` train-vs-validation curve per log key present in both histories."""
        common_keys = sorted(set(self.train_logs_dict) & set(self.val_logs_dict))
        if common_keys:
            os.makedirs(self.plots_path, exist_ok=True)

        for key in common_keys:
            train_values = self.train_logs_dict[key]
            val_values = self.val_logs_dict[key]
            save_path = os.path.join(self.plots_path, f'{key}.jpg')
            plt.figure(figsize=(10, 6))
            # Epochs are numbered from 1 in train(), so the x-axis starts at 1 too.
            plt.plot(range(1, len(train_values) + 1), train_values, label='Train', marker='o')
            plt.plot(range(1, len(val_values) + 1), val_values, label='Validation', marker='x')
            plt.title(f'{key} over epochs')
            plt.xlabel('Epoch')
            plt.ylabel(key)
            plt.legend()
            plt.savefig(save_path)
            plt.close()

    def save_batch_images(self, tensors, tensor_names, save_dir, figsize_per_image=(3, 3)):
        """Save a grid figure with one row per batch sample and one column per tensor.

        Args:
            tensors: batched tensors; ``tensor[i]`` is drawn as a ``(C, H, W)``
                image. The number of rows is ``tensors[0].shape[0]``.
            tensor_names: column titles, parallel to ``tensors``.
            save_dir: output directory (created if missing). Files are numbered
                ``00001.jpg``, ``00002.jpg``, ... from the count of ``*.jpg``
                files already there.
            figsize_per_image: ``(width, height)`` of one grid cell in inches.
        """
        batch_size = tensors[0].shape[0]
        num_tensors = len(tensors)
        fig, axs = plt.subplots(
            batch_size, num_tensors,
            figsize=(figsize_per_image[0] * num_tensors, figsize_per_image[1] * batch_size),
            squeeze=False)

        for i in range(batch_size):
            for j, (tensor, name) in enumerate(zip(tensors, tensor_names)):
                ax = axs[i, j]
                sample = tensor[i].detach().cpu()
                channels = sample.shape[0]
                # Decide the layout per tensor: the grid mixes 3-channel images
                # with 1-channel gray/edge/mask maps.
                if channels == 3:
                    img = sample.permute(1, 2, 0).numpy()
                else:
                    img = sample.squeeze(0).numpy()
                ax.imshow(img, cmap='gray' if channels == 1 else None)
                if j == 0:
                    ax.set_ylabel(f"Batch {i+1}", rotation=0, size='large', labelpad=20)
                if i == 0:
                    ax.set_title(name)
                # Hide ticks and frame but keep the axis, so the row label stays
                # visible (ax.axis('off') would hide the ylabel too).
                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_visible(False)

        os.makedirs(save_dir, exist_ok=True)
        existing_files = glob.glob(os.path.join(save_dir, '*.jpg'))
        file_path = os.path.join(save_dir, f"{str(len(existing_files) + 1).zfill(5)}.jpg")
        fig.savefig(file_path)
        plt.close(fig)

    def save_single_image(self, tensor, save_dir, figsize=(4, 4), is_gray_scale=True):
        """Save one ``(C, H, W)`` tensor as a numbered ``.jpg`` titled with its epoch number.

        Args:
            tensor: image tensor of shape ``(C, H, W)``.
            save_dir: output directory (created if missing). The file number,
                also shown in the title as the epoch, is one more than the
                number of ``*.jpg`` files already in the directory. This
                assumes exactly one image is written there per epoch.
            figsize: ``(width, height)`` of the whole figure in inches.
            is_gray_scale: True for single-channel maps (drawn with a gray
                colormap after squeezing), False for colour images.
        """
        plt.figure(figsize=figsize)
        img = tensor.detach().cpu()

        if is_gray_scale:
            plt.imshow(img.squeeze().numpy(), cmap='gray')
        else:
            plt.imshow(img.permute(1, 2, 0).numpy())  # [C, H, W] -> [H, W, C]

        plt.axis('off')
        os.makedirs(save_dir, exist_ok=True)
        file_number = str(len(glob.glob(os.path.join(save_dir, '*.jpg'))) + 1).zfill(5)
        plt.title(f"epoch: {file_number}")
        plt.savefig(os.path.join(save_dir, f"{file_number}.jpg"))
        plt.close()

'코드 분석 > SGGLAT' 카테고리의 다른 글

models.py  (0) 2024.02.19
args.py  (0) 2024.02.19
dataset.ipynb  (0) 2024.02.13
SGGLAT_G.ipynb  (0) 2024.02.13
dataset.py  (0) 2024.02.13
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
TAG
more
«   2025/04   »
1 2 3 4 5
6 7 8 9 10 11 12
13 14 15 16 17 18 19
20 21 22 23 24 25 26
27 28 29 30
글 보관함