Deep Learning/PyTorch
PyTorch | 이미지 복원
jiseong
2021. 6. 4. 23:59
728x90
문제
weird_function() 함수는 이미지를 오염시키는 함수이다.
어떤 이미지에 weird_function() 함수를 적용 시켜 이미지가 오염되어 원본을 알아볼 수 없게 되었다.
이미지를 복원해 보자.
이미지 다운로드 링크 :
접근법
- 원본 이미지와 크기가 같은 random_tensor를 생성한다.
- 원본 이미지에 weird_function() 함수를 적용한 이미지와 random_tensor에 weird_function() 함수를 적용한 이미지의 오차를 구한다.
- 오차의 값을 경사하강법을 통해 줄여나간다.
- 오차가 최소가 되었을 때 random_tensor는 원본 이미지에 가까운 이미지가 될 것이다.
이미지 복원
(마지막에 전체 소스 코드 있음!)
import
필요한 모듈을 import한다.
import torch
import pickle
import matplotlib.pyplot as plt
from image_recovery import weird_function
werid_function()
weird_function은 이렇게 생기긴 했다. 이 코드의 내용을 이해할 필요없다.
모른다는 가정 하에 이미지를 복원 시킬것이다.
def weird_function(x, n_iter=5):
h = x
filt = torch.tensor([-1./3, 1./3, -1./3])
for i in range(n_iter):
zero_tensor = torch.tensor([1.0*0])
h_l = torch.cat( (zero_tensor, h[:-1]), 0)
h_r = torch.cat((h[1:], zero_tensor), 0 )
h = filt[0] * h + filt[2] * h_l + filt[1] * h_r
if i % 2 == 0:
h = torch.cat( (h[h.shape[0]//2:],h[:h.shape[0]//2]), 0 )
return h
오차 구하기
오차를 구하는 함수를 정의하자.
단순하게 hypothesis와 broken_img의 거리를 측정한 것을 오차로 둔다.
hypothesis : 원본 이미지와 같은 크기의 행렬 random_tensor를 만들고 weird_function()을 적용시킨 이미지이다.
broken_img : 원본 이미지에 werid_function()을 적용시킨 이미지이다.
def distance_loss(hypothesis, broken_image):
return torch.dist(hypothesis, broken_image) # dist : 두 텐서 사이의 거리를 구하는 함수
오염된 이미지 정보
오염된 이미지의 크기를 알아냈다.
[10000] 이다.
broken_image = torch.FloatTensor(pickle.load(open('./broken_image_t.p', 'rb'), encoding='latin1'))
print(broken_image.shape)
torch.Size([10000])
오염된 이미지를 출력 해보자.
plt.imshow(broken_image.view(100,100).data)
plt.show()
무작위 텐서 생성
오염된 이미지의 크기와 같은 다른 말로하면 원본 이미지의 크기와 같은 random_tensor를 생성한다.
random_tensor = torch.randn(10000, dtype=torch.float)
경사하강법
학습률은 0.8로 하고
20000번 학습 한다.
requires_grad_(True)를 설정하면 미분된 값이 .grad에 자동 저장된다.
lr = 0.8
for i in range(20000):
random_tensor.requires_grad_(True)
hypothesis = weird_function(random_tensor)
loss = distance_loss(hypothesis, broken_image) #오차 구하기
loss.backward() #역전파
with torch.no_grad(): # 경사하강법을 직접 구현하기에 자동 기울기 계산을 비활성화
random_tensor = random_tensor - lr * random_tensor.grad
if i % 1000 == 0: # 1000번 반복될 때마다 오차를 출력
print(f"Loss at {i} = {loss.item()}")
torch.Size([10000]) Loss at 0 = 12.0382661819458 Loss at 1000 = 1.1228744983673096 Loss at 2000 = 0.5494382977485657 Loss at 3000 = 0.3817315101623535 Loss at 4000 = 0.30174607038497925 Loss at 5000 = 0.25323012471199036 Loss at 6000 = 0.21860259771347046 Loss at 7000 = 0.19111904501914978 Loss at 8000 = 0.16777442395687103 Loss at 9000 = 0.14704976975917816 Loss at 10000 = 0.12809991836547852 Loss at 11000 = 0.11041469871997833 Loss at 12000 = 0.09366687387228012 Loss at 13000 = 0.07763271778821945 Loss at 14000 = 0.06215355917811394 Loss at 15000 = 0.04711384326219559 Loss at 16000 = 0.03242957592010498 Loss at 17000 = 0.02113625593483448 Loss at 18000 = 0.021164560690522194 Loss at 19000 = 0.02116667479276657 |
복원된 이미지
plt.imshow(random_tensor.view(100,100).data)
plt.show()
전체 소스 코드
import torch
import pickle
import matplotlib.pyplot as plt
def weird_function(x, n_iter=5):
h = x
filt = torch.tensor([-1. / 3, 1. / 3, -1. / 3])
for i in range(n_iter):
zero_tensor = torch.tensor([1.0 * 0])
h_l = torch.cat((zero_tensor, h[:-1]), 0)
h_r = torch.cat((h[1:], zero_tensor), 0)
h = filt[0] * h + filt[2] * h_l + filt[1] * h_r
if i % 2 == 0:
h = torch.cat((h[h.shape[0] // 2:], h[:h.shape[0] // 2]), 0)
return h
def distance_loss(hypothesis, broken_image):
return torch.dist(hypothesis, broken_image) # dist : 두 텐서 사이의 거리를 구하는 함수
broken_image = torch.FloatTensor(pickle.load(open('./broken_image_t.p', 'rb'), encoding='latin1'))
# 오염된 이미지 정보
# print(broken_image.shape)
# plt.imshow(broken_image.view(100,100).data)
# plt.show()
random_tensor = torch.randn(10000, dtype=torch.float)
lr = 0.8
for i in range(20000):
random_tensor.requires_grad_(True)
hypothesis = weird_function(random_tensor)
loss = distance_loss(hypothesis, broken_image)
loss.backward()
with torch.no_grad():
random_tensor = random_tensor - lr * random_tensor.grad
if i % 1000 == 0:
print(f"Loss at {i} = {loss.item()}")
plt.imshow(random_tensor.view(100,100).data)
plt.show()
728x90