ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • PyTorch | 이미지 복원
    Deep Learning/PyTorch 2021. 6. 4. 23:59
    728x90

    문제

    weird_function() 함수는 이미지를 오염시키는 함수이다.

    어떤 이미지에 weird_function() 함수를 적용 시켜 이미지가 오염되어 원본을 알아볼 수 없게 되었다.

    이미지를 복원해 보자. 


    이미지 다운로드 링크 :

    https://github.com/keon/3-min-pytorch/raw/master/03-%ED%8C%8C%EC%9D%B4%ED%86%A0%EC%B9%98%EB%A1%9C_%EA%B5%AC%ED%98%84%ED%95%98%EB%8A%94_ANN/broken_image_t.p

     

    접근법

    1. 원본 이미지와 크기가 같은 random_tensor를 생성한다.
    2. 원본 이미지에 weird_function() 함수를 적용한 이미지와 random_tensor에 weird_function() 함수를 적용한 이미지의 오차를 구한다.
    3. 오차의 값을 경사하강법을 통해 줄여나간다.
    4. 오차가 최소가 되었을 때 random_tensor는 원본 이미지에 가까운 이미지가 될 것이다.

    이미지 복원

    (마지막에 전체 소스 코드 있음!)

    import

    필요한 모듈을 import한다.

    import torch
    import pickle
    import matplotlib.pyplot as plt
    from image_recovery import weird_function

     

    werid_function()

    weird_function은 이렇게 생기긴 했다. 이 코드의 내용을 이해할 필요없다.

    모른다는 가정 하에 이미지를 복원 시킬것이다.

    def weird_function(x, n_iter=5):
        h = x    
        filt = torch.tensor([-1./3, 1./3, -1./3])
        for i in range(n_iter):
            zero_tensor = torch.tensor([1.0*0])
            h_l = torch.cat( (zero_tensor, h[:-1]), 0)
            h_r = torch.cat((h[1:], zero_tensor), 0 )
            h = filt[0] * h + filt[2] * h_l + filt[1] * h_r
            if i % 2 == 0:
                h = torch.cat( (h[h.shape[0]//2:],h[:h.shape[0]//2]), 0  )
        return h

     

    오차 구하기

    오차를 구하는 함수를 정의하자.

    단순하게 hypothesis와 broken_img의 거리를 측정한 것을 오차로 둔다.

    hypothesis : 원본 이미지와 같은 크기의 행렬 random_tensor를 만들고 weird_function()을 적용시킨 이미지이다.

    broken_img : 원본 이미지에 werid_function()을 적용시킨 이미지이다.

    def distance_loss(hypothesis, broken_image):
        return torch.dist(hypothesis, broken_image)  # dist : 두 텐서 사이의 거리를 구하는 함수

     

    오염된 이미지 정보

    오염된 이미지의 크기를 알아냈다.

    [10000] 이다.

    broken_image = torch.FloatTensor(pickle.load(open('./broken_image_t.p', 'rb'), encoding='latin1'))
    print(broken_image.shape)
    torch.Size([10000])

     

    오염된 이미지를 출력 해보자.

    plt.imshow(broken_image.view(100,100).data)
    plt.show()

    broken_img

     

    무작위 텐서 생성

    오염된 이미지의 크기와 같은 다른 말로하면 원본 이미지의 크기와 같은 random_tensor를 생성한다.

    random_tensor = torch.randn(10000, dtype=torch.float)

     

    경사하강법

    학습률은 0.8로 하고

    20000번 학습 한다.

    requires_grad_(True)를 설정하면 미분된 값이 .grad에 자동 저장된다.

    lr = 0.8
    
    for i in range(20000):
        random_tensor.requires_grad_(True)
        hypothesis = weird_function(random_tensor)
        loss = distance_loss(hypothesis, broken_image)  #오차 구하기
        loss.backward()	#역전파
        with torch.no_grad():	# 경사하강법을 직접 구현하기에 자동 기울기 계산을 비활성화
            random_tensor = random_tensor - lr * random_tensor.grad
        if i % 1000 == 0:	# 1000번 반복될 때마다 오차를 출력
            print(f"Loss at {i} = {loss.item()}")
    torch.Size([10000])
    Loss at 0 = 12.0382661819458
    Loss at 1000 = 1.1228744983673096
    Loss at 2000 = 0.5494382977485657
    Loss at 3000 = 0.3817315101623535
    Loss at 4000 = 0.30174607038497925
    Loss at 5000 = 0.25323012471199036
    Loss at 6000 = 0.21860259771347046
    Loss at 7000 = 0.19111904501914978
    Loss at 8000 = 0.16777442395687103
    Loss at 9000 = 0.14704976975917816
    Loss at 10000 = 0.12809991836547852
    Loss at 11000 = 0.11041469871997833
    Loss at 12000 = 0.09366687387228012
    Loss at 13000 = 0.07763271778821945
    Loss at 14000 = 0.06215355917811394
    Loss at 15000 = 0.04711384326219559
    Loss at 16000 = 0.03242957592010498
    Loss at 17000 = 0.02113625593483448
    Loss at 18000 = 0.021164560690522194
    Loss at 19000 = 0.02116667479276657

    복원된 이미지

    plt.imshow(random_tensor.view(100,100).data)
    plt.show()

    복원된 이미지


    전체 소스 코드

    import torch
    import pickle
    import matplotlib.pyplot as plt
    
    
    def weird_function(x, n_iter=5):
        h = x
        filt = torch.tensor([-1. / 3, 1. / 3, -1. / 3])
        for i in range(n_iter):
            zero_tensor = torch.tensor([1.0 * 0])
            h_l = torch.cat((zero_tensor, h[:-1]), 0)
            h_r = torch.cat((h[1:], zero_tensor), 0)
            h = filt[0] * h + filt[2] * h_l + filt[1] * h_r
            if i % 2 == 0:
                h = torch.cat((h[h.shape[0] // 2:], h[:h.shape[0] // 2]), 0)
        return h
    
    
    def distance_loss(hypothesis, broken_image):
        return torch.dist(hypothesis, broken_image)  # dist : 두 텐서 사이의 거리를 구하는 함수
    
    
    broken_image = torch.FloatTensor(pickle.load(open('./broken_image_t.p', 'rb'), encoding='latin1'))
    
    # 오염된 이미지 정보
    # print(broken_image.shape) 
    # plt.imshow(broken_image.view(100,100).data)
    # plt.show()
    
    random_tensor = torch.randn(10000, dtype=torch.float)
    lr = 0.8
    
    for i in range(20000):
        random_tensor.requires_grad_(True)
        hypothesis = weird_function(random_tensor)
        loss = distance_loss(hypothesis, broken_image)
        loss.backward()
        with torch.no_grad():
            random_tensor = random_tensor - lr * random_tensor.grad
        if i % 1000 == 0:
            print(f"Loss at {i} = {loss.item()}")
    
    plt.imshow(random_tensor.view(100,100).data)
    plt.show()
    

     

    728x90

    'Deep Learning > PyTorch' 카테고리의 다른 글

    PyTorch | GAN  (4) 2022.01.09
    PyTorch | DNN  (0) 2021.06.20
    Pytorch | ANN  (0) 2021.06.12
    Pythrch | 기초  (0) 2021.06.01
    PyTorch 설치  (0) 2021.06.01

    댓글

Designed by Tistory.