2017년 7월 10일 월요일

딥러닝 GAN 기반 image-to-image 모델

이 글은 딥러닝 GAN(Generative Adversarial Network. 생성 대립 신경망)모델에 기반한 image-to-image를 소개한다. 관련 논문은 다음 링크를 참고한다. 신경망의 종류와 개념 소개에 대해서는 여기를 참고한다.
소개
이 글은 조건적 adversarial 네트워크를 이용한 GAN에 기반해, 이미지-이미지 생성 문제에 대한 범용 솔루션을 제안한다. 이러한 네트워크는 입력 이미지에서 출력 이미지로의 매핑을 학습할 수 있다. 이 접근법은 레이블 맵에서 사진을 합성하고, 가장자리 맵에서 객체를 재구성하고, 다른 작업 중에서 이미지를 페인팅하는 데 효과적이다. 

GAN기반 image-to-image

내용
GAN은 랜덤 노이즈 벡터 z에서 출력 이미지 y : G : z → y 로의 매핑을 학습하는 생성 모델이다. 대조적으로 조건부 GAN은 관찰된 이미지 x와 랜덤 노이즈 벡터 z에서 y : G : {x, z} → y 로의 매핑을 학습한다. 생성기 G는 "가짜"이미지 감지 시, 훈련된 discrimintor D에 의해 "실제"이미지와 유사한 이미지를 출력하도록 훈련된다. 다음 그림은 훈련과정을 보여준다.

Training conditional GAN

G는 목적함수를 최소화하기 위해 반복 훈련하다. 이 목적함수는 D가 최대화하려는 것에 비해 대조된다. 

G∗ = arg minG maxD LcGAN (G, D)
LGAN(G, D)=Ey∼pdata(y) [log D(y)] + Ex∼pdata(x),z∼pz(z)[log(1 − D(G(x, z))]

다음은 이를 통해 훈련된 image-image net을 이용하여, 이미지를 생성한 예이다.
손실 차이에 따른 결과의 품질

코딩
다음은 GAN 아키텍처를 코딩한 주피터 노트북 코드이다.
from google.colab import drive
drive.mount('/content/drive')

!cp "/content/drive/MyDrive/Colab Notebooks/data/img_align_celeba.zip" "."
!unzip "./img_align_celeba.zip" -d "./GAN/"

import glob
import matplotlib.pyplot as plt
import os

from PIL import Image

# 이미지까지의 경로
pth_to_imgs = "./GAN/img_align_celeba"
imgs = glob.glob(os.path.join(pth_to_imgs, "*"))


# 9개의 이미지를 보여줌
for i in range(9):
   plt.subplot(3, 3, i+1)
   img = Image.open(imgs[i])
   plt.imshow(img)

plt.show()

import torch
import torchvision.transforms as tf

from torchvision.datasets import ImageFolder
from torch.utils.data.dataloader import DataLoader


# 이미지의 전처리 과정
transforms = tf.Compose([
   tf.Resize(64),
   tf.CenterCrop(64),
   tf.ToTensor(),
   tf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])


# ImageFolder()를 이용해 데이터셋을 작성
# root는 최상위 경로를, transform은 전처리를 의미합니다.
dataset = ImageFolder(
   root="./GAN",
   transform=transforms
)
loader = DataLoader(dataset, batch_size=128, shuffle=True)

import torch.nn as nn


class Generator(nn.Module):
   def __init__(self):
       super(Generator, self).__init__()
      
       # 생성자를 구성하는 층 정의
       self.gen = nn.Sequential(
           nn.ConvTranspose2d(100, 512, kernel_size=4, bias=False),
           nn.BatchNorm2d(512),
           nn.ReLU(),

           nn.ConvTranspose2d(512, 256, kernel_size=4, 
                              stride=2, padding=1, bias=False),
           nn.BatchNorm2d(256),
           nn.ReLU(),

           nn.ConvTranspose2d(256, 128, kernel_size=4, 
                              stride=2, padding=1, bias=False),
           nn.BatchNorm2d(128),
           nn.ReLU(),

           nn.ConvTranspose2d(128, 64, kernel_size=4, 
                              stride=2, padding=1, bias=False),
           nn.BatchNorm2d(64),
           nn.ReLU(),

           nn.ConvTranspose2d(64, 3, kernel_size=4, 
                              stride=2, padding=1, bias=False),
           nn.Tanh()
       )

   def forward(self, x):
       return self.gen(x)

class Discriminator(nn.Module):
   def __init__(self):
       super(Discriminator, self).__init__()
      
       # 감별자를 구성하는 층의 정의
       self.disc = nn.Sequential(
           nn.Conv2d(3, 64, kernel_size=4, 
                     stride=2, padding=1, bias=False),
           nn.BatchNorm2d(64),
           nn.LeakyReLU(0.2),

           nn.Conv2d(64, 128, kernel_size=4, 
                     stride=2, padding=1, bias=False),
           nn.BatchNorm2d(128),
           nn.LeakyReLU(0.2),

           nn.Conv2d(128, 256, kernel_size=4, 
                     stride=2, padding=1, bias=False),
           nn.BatchNorm2d(256),
           nn.LeakyReLU(0.2),

           nn.Conv2d(256, 512, kernel_size=4, 
                     stride=2, padding=1, bias=False),
           nn.BatchNorm2d(512),
           nn.LeakyReLU(0.2),

           nn.Conv2d(512, 1, kernel_size=4),
           nn.Sigmoid()
       )

   def forward(self, x):
       return self.disc(x)

def weights_init(m):
   # 층의 종류 추출
   classname = m.__class__.__name__
   if classname.find('Conv') != -1:
       # 합성곱층 초기화
       nn.init.normal_(m.weight.data, 0.0, 0.02)
   elif classname.find('BatchNorm') != -1:
       # 배치정규화층 초기화
       nn.init.normal_(m.weight.data, 1.0, 0.02)
       nn.init.constant_(m.bias.data, 0)

device = "cuda" if torch.cuda.is_available() else "cpu"

# 생성자 정의
G = Generator().to(device)
# 생성자 가중치 초기화
G.apply(weights_init)

# 감별자 정의
D = Discriminator().to(device)
# 감별자 가중치 초기화
D.apply(weights_init)

import tqdm
from torch.optim.adam import Adam

G_optim = Adam(G.parameters(), lr=0.0001, betas=(0.5, 0.999))
D_optim = Adam(D.parameters(), lr=0.0001, betas=(0.5, 0.999))

for epochs in range(50):
   iterator = tqdm.tqdm(enumerate(loader, 0), total=len(loader))

   for i, data in iterator:
       D_optim.zero_grad()
      
       # 실제 이미지에는 1, 생성된 이미지는 0으로 정답을 설정
       label = torch.ones_like(
           data[1], dtype=torch.float32).to(device)
       label_fake = torch.zeros_like(
           data[1], dtype=torch.float32).to(device)
      
       # 실제 이미지를 감별자에 입력
       real = D(data[0].to(device))
      
       # 실제 이미지에 대한 감별자의 오차를 계산
       Dloss_real = nn.BCELoss()(torch.squeeze(real), label)
       Dloss_real.backward()

       # 가짜 이미지 생성
       noise = torch.randn(label.shape[0], 100, 1, 1, device=device)
       fake = G(noise)
      
       # 가짜 이미지를 감별자에 입력
       output = D(fake.detach())
      
       # 가짜 이미지에 대한 감별자의 오차를 계산
       Dloss_fake = nn.BCELoss()(torch.squeeze(output), label_fake)
       Dloss_fake.backward()
      
       # 감별자의 전체 오차를 학습
       Dloss = Dloss_real + Dloss_fake
       D_optim.step()

       # 생성자의 학습
       G_optim.zero_grad()
       output = D(fake)
       Gloss = nn.BCELoss()(torch.squeeze(output), label)
       Gloss.backward()

       G_optim.step()

       iterator.set_description(f"epoch:{epochs} iteration:{i} D_loss:{Dloss} G_loss:{Gloss}")

torch.save(G.state_dict(), "Generator.pth")
torch.save(D.state_dict(), "Discriminator.pth")

with torch.no_grad():
   G.load_state_dict(
       torch.load("./Generator.pth", map_location=device))

   # 특징 공간 상의 랜덤한 하나의 점을 지정
   feature_vector = torch.randn(1, 100, 1, 1).to(device)

   # 이미지 생성
   pred = G(feature_vector).squeeze()
   pred = pred.permute(1, 2, 0).cpu().numpy()

   plt.imshow(pred)
   plt.title("predicted image")
   plt.show()

마무리
GAN은 사람의 인지적 활동을 모사하는 데 탁월한 성능을 보여준다. 특히, 예술 분야에서 GAN은 많은 관심을 받고 있는 흥미로운 학습 방법이다. 

레퍼런스





댓글 없음:

댓글 쓰기