티스토리 뷰

코드 분석/SGGLAT

dataset.py

상솜공방 2024. 2. 13. 12:12
import torch
import random
import numpy as np
import os, cv2, glob


class Dataset(torch.utils.data.Dataset):
    def __init__(self, args, image_flist, edge_flist, mask_flist, augment=True, training=True):
        # args: 다양한 모델 옵션을 딕셔너리로 저장한 객체
        # image_flist, edge_flist, mask_flist: 이미지, 엣지, 마스크 파일이 들어있는 폴더까지의 경로
        # augment: 데이터 증강 여부
        # training: 훈련 모드 여부
        super(Dataset, self).__init__()
        self.augment = augment # True
        self.training = training # True

        self.image_flist = self.load_flist(image_flist) # 이미지 데이터 리스트를 저장
        self.edge_flist = self.load_flist(edge_flist) # 외부의 엣지맵 리스트를 저장
        self.mask_flist = self.load_flist(mask_flist) # 마스크 리스트를 저장
        # load_flist()는 데이터를 정렬해서 저장하므로 하단의 load_item() 함수에서 페어끼리 같이 불러올 수 있다.

        self.input_size = args.input_size # 이미지 해상도
        self.sigma = args.sigma # 캐니 엣지 디텍터에 쓰이는 가우시안 필터의 표준 편차
        self.edge = args.edge # 1.canny, 2.external
        self.mask = args.mask # 1: random block, 2: half, 3: external, 4: (external, random block), 5: (external, random block, half)
        self.patch_size = args.patch_size # 마스크를 자를 때 쓰는 패치 변의 길이
        #self.nms = args. # 0: no non-max-suppression, 1: applies non-max-suppression on the external edges by multiplying by Canny

        # in test mode, there's a one-to-one relationship between mask and image
        # masks are loaded non random
        if args.mode == 2:
            self.mask = 6 # 6번 옵션은 마스크가 랜덤하게 선택되지 않고, index에 따라 이미지와 하나의 페어로 생성된다.


    # 데이터셋의 총 길이를 반환하는 특수 메서드
    def __len__(self):
        return len(self.image_flist)


    # 인덱스에 해당하는 이미지, 회색조 이미지, 엣지맵, 마스크를 모두 텐서로 변환해 가져온 뒤 반환.
    def __getitem__(self, index):
        item = self.load_item(index)
        return item


    def load_name(self, index):
        name = self.image_flist[index] # 클래스의 data 리스트에서 지정된 인덱스에 해당하는 요소(파일 경로)를 name 변수에 할당 ex) "./data/images/sample.jpg"
        return os.path.basename(name) # 전체 파일 경로에서 기본 파일 이름만을 추출하여 반환 ex) "sample.jpg"


    def load_item(self, index): # 인덱스를 받아 그에 대한 아이템을 받아오는 것.
        patch_size = self.patch_size

        # load image
        img = cv2.imread(self.image_flist[index])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # create grayscale image
        img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

        # load mask
        mask = self.load_mask(img, index)

        # load edge
        edge = self.load_edge(img_gray, index, mask)

        # augment data
        if self.augment and np.random.binomial(1, 0.5) > 0: # 데이터 증강(augmentation)이 활성화되어 있고, 50% 확률로 조건이 충족되면 이미지를 수평으로 뒤집는다.
            # 이미지, 회색조 이미지, 에지, 마스크를 수평으로 뒤집는다.
            img = img[:, ::-1, ...]
            img_gray = img_gray[:, ::-1, ...]
            edge = edge[:, ::-1, ...]
            mask = mask[:, ::-1, ...]
        # 처리된 이미지, 회색조 이미지, 에지, 마스크를 텐서로 변환하여 반환
            
        img_t = self.to_tensor(img, 'color_image')
        img_gray_t = self.to_tensor(img_gray, 'gray_image')
        edge_t = self.to_tensor(edge, 'canny_image')
        mask_t = self.to_tensor(mask, 'mask')
        mask_t = self.preprocess_mask_image(mask_t, patch_size)
            
        return img_t, img_gray_t, edge_t, mask_t

    # 이미지의 에지(edge)를 로드하거나 생성하는 역할.
    # 훈련과 테스트 모드에 따라 다르게 동작하며, Canny 에지 검출 알고리즘 또는 외부에서 제공된 에지 정보를 사용.
    def load_edge(self, img, index, mask): # img: 처리할 이미지, index: 이미지의 인덱스, mask: 이미지의 마스크.
        sigma = self.sigma # Canny 에지 검출에 사용될 sigma 값을 클래스 인스턴스 변수에서 가져

        # in test mode images are masked (with masked regions), 테스트 모드에서 이미지는 마스크에 의해 가려진다.
        # using 'mask' parameter prevents canny to detect edges for the masked regions. mask 인자를 이용하여 가려진 부분의 엣지를 canny 필터에 씌우지 않는다.
        mask = None if self.training else (1 - mask / 255).astype(np.bool) # 아마 여기서 test 시점에 오류가 날 것 같다.

        # self.edge가 1인 경우 canny 사용
        if self.edge == 1:
            # sigma가 -1인 경우 엣지 생성을 하지 않는다.
            if sigma == -1:
                return np.zeros(img.shape).astype(np.float32)

            # random sigma
            if sigma == 0:
                sigma = random.randint(1, 4)
                
            img_canny = cv2.Canny(img, 100, 300)
            return img_canny
        
        # # external
        # else:
        #     imgh, imgw = img.shape[0:2] # 이미지의 높이와 너비 추출
        #     edge = cv2.imread(self.edge_data[index]) # 외부 엣지맵을 불러온다.
        #     edge = self.resize(edge, imgh, imgw) # 엣지맵을 리사이즈
        #     # non-max suppression
        #     if self.nms == 1: # NMS를 적용하기 위해 Canny 에지 검출 결과와 외부 에지 정보를 결합합니다.
        #         edge = edge * canny(img, sigma=sigma, mask=mask)
        #     return edge


    def load_mask(self, img, index): # img: 처리할 이미지, index: 이미지의 인덱스.
        imgh, imgw = img.shape[0:2]
        mask_type = self.mask

        # external
        if mask_type == 3:
            mask_index = random.randint(0, len(self.mask_flist) - 1) # 무작위로 마스크 인덱스 생성
            mask = cv2.imread(self.mask_flist[mask_index], cv2.IMREAD_GRAYSCALE)
        
        # test mode: load mask non random
        elif mask_type == 6:
            mask = cv2.imread(self.mask_flist[index], cv2.IMREAD_GRAYSCALE)
            
        return mask

    def to_tensor(self, img, type):
        img= img.astype(np.float32) / 255.0
        if type == 'color_image':
            img_t = torch.from_numpy(img).permute(2, 0, 1)
        elif type == 'gray_image':
            img_t = torch.from_numpy(img).unsqueeze(0) # 채널 차원을 추가 (1, H, W)
        elif type == 'canny_image':
            img_t = torch.from_numpy(img).unsqueeze(0) # 채널 차원을 추가 (1, H, W)
        elif type == 'mask':
            img_t = torch.from_numpy(img).unsqueeze(0) # 채널 차원을 추가 (1, H, W)
            img_t = torch.round(img_t) # 혹시 0과 1 사이의 값이 있을 수 있으니, 반올림하여 0 혹은 1로 매핑
            #img_t = 1 - img_t # 0과 1을 반전
        return img_t


    def load_flist(self, flist):
        if os.path.isdir(flist):
            file_list = list(glob.glob(flist + '/*.jpg')) + list(glob.glob(flist + '/*.png'))
            file_list.sort()
            return file_list
        else:
            return []

    
    def preprocess_mask_image(self, mask, patch_size): # mask (torch.Tensor): (1, H, W) tensor, patch_size (int): 패치의 크기 P
        # 이미지를 패치로 나눕니다.
        patches = mask.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size) # (1, nH, nW, P, P)
        C, nH, nW, P, _ = patches.size()
        patches = patches.contiguous().view(C, nH*nW, P, P) # (C, nH*nW, P, P)
        reconstructed_image = torch.zeros((C, nH*patch_size, nW*patch_size)) # 전체 이미지를 위한 빈 텐서를 생성 (C, H, W)
        
        # 패치 순회
        for i in range(nH):
            for j in range(nW):
                patch = patches[:, i * nW + j, :, :] # (i * nW + j)는 0부터 nH*nW까지 순차적으로 1씩 증가, patch = (C, P, P)
                if torch.any(patch == 1): # 1이 하나라도 있는 텐서는 모두 1으로 마스킹
                    patch = torch.ones_like(patch)
                reconstructed_image[:, i*patch_size:(i+1)*patch_size, j*patch_size:(j+1)*patch_size] = patch.reshape(C, P, P)
                
        return reconstructed_image # 전처리된 이미지 텐서, 형태는 (1, 256, 256)

'코드 분석 > SGGLAT' 카테고리의 다른 글

dataset.ipynb  (0) 2024.02.13
SGGLAT_G.ipynb  (0) 2024.02.13
SGGLAT_D.py  (0) 2024.02.13
모델 구조도  (0) 2024.02.13
SGGLAT_G.py  (0) 2024.02.12
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
TAG
more
«   2025/04   »
1 2 3 4 5
6 7 8 9 10 11 12
13 14 15 16 17 18 19
20 21 22 23 24 25 26
27 28 29 30
글 보관함