코드 분석/Edge-Connect

Edge-Connect: 모듈 분석

상솜공방 2024. 1. 30. 11:38

나만의 머신러닝 코드를 작성하기 위해선 기존에 완성된 코드를 분석하는 것이 우선이라고 생각하였다. 따라서 2019년에 발표된 edge connect 논문의 코드를 먼저 깊게 분석하기로 한다.

 

메인 함수

더보기

main.py

개요 모델을 불러와서 argument를 넣고 train 혹은 test를 진행하는 함수
보조 모듈 config.py, edge_conncect.py
함수 main(mode = None)
1. arguements 받기 (import config.py)
2. CUDA 설정
3. 랜덤 시드 설정
4. 모델 초기화 (import edge_connect.py)
5. train/test/eval에 따른 실험 진행
load_config(mode = None)
config.yml 파일을 받아서 argparse 함수를 이용해 필요한 옵션을 모두 세팅

 

 

아규먼트를 바탕으로 모델 훈련 및 시험에 대한 전체 시퀀스 제어 

더보기

config.py

개요 yml 파일이 있는 경로를 받아서 읽어온 후, 이를 바탕으로 dict 파일을 생성
클래스 Config __init__(self, config_path)
yml 파일의 경로를 받아 이를 dict 파일로 생성
__getattr__(self, name)
딕셔너리의 정보에 따라 dict[name]을 반환
print(self)
모델에 대한 yml 파일을 콘솔에 출력

 

edge_connect.py

개요 모델을 불러오고, 이에 대한 훈련 및 테스트 등의 일괄적인 프로세스를 진행
보조 모듈 dataset.py: 커스텀 데이터셋을 만들 때 사용
models.py: 모델을 구성할 때 사용
metrics.py: 메트릭을 구할 때 사용
utils.py: 진척도, 디렉토리 생성, 이미지 스티칭, 이미지 저장 등의 기능에 사용
클래스




Edge
Connect





__init__(self, config)
config 파일을 받아 다음 일련의 수행을 초기화
1. 어떤 모델을 돌릴 것인지 결정
2. 엣지 모델과 인페인트 모델을 초기화
3. 데이터셋 불러오기(train/test에 따라 다름)
4. PSNR 등의 메트릭 초기화
5. 각종 파일(샘플 이미지, 훈련 로그, 모델 가중치) 저장 위치 설정
load(self)
프로세스 종류에 따라 특정 모델의 가중치를 로드
save(self)
프로세스 종류에 따라 특정 모델의 가중치를 세이브
train(self)
1. DataLoader 객체를 만들어 배치 단위로 train 데이터를 분할
2. 이터레이션을 돌면서 모델 훈련 진행
3. 각종 파일을 저장
eval(self)
1. DataLoader 객체를 만들어 배치 단위로 eval 데이터를 분할
2. 모델에 데이터를 넣고 메트릭 계산
3. 각종 파일을 저장
test(self)
1. DataLoader 객체를 만들어 배치 단위로 데이터를 분할
2. 모델에 데이터를 넣고 메트릭 계산
3. 각종 파일을 저장
sample(self, it=None): 샘플 이미지를 저장하는 함수
log(self, logs): 
모델 로그를 저장
cuda(self, *args): 
모델을 CUDA 디바이스로 이동
postprocess(self, img): 
텐서 이미지를 넘파이로 후처리

 

 

커스텀 데이터셋과 모델

더보기

models.py

개요 모델 네트워크를 불러와 생성자와 판별자로 구성된 모델을 생성
보조 모듈 networks.py: 모델 생성에 필요한 네트워크
loss.py: 모델 훈련에 필요한 손실 함수
클래스
  Base
Model

__init__(self, name, config)
EdgeModel과 InpaintModel이 공유하는 인스턴스 변수를 초기화
(모델 이름, 아규먼트, 이터레이션, 생성자와 판별자의 가중치 저장 경로)
load(self)
생성자: 가중치 저장 경로가 존재한다면 로드
판별자: 가중치 저장 경로가 존재한다면 로드
save(self)
생성자: 이터레이션 정보, 모델의 state_dict() 저장
판별자: 모델의 state_dict() 저장
Edge
Model
__init__(self, config)
1. 생성자와 판별자 초기화
2. 로스 초기화
3. 옵티마이저 초기화
forward(self, images, edges, masks)
이미지, 엣지, 마스크를 생성자에 집어넣고 아웃풋을 반환
backward(self, gen_loss, dis_loss)
1. 생성자와 판별자의 각각의 손실에 대한 역전파 수행
2. 생성자와 판별자의 가중치 업데이트
process(self, images, edges, masks)
1. 이터레이션 하나 올리기
2. 생성자, 판별자의 오차 0으로 초기화
3. 생성자, 판별자의 옵티마이저 zero_grad()
4. 생성자 forward 함수로 아웃풋 생성
5. 판별자 오차 계산
6. 이를 바탕으로 생성자 오차 계산
Inpaint
Model
__init__(self, config)
1. 생성자와 판별자 초기화
2. 로스 초기화
3. 옵티마이저 초기화
forward(self, images, edges, masks)
이미지, 엣지, 마스크를 생성자에 집어넣고 아웃풋을 반환
backward(self, gen_loss, dis_loss)
1. 생성자와 판별자의 각각의 손실에 대한 역전파 수행
2. 생성자와 판별자의 가중치 업데이트
process(self, images, edges, masks)
1. 이터레이션 하나 올리기
2. 생성자, 판별자의 오차 0으로 초기화
3. 생성자, 판별자의 옵티마이저 zero_grad()
4. 생성자 forward 함수로 아웃풋 생성
5. 판별자 오차 계산
6. 이를 바탕으로 생성자 오차 계산

 

dataset.py

개요 커스텀 데이터셋을 제작
보조 모듈 utils.py: 마스크를 생성하는 함수를 불러옴
클래스
Dataset
__init__(self, config, flist, edge_flist, mask_flist, augment=True, training=True)
config 파일과 이미지, 엣지, 마스크 디렉토리를 받아옴
그리고 이에 따라 train과 test에 각각 어떻게 데이터를 불러올지 다르게 설정
__len__(self):
데이터 전체 길이 반환
__getitem__(self, index)
load_item() 함수를 통해 해당 인덱스의 아이템을 반환
load_item(self, index)
1. imread() 함수를 통해 이미지를 불러와 저장
2. 해당 이미지를 바탕으로 흑백 이미지 생성
3. load_mask() 함수를 통해 마스크를 불러와 저장
4. load_edge() 함수를 통해 엣지맵을 불러와 저장
5. 이미지, 흑백 이미지, 엣지, 마스크를 모두 텐서로 변환해 반환
load_edge(self, img, index, mask)
1. edge 옵션에 따라 canny를 쓸지, 외부 파일을 쓸지 결정한 후 반환
load_mask(self, img, index)
mask 옵션에 따라 create_mask() 함수를 이용해 어떤 마스크를 생성할지 결정한 후 반환
load_name(self, index)
파일의 이름을 반환
to_tensor(self, img): 이미지를 텐서로 변환
resize(self, img, height, width, centerCrop):
 이미지 리사이즈
load_flist(self, flist): 
데이터가 존재하는 디렉토리에서 파일을 읽어 list 형태로 반환
create_iterator(self, batch_size): 
이터레이터를 생성

 

 

모델에 직접적으로 사용될 네트워크

더보기

1. BaseNetwork는 nn.Module을 상속받아 생성한다.

BaseNetwork에는 모든 네트워크에 공통적으로 들어갈 함수가 정의되어있다.

init_func(m)은 모델의 모든 레이어를 재귀적으로 돌면서 가중치를 초기화 한다.

 

2. EdgeNetwork, InpaintNetwork, Discriminator, ResnetBlock

이 네트워크는 모두 BaseNetwork를 상속받아 만들어진다.

__init__() 생성자 함수에선 nn.Sequential() 함수를 통해 각 계층을 생성한다.

forward() 함수에선 인풋 데이터 x를 받아 각 레이어를 통과시킨 뒤 그 결과값을 반환한다.

'코드 분석 > Edge-Connect' 카테고리의 다른 글

Edge-Connect: edge_connect.py  (2) 2024.01.27
Edge-Connect: models.py  (1) 2024.01.27
Edge-Connect: dataset.py  (0) 2024.01.25
Edge-Connect: utils.py  (1) 2024.01.25
Edge-Connect: main.py  (0) 2024.01.25