-
PyTorch | 이미지 복원Deep Learning/PyTorch 2021. 6. 4. 23:59728x90
문제
weird_function() 함수는 이미지를 오염시키는 함수이다.
어떤 이미지에 weird_function() 함수를 적용 시켜 이미지가 오염되어 원본을 알아볼 수 없게 되었다.
이미지를 복원해 보자.
이미지 다운로드 링크 :접근법
- 원본 이미지와 크기가 같은 random_tensor를 생성한다.
- 원본 이미지에 weird_function() 함수를 적용한 이미지와 random_tensor에 weird_function() 함수를 적용한 이미지의 오차를 구한다.
- 오차의 값을 경사하강법을 통해 줄여나간다.
- 오차가 최소가 되었을 때 random_tensor는 원본 이미지에 가까운 이미지가 될 것이다.
이미지 복원
(마지막에 전체 소스 코드 있음!)
import
필요한 모듈을 import한다.
import torch import pickle import matplotlib.pyplot as plt from image_recovery import weird_function
werid_function()
weird_function은 이렇게 생기긴 했다. 이 코드의 내용을 이해할 필요없다.
모른다는 가정 하에 이미지를 복원 시킬것이다.
def weird_function(x, n_iter=5): h = x filt = torch.tensor([-1./3, 1./3, -1./3]) for i in range(n_iter): zero_tensor = torch.tensor([1.0*0]) h_l = torch.cat( (zero_tensor, h[:-1]), 0) h_r = torch.cat((h[1:], zero_tensor), 0 ) h = filt[0] * h + filt[2] * h_l + filt[1] * h_r if i % 2 == 0: h = torch.cat( (h[h.shape[0]//2:],h[:h.shape[0]//2]), 0 ) return h
오차 구하기
오차를 구하는 함수를 정의하자.
단순하게 hypothesis와 broken_img의 거리를 측정한 것을 오차로 둔다.
hypothesis : 원본 이미지와 같은 크기의 행렬 random_tensor를 만들고 weird_function()을 적용시킨 이미지이다.
broken_img : 원본 이미지에 werid_function()을 적용시킨 이미지이다.
def distance_loss(hypothesis, broken_image): return torch.dist(hypothesis, broken_image) # dist : 두 텐서 사이의 거리를 구하는 함수
오염된 이미지 정보
오염된 이미지의 크기를 알아냈다.
[10000] 이다.
broken_image = torch.FloatTensor(pickle.load(open('./broken_image_t.p', 'rb'), encoding='latin1')) print(broken_image.shape) torch.Size([10000])
오염된 이미지를 출력 해보자.
plt.imshow(broken_image.view(100,100).data) plt.show()
broken_img 무작위 텐서 생성
오염된 이미지의 크기와 같은 다른 말로하면 원본 이미지의 크기와 같은 random_tensor를 생성한다.
random_tensor = torch.randn(10000, dtype=torch.float)
경사하강법
학습률은 0.8로 하고
20000번 학습 한다.
requires_grad_(True)를 설정하면 미분된 값이 .grad에 자동 저장된다.
lr = 0.8 for i in range(20000): random_tensor.requires_grad_(True) hypothesis = weird_function(random_tensor) loss = distance_loss(hypothesis, broken_image) #오차 구하기 loss.backward() #역전파 with torch.no_grad(): # 경사하강법을 직접 구현하기에 자동 기울기 계산을 비활성화 random_tensor = random_tensor - lr * random_tensor.grad if i % 1000 == 0: # 1000번 반복될 때마다 오차를 출력 print(f"Loss at {i} = {loss.item()}")
torch.Size([10000])
Loss at 0 = 12.0382661819458
Loss at 1000 = 1.1228744983673096
Loss at 2000 = 0.5494382977485657
Loss at 3000 = 0.3817315101623535
Loss at 4000 = 0.30174607038497925
Loss at 5000 = 0.25323012471199036
Loss at 6000 = 0.21860259771347046
Loss at 7000 = 0.19111904501914978
Loss at 8000 = 0.16777442395687103
Loss at 9000 = 0.14704976975917816
Loss at 10000 = 0.12809991836547852
Loss at 11000 = 0.11041469871997833
Loss at 12000 = 0.09366687387228012
Loss at 13000 = 0.07763271778821945
Loss at 14000 = 0.06215355917811394
Loss at 15000 = 0.04711384326219559
Loss at 16000 = 0.03242957592010498
Loss at 17000 = 0.02113625593483448
Loss at 18000 = 0.021164560690522194
Loss at 19000 = 0.02116667479276657복원된 이미지
plt.imshow(random_tensor.view(100,100).data) plt.show()
복원된 이미지
전체 소스 코드
import torch import pickle import matplotlib.pyplot as plt def weird_function(x, n_iter=5): h = x filt = torch.tensor([-1. / 3, 1. / 3, -1. / 3]) for i in range(n_iter): zero_tensor = torch.tensor([1.0 * 0]) h_l = torch.cat((zero_tensor, h[:-1]), 0) h_r = torch.cat((h[1:], zero_tensor), 0) h = filt[0] * h + filt[2] * h_l + filt[1] * h_r if i % 2 == 0: h = torch.cat((h[h.shape[0] // 2:], h[:h.shape[0] // 2]), 0) return h def distance_loss(hypothesis, broken_image): return torch.dist(hypothesis, broken_image) # dist : 두 텐서 사이의 거리를 구하는 함수 broken_image = torch.FloatTensor(pickle.load(open('./broken_image_t.p', 'rb'), encoding='latin1')) # 오염된 이미지 정보 # print(broken_image.shape) # plt.imshow(broken_image.view(100,100).data) # plt.show() random_tensor = torch.randn(10000, dtype=torch.float) lr = 0.8 for i in range(20000): random_tensor.requires_grad_(True) hypothesis = weird_function(random_tensor) loss = distance_loss(hypothesis, broken_image) loss.backward() with torch.no_grad(): random_tensor = random_tensor - lr * random_tensor.grad if i % 1000 == 0: print(f"Loss at {i} = {loss.item()}") plt.imshow(random_tensor.view(100,100).data) plt.show()
728x90'Deep Learning > PyTorch' 카테고리의 다른 글
PyTorch | GAN (4) 2022.01.09 PyTorch | DNN (0) 2021.06.20 Pytorch | ANN (0) 2021.06.12 Pythrch | 기초 (0) 2021.06.01 PyTorch 설치 (0) 2021.06.01