import os
import glob
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import DataLoader
from src.utils import imsave
from src.dataset import Dataset
from src.metrics import PSNR, EdgeAccuracy
from src.models import EdgeModel, InpaintingModel
class SGGLAT:
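    """Coordinator for the edge / inpainting pipeline: builds the models and
    datasets from ``args`` and exposes train / eval / test loops plus logging,
    plotting, and image-saving helpers.

    ``args.model`` selects the stage: 1 = edge, 2 = inpaint,
    3 = edge_inpaint, 4 = joint.
    """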
def __init__(self, args):
self.args = args
        if self.args.model == 1:
            model_name = 'edge'
        elif self.args.model == 2:
            model_name = 'inpaint'
        elif self.args.model == 3:
            model_name = 'edge_inpaint'
        elif self.args.model == 4:
            model_name = 'joint'
        else:
            # Without this guard, an unrecognized value would leave model_name unbound.
            raise ValueError(f"Unsupported model type: {self.args.model}")
        self.model_name = model_name
self.device = torch.device(self.args.device)
self.edge_model = EdgeModel(self.args).to(self.device)
self.inpaint_model = InpaintingModel(self.args).to(self.device)
self.edgeacc = EdgeAccuracy(self.args.edge_threshold).to(self.device)
self.psnr = PSNR(255.0).to(self.device)
if self.args.mode == 2:
self.test_dataset = Dataset(self.args, self.args.test_img_dir, self.args.test_edge_dir, self.args.test_mask_dir, augment=False, training=False)
else:
self.train_dataset = Dataset(self.args, self.args.train_img_dir, self.args.train_edge_dir, self.args.train_mask_dir, augment=True, training=True)
self.val_dataset = Dataset(self.args, self.args.valid_img_dir, self.args.valid_edge_dir, self.args.valid_mask_dir, augment=False, training=True)
self.train_logs_dict = {}
self.val_logs_dict = {}
self.samples_path = os.path.join(self.args.output_dir, 'validation_samples')
self.samples_outputs_only_path = os.path.join(self.args.output_dir, 'validation_samples_outputs_only')
self.results_path = os.path.join(self.args.output_dir, 'test_results')
self.plots_path = os.path.join(self.args.output_dir, 'training_plots')
self.train_log_file = os.path.join(self.args.output_dir, 'train_log_' + model_name + '.txt')
self.val_log_file = os.path.join(self.args.output_dir, 'val_log_' + model_name + '.txt')
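        # Assumption: these output directories may not already exist; creating them
        # here keeps the later plt.savefig / imsave calls from failing.
        for path in (self.samples_path, self.samples_outputs_only_path, self.results_path, self.plots_path):
            os.makedirs(path, exist_ok=True)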
def load(self):
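        """Load the checkpoint(s) for the model stage selected by ``args.model``."""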
if self.args.model == 1:
self.edge_model.load()
elif self.args.model == 2:
self.inpaint_model.load()
else:
self.edge_model.load()
self.inpaint_model.load()
def save(self):
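        """Save the checkpoint(s) for the model stage selected by ``args.model``."""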
        if self.args.model == 1:
            self.edge_model.save()
        elif self.args.model == 2:
            self.inpaint_model.save()
        elif self.args.model in (3, 4):
            self.edge_model.save()
            self.inpaint_model.save()
        else:
            print("No valid model type specified for saving.")
def train(self):
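        """Train the selected model for up to ``args.max_epoch`` epochs, running
        validation and refreshing the training plots as training proceeds.
        Only the edge (1) and inpainting (2) stages are handled directly in this loop."""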
train_loader = DataLoader(dataset=self.train_dataset, batch_size=self.args.batch_size, num_workers=4, drop_last=True, shuffle=True)
epoch = 0
keep_training = True
model = self.args.model
total = len(self.train_dataset)
print("Training data lengh: {}".format(total))
max_epoch = int(float((self.args.max_epoch)))
while(keep_training):
epoch += 1
self.edge_model.train()
self.inpaint_model.train()
for items in tqdm(train_loader, desc=f"Epoch {epoch}", total=len(train_loader)):
images, images_gray, edges, masks = [item.to(self.device) for item in items]
if model == 1:
outputs, gen_loss, dis_loss, logs = self.edge_model.process(images_gray, edges, masks)
precision, recall, f1_score = self.edgeacc(edges * masks, outputs * masks)
logs.append(('precision', precision.item()))
logs.append(('recall', recall.item()))
logs.append(('f1_score', f1_score.item()))
self.edge_model.backward(gen_loss, dis_loss)
iteration = self.edge_model.iteration
if model == 2:
outputs, gen_loss, dis_loss, logs = self.inpaint_model.process(images, edges, masks)
outputs_merged = (outputs * masks) + (images * (1 - masks))
psnr = self.psnr(self.postprocess(images), self.postprocess(outputs_merged))
mae = (torch.sum(torch.abs(images - outputs_merged)) / torch.sum(images)).float()
logs.append(('psnr', psnr.item()))
logs.append(('mae', mae.item()))
self.inpaint_model.backward(gen_loss, dis_loss)
iteration = self.inpaint_model.iteration
logs = [("epoch", epoch)] + logs
self.add_logs_to_dict(self.train_logs_dict, logs)
self.save_log(self.train_log_file, logs)
self.eval()
self.save_plot()
if epoch >= max_epoch:
keep_training = False
break
print('End training...')
def eval(self):
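        """Run validation: log metrics, save sample grids and single outputs, and
        checkpoint the model when its tracked metric (F1 for the edge model,
        PSNR for the inpainting model) improves over the previous evaluation."""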
val_loader = DataLoader(dataset=self.val_dataset, batch_size=self.args.batch_size, drop_last=True, shuffle=True)
model = self.args.model
self.edge_model.eval()
self.inpaint_model.eval()
for items in val_loader:
images, images_gray, edges, masks = [item.to(self.device) for item in items]
if model == 1:
outputs, _, _, logs = self.edge_model.process(images_gray, edges, masks)
outputs_merged = (outputs * masks) + (edges * (1 - masks))
precision, recall, f1_score = self.edgeacc(edges * masks, outputs * masks)
logs.append(('precision', precision.item()))
logs.append(('recall', recall.item()))
logs.append(('f1_score', f1_score.item()))
elif model == 2:
outputs, _, _, logs = self.inpaint_model.process(images, edges, masks)
outputs_merged = (outputs * masks) + (images * (1 - masks))
psnr = self.psnr(self.postprocess(images), self.postprocess(outputs_merged))
mae = (torch.sum(torch.abs(images - outputs_merged)) / torch.sum(images)).float()
logs.append(('psnr', psnr.item()))
logs.append(('mae', mae.item()))
self.add_logs_to_dict(self.val_logs_dict, logs)
self.save_log(self.val_log_file, logs)
self.save_batch_images([images, images_gray, edges, masks, outputs, outputs_merged],
['images', 'images_gray', 'edges', 'masks', 'outputs', 'outputs_merged'],
self.samples_path)
            is_gray_scale = (model == 1)
self.save_single_image(outputs[0], self.samples_outputs_only_path, (4, 4), is_gray_scale)
if model == 1 and 'f1_score' in self.val_logs_dict and len(self.val_logs_dict['f1_score']) > 1:
if self.val_logs_dict['f1_score'][-1] > self.val_logs_dict['f1_score'][-2]:
print("Model 1: F1 Score improved.", end = " ")
self.save()
elif model == 2 and 'psnr' in self.val_logs_dict and len(self.val_logs_dict['psnr']) > 1:
if self.val_logs_dict['psnr'][-1] > self.val_logs_dict['psnr'][-2]:
print("Model 2: PSNR improved.", end = " ")
self.save()
def test(self):
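        """Run inference on the test set one image at a time and write the merged
        outputs to ``results_path`` under each image's original name."""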
test_loader = DataLoader(dataset=self.test_dataset, batch_size=1)
model = self.args.model
self.edge_model.eval()
self.inpaint_model.eval()
print("Model:", model)
index = 0
for items in test_loader:
name = self.test_dataset.load_name(index)
images, images_gray, edges, masks = [item.to(self.device) for item in items]
index += 1
if model == 1:
outputs, _, _, _ = self.edge_model.process(images_gray, edges, masks)
outputs_merged = (outputs * masks) + (edges * (1 - masks))
elif model == 2:
outputs, _, _, _ = self.inpaint_model.process(images, edges, masks)
outputs_merged = (outputs * masks) + (images * (1 - masks))
output = self.postprocess(outputs_merged)[0]
path = os.path.join(self.results_path, name)
print(index, name)
imsave(output, path)
    def add_logs_to_dict(self, logs_dict, logs):
        """Append each (name, value) pair in ``logs`` to the running history in ``logs_dict``."""
        for name, value in logs:
            if name in logs_dict:
                logs_dict[name].append(value)
            else:
                logs_dict[name] = [value]
def postprocess(self, img):
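        """Multiply by 255, reorder from (B, C, H, W) to (B, H, W, C), and cast to int."""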
img = img * 255.0
img = img.permute(0, 2, 3, 1)
return img.int()
def save_log(self, log_file, logs):
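        """Append one comma-separated line of ``name: value`` pairs to ``log_file``."""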
log_str = ', '.join([f"{name}: {value}" for name, value in logs])
with open(log_file, 'a') as f:
f.write(log_str + '\n')
def save_plot(self):
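        """Plot train vs. validation curves for every metric present in both log dicts."""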
common_keys = set(self.train_logs_dict.keys()) & set(self.val_logs_dict.keys())
for key in common_keys:
save_dir = os.path.join(self.plots_path, f'{key}.jpg')
plt.figure(figsize=(10, 6))
plt.plot(self.train_logs_dict[key], label='Train', marker='o')
plt.plot(self.val_logs_dict[key], label='Validation', marker='x')
plt.title(f'{key} over epochs')
plt.xlabel('Epoch')
plt.ylabel(key)
plt.legend()
plt.savefig(save_dir)
plt.close()
def save_batch_images(self, tensors, tensor_names, save_dir, figsize_per_image=(3, 3)):
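        """Save a grid of the given tensors (one column per tensor, one row per batch
        item) as a sequentially numbered .jpg in ``save_dir``."""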
B, C, H, W = tensors[0].shape
num_tensors = len(tensors)
fig_width = figsize_per_image[0] * num_tensors
fig_height = figsize_per_image[1] * B
fig, axs = plt.subplots(B, num_tensors, figsize=(fig_width, fig_height), squeeze=False)
for i in range(B):
for j, (tensor, name) in enumerate(zip(tensors, tensor_names)):
ax = axs[i, j]
                # Use each tensor's own channel count: grayscale tensors (edges, masks) have 1 channel.
                img = tensor[i].permute(1, 2, 0).cpu().detach().numpy() if tensor.shape[1] == 3 else tensor[i].squeeze().cpu().detach().numpy()
                ax.imshow(img, cmap='gray' if tensor.shape[1] == 1 else None)
if j == 0:
ax.set_ylabel(f"Batch {i+1}", rotation=0, size='large', labelpad=20)
if i == 0:
ax.set_title(name)
ax.axis('off')
        existing_files = glob.glob(os.path.join(save_dir, '*.jpg'))
        next_file_number = len(existing_files) + 1
        save_path = os.path.join(save_dir, f"{str(next_file_number).zfill(5)}.jpg")
        plt.savefig(save_path)
plt.close()
def save_single_image(self, tensor, save_dir, figsize=(4, 4), is_gray_scale=True):
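        """Save one output image as a sequentially numbered .jpg in ``save_dir``,
        titling it with that number."""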
plt.figure(figsize=figsize)
if is_gray_scale:
img = tensor.squeeze().cpu().detach().numpy()
plt.imshow(img, cmap='gray')
else:
img = tensor.permute(1, 2, 0).cpu().detach().numpy()
plt.imshow(img)
plt.axis('off')
        existing_files = glob.glob(os.path.join(save_dir, '*.jpg'))
        next_file_number = len(existing_files) + 1
        save_path = os.path.join(save_dir, f"{str(next_file_number).zfill(5)}.jpg")
        epoch_text = f"epoch: {str(next_file_number).zfill(5)}"
        plt.title(epoch_text)
        plt.savefig(save_path)
plt.close()