개발자로 후회없는 삶 살기

논문 분석 PART.Tensorflow로 DCGAN 구현하고 학습하기 본문

[AI]/[논문 리뷰, 분석]

논문 분석 PART.Tensorflow로 DCGAN 구현하고 학습하기

몽이장쥰 2022. 12. 30. 18:28

서론

 

이번 포스팅에서는 Tensorflow로 구현한 DCGAN을 몽타주 dataset으로 학습한 후, 학습된 generator이 생성한 가짜 이미지를 확인하는 것을 목표로 합니다. 작업 환경은 Google Colab에서 진행합니다.

 

-> 전체 코드

https://github.com/SangBeom-Hahn/BOAZ/tree/main/GanStudy/DCGAN/montage_project

 

GitHub - SangBeom-Hahn/BOAZ

Contribute to SangBeom-Hahn/BOAZ development by creating an account on GitHub.

github.com

 

본론

1. 학습에 필요한 util 함수들 정의

2. 데이터셋 로드

3. 모델 구축

4. 학습

5. generator이 생성한 가짜 이미지 확인하기

 

 

- 참고한 코드 : 

https://github.com/carpedm20/DCGAN-tensorflow

 

GitHub - carpedm20/DCGAN-tensorflow: A tensorflow implementation of "Deep Convolutional Generative Adversarial Networks"

A tensorflow implementation of "Deep Convolutional Generative Adversarial Networks" - GitHub - carpedm20/DCGAN-tensorflow: A tensorflow implementation of "Deep Convolutional Generati...

github.com

star 수가 굉장히 많은 DCGAN 깃허브의 모델 구현 부분만 참고했습니다.

 

- 라이브러리 정의

%matplotlib inline

from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Activation, BatchNormalization, Dense, Dropout, Flatten, Reshape
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import Conv2D, Conv2DTranspose
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import glob
import cv2
import tensorflow as tf

 

 

1. 학습에 필요한 util 함수들 정의

# 이미지를 저장하는 함수
def sample_images(generator, image_grid_rows=4, image_grid_columns=4):

    # 랜덤한 잡음 샘플링
    z = np.random.normal(0, 1, (image_grid_rows * image_grid_columns, z_dim))

    # 랜덤한 잡음에서 이미지 생성하기
    gen_imgs = generator.predict(z)

    # 이미지 픽셀 값을 [0, 1] 사이로 스케일 조정
    gen_imgs = 0.5 * gen_imgs + 0.5

    # 이미지 그리드 설정
    fig, axs = plt.subplots(image_grid_rows,
                            image_grid_columns,
                            figsize=(4, 4),
                            sharey=True,
                            sharex=True)

    cnt = 0
    for i in range(image_grid_rows):
        for j in range(image_grid_columns):
            # 이미지 그리드 출력
            axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            axs[i, j].axis('off')
            cnt += 1
    plt.show()

# 그래프를 생성하는 함수
def plotLoss(G_loss, D_loss, epoch):
  cur_dir = os.getcwd()
  loss_dir = "loss_graph"
  file_name = 'gan_loss_epoch_%d.png' % epoch
  dir = os.path.join(cur_dir, loss_dir) 
  os.makedirs(dir, exist_ok = True)

  file_path = os.path.join(dir, file_name)

  plt.figure(figsize=(10, 8))
  plt.plot(D_loss, label='Discriminitive loss')
  plt.plot(G_loss, label='Generative loss')
  plt.xlabel('BatchCount')
  plt.ylabel('Loss')
  plt.legend()
  plt.savefig(file_path)

 

 

 

 

 

 

2. 데이터셋 로드

# 라이브러리 로딩
import glob
import numpy as np
import matplotlib.pyplot as plt

paths = glob.glob('/content/BOAZ/GanStudy/4장_DCGAN/photo/*.jpg')
paths = np.random.permutation(paths)
X = np.array([plt.imread(paths[i]) for i in range(len(paths))]) # 이 코드는 경로의 개수만큼 impread해서 np.array 하는 것으로 확인

출처 : https://www.aihub.or.kr/aihubdata/data/view.do?currMenu=115&topMenu=100&aihubDataSe=realm&dataSetSn=618 

 

AI-Hub

샘플 데이터 ? ※샘플데이터는 데이터의 이해를 돕기 위해 별도로 가공하여 제공하는 정보로써 원본 데이터와 차이가 있을 수 있으며, 데이터에 따라서 민감한 정보는 일부 마스킹(*) 처리가 되

www.aihub.or.kr

Ai-Hub의 몽타주 데이터 셋을 사용하였습니다.

 

 

3. 모델 구축

1) generator

def build_generator(z_dim):

    model = Sequential()
    model.add(Dense(256 * 7 * 7, input_dim=z_dim))
    model.add(Reshape((7, 7, 256)))

    # 7x7x256에서 14x14x128 텐서로 바꾸는 전치 합성곱 층
    model.add(Conv2DTranspose(128, kernel_size=3, strides=2, padding='same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.01))

    # 14x14x128에서 14x14x64 텐서로 바꾸는 전치 합성곱 층
    model.add(Conv2DTranspose(64, kernel_size=3, strides=1, padding='same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.01))

    # 14x14x64에서 28x28x3 텐서로 바꾸는 전치 합성곱 층
    model.add(Conv2DTranspose(3, kernel_size=3, strides=2, padding='same'))
    model.add(Activation('tanh'))

    return model

입력 데이터의 shape과 모델 Input 단의 shape를 통일한다. 데이터의 복잡도가 크므로 모델의 용량 또한 비례합니다.

 

 

2) discriminator

def build_discriminator(img_shape):

    model = Sequential()

    # 28x28x3에서 14x14x32 텐서로 바꾸는 합성곱 층
    model.add(
        Conv2D(32,
               kernel_size=3,
               strides=2,
               input_shape=img_shape,
               padding='same'))

    # LeakyReLU 활성화 함수
    model.add(LeakyReLU(alpha=0.01))

    # 14x14x32에서 7x7x64 텐서로 바꾸는 합성곱 층
    model.add(
        Conv2D(64,
               kernel_size=3,
               strides=2,
               padding='same'))

    # LeakyReLU 활성화 함수
    model.add(LeakyReLU(alpha=0.01))

    # 7x7x64에서 3x3x128 텐서로 바꾸는 합성곱 층
    model.add(
        Conv2D(128,
               kernel_size=3,
               strides=2,
               padding='same'))

    # LeakyReLU 활성화 함수
    model.add(LeakyReLU(alpha=0.01))
    
    # sigmoid 활성화 함수를 사용한 출력층
    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))

    return model

판별자는 feature를 추출하는 일반 2진 분류 모델과 같은 모양입니다.

 

4. 학습

1) DCGAN 모델 선언

def build_gan(generator, discriminator):

    model = Sequential()

    # 생성자 -> 판별자로 연결된 모델
    model.add(generator)
    model.add(discriminator)

    return model
# 판별자 모델을 만들고 컴파일하기
discriminator = build_discriminator(img_shape)
discriminator.compile(loss='binary_crossentropy',
                      optimizer=Adam(),
                      metrics=['accuracy'])

# 생성자 모델 만들기
generator = build_generator(z_dim)

# 생성자를 훈련하는 동안 판별자의 파라미터를 유지
discriminator.trainable = False

# 생성자를 훈련하기 위해 동결된 판별자로 GAN 모델을 만들고 컴파일
gan = build_gan(generator, discriminator)
gan.compile(loss='binary_crossentropy', optimizer=Adam())

 

 

2) train 코드

def train(iterations, batch_size, sample_interval, X):

    # 몽타주 데이터 로드
    # [0, 255] 흑백 픽셀 값을 [-1, 1] 사이로 스케일 조정
    print(X.shape)
    X = X / 127.5 - 1.0

    # 진짜 이미지 레이블: 모두 1
    real = np.ones((batch_size, 1))

    # 가짜 이미지 레이블: 모두 0
    fake = np.zeros((batch_size, 1))
for iteration in range(iterations):
        
        #  판별자 훈련
        # 진짜 이미지에서 랜덤 배치 가져오기
        idx = np.random.randint(0, X.shape[0], batch_size)
        imgs = X[idx]

        # 가짜 이미지 배치 생성
        z = np.random.normal(0, 1, (batch_size, 100))
        gen_imgs = generator.predict(z)

        # 판별자 훈련
        d_loss_real = discriminator.train_on_batch(imgs, real)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        d_loss, accuracy = 0.5 * np.add(d_loss_real, d_loss_fake)


        #  생성자 훈련
        # 가짜 이미지 배치 생성
        z = np.random.normal(0, 1, (batch_size, 100))
        gen_imgs = generator.predict(z)

        # 생성자 훈련
        g_loss = gan.train_on_batch(z, real)

        if (iteration + 1) % sample_interval == 0:

            # 훈련이 끝난 후 그래프를 그리기 위해 손실과 정확도 저장
            
            d_loss_list.append(d_loss)
            g_loss_list.append(g_loss)
            accuracies.append(100.0 * accuracy)
            iteration_checkpoints.append(iteration + 1)

            # 훈련 과정 출력
            print("%d [D 손실: %f, 정확도: %.2f%%] [G 손실: %f]" %
                  (iteration + 1, d_loss, 100.0 * accuracy, g_loss))

            # 생성된 이미지 샘플 출력
            sample_images(generator)

    plotLoss(g_loss_list, d_loss_list, iterations)

 

 

3) 손실 그래프 출력

모델의 용량이 큰 만큼 학습이 오래 걸립니다. 10000 에폭까지는 규형을 이루다가 12000 에폭부터 손실이 증가합니다. 마지막에 20000 에폭에서도 손실이 증가하는 것으로 보아 에폭을 더 늘려야겠습니다.

 

5. generator이 생성한 가짜 이미지 확인하기

 

결과를 보아도 역시 좀 더 많은 학습이 필요해 보입니다.

Comments