개발자로 후회없는 삶 살기

논문 분석 PART.Tensorflow로 CycleGAN 구현하고 학습하기 본문

[AI]/[논문 리뷰, 분석]

논문 분석 PART.Tensorflow로 CycleGAN 구현하고 학습하기

몽이장쥰 2022. 12. 30. 19:16

서론

이번 포스팅에서는 Tensorflow로 구현한 DCGAN을 fashion dataset으로 학습한 후, 학습된 generator이 생성한 가짜 이미지를 확인하는 것을 목표로 합니다. 작업 환경은 Google Colab에서 진행합니다.

 

-> 전체 코드

https://github.com/SangBeom-Hahn/BOAZ/tree/main/GanStudy/CycleGAN/fashion_cycle_gan_project

 

GitHub - SangBeom-Hahn/BOAZ

Contribute to SangBeom-Hahn/BOAZ development by creating an account on GitHub.

github.com

 

본론

1. 학습에 필요한 util 함수들 정의

2. 데이터셋 로드

3. edge detection

4. 모델 구축

5. 학습

6. generator이 생성한 가짜 이미지 확인하기

 

- 참고한 코드 : 

https://github.com/LynnHo/CycleGAN-Tensorflow-2/blob/master/train.py

 

GitHub - LynnHo/CycleGAN-Tensorflow-2

Contribute to LynnHo/CycleGAN-Tensorflow-2 development by creating an account on GitHub.

github.com

위 깃허브의 Cycle GAN 모델 구현 부분만 참고했습니다.

 

- 라이브러리 정의

from __future__ import print_function, division
import scipy
from tensorflow.keras.datasets import mnist
from tensorflow_addons.layers import InstanceNormalization
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
import datetime
import matplotlib.pyplot as plt
import sys
import numpy as np
import os

 

1. 학습에 필요한 util 함수들 정의

# 이미지를 저장하는 함수
def sample_images(self, epoch, batch_i):
        r, c = 2, 3

        imgs_A = self.data_loader.load_data(domain="A", batch_size=1, is_testing=True)
        imgs_B = self.data_loader.load_data(domain="B", batch_size=1, is_testing=True)
        
        # 이미지를 다른 도메인으로 변환
        fake_B = self.g_AB.predict(imgs_A)
        fake_A = self.g_BA.predict(imgs_B)
        # 원본 도메인으로 되돌린다.
        reconstr_A = self.g_BA.predict(fake_B)
        reconstr_B = self.g_AB.predict(fake_A)

        gen_imgs = np.concatenate([imgs_A, fake_B, reconstr_A, imgs_B, fake_A, reconstr_B])

        # 이미지를 0 - 1 사이로 스케일을 바꿈
        gen_imgs = 0.5 * gen_imgs + 0.5

        titles = ['Original', 'Translated', 'Reconstructed']
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt])
                axs[i, j].set_title(titles[j])
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/%s/%d_%d.png" % (self.dataset_name, epoch, batch_i))
        plt.show()
# 데이터셋 클래스
class DataLoader():
    def __init__(self, dataset_name, img_res=(128, 128)):
        self.dataset_name = dataset_name
        self.img_res = img_res

    def load_data(self, domain, batch_size=1, is_testing=False):
        data_type = "train%s" % domain if not is_testing else "test%s" % domain
        path = glob('./datasets/%s/%s/*' % (self.dataset_name, data_type))

        batch_images = np.random.choice(path, size=batch_size)

        imgs = []
        for img_path in batch_images:
            img = self.imread(img_path)
            if not is_testing:
                img = resize(img, self.img_res)

                if np.random.random() > 0.5:
                    img = np.fliplr(img)
            else:
                img = resize(img, self.img_res)
            imgs.append(img)

        imgs = np.array(imgs)/127.5 - 1.

        return imgs

    def load_batch(self, batch_size=1, is_testing=False):
        data_type = "train" if not is_testing else "val"
        path_A = glob('./datasets/%s/%sA/*' % (self.dataset_name, data_type))
        path_B = glob('./datasets/%s/%sB/*' % (self.dataset_name, data_type))

        self.n_batches = int(min(len(path_A), len(path_B)) / batch_size)
        total_samples = self.n_batches * batch_size

        # Sample n_batches * batch_size from each path list so that model sees all
        # samples from both domains
        path_A = np.random.choice(path_A, total_samples, replace=False)
        path_B = np.random.choice(path_B, total_samples, replace=False)

        for i in range(self.n_batches-1):
            batch_A = path_A[i*batch_size:(i+1)*batch_size]
            batch_B = path_B[i*batch_size:(i+1)*batch_size]
            imgs_A, imgs_B = [], []
            for img_A, img_B in zip(batch_A, batch_B):
                img_A = self.imread(img_A)
                img_B = self.imread(img_B)

                img_A = resize(img_A, self.img_res)
                img_B = resize(img_B, self.img_res)

                if not is_testing and np.random.random() > 0.5:
                        img_A = np.fliplr(img_A)
                        img_B = np.fliplr(img_B)

                imgs_A.append(img_A)
                imgs_B.append(img_B)

            imgs_A = np.array(imgs_A)/127.5 - 1.
            imgs_B = np.array(imgs_B)/127.5 - 1.

            yield imgs_A, imgs_B

    def imread(self, path):
        return imageio.imread(path, pilmode='RGB').astype(np.float)

 

2. 데이터셋 로드

import json

# 이미지 주소 저장
FILE_PATH = "/content/drive/MyDrive/학교 계정/프로그래밍 언어 공부 자료/인공지능/BOAZ 18기/ADV 방학/ADV 방학 세션/GAN 공부 실습/sketch2fashion/FEIDEGGER_release_1.2.json"
file = json.load(open(FILE_PATH))
image = [i["url"] for i in file ]


# 이미지 주소의 jpg 다운
def save_img():
  i = 1
  for url in image:
    i += 1
    urllib.request.urlretrieve(url, f"test{i}.jpg")

 

출처 : https://github.com/zalandoresearch/feidegger

 

GitHub - zalandoresearch/feidegger: A Multi-modal Corpus of Fashion Images and Descriptions in German

A Multi-modal Corpus of Fashion Images and Descriptions in German - GitHub - zalandoresearch/feidegger: A Multi-modal Corpus of Fashion Images and Descriptions in German

github.com

외국의 데이터 구축 시스템으로 오픈되어있는 fashion 데이터셋을 활용하였습니다.

 

3. edge detection

# canny edge detection 메서드
def detect_edges(img):
    img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    img_gray = cv2.bilateralFilter(img_gray, 5, 50, 50)
    img_gray_edges = cv2.Canny(img_gray, 45, 100)
    img_gray_edges = cv2.bitwise_not(img_gray_edges) # invert black/white
    img_edges = cv2.cvtColor(img_gray_edges, cv2.COLOR_GRAY2RGB)
    
    return img_edges
# 이미지를 가져와서 target 폴더에 저장하는 함수 2개
def create_edge_imgs(target_dir, source_dir):
    pathname = f'{target_dir}/*.jpg' # target_dir에 폴더명
    for filepath in glob.glob(pathname):
        img_target = load_img(filepath, target_size=(256, 256))
        img_target = np.array(img_target)
        img_source = detect_edges(img_target) 

        filename = os.path.basename(filepath)
        img_source_filepath = os.path.join(source_dir, filename)
        save_img(img_source_filepath, img_source)

Cycle GAN은 vanila GAN과 다르게 생성자 두 쌍, 판별자 두 쌍으로 구성되어 있기 때문에 pair 데이터가 필요합니다. 필자가 준비한 pair 데이터는 실제 옷 이미지와 실제 옷의 테두리만 추출한(edge detection) 데이터입니다.

 

4. 모델 구축

1) generator

class CycleGAN(CycleGAN):
      @staticmethod
      def conv2d(layer_input, filters, f_size=4, normalization=True):
        """다운샘플링하는 동안 사용되는 층"""
        d = Conv2D(filters, kernel_size=f_size,
                   strides=2, padding='same')(layer_input)
        d = LeakyReLU(alpha=0.2)(d)
        if normalization:
            d = InstanceNormalization()(d)
        return d
      
      @staticmethod
      def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
            """업샘플링하는 동안 사용되는 층"""
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(filters, kernel_size=f_size, strides=1,
                       padding='same', activation='relu')(u)
            if dropout_rate:
                u = Dropout(dropout_rate)(u)
            u = InstanceNormalization()(u)
            u = Concatenate()([u, skip_input])
            return u
class CycleGAN(CycleGAN):
    def build_generator(self):
        """U-Net 생성자"""
        # 이미지 입력
        d0 = Input(shape=self.img_shape)

        # 다운샘플링
        d1 = self.conv2d(d0, self.gf)
        d2 = self.conv2d(d1, self.gf * 2)
        d3 = self.conv2d(d2, self.gf * 4)
        d4 = self.conv2d(d3, self.gf * 8)

        # 업샘플링
        u1 = self.deconv2d(d4, d3, self.gf * 4)
        u2 = self.deconv2d(u1, d2, self.gf * 2)
        u3 = self.deconv2d(u2, d1, self.gf)

        u4 = UpSampling2D(size=2)(u3)
        output_img = Conv2D(self.channels, kernel_size=4,
                            strides=1, padding='same', activation='tanh')(u4)

        return Model(d0, output_img)

 

 

2) discriminator

class CycleGAN(CycleGAN):
    def build_discriminator(self):
      img = Input(shape=self.img_shape)

      d1 = self.conv2d(img, self.df, normalization=False) #필터 개수를 두배씩 늘린다.
      d2 = self.conv2d(d1, self.df * 2)
      d3 = self.conv2d(d2, self.df * 4)
      d4 = self.conv2d(d3, self.df * 8)

      validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(d4)

      return Model(img, validity)

 

5. 학습

class CycleGAN(CycleGAN):
      def train(self, epochs, batch_size=1, sample_interval=50):
        # 적대 손실에 대한 정답
        valid = np.ones((batch_size,) + self.disc_patch) # 1
        fake = np.zeros((batch_size,) + self.disc_patch) # 0


        for epoch in range(epochs):
          # 각 도메인에서 랜덤한 이미지의 미니배치를 만듬
            for batch_i, (imgs_A, imgs_B) in enumerate(self.data_loader.load_batch(batch_size)):


                #  판별자 훈련
                # 이미지를 상대 도메인으로 변환

                #생성자 gab를 사용해 이미지a를 b로 변환하고 gba를 사용해 이미지 b를 도메인 a로 변환
                fake_B = self.g_AB.predict(imgs_A)
                fake_A = self.g_BA.predict(imgs_B)



                # 판별자를 훈련 (원본 이미지 = real / 변환된 이미지 = fake)
                #da에 실제 이미지a를 1로 알려주고 손실을 구함
                dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
                #gba가 만든 가짜이미지를 0으로 알려줘서 손실을 구함
                dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
                # 이 두 손실을 더함
                dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)


                # dB에도 똑같이 함
                dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
                dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
                dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

                # 위에 두 단계에서 나온 손실을 더해서 판별자 전체 손실 생성
                d_loss = 0.5 * np.add(dA_loss, dB_loss)

                #  생성자 훈련
                g_loss = self.combined.train_on_batch([imgs_A, imgs_B],
                                                      [valid, valid,
                                                       imgs_A, imgs_B,
                                                       imgs_A, imgs_B])

                if batch_i % sample_interval == 0:
                    self.sample_images(epoch, batch_i)

왼쪽 : 학습초기/ 오른쪽 : 학습 중반

 

6. generator이 생성한 가짜 이미지 확인하기

왼쪽 : 에폭 10/ 오른쪽 : 에폭 50

1) 에폭 10

왼쪽 사진의 상단을 보면 에폭 10에서 스케치 이미지에 색깔이 완벽하게 입히지 않은 것을 볼 수 있습니다. 하지만 왼쪽 사진의 하단을 보면 실제 이미지가 스케치처럼 살짝 부연 형태를 띠고 있습니다.

 

2) 에폭 50

오른쪽 사진의 상단을 보면 완벽히 채색이 된 것을 볼 수 있습니다. 하지만 색깔을 지정할 수 없다는 단점이 있습니다. 또한 하단 이미지를 보고 스케치라고 확신 있게 말할 수는 없을 것 같습니다.

 

결론

CycleGAN은 완벽히 변환을 하는 것이 아닌 다른 도메인 풍을 보이는 현재 도메인의 이미지를 생성하는 것이기 때문에 스케치를 실제 이미지로 변환했다고 보기는 힘듭니다. ( ex) 현재 도메인이 스케치이고 다른 도메인이 실제라면 실제 풍을 가지고 있는 스케치 이미지)

 

∴ 따라서 Pix2Pix 모델을 사용해 봐야겠습니다.

Comments