end0tknr's kipple - web写経開発

太宰府天満宮の狛犬って、妙にカワイイ

STL-10データを使用した 敵対的生成ネットワーク ( GAN : Generative Adversarial Networks )

GitHub - miyamotok0105/pytorch_handbook: pytorch_handbook

上記urlの6章を写経。

deep learning による画像生成とは、GAN を使用しているらしい。

f:id:end0tknr:20191022101807p:plain

GANの学習不安定を改善する為、その後、DCGANやLSGANが現れたそうですが、 今回、STL-10データを使用し、LSGAN による画像生成を実施。

#!/usr/local/python3/bin/python3
# -*- coding: utf-8 -*-

import os
# ↓ $ sudo /usr/local/python3/bin/pip install google.colab
from google.colab import drive
import random
import numpy as np
# import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from PIL import Image

# 本src内にcopyしました
# from net import weights_init, Generator, Discriminator

# import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')
import matplotlib.pylab as plt

# google colaboratory なら google drive を mount 可
gdrive_mount_point = None
gdrive_path = None
#gdrive_mount_point = '/content/gdrive'
#gdrive_path = \
#   '/content/gdrive/My Drive/Colab Notebooks/pytorch_handbook/chapter6/'

workers = 2
batch_size=50
nz = 100
nch_g = 64
nch_d = 64
n_epoch = 200
lr = 0.0002
beta1 = 0.5
outf = './result_lsgan' # LSGANにより作成された画像の保存先
display_interval = 100

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def main():
    print("DEVICE:", device)

    avoid_Imabe_colab_bug()
    mount_and_cd_gdrive()
    make_output_dir()

    random.seed(0)  # 乱数seed固定 (目的は理解していません)
    np.random.seed(0)
    torch.manual_seed(0)

    (dataloader) = load_train_and_test_data()

    (netG,netD) = make_generator_and_discriminator()     # 贋作生成器と識別器
    (optimizerG,optimizerD) = make_optimizer(netG,netD)  # optimizer

    criterion = nn.MSELoss()    # 損失関数は平均二乗誤差損失
    # 確認用の固定したノイズ
    fixed_noise = torch.randn(batch_size, nz, 1, 1, device=device)

    # LSGAN の学習が 1 epoch 進む毎に画像fileを自動保存します
    train_lsgan(dataloader,netG,netD,optimizerD,optimizerG,criterion,fixed_noise)

def make_optimizer(netG,netD):
    optimizerG = optim.Adam(netG.parameters(),
                            lr=lr,
                            betas=(beta1, 0.999),
                            weight_decay=1e-5) 
    optimizerD = optim.Adam(netD.parameters(),
                            lr=lr,
                            betas=(beta1, 0.999),
                            weight_decay=1e-5)
    return optimizerG,optimizerD

def train_lsgan(dataloader,netG,netD,optimizerD,optimizerG,criterion,fixed_noise):
    # 学習のループ
    for epoch in range(n_epoch):
        for itr, data in enumerate(dataloader):
            real_image = data[0].to(device)     # 元画像
            sample_size = real_image.size(0)    # 画像枚数
            
            # 正規分布からノイズを生成
            noise = torch.randn(sample_size, nz, 1, 1, device=device)
            # 元画像に対する識別信号の目標値「1」
            real_target = torch.full((sample_size,), 1., device=device)
            # 贋作画像に対する識別信号の目標値「0」
            fake_target = torch.full((sample_size,), 0., device=device)

            ############################
            # 識別器Dの更新
            ###########################
            netD.zero_grad()    # 勾配の初期化

            output = netD(real_image)   # 識別器Dで元画像に対する識別信号を出力
            # 元画像に対する識別信号の損失値
            errD_real = criterion(output, real_target)
            D_x = output.mean().item()

            fake_image = netG(noise)    # 生成器Gでノイズから贋作画像を生成
            # 識別器Dで元画像に対する識別信号を出力
            output = netD(fake_image.detach())
            # 贋作画像に対する識別信号の損失値
            errD_fake = criterion(output, fake_target)
            D_G_z1 = output.mean().item()

            errD = errD_real + errD_fake    # 識別器Dの全体の損失
            errD.backward()    # 誤差逆伝播
            optimizerD.step()   # Dのパラメーターを更新

            ############################
            # 生成器Gの更新
            ###########################
            netG.zero_grad()    # 勾配の初期化

            # 更新した識別器Dで改めて贋作画像に対する識別信号を出力
            output = netD(fake_image)
            # 生成器Gの損失値。Dに贋作画像を元画像と誤認させたいため目標値は「1」
            errG = criterion(output, real_target)
            errG.backward()     # 誤差逆伝播
            D_G_z2 = output.mean().item()

            optimizerG.step()   # Gのパラメータを更新

            if itr % display_interval == 0: 
                print('[{}/{}][{}/{}] Loss_D: {:.3f} Loss_G: {:.3f} D(x): {:.3f} D(G(z)): {:.3f}/{:.3f}'
                      .format(epoch + 1, n_epoch,
                              itr + 1, len(dataloader),
                              errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
            if epoch == 0 and itr == 0:     # 初回に元画像を保存する
                vutils.save_image(real_image, '{}/real_samples.png'.format(outf),
                                  normalize=True, nrow=10)

        ############################
        # 確認用画像の生成
        ############################
        # 1エポック終了ごとに確認用の贋作画像を生成する
        fake_image = netG(fixed_noise)
        vutils.save_image(fake_image.detach(),
                          '{}/fake_samples_epoch_{:03d}.png'.format(outf,
                                                                    epoch + 1),
                          normalize=True,
                          nrow=10)

        ############################
        # モデルの保存
        ############################
        if (epoch + 1) % 50 == 0:   # 50エポックごとにモデルを保存する
            torch.save(netG.state_dict(),
                       '{}/netG_epoch_{}.pth'.format(outf, epoch + 1))
            torch.save(netD.state_dict(),
                       '{}/netD_epoch_{}.pth'.format(outf, epoch + 1))
    
def make_generator_and_discriminator():
    # 生成器G。ランダムベクトルから贋作画像を生成する
    netG = Generator(nz=nz, nch_g=nch_g).to(device)
    netG.apply(weights_init)    # weights_init関数で初期化
    # print(netG)


    # 識別器D。画像が、元画像か贋作画像かを識別する
    netD = Discriminator(nch_d=nch_d).to(device)
    netD.apply(weights_init)
    # print(netD)
    return netG, netD

def load_train_and_test_data():
    # STL-10のtrain & test data (stl10_binary.tar.gz 2.5GB)を download & read
    trainset = dset.STL10(root='./dataset/stl10_root',
                          download=True,
                          # labalを使用しない為
                          # labelなしを混在した'train+unlabeled'を使用
                          split='train+unlabeled',
                          transform=transforms.Compose([
                              transforms.RandomResizedCrop(64,
                                                           scale=(88/96, 1.0),
                                                           ratio=(1., 1.)),
                              transforms.RandomHorizontalFlip(),
                              transforms.ColorJitter(brightness=0.05,
                                                     contrast=0.05,
                                                     saturation=0.05,
                                                     hue=0.05),
                              transforms.ToTensor(),
                              transforms.Normalize((0.5, 0.5, 0.5),
                                                   (0.5, 0.5, 0.5)),
                          ]))
    testset = dset.STL10(root='./dataset/stl10_root',
                         download=True,
                         split='test',
                         transform=transforms.Compose([
                             transforms.RandomResizedCrop(64,
                                                          scale=(88/96, 1.0),
                                                          ratio=(1., 1.)),
                             transforms.RandomHorizontalFlip(),
                             transforms.ColorJitter(brightness=0.05,
                                                    contrast=0.05,
                                                    saturation=0.05,
                                                    hue=0.05),
                             transforms.ToTensor(),
                             transforms.Normalize((0.5, 0.5, 0.5),
                                                  (0.5, 0.5, 0.5)),
                         ]))
    # STL-10の train dataとtest dataを合わせ訓練データとする
    dataset = trainset + testset

    # 訓練データをセットしたデータローダを作成
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=int(workers))
    return dataloader
    
def make_output_dir():
    try:
        os.makedirs(outf, exist_ok=True)
    except OSError as error:
        print(error)
        pass

    
# colab固有のerror回避の為らしい
def avoid_Imabe_colab_bug():
    Image.register_extension = register_extension
    Image.register_extensions = register_extensions
    
def register_extension(id, extension): 
    Image.EXTENSION[extension.lower()] = id.upper()

def register_extensions(id, extensions): 
    for extension in extensions: 
        register_extension(id, extension)

def mount_and_cd_gdrive():
    if gdrive_mount_point:
        drive.mount(gdrive_mount_point)
        os.chdir(gdrive_path)
    print(os.getcwd())

def weights_init(m):
    """
    ニューラルネットワークの重みを初期化する。
    作成したインスタンスに対しapplyメソッドで適用する
    :param m: ニューラルネットワークを構成する層
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:            # 畳み込み層の場合
        m.weight.data.normal_(0.0, 0.02)
        m.bias.data.fill_(0)
    elif classname.find('Linear') != -1:        # 全結合層の場合
        m.weight.data.normal_(0.0, 0.02)
        m.bias.data.fill_(0)
    elif classname.find('BatchNorm') != -1:     # バッチノーマライゼーションの場合
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


class Generator(nn.Module):
    """
    生成器Gのクラス 
    """
    def __init__(self, nz=100, nch_g=64, nch=3):
        """
        :param nz: 入力ベクトルzの次元
        :param nch_g: 最終層の入力チャネル数
        :param nch: 出力画像のチャネル数
        """
        super(Generator, self).__init__()
        
        # ニューラルネットワークの構造を定義する
        self.layers = nn.ModuleDict({
            'layer0': nn.Sequential(
                nn.ConvTranspose2d(nz, nch_g * 8, 4, 1, 0),     # 転置畳み込み
                nn.BatchNorm2d(nch_g * 8),                      # バッチノーマライゼーション
                nn.ReLU()                                       # 正規化線形関数
            ),  # (B, nz, 1, 1) -> (B, nch_g*8, 4, 4)
            'layer1': nn.Sequential(
                nn.ConvTranspose2d(nch_g * 8, nch_g * 4, 4, 2, 1),
                nn.BatchNorm2d(nch_g * 4),
                nn.ReLU()
            ),  # (B, nch_g*8, 4, 4) -> (B, nch_g*4, 8, 8)
            'layer2': nn.Sequential(
                nn.ConvTranspose2d(nch_g * 4, nch_g * 2, 4, 2, 1),
                nn.BatchNorm2d(nch_g * 2),
                nn.ReLU()
            ),  # (B, nch_g*4, 8, 8) -> (B, nch_g*2, 16, 16)

            'layer3': nn.Sequential(
                nn.ConvTranspose2d(nch_g * 2, nch_g, 4, 2, 1),
                nn.BatchNorm2d(nch_g),
                nn.ReLU()
            ),  # (B, nch_g*2, 16, 16) -> (B, nch_g, 32, 32)
            'layer4': nn.Sequential(
                nn.ConvTranspose2d(nch_g, nch, 4, 2, 1),
                nn.Tanh()
            )   # (B, nch_g, 32, 32) -> (B, nch, 64, 64)
        })

    def forward(self, z):
        """
        順方向の演算
        :param z: 入力ベクトル
        :return: 生成画像
        """
        for layer in self.layers.values():  # self.layersの各層で演算を行う
            z = layer(z)
        return z


class Discriminator(nn.Module):
    """
    識別器Dのクラス
    """
    def __init__(self, nch=3, nch_d=64):
        """
        :param nch: 入力画像のチャネル数
        :param nch_d: 先頭層の出力チャネル数
        """
        super(Discriminator, self).__init__()

        # ニューラルネットワークの構造を定義する
        self.layers = nn.ModuleDict({
            'layer0': nn.Sequential(
                nn.Conv2d(nch, nch_d, 4, 2, 1),     # 畳み込み
                nn.LeakyReLU(negative_slope=0.2)    # leaky ReLU関数
            ),  # (B, nch, 64, 64) -> (B, nch_d, 32, 32)
            'layer1': nn.Sequential(
                nn.Conv2d(nch_d, nch_d * 2, 4, 2, 1),
                nn.BatchNorm2d(nch_d * 2),
                nn.LeakyReLU(negative_slope=0.2)
            ),  # (B, nch_d, 32, 32) -> (B, nch_d*2, 16, 16)
            'layer2': nn.Sequential(
                nn.Conv2d(nch_d * 2, nch_d * 4, 4, 2, 1),
                nn.BatchNorm2d(nch_d * 4),
                nn.LeakyReLU(negative_slope=0.2)
            ),  # (B, nch_d*2, 16, 16) -> (B, nch_d*4, 8, 8)
            'layer3': nn.Sequential(
                nn.Conv2d(nch_d * 4, nch_d * 8, 4, 2, 1),
                nn.BatchNorm2d(nch_d * 8),
                nn.LeakyReLU(negative_slope=0.2)
            ),  # (B, nch_d*4, 8, 8) -> (B, nch_g*8, 4, 4)
            'layer4': nn.Conv2d(nch_d * 8, 1, 4, 1, 0)
            # (B, nch_d*8, 4, 4) -> (B, 1, 1, 1)
        })

    def forward(self, x):
        """
        順方向の演算
        :param x: 元画像あるいは贋作画像
        :return: 識別信号
        """
        for layer in self.layers.values():  # self.layersの各層で演算を行う
            x = layer(x)
        return x.squeeze()     # Tensorの形状を(B)に変更して戻り値とする



if __name__ == '__main__':
    main()

↑こちらを実行すると、以下の画像ファイルが生成されます。

google colab の gpu付環境で、6時間程、学習させ、80epochの時点で停止させました。 80epochの学習程度では、まだまだといった印象です。

サンプル画像 f:id:end0tknr:20191022102148p:plain

1epoch後の画像  f:id:end0tknr:20191022102241p:plain

81epoch後の画像 f:id:end0tknr:20191022102255p:plain