티스토리 뷰

코드 분석/SGGLAT

args.py

상솜공방 2024. 2. 19. 13:11
import argparse
import sys

import torch

def get_args(argv=None):
    """Build the SGGLAT command-line parser and return the parsed options.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Pass ``[]`` to get every default (useful from a
            notebook or a test). When ``None`` (the default), the process
            command line is parsed.

    Returns:
        argparse.Namespace holding every option defined below.

    Raises:
        SystemExit: on ``--help`` or an invalid/unrecognised argument. The
            only exception is when ``argv`` is ``None`` and the code runs
            inside a Jupyter kernel; then parsing failures fall back to the
            defaults.
    """
    parser = argparse.ArgumentParser()

    # Experiment settings
    parser.add_argument('--mode', type=int, default=1, choices=[1, 2], help='1: train, 2:test')
    parser.add_argument('--model', type=int, default=2, choices=[1, 2, 3, 4], help='1: edge model, 2: inpaint model, 3: edge-inpaint model, 4: joint model')
    # Help text lists exactly the accepted choices; a mode 5 is not enabled here.
    parser.add_argument('--mask', type=int, default=3, choices=[1, 2, 3, 4], help='1: random block, 2: half, 3: external, 4: (external, random block)')
    parser.add_argument('--edge', type=int, default=1, choices=[1, 2], help='1: canny, 2: external')
    parser.add_argument('--seed', type=int, default=10, help='random seed')
    # Default device is resolved each time get_args() is called.
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='device to use for computation (cuda or cpu)')
    # NOTE(review): the option is named max_epoch but its help says "iterations";
    # confirm which unit the training loop actually counts.
    parser.add_argument('--max_epoch', type=int, default=1000, help='maximum number of iterations to train the model')

    # Directory settings
    parser.add_argument('--checkpoint_dir', type=str, default='./data/output/inpaint_model', help='model checkpoints path')
    parser.add_argument('--output_dir', type=str, default='./data/output/inpaint_model', help='path to the output directory')

    parser.add_argument('--train_img_dir', type=str, default='./data/train_img', help='path to the input images directory or an input image')
    parser.add_argument('--train_edge_dir', type=str, default='./data/train_edge', help='path to the edges directory or an edge file')
    parser.add_argument('--train_mask_dir', type=str, default='./data/train_mask', help='path to the masks directory or a mask file')

    parser.add_argument('--test_img_dir', type=str, default='./data/test_img', help='path to the input images directory or an input image')
    parser.add_argument('--test_edge_dir', type=str, default='./data/test_edge', help='path to the edges directory or an edge file')
    parser.add_argument('--test_mask_dir', type=str, default='./data/test_mask', help='path to the masks directory or a mask file')

    parser.add_argument('--valid_img_dir', type=str, default='./data/val_img', help='path to the input images directory or an input image')
    parser.add_argument('--valid_edge_dir', type=str, default='./data/val_edge', help='path to the edges directory or an edge file')
    parser.add_argument('--valid_mask_dir', type=str, default='./data/val_mask', help='path to the masks directory or a mask file')

    # Model component settings
    parser.add_argument('--input_size', type=int, default=256, help='input image size for training 0 for original size')
    parser.add_argument('--batch_size', type=int, default=8, help='input batch size for training')
    parser.add_argument('--patch_size', type=int, default=8, help='length of patch side')
    parser.add_argument('--img_channel', type=int, default=3, help='img_channel')
    parser.add_argument('--mask_channel', type=int, default=1, help='mask_channel')
    parser.add_argument('--latent_vec_dim', type=int, default=512, help='latent_vector_dim')
    parser.add_argument('--drop_rate', type=float, default=0.1, help='drop_rate')
    parser.add_argument('--num_heads', type=int, default=8, help='num_heads')
    parser.add_argument('--num_layers', type=int, default=11, help='num_layers')

    # Hyperparameter settings
    parser.add_argument('--lr', type=float, default=0.0001, help='learning rate')
    parser.add_argument('--d2g_lr', type=float, default=0.1, help='discriminator/generator learning rate ratio')
    parser.add_argument('--beta1', type=float, default=0.0, help='adam optimizer beta1')
    parser.add_argument('--beta2', type=float, default=0.9, help='adam optimizer beta2')
    parser.add_argument('--sigma', type=int, default=2, help='standard deviation of the Gaussian filter used in Canny edge detector (0: random, -1: no edge)')
    parser.add_argument('--edge_threshold', type=float, default=0.5, help='edge detection threshold')
    parser.add_argument('--l1_loss_weight', type=float, default=1, help='l1 loss weight')
    # float like every other loss weight, so fractional weights are accepted.
    parser.add_argument('--fm_loss_weight', type=float, default=10, help='feature-matching loss weight')
    parser.add_argument('--style_loss_weight', type=float, default=250, help='style loss weight')
    parser.add_argument('--content_loss_weight', type=float, default=0.1, help='perceptual loss weight')
    parser.add_argument('--inpaint_adv_loss_weight', type=float, default=0.1, help='adversarial loss weight')
    parser.add_argument('--gan_loss', type=str, default='nsgan', choices=['nsgan', 'lsgan', 'hinge'], help='nsgan | lsgan | hinge')
    parser.add_argument('--gan_pool_size', type=int, default=0, help='fake images pool size')

    if argv is not None:
        return parser.parse_args(argv)

    try:
        return parser.parse_args()  # normal Python script / module
    except SystemExit:
        # A notebook kernel's sys.argv typically carries kernel flags
        # (e.g. ``-f <connection file>``) that this parser rejects. Only in
        # that environment fall back to the defaults. In a regular script,
        # ``--help`` or a bad option must still terminate instead of
        # silently training with default settings.
        if 'ipykernel' not in sys.modules:
            raise
        return parser.parse_args(args=[])  # Jupyter

'코드 분석 > SGGLAT' 카테고리의 다른 글

sgglat.py  (0) 2024.02.19
models.py  (0) 2024.02.19
dataset.ipynb  (0) 2024.02.13
SGGLAT_G.ipynb  (0) 2024.02.13
dataset.py  (0) 2024.02.13
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
TAG
more
«   2025/04   »
1 2 3 4 5
6 7 8 9 10 11 12
13 14 15 16 17 18 19
20 21 22 23 24 25 26
27 28 29 30
글 보관함