코드 분석/Edge-Connect

Edge-Connect: dataset.py

상솜공방 2024. 1. 25. 17:39

코드를 이해하기 위해 필요한 지식

더보기

특수 메서드 __len__과 __getitem__
__len__ 메서드:
__len__은 컨테이너 타입의 객체(예: 리스트, 튜플, 세트, 사용자 정의 컬렉션 등)의 길이를 반환하는 특수 메서드입니다.
예를 들어, len(obj)를 호출하면 내부적으로 obj.__len__()이 호출됩니다.
PyTorch의 Dataset 클래스에서 __len__은 데이터셋의 총 항목 수를 반환합니다.


__getitem__ 메서드:
__getitem__은 컨테이너 타입의 객체에서 특정 인덱스의 항목을 가져오는 특수 메서드입니다.
예를 들어, obj[index]를 호출하면 내부적으로 obj.__getitem__(index)가 호출됩니다.
PyTorch의 Dataset 클래스에서 __getitem__은 지정된 인덱스의 데이터 항목(예: 이미지와 레이블)을 반환합니다. 이는 DataLoader 클래스에서도 쓰이므로 커스터마이징 할 때 반드시 수정해야 하는 부분입니다.

 

mask = None if self.training else (1 - mask / 255).astype(np.bool)
이 코드는 훈련 모드와 테스트 모드에 따라 마스크의 처리 방식을 다르게 합니다. 예를 들어:
훈련 모드(self.training == True): mask는 None으로 설정됩니다. 이는 마스크를 사용하지 않음을 의미합니다.
테스트 모드(self.training == False): mask는 (1 - mask / 255)로 계산되어 이진 마스크로 변환됩니다. 예를 들어, mask 배열이 [255, 0, 255]라면, 이 코드는 [0, 1, 0]으로 변환되고, np.bool로 형변환하여 [False, True, False]가 됩니다.

 

imgh, imgw = img.shape[0:2]
이 코드는 이미지의 높이(imgh)와 너비(imgw)를 추출합니다. img.shape는 이미지의 차원을 나타내며, [높이, 너비, 채널] 형태를 가집니다. img.shape[0:2]는 첫 번째와 두 번째 요소(높이와 너비)를 가져옵니다.

 

mask_type 변수 사용 이유
mask_type 변수는 self.mask의 값을 복사하여 로컬 변수로 사용하는 것입니다. 이는 코드의 가독성을 높이고, self.mask 값이 함수 실행 도중 변경되는 것을 방지하기 위함일 수 있습니다. 또한, mask_type을 수정하거나 다른 값으로 재할당하는 경우, 원본 self.mask에 영향을 주지 않기 위해 사용됩니다.

mask = (mask > 0).astype(np.uint8) * 255
이 코드는 마스크를 이진 마스크(binary mask)로 변환합니다.
mask > 0: 마스크에서 0보다 큰 값은 True, 그렇지 않은 값은 False로 변환됩니다.
.astype(np.uint8): 불리언 값을 0과 1의 정수로 변환합니다.
* 255: 1을 255로 변환하여 일반적인 이미지 마스크 형식(0과 255)으로 만듭니다.
예를 들어, 마스크가 [0, 100, 200]라면, 이 코드는 [0, 255, 255]로 변환됩니다.

import os # 파일 시스템 경로와 관련된 기능을 제공하는 표준 라이브러리
import glob #  파일 경로명을 확장하는 데 사용되는 glob 모듈. 이는 특정 패턴과 일치하는 모든 파일 경로를 찾는 데 사용
import scipy # 과학계산
import torch # 파이토치
import random # 난수 생성
import numpy as np # 넘파이
import torchvision.transforms.functional as F # 이미지 변환 기능
from torch.utils.data import DataLoader # 전체 데이터셋을 미니 배치로 로드하는데 사용
from PIL import Image # 이미지 가공
from scipy.misc import imread # 이미지를 읽는 함수
from skimage.feature import canny # 캐니 필터
from skimage.color import rgb2gray, gray2rgb # 흑백 <=> 컬러
from .utils import create_mask # 사각형 마스크 생성기


class Dataset(torch.utils.data.Dataset):
    def __init__(self, config, flist, edge_flist, mask_flist, augment=True, training=True):
        # config: 다양한 모델 옵션을 딕셔너리로 저장한 객체
        # flist, edge_flist, mask_flist: 이미지, 엣지, 마스크 리스트
        # augment: 데이터 증강 여부
        # training: 훈련 모드 여부
        super(Dataset, self).__init__()
        self.augment = augment # True
        self.training = training # True

        self.data = self.load_flist(flist) # 이미지 데이터 리스트를 저장
        self.edge_data = self.load_flist(edge_flist) # 외부의 엣지맵 리스트를 저장
        self.mask_data = self.load_flist(mask_flist) # 마스크 리스트를 저장
        # load_flist()는 데이터를 정렬해서 저장하므로 하단의 load_item() 함수에서 페어끼리 같이 불러올 수 있다.

        self.input_size = config.INPUT_SIZE # 이미지 해상도
        self.sigma = config.SIGMA # 캐니 엣지 디텍터에 쓰이는 가우시안 필터의 표준 편차
        self.edge = config.EDGE # 1.canny, 2.external
        self.mask = config.MASK # 1: random block, 2: half, 3: external, 4: (external, random block), 5: (external, random block, half)
        self.nms = config.NMS # 0: no non-max-suppression, 1: applies non-max-suppression on the external edges by multiplying by Canny

        # in test mode, there's a one-to-one relationship between mask and image
        # masks are loaded non random
        if config.MODE == 2:
            self.mask = 6 # 6번 옵션은 마스크가 랜덤하게 선택되지 않고, index에 따라 이미지와 하나의 페어로 생성된다.


    # 데이터셋의 총 길이를 반환하는 특수 메서드
    def __len__(self):
        return len(self.data)


    # 인덱스에 해당하는 이미지, 회색조 이미지, 엣지맵, 마스크를 모두 텐서로 변환해 가져온 뒤 반환.
    def __getitem__(self, index):
        try:
            item = self.load_item(index)
        except:
            print('loading error: ' + self.data[index]) # 로드에 실패한 파일의 경로를 출력
            item = self.load_item(0) # 데이터셋의 첫 번째 항목을 대신 로드
        return item


    def load_name(self, index):
        name = self.data[index] # 클래스의 data 리스트에서 지정된 인덱스에 해당하는 요소(파일 경로)를 name 변수에 할당 ex) "./data/images/sample.jpg"
        return os.path.basename(name) # 전체 파일 경로에서 기본 파일 이름만을 추출하여 반환 ex) "sample.jpg"


    def load_item(self, index): # 인덱스를 받아 그에 대한 아이템을 받아오는 것.

        size = self.input_size

        # load image
        img = imread(self.data[index])
        # gray to rgb
        if len(img.shape) < 3:
            img = gray2rgb(img)
        # resize/crop if needed
        if size != 0: # 원본 이미지 크기를 사용하는 게 아니면
            img = self.resize(img, size, size)

        # create grayscale image
        img_gray = rgb2gray(img)

        # load mask
        mask = self.load_mask(img, index)

        # load edge
        edge = self.load_edge(img_gray, index, mask)

        # augment data
        if self.augment and np.random.binomial(1, 0.5) > 0: # 데이터 증강(augmentation)이 활성화되어 있고, 50% 확률로 조건이 충족되면 이미지를 수평으로 뒤집는다.
            # 이미지, 회색조 이미지, 에지, 마스크를 수평으로 뒤집는다.
            img = img[:, ::-1, ...]
            img_gray = img_gray[:, ::-1, ...]
            edge = edge[:, ::-1, ...]
            mask = mask[:, ::-1, ...]
        # 처리된 이미지, 회색조 이미지, 에지, 마스크를 텐서로 변환하여 반환
        return self.to_tensor(img), self.to_tensor(img_gray), self.to_tensor(edge), self.to_tensor(mask)


    # 이미지의 에지(edge)를 로드하거나 생성하는 역할.
    # 훈련과 테스트 모드에 따라 다르게 동작하며, Canny 에지 검출 알고리즘 또는 외부에서 제공된 에지 정보를 사용.
    def load_edge(self, img, index, mask): # img: 처리할 이미지, index: 이미지의 인덱스, mask: 이미지의 마스크.
        sigma = self.sigma # Canny 에지 검출에 사용될 sigma 값을 클래스 인스턴스 변수에서 가져

        # in test mode images are masked (with masked regions), 테스트 모드에서 이미지는 마스크에 의해 가려진다.
        # using 'mask' parameter prevents canny to detect edges for the masked regions. mask 인자를 이용하여 가려진 부분의 엣지를 canny 필터에 씌우지 않는다.
        mask = None if self.training else (1 - mask / 255).astype(np.bool)

        # self.edge가 1인 경우 canny 사용
        if self.edge == 1:
            # sigma가 -1인 경우 엣지 생성을 하지 않는다.
            if sigma == -1:
                return np.zeros(img.shape).astype(np.float)

            # random sigma
            if sigma == 0:
                sigma = random.randint(1, 4)
            # 이 경우에는 args로 들어간 엣지 이미지 데이터가 필요가 없어진다.
            return canny(img, sigma=sigma, mask=mask).astype(np.float)

        # external
        else:
            imgh, imgw = img.shape[0:2] # 이미지의 높이와 너비 추출
            edge = imread(self.edge_data[index]) # 외부 엣지맵을 불러온다.
            edge = self.resize(edge, imgh, imgw) # 엣지맵을 리사이즈

            # non-max suppression
            if self.nms == 1: # NMS를 적용하기 위해 Canny 에지 검출 결과와 외부 에지 정보를 결합합니다.
                edge = edge * canny(img, sigma=sigma, mask=mask)

            return edge


    def load_mask(self, img, index): # img: 처리할 이미지, index: 이미지의 인덱스.
        imgh, imgw = img.shape[0:2]
        mask_type = self.mask

        # external + random block
        if mask_type == 4: # mask_type == 4는 외부 마스크와 랜덤 블록 마스크 중 하나를 무작위로 선택
            mask_type = 1 if np.random.binomial(1, 0.5) == 1 else 3

        # external + random block + half 외부 마스크, 랜덤 블록 마스크, 반 마스크 중 하나를 무작위로 선택
        elif mask_type == 5:
            mask_type = np.random.randint(1, 4)

        # random block 이미지의 절반 크기의 마스크를 생성
        if mask_type == 1:
            return create_mask(imgw, imgh, imgw // 2, imgh // 2)

        # half
        if mask_type == 2:
            # randomly choose right or left 이미지의 오른쪽 또는 왼쪽 절반에 마스크를 생성
            return create_mask(imgw, imgh, imgw // 2, imgh, 0 if random.random() < 0.5 else imgw // 2, 0)

        # external
        if mask_type == 3:
            mask_index = random.randint(0, len(self.mask_data) - 1) # 무작위로 마스크 인덱스 생성
            mask = imread(self.mask_data[mask_index]) # 마스크 불러오기
            mask = self.resize(mask, imgh, imgw) # 크기 조절
            mask = (mask > 0).astype(np.uint8) * 255       # threshold due to interpolation
            return mask

        # test mode: load mask non random
        if mask_type == 6:
            mask = imread(self.mask_data[index]) # 지정된 인덱스의 마스크 파일을 로드
            mask = self.resize(mask, imgh, imgw, centerCrop=False) # 크기 조절
            mask = rgb2gray(mask) # 회색조 이미지로 변환
            mask = (mask > 0).astype(np.uint8) * 255 # 이진 마스크로 변환
            return mask


    def to_tensor(self, img):
        img = Image.fromarray(img) # NumPy 배열을 PIL 이미지로 변환. (PyTorch의 텐서 변환을 위해 필요)
        img_t = F.to_tensor(img).float() # 함수를 사용하여 PIL 이미지를 PyTorch 텐서로 변환하고, 데이터 타입을 부동소수점(float)으로 설정.
        return img_t


    def resize(self, img, height, width, centerCrop=True): # img: NumPy 배열 형태의 이미지 height, width: 조정할 크기
        imgh, imgw = img.shape[0:2]

        if centerCrop and imgh != imgw:
            # center crop
            side = np.minimum(imgh, imgw) # 높이와 너비 중 더 짧은 것을 기준으로 자른다.
            j = (imgh - side) // 2
            i = (imgw - side) // 2
            img = img[j:j + side, i:i + side, ...]

        img = scipy.misc.imresize(img, [height, width]) # 이미지 크기 조정.

        return img


    def load_flist(self, flist):
        if isinstance(flist, list): # flist가 list인지 확인
            return flist # 맞으면 flist를 그대로 반환하고 함수 종료.

        # flist: image file path, image directory path, text file flist path
        if isinstance(flist, str): # flist가 string인지 확인
            if os.path.isdir(flist): # 그리고 디렉토리 경로인지 확인
                flist = list(glob.glob(flist + '/*.jpg')) + list(glob.glob(flist + '/*.png')) # 디렉토리의 모든 이미지를 저장
                flist.sort() # 정렬해서
                return flist # 반환한다

            if os.path.isfile(flist): # flist가 파일 경로인지 확인
                try: # 파일을 읽으려고 시도
                    return np.genfromtxt(flist, dtype=np.str, encoding='utf-8')
                    # genfromtxt: 텍스트 파일의 각 줄(이미지 파일 경로)을 문자열(np.str)로 읽고 utf-8로 인코딩 한다.
                except:
                    return [flist] # 텍스트 파일이 아니라면 그냥 파일 경로 자체를 리스트로 감싸 반환. (이미지가 하나인 경우)

        return [] # flist가 유효한 형식이 아닌 경우 단순히 빈 리스트를 반환.


    # 데이터셋에 대한 이터레이터(iterator)를 생성하는 데 사용
    # 무한 루프 내에서 데이터를 배치(batch) 단위로 로드하고, 각 배치를 순차적으로 반환
    def create_iterator(self, batch_size):
        while True: # 무한 루프 시작
            # DataLoader는 전체 데이터에서 배치 사이즈만큼의 데이터만 담아 온다.
            sample_loader = DataLoader(
                dataset=self,
                batch_size=batch_size,
                drop_last=True
            )
            # 그 후 데이터로더를 순회하며 아이템을 가져온다.
            # yield를 쓰는 이유는, 제너레이터를 사용하는 외부 코드가 더이상 next()를 호출하지 않게 하여 무한 루프를 빠져나가기 위함이다.
            for item in sample_loader:
                yield item

'코드 분석 > Edge-Connect' 카테고리의 다른 글

Edge-Connect: edge_connect.py  (2) 2024.01.27
Edge-Connect: models.py  (1) 2024.01.27
Edge-Connect: utils.py  (1) 2024.01.25
Edge-Connect: main.py  (0) 2024.01.25
Edge-Connect: metrics.py  (0) 2024.01.25