코드 분석/Edge-Connect

Edge-Connect: metrics.py

상솜공방 2024. 1. 25. 12:08

코드를 이해하기 위해 필요한 지식

더보기

혼동행렬(Confusion Matrix)

혼동 행렬은 분류 모델의 성능을 설명하는 데 사용되는 표입니다. 이 표는 네 가지 구성 요소로 이루어져 있습니다:
1. 진양성 (True Positive, TP): 모델이 양성 클래스를 정확하게 예측함.
2. 거짓양성 (False Positive, FP): 모델이 양성 클래스를 잘못 예측함 (오류의 한 종류).
3. 진음성 (True Negative, TN): 모델이 음성 클래스를 정확하게 예측함.
4. 거짓음성 (False Negative, FN): 모델이 음성 클래스를 잘못 예측함 (또 다른 오류의 종류).

 

  P N
P TP FP
N FN TN

정밀도(Precision):

공식: TP / (TP + FP)

의미: 모델이 예측한 P 중에 진짜 P는 어느정도인가?

 

재현율(Recall):

공식: TP / (TP + FN)

의미: 실제 P 중에 모델이 맞춘 P는 어느정도인가?

 

self.register_buffer 사용 이유:
self.register_buffer는 PyTorch의 nn.Module에서 제공하는 메서드로, 모듈의 상태에 텐서를 저장하지만, 이 텐서는 모델의 학습 가능한 파라미터로 간주되지 않습니다.
이 방법으로 저장된 텐서는 모델의 state_dict에 포함되어 모델 저장 및 로딩 시 함께 저장되고 로드됩니다. 그러나 이 텐서는 역전파(backpropagation) 과정에서 그래디언트를 계산하거나 업데이트하지 않습니다.
self.register_buffer를 사용하는 주된 이유는 모델이 GPU로 이동할 때 이 텐서도 자동으로 같이 이동하게 하기 위함입니다. 즉, 모델이 .to(device) 메서드를 통해 다른 디바이스로 이동할 때, 등록된 버퍼도 함께 이동합니다.


호출 가능 메서드(__call__):
파이썬에서 __call__ 메서드는 객체를 함수처럼 호출할 수 있게 해주는 특수 메서드입니다.
클래스에 __call__ 메서드를 정의하면, 해당 클래스의 인스턴스(객체)를 마치 함수처럼 사용할 수 있습니다. 즉, 객체를 함수 호출과 같은 방식으로 사용하면 __call__ 메서드가 실행됩니다.
예를 들어, psnr = PSNR(max_val)로 PSNR 클래스의 인스턴스를 생성한 후, psnr(a, b)와 같이 호출하면 PSNR 클래스의 __call__ 메서드가 실행됩니다.

import torch
import torch.nn as nn


class EdgeAccuracy(nn.Module):
    """
    Measures the accuracy of the edge map
    """
    def __init__(self, threshold=0.5): # threshold는 엣지를 결정하는데 사용되는 임계값.
        super(EdgeAccuracy, self).__init__() # 상위 클래스 생성자 호출 & 변수 초기화.
        self.threshold = threshold # 인스턴스 변수로 할당.

    def __call__(self, inputs, outputs): # inputs: 실제 에지맵, outputs: 예측 에지맵.
        labels = (inputs > self.threshold) # 실제 에지맵에서 임계값보다 큰 값을 True로 변환.
        outputs = (outputs > self.threshold) # 예측 에지맵에서 임계값보다 큰 값을 True로 변환.

        relevant = torch.sum(labels.float()) # 실제 에지맵에서 True로 표시된 픽셀의 총 개수를 더한다. (TP + FN)
        selected = torch.sum(outputs.float()) # 예측 에지맵에서 True로 표시된 픽셀의 총 개수를 더한다. (TP + FP)

        if relevant == 0 and selected == 0: # 실제와 예측된 에지가 모두 없는 경우
            return torch.tensor(1), torch.tensor(1) # 정밀도와 재현율을 1로 반환.

        true_positive = ((outputs == labels) * labels).float() # TP
        recall = torch.sum(true_positive) / (relevant + 1e-8) # TP / TP + FN
        precision = torch.sum(true_positive) / (selected + 1e-8) # TP / TP + FP

        return precision, recall


class PSNR(nn.Module): # Peak Signal-to-Noise Ratio
    def __init__(self, max_val): # max_val은 신호의 최대 가능 값입니다 (예: 8비트 이미지의 경우 255).
        super(PSNR, self).__init__()

        base10 = torch.log(torch.tensor(10.0)) # 10의 자연 로그 값을 계산하여 base10에 저장합니다. 이는 로그 계산의 기반으로 사용됩니다.
        max_val = torch.tensor(max_val).float() # 입력받은 max_val을 텐서로 변환하고 부동소수점(float) 타입으로 변환합니다.

        self.register_buffer('base10', base10) # base10을 모듈의 버퍼로 등록합니다. 이는 모델의 상태에 포함되지만, 학습 과정에서는 업데이트되지 않습니다.
        self.register_buffer('max_val', 20 * torch.log(max_val) / base10) # max_val에 대한 PSNR 계산에 필요한 값을 미리 계산하여 버퍼로 등록합니다.

    def __call__(self, a, b):
        mse = torch.mean((a.float() - b.float()) ** 2) # 두 이미지 간의 평균 제곱 오차(Mean Squared Error, MSE)를 계산합니다.

        if mse == 0: # MSE가 0인 경우를 확인합니다. 이는 두 이미지가 완벽하게 동일함을 의미합니다.
            return torch.tensor(0) # MSE가 0이면 PSNR은 무한대가 되므로, 실제 계산에서는 0을 반환합니다.

        # PSNR을 계산하여 반환합니다. PSNR은 MSE의 로그 스케일 반전 값으로, 높을수록 두 이미지 간의 차이가 적음을 의미합니다.
        return self.max_val - 10 * torch.log(mse) / self.base10

'코드 분석 > Edge-Connect' 카테고리의 다른 글

Edge-Connect: utils.py  (1) 2024.01.25
Edge-Connect: main.py  (0) 2024.01.25
Edge-Connect: config.py  (1) 2024.01.24
Edge-Connect: loss.py  (0) 2024.01.24
Edge-Connect: networks.py  (1) 2024.01.24