import argparse
import sys

import torch
def _build_parser():
    """Create the argument parser holding every experiment option and its default."""
    parser = argparse.ArgumentParser()

    # Experiment settings
    parser.add_argument('--mode', type=int, default=1, choices=[1, 2], help='1: train, 2:test')
    parser.add_argument('--model', type=int, default=2, choices=[1, 2, 3, 4], help='1: edge model, 2: inpaint model, 3: edge-inpaint model, 4: joint model')
    # NOTE(review): the help text used to also list "5: (external, random block, half)",
    # but 5 was never in `choices`, so argparse rejected it. Add 5 back to both
    # `choices` and the help only if the dataset/mask code actually handles it.
    parser.add_argument('--mask', type=int, default=3, choices=[1, 2, 3, 4], help='1: random block, 2: half, 3: external, 4: (external, random block)')
    parser.add_argument('--edge', type=int, default=1, choices=[1, 2], help='1: canny, 2: external')
    parser.add_argument('--seed', type=int, default=10, help='random seed')
    # Evaluated once, when the parser is built: prefer CUDA when available.
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='device to use for computation (cuda or cpu)')
    parser.add_argument('--max_epoch', type=int, default=1000, help='maximum number of iterations to train the model')

    # Directories
    parser.add_argument('--checkpoint_dir', type=str, default='./data/output/inpaint_model', help='model checkpoints path')
    parser.add_argument('--output_dir', type=str, default='./data/output/inpaint_model', help='path to the output directory')
    parser.add_argument('--train_img_dir', type=str, default='./data/train_img', help='path to the input images directory or an input image')
    parser.add_argument('--train_edge_dir', type=str, default='./data/train_edge', help='path to the edges directory or an edge file')
    parser.add_argument('--train_mask_dir', type=str, default='./data/train_mask', help='path to the masks directory or a mask file')
    parser.add_argument('--test_img_dir', type=str, default='./data/test_img', help='path to the input images directory or an input image')
    parser.add_argument('--test_edge_dir', type=str, default='./data/test_edge', help='path to the edges directory or an edge file')
    parser.add_argument('--test_mask_dir', type=str, default='./data/test_mask', help='path to the masks directory or a mask file')
    parser.add_argument('--valid_img_dir', type=str, default='./data/val_img', help='path to the input images directory or an input image')
    parser.add_argument('--valid_edge_dir', type=str, default='./data/val_edge', help='path to the edges directory or an edge file')
    parser.add_argument('--valid_mask_dir', type=str, default='./data/val_mask', help='path to the masks directory or a mask file')

    # Model components
    parser.add_argument('--input_size', type=int, default=256, help='input image size for training 0 for original size')
    parser.add_argument('--batch_size', type=int, default=8, help='input batch size for training')
    parser.add_argument('--patch_size', type=int, default=8, help='length of patch side')
    parser.add_argument('--img_channel', type=int, default=3, help='img_channel')
    parser.add_argument('--mask_channel', type=int, default=1, help='mask_channel')
    parser.add_argument('--latent_vec_dim', type=int, default=512, help='latent_vector_dim')
    parser.add_argument('--drop_rate', type=float, default=0.1, help='drop_rate')
    parser.add_argument('--num_heads', type=int, default=8, help='num_heads')
    parser.add_argument('--num_layers', type=int, default=11, help='num_layers')

    # Hyperparameters
    parser.add_argument('--lr', type=float, default=0.0001, help='learning rate')
    parser.add_argument('--d2g_lr', type=float, default=0.1, help='discriminator/generator learning rate ratio')
    parser.add_argument('--beta1', type=float, default=0.0, help='adam optimizer beta1')
    parser.add_argument('--beta2', type=float, default=0.9, help='adam optimizer beta2')
    parser.add_argument('--sigma', type=int, default=2, help='standard deviation of the Gaussian filter used in Canny edge detector (0: random, -1: no edge)')
    parser.add_argument('--edge_threshold', type=float, default=0.5, help='edge detection threshold')
    parser.add_argument('--l1_loss_weight', type=float, default=1, help='l1 loss weight')
    # float (not int) like every other loss weight, so fractional weights are accepted.
    parser.add_argument('--fm_loss_weight', type=float, default=10, help='feature-matching loss weight')
    parser.add_argument('--style_loss_weight', type=float, default=250, help='style loss weight')
    parser.add_argument('--content_loss_weight', type=float, default=0.1, help='perceptual loss weight')
    parser.add_argument('--inpaint_adv_loss_weight', type=float, default=0.1, help='adversarial loss weight')
    parser.add_argument('--gan_loss', type=str, default='nsgan', help='nsgan | lsgan | hinge')
    parser.add_argument('--gan_pool_size', type=int, default=0, help='fake images pool size')
    return parser


def get_args(argv=None):
    """Parse the experiment options.

    Args:
        argv: Optional list of argument strings. If given, it is parsed
            instead of ``sys.argv[1:]``, which is useful in tests and
            notebooks. The default, ``None``, reads ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with one attribute per option defined in
        ``_build_parser``.

    Raises:
        SystemExit: argparse's standard exit on invalid arguments or
            ``-h/--help``. This happens when ``argv`` is passed explicitly,
            or when ``sys.argv`` is parsed outside an IPython/Jupyter kernel.
    """
    parser = _build_parser()
    if argv is not None:
        return parser.parse_args(argv)
    try:
        # Normal script / module execution.
        return parser.parse_args()
    except SystemExit:
        # Inside a Jupyter/IPython kernel, sys.argv carries the kernel's own
        # flags (e.g. ``-f <connection file>``), which argparse rejects; only
        # there do we fall back to the defaults. In a script, a bad argument
        # or --help must still stop the program rather than silently running
        # with default settings.
        if 'ipykernel' not in sys.modules:
            raise
        return parser.parse_args(args=[])