Deep Learning/PyTorch

PyTorch | 이미지 복원

jiseong 2021. 6. 4. 23:59
728x90

문제

weird_function() 함수는 이미지를 오염시키는 함수이다.

어떤 이미지에 weird_function() 함수를 적용 시켜 이미지가 오염되어 원본을 알아볼 수 없게 되었다.

이미지를 복원해 보자. 


이미지 다운로드 링크 :

https://github.com/keon/3-min-pytorch/raw/master/03-%ED%8C%8C%EC%9D%B4%ED%86%A0%EC%B9%98%EB%A1%9C_%EA%B5%AC%ED%98%84%ED%95%98%EB%8A%94_ANN/broken_image_t.p

 

접근법

  1. 원본 이미지와 크기가 같은 random_tensor를 생성한다.
  2. 원본 이미지에 weird_function() 함수를 적용한 이미지와 random_tensor에 weird_function() 함수를 적용한 이미지의 오차를 구한다.
  3. 오차의 값을 경사하강법을 통해 줄여나간다.
  4. 오차가 최소가 되었을 때 random_tensor는 원본 이미지에 가까운 이미지가 될 것이다.

이미지 복원

(마지막에 전체 소스 코드 있음!)

import

필요한 모듈을 import한다.

import torch
import pickle
import matplotlib.pyplot as plt
from image_recovery import weird_function

 

werid_function()

weird_function은 이렇게 생기긴 했다. 이 코드의 내용을 이해할 필요없다.

모른다는 가정 하에 이미지를 복원 시킬것이다.

def weird_function(x, n_iter=5):
    h = x    
    filt = torch.tensor([-1./3, 1./3, -1./3])
    for i in range(n_iter):
        zero_tensor = torch.tensor([1.0*0])
        h_l = torch.cat( (zero_tensor, h[:-1]), 0)
        h_r = torch.cat((h[1:], zero_tensor), 0 )
        h = filt[0] * h + filt[2] * h_l + filt[1] * h_r
        if i % 2 == 0:
            h = torch.cat( (h[h.shape[0]//2:],h[:h.shape[0]//2]), 0  )
    return h

 

오차 구하기

오차를 구하는 함수를 정의하자.

단순하게 hypothesis와 broken_img의 거리를 측정한 것을 오차로 둔다.

hypothesis : 원본 이미지와 같은 크기의 행렬 random_tensor를 만들고 weird_function()을 적용시킨 이미지이다.

broken_img : 원본 이미지에 werid_function()을 적용시킨 이미지이다.

def distance_loss(hypothesis, broken_image):
    return torch.dist(hypothesis, broken_image)  # dist : 두 텐서 사이의 거리를 구하는 함수

 

오염된 이미지 정보

오염된 이미지의 크기를 알아냈다.

[10000] 이다.

broken_image = torch.FloatTensor(pickle.load(open('./broken_image_t.p', 'rb'), encoding='latin1'))
print(broken_image.shape)
torch.Size([10000])

 

오염된 이미지를 출력 해보자.

plt.imshow(broken_image.view(100,100).data)
plt.show()

broken_img

 

무작위 텐서 생성

오염된 이미지의 크기와 같은 다른 말로하면 원본 이미지의 크기와 같은 random_tensor를 생성한다.

random_tensor = torch.randn(10000, dtype=torch.float)

 

경사하강법

학습률은 0.8로 하고

20000번 학습 한다.

requires_grad_(True)를 설정하면 미분된 값이 .grad에 자동 저장된다.

lr = 0.8

for i in range(20000):
    random_tensor.requires_grad_(True)
    hypothesis = weird_function(random_tensor)
    loss = distance_loss(hypothesis, broken_image)  #오차 구하기
    loss.backward()	#역전파
    with torch.no_grad():	# 경사하강법을 직접 구현하기에 자동 기울기 계산을 비활성화
        random_tensor = random_tensor - lr * random_tensor.grad
    if i % 1000 == 0:	# 1000번 반복될 때마다 오차를 출력
        print(f"Loss at {i} = {loss.item()}")
torch.Size([10000])
Loss at 0 = 12.0382661819458
Loss at 1000 = 1.1228744983673096
Loss at 2000 = 0.5494382977485657
Loss at 3000 = 0.3817315101623535
Loss at 4000 = 0.30174607038497925
Loss at 5000 = 0.25323012471199036
Loss at 6000 = 0.21860259771347046
Loss at 7000 = 0.19111904501914978
Loss at 8000 = 0.16777442395687103
Loss at 9000 = 0.14704976975917816
Loss at 10000 = 0.12809991836547852
Loss at 11000 = 0.11041469871997833
Loss at 12000 = 0.09366687387228012
Loss at 13000 = 0.07763271778821945
Loss at 14000 = 0.06215355917811394
Loss at 15000 = 0.04711384326219559
Loss at 16000 = 0.03242957592010498
Loss at 17000 = 0.02113625593483448
Loss at 18000 = 0.021164560690522194
Loss at 19000 = 0.02116667479276657

복원된 이미지

plt.imshow(random_tensor.view(100,100).data)
plt.show()

복원된 이미지


전체 소스 코드

import torch
import pickle
import matplotlib.pyplot as plt


def weird_function(x, n_iter=5):
    h = x
    filt = torch.tensor([-1. / 3, 1. / 3, -1. / 3])
    for i in range(n_iter):
        zero_tensor = torch.tensor([1.0 * 0])
        h_l = torch.cat((zero_tensor, h[:-1]), 0)
        h_r = torch.cat((h[1:], zero_tensor), 0)
        h = filt[0] * h + filt[2] * h_l + filt[1] * h_r
        if i % 2 == 0:
            h = torch.cat((h[h.shape[0] // 2:], h[:h.shape[0] // 2]), 0)
    return h


def distance_loss(hypothesis, broken_image):
    return torch.dist(hypothesis, broken_image)  # dist : 두 텐서 사이의 거리를 구하는 함수


broken_image = torch.FloatTensor(pickle.load(open('./broken_image_t.p', 'rb'), encoding='latin1'))

# 오염된 이미지 정보
# print(broken_image.shape) 
# plt.imshow(broken_image.view(100,100).data)
# plt.show()

random_tensor = torch.randn(10000, dtype=torch.float)
lr = 0.8

for i in range(20000):
    random_tensor.requires_grad_(True)
    hypothesis = weird_function(random_tensor)
    loss = distance_loss(hypothesis, broken_image)
    loss.backward()
    with torch.no_grad():
        random_tensor = random_tensor - lr * random_tensor.grad
    if i % 1000 == 0:
        print(f"Loss at {i} = {loss.item()}")

plt.imshow(random_tensor.view(100,100).data)
plt.show()

 

728x90