
DCGAN Tutorial

Author: Nathan Inkawhich

Introduction

This tutorial gives an introduction to DCGANs through an example. We will train a generative adversarial network (GAN) to produce new celebrities after showing it pictures of many real celebrities. Most of the code here comes from the dcgan implementation in pytorch/examples, and this document will give a thorough explanation of the implementation and shed light on how and why this model works. But don't worry, no prior knowledge of GANs is required, although it may require a first-timer to spend some time reasoning about what is actually happening under the hood. Also, to save time it will help to have a GPU, or two. Let's start from the beginning.

Generative Adversarial Networks

What is a GAN?

A GAN is a framework for teaching a DL model to capture the training data's distribution so we can generate new data from that same distribution. GANs were invented by Ian Goodfellow in 2014 and first described in the paper Generative Adversarial Nets. They are made of two distinct models, a generator and a discriminator. The job of the generator is to spawn "fake" images that look like the training images. The job of the discriminator is to look at an image and output whether it is a real training image or a fake image from the generator. During training, the generator is constantly trying to outsmart the discriminator by generating better and better fakes, while the discriminator is working to become a better detective and correctly classify the real and fake images. The equilibrium of this game is when the generator produces perfect fakes that look as if they came directly from the training data, and the discriminator is left to always guess at 50% confidence whether the generator output is real or fake.

Now, let's define some notation to be used throughout the tutorial, starting with the discriminator. Let \(x\) be data representing an image. \(D(x)\) is the discriminator network which outputs the (scalar) probability that \(x\) came from the training data rather than the generator. Here, since we are dealing with images, the input to \(D(x)\) is an image of CHW size 3x64x64. Intuitively, \(D(x)\) should be HIGH when \(x\) comes from the training data and LOW when \(x\) comes from the generator. \(D(x)\) can also be thought of as a traditional binary classifier.

For the generator's notation, let \(z\) be a latent space vector sampled from a standard normal distribution. \(G(z)\) represents the generator function which maps the latent vector \(z\) to data space. The goal of \(G\) is to estimate the distribution that the training data comes from (\(p_{data}\)) so it can generate fake samples from that estimated distribution (\(p_g\)).

So, \(D(G(z))\) is the probability (scalar) that the output of the generator is a real image. As described in Goodfellow's paper, \(D\) and \(G\) play a minimax game in which \(D\) tries to maximize the probability of correctly classifying reals and fakes (\(\log D(x)\)), and \(G\) tries to minimize the probability that \(D\) will predict its outputs are fake (\(\log(1-D(G(z)))\)). From the paper, the GAN loss function is

\[\underset{G}{\min}\ \underset{D}{\max}\ V(D,G) = \mathbb{E}_{x\sim p_{data}(x)}\big[\log D(x)\big] + \mathbb{E}_{z\sim p_{z}(z)}\big[\log(1 - D(G(z)))\big]\]

In theory, the solution to this minimax game is where \(p_g = p_{data}\), and the discriminator guesses randomly whether its inputs are real or fake. However, the convergence theory of GANs is still being actively researched, and in reality models do not always train to this point.
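As a brief aside (a reconstruction from Goodfellow's paper, not part of the tutorial code), the optimal discriminator for a fixed generator has a closed form, which makes the "random guessing" behavior at equilibrium plausible:

\[D^{*}(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}, \qquad p_g = p_{data} \implies D^{*}(x) = \tfrac{1}{2}\]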

What is a DCGAN?

A DCGAN is a direct extension of the GAN described above, except that it explicitly uses convolutional and convolutional-transpose layers in the discriminator and generator, respectively. It was first described by Radford et al. in the paper Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks. The discriminator is made up of strided convolution layers, batch norm layers, and LeakyReLU activations. The input is a 3x64x64 image and the output is a scalar probability that the input came from the real data distribution. The generator is comprised of convolutional-transpose layers, batch norm layers, and ReLU activations. The input is a latent vector \(z\) drawn from a standard normal distribution and the output is a 3x64x64 RGB image. The strided conv-transpose layers allow the latent vector to be transformed into a volume with the same shape as an image. In the paper, the authors also give some tips about how to set up the optimizers, how to calculate the loss functions, and how to initialize the model weights, all of which will be explained in the coming sections.
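To make the "latent vector to image-shaped volume" idea concrete, here is a minimal sketch (not part of the tutorial's own pipeline) that checks the spatial-size arithmetic of the kind of transposed convolutions the generator will use, via the standard formula out = (in - 1)*stride - 2*padding + kernel_size:

import torch
import torch.nn as nn

# A 1x1 spatial input, like the latent vector reshaped to nz x 1 x 1
x = torch.randn(1, 100, 1, 1)
up1 = nn.ConvTranspose2d(100, 512, kernel_size=4, stride=1, padding=0)
up2 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)

h = up1(x)      # (1 - 1)*1 - 2*0 + 4 = 4  -> shape (1, 512, 4, 4)
h = up2(h)      # (4 - 1)*2 - 2*1 + 4 = 8  -> shape (1, 256, 8, 8)
print(h.shape)  # torch.Size([1, 256, 8, 8])

Repeating the stride-2 step doubles the spatial size each time, which is how the generator below goes from the latent vector up to 64x64.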

from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

Output:

Random Seed:  999

Inputs

Let's define some inputs for the run:

  • dataroot - the path to the root of the dataset folder. We will talk more about the dataset in the next section

  • workers - the number of worker threads for loading the data with the DataLoader

  • batch_size - the batch size used in training. The DCGAN paper uses a batch size of 128

  • image_size - the spatial size of the images used for training. This implementation defaults to 64x64. If another size is desired, the structures of D and G must be changed. See here for more details

  • nc - number of color channels in the input images. For color images this is 3

  • nz - length of the latent vector

  • ngf - relates to the depth of feature maps carried through the generator

  • ndf - sets the depth of feature maps propagated through the discriminator

  • num_epochs - number of training epochs to run. Training for longer will probably lead to better results but will also take much longer

  • lr - learning rate for training. As described in the DCGAN paper, this number should be 0.0002

  • beta1 - beta1 hyperparameter for the Adam optimizers. As described in the paper, this number should be 0.5

  • ngpu - number of GPUs available. If this is 0, the code will run in CPU mode. If this number is greater than 0 it will run on that number of GPUs

# Root directory for dataset
dataroot = "data/celeba"

# Number of workers for dataloader
workers = 2

# Batch size during training
batch_size = 128

# Spatial size of training images. All images will be resized to this
#   size using a transformer.
image_size = 64

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 64

# Size of feature maps in discriminator
ndf = 64

# Number of training epochs
num_epochs = 5

# Learning rate for optimizers
lr = 0.0002

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

Data

In this tutorial we will use the Celeb-A Faces dataset, which can be downloaded at the linked site or from Google Drive. The dataset will download as a file named img_align_celeba.zip. Once downloaded, create a directory named celeba and extract the zip file into that directory. Then, set the dataroot input for this notebook to the celeba directory you just created. The resulting directory structure should be:

/path/to/celeba
    -> img_align_celeba
        -> 188242.jpg
        -> 173822.jpg
        -> 284702.jpg
        -> 537394.jpg
           ...

This is an important step because we will be using the ImageFolder dataset class, which requires there to be subdirectories in the dataset's root folder. Now, we can create the dataset, create the dataloader, set the device to run on, and finally visualize some of the training data. A small sketch of one way to extract the archive is shown below.
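As a convenience, here is a minimal sketch (assuming the zip was downloaded to data/img_align_celeba.zip; adjust the paths to your own setup) that unpacks the archive into the layout shown above using Python's standard library:

import os
import zipfile

# Hypothetical paths -- change these to wherever you downloaded the archive.
zip_path = "data/img_align_celeba.zip"
extract_root = "data/celeba"

os.makedirs(extract_root, exist_ok=True)
with zipfile.ZipFile(zip_path) as zf:
    # The archive contains an img_align_celeba/ folder, so extracting into
    # data/celeba should yield data/celeba/img_align_celeba/*.jpg
    zf.extractall(extract_root)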

# We can use an image folder dataset the way we have it setup.
# Create the dataset
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Plot some training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
[Image: grid of sample training images (sphx_glr_dcgan_faces_tutorial_001.png)]

Implementation

With our input parameters set and the dataset prepared, we can now get into the implementation. We will start with the weight initialization strategy, then talk about the generator, discriminator, loss functions, and training loop in detail.

Weight Initialization

In the DCGAN paper, the authors specify that all model weights should be randomly initialized from a Normal distribution with mean=0, stdev=0.02. The weights_init function takes an initialized model as input and reinitializes all convolutional, convolutional-transpose, and batch normalization layers to meet this criterion. This function is applied to the models immediately after initialization.

# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

Generator

The generator, \(G\), is designed to map the latent space vector \(z\) to data space. Since our data are images, converting \(z\) to data space means ultimately creating an RGB image with the same size as the training images (i.e. 3x64x64). In practice, this is accomplished through a series of strided two dimensional convolutional-transpose layers, each paired with a 2d batch norm layer and a relu activation. The output of the generator is fed through a tanh function to return it to the input data range of \([-1,1]\) (which is also why the training images are normalized to that range). It is worth noting the existence of the batch norm functions after the conv-transpose layers, as this is a critical contribution of the DCGAN paper. These layers help with the flow of gradients during training. An image of the generator from the DCGAN paper is shown below.

[Image: generator architecture from the DCGAN paper (dcgan_generator)]

Notice how the inputs we set in the input section (nz, ngf, and nc) influence the generator architecture in code. nz is the length of the z input vector, ngf relates to the size of the feature maps that are propagated through the generator, and nc is the number of channels in the output image (set to 3 for RGB images). Below is the code for the generator.

# Generator Code

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)

Now, we can instantiate the generator and apply the weights_init function. Check out the printed model to see how the generator object is structured.

# Create the generator
netG = Generator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.02.
netG.apply(weights_init)

# Print the model
print(netG)

Output:

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)

Discriminator

As mentioned, the discriminator, \(D\), is a binary classification network that takes an image as input and outputs a scalar probability that the input image is real (as opposed to fake). Here, \(D\) takes a 3x64x64 input image, processes it through a series of Conv2d, BatchNorm2d, and LeakyReLU layers, and outputs the final probability through a Sigmoid activation function. This architecture can be extended with more layers if necessary, but there is significance to the use of strided convolution, BatchNorm, and LeakyReLU. The DCGAN paper mentions that it is good practice to use strided convolution rather than pooling to downsample because it lets the network learn its own pooling function. Also, batch norm and leaky relu functions promote healthy gradient flow, which is critical for the learning process of both \(G\) and \(D\).

Discriminator Code

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

Now, as with the generator, we can create the discriminator, apply the weights_init function, and print the model's structure.

# Create the Discriminator
netD = Discriminator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.02.
netD.apply(weights_init)

# Print the model
print(netD)

Output:

Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)

Loss Functions and Optimizers

With \(D\) and \(G\) set up, we can specify how they learn through the loss functions and optimizers. We will use the Binary Cross Entropy loss (BCELoss) function, which is defined in PyTorch as:

\[\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad l_n = -\left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log(1 - x_n) \right]\]

Notice how this function provides the calculation of both log components in the objective function (i.e. \(\log(D(x))\) and \(\log(1-D(G(z)))\)). We can specify which part of the BCE equation to use with the \(y\) input. This is accomplished in the training loop, which is coming up soon, but it is important to understand how we can choose which component we wish to calculate just by changing \(y\) (i.e. the GT labels).
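To see this concretely (a minimal sketch, not part of the tutorial's own code): with a target of 1, BCELoss reduces to \(-\log x\), and with a target of 0 it reduces to \(-\log(1-x)\).

import torch
import torch.nn as nn

bce = nn.BCELoss()
x = torch.tensor([0.9])          # stand-in for a discriminator output probability

# Target y = 1 keeps only the -log(x) term
print(bce(x, torch.ones(1)))     # tensor(0.1054) == -log(0.9)

# Target y = 0 keeps only the -log(1 - x) term
print(bce(x, torch.zeros(1)))    # tensor(2.3026) == -log(0.1)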

Next, we define our real label as 1 and the fake label as 0. These labels will be used when calculating the losses of \(D\) and \(G\), and this is also the convention used in the original GAN paper. Finally, we set up two separate optimizers, one for \(D\) and one for \(G\). As specified in the DCGAN paper, both are Adam optimizers with learning rate 0.0002 and Beta1 = 0.5. To keep track of the generator's learning progression, we will generate a fixed batch of latent vectors drawn from a Gaussian distribution (i.e. fixed_noise). In the training loop, we will periodically input this fixed_noise into \(G\), and over the iterations we will see images form out of the noise.

# Initialize BCELoss function
criterion = nn.BCELoss()

# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Establish convention for real and fake labels during training
real_label = 1
fake_label = 0

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

Training

Finally, now that we have all of the parts of the GAN framework defined, we can train it. Be mindful that training GANs is somewhat of an art form in itself, as incorrect hyperparameter settings lead to mode collapse with little explanation of what went wrong. Here, we will closely follow Algorithm 1 from Goodfellow's paper, while abiding by some of the best practices shown in ganhacks. Namely, we will "construct different mini-batches for real and fake" images, and also adjust G's objective function to maximize \(\log D(G(z))\). Training is split up into two main parts: Part 1 updates the Discriminator and Part 2 updates the Generator.

Part 1 - Train the Discriminator

Recall, the goal of training the discriminator is to maximize the probability of correctly classifying a given input as real or fake. In terms of Goodfellow, we wish to "update the discriminator by ascending its stochastic gradient". Practically, we want to maximize \(\log(D(x)) + \log(1-D(G(z)))\). Due to the separate mini-batch suggestion from ganhacks, we will calculate this in two steps. First, we will construct a batch of real samples from the training set, forward pass through \(D\), calculate the loss (\(\log(D(x))\)), then calculate the gradients in a backward pass. Second, we will construct a batch of fake samples with the current generator, forward pass this batch through \(D\), calculate the loss (\(\log(1-D(G(z)))\)), and accumulate the gradients with a backward pass. Now, with the gradients accumulated from both the all-real and all-fake batches, we call a step of the Discriminator's optimizer.

Part 2 - Train the Generator

As stated in the original paper, we want to train the Generator by minimizing \(\log(1-D(G(z)))\) in an effort to generate better fakes. As mentioned, Goodfellow showed this does not provide sufficient gradients, especially early in the learning process. As a fix, we instead wish to maximize \(\log(D(G(z)))\). In the code we accomplish this by: classifying the Generator output from Part 1 with the Discriminator, computing G's loss using real labels as GT, computing G's gradients in a backward pass, and finally updating G's parameters with an optimizer step. It may seem counter-intuitive to use the real labels as GT labels for the loss function, but this allows us to use the \(\log(x)\) part of the BCELoss (rather than the \(\log(1-x)\) part), which is exactly what we want.
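Spelled out as a quick check (an aside that follows directly from the BCELoss definition above), plugging \(x = D(G(z))\) and the real label \(y = 1\) into a single BCE term gives

\[l = -\left[1 \cdot \log D(G(z)) + (1 - 1) \cdot \log(1 - D(G(z)))\right] = -\log D(G(z)),\]

so minimizing this loss is exactly maximizing \(\log D(G(z))\).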

Finally, we will do some statistic reporting, and at the end of each epoch we will push our fixed_noise batch through the generator to visually track the progress of G's training. The training statistics reported are:

  • Loss_D - discriminator loss calculated as the sum of losses for the all-real and all-fake batches (\(\log(D(x)) + \log(1 - D(G(z)))\)).

  • Loss_G - generator loss calculated as \(\log(D(G(z)))\)

  • D(x) - the average output (across the batch) of the discriminator for the all-real batch. This should start close to 1 and then theoretically converge to 0.5 as G gets better. Think about why this is.

  • D(G(z)) - average discriminator outputs for the all-fake batch. The first number is before D is updated and the second number is after D is updated. These numbers should start near 0 and converge to 0.5 as G gets better. Think about why this is.

Note: This step might take a while, depending on how many epochs you run and whether you removed some data from the dataset.

# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

Output:

Starting Training Loop...
[0/5][0/1583]   Loss_D: 1.8664  Loss_G: 4.9949  D(x): 0.5050    D(G(z)): 0.5928 / 0.0106
[0/5][50/1583]  Loss_D: 0.1164  Loss_G: 6.3327  D(x): 0.9821    D(G(z)): 0.0353 / 0.0088
[0/5][100/1583] Loss_D: 0.6500  Loss_G: 8.5506  D(x): 0.8983    D(G(z)): 0.3084 / 0.0008
[0/5][150/1583] Loss_D: 0.3882  Loss_G: 2.3167  D(x): 0.8348    D(G(z)): 0.0826 / 0.1773
[0/5][200/1583] Loss_D: 0.4996  Loss_G: 4.8283  D(x): 0.7188    D(G(z)): 0.0304 / 0.0122
[0/5][250/1583] Loss_D: 0.7112  Loss_G: 3.0257  D(x): 0.6858    D(G(z)): 0.1269 / 0.0874
[0/5][300/1583] Loss_D: 0.6014  Loss_G: 2.9924  D(x): 0.7573    D(G(z)): 0.2083 / 0.0730
[0/5][350/1583] Loss_D: 1.0553  Loss_G: 2.3669  D(x): 0.5003    D(G(z)): 0.0628 / 0.1624
[0/5][400/1583] Loss_D: 0.7078  Loss_G: 5.3654  D(x): 0.9079    D(G(z)): 0.4053 / 0.0130
[0/5][450/1583] Loss_D: 0.6259  Loss_G: 5.8745  D(x): 0.8932    D(G(z)): 0.3414 / 0.0064
[0/5][500/1583] Loss_D: 0.6442  Loss_G: 3.9164  D(x): 0.7826    D(G(z)): 0.2192 / 0.0359
[0/5][550/1583] Loss_D: 0.5650  Loss_G: 4.2448  D(x): 0.8038    D(G(z)): 0.2045 / 0.0262
[0/5][600/1583] Loss_D: 0.8724  Loss_G: 8.5326  D(x): 0.9339    D(G(z)): 0.4805 / 0.0006
[0/5][650/1583] Loss_D: 0.2510  Loss_G: 4.6053  D(x): 0.9121    D(G(z)): 0.1211 / 0.0174
[0/5][700/1583] Loss_D: 0.3778  Loss_G: 3.9419  D(x): 0.7693    D(G(z)): 0.0538 / 0.0308
[0/5][750/1583] Loss_D: 0.4424  Loss_G: 4.4343  D(x): 0.9226    D(G(z)): 0.2667 / 0.0174
[0/5][800/1583] Loss_D: 0.2920  Loss_G: 4.5266  D(x): 0.8288    D(G(z)): 0.0636 / 0.0200
[0/5][850/1583] Loss_D: 0.8617  Loss_G: 7.8081  D(x): 0.9332    D(G(z)): 0.4871 / 0.0010
[0/5][900/1583] Loss_D: 0.4552  Loss_G: 3.1839  D(x): 0.7460    D(G(z)): 0.0812 / 0.0722
[0/5][950/1583] Loss_D: 0.2615  Loss_G: 4.3918  D(x): 0.8956    D(G(z)): 0.1195 / 0.0237
[0/5][1000/1583]        Loss_D: 0.8312  Loss_G: 1.9589  D(x): 0.5710    D(G(z)): 0.0525 / 0.2147
[0/5][1050/1583]        Loss_D: 0.3988  Loss_G: 5.1550  D(x): 0.8943    D(G(z)): 0.2070 / 0.0102
[0/5][1100/1583]        Loss_D: 0.4076  Loss_G: 2.7323  D(x): 0.8034    D(G(z)): 0.1133 / 0.1005
[0/5][1150/1583]        Loss_D: 1.2152  Loss_G: 1.0146  D(x): 0.4353    D(G(z)): 0.0423 / 0.4245
[0/5][1200/1583]        Loss_D: 0.7570  Loss_G: 6.6842  D(x): 0.8981    D(G(z)): 0.4155 / 0.0035
[0/5][1250/1583]        Loss_D: 0.5240  Loss_G: 4.2630  D(x): 0.8281    D(G(z)): 0.2270 / 0.0226
[0/5][1300/1583]        Loss_D: 0.4874  Loss_G: 3.8270  D(x): 0.8750    D(G(z)): 0.2513 / 0.0325
[0/5][1350/1583]        Loss_D: 0.4722  Loss_G: 4.4401  D(x): 0.9174    D(G(z)): 0.2869 / 0.0186
[0/5][1400/1583]        Loss_D: 0.4843  Loss_G: 3.1688  D(x): 0.8548    D(G(z)): 0.2186 / 0.0695
[0/5][1450/1583]        Loss_D: 0.8983  Loss_G: 6.8422  D(x): 0.8979    D(G(z)): 0.4802 / 0.0026
[0/5][1500/1583]        Loss_D: 0.4259  Loss_G: 3.0975  D(x): 0.7808    D(G(z)): 0.1153 / 0.0658
[0/5][1550/1583]        Loss_D: 0.5618  Loss_G: 3.2871  D(x): 0.8243    D(G(z)): 0.2453 / 0.0603
[1/5][0/1583]   Loss_D: 0.4305  Loss_G: 3.4571  D(x): 0.7496    D(G(z)): 0.0625 / 0.0508
[1/5][50/1583]  Loss_D: 0.3509  Loss_G: 3.8943  D(x): 0.8236    D(G(z)): 0.0963 / 0.0344
[1/5][100/1583] Loss_D: 0.3889  Loss_G: 3.5543  D(x): 0.7929    D(G(z)): 0.1042 / 0.0470
[1/5][150/1583] Loss_D: 1.1344  Loss_G: 5.9704  D(x): 0.9539    D(G(z)): 0.6008 / 0.0054
[1/5][200/1583] Loss_D: 0.8094  Loss_G: 5.2500  D(x): 0.9470    D(G(z)): 0.4703 / 0.0097
[1/5][250/1583] Loss_D: 0.5325  Loss_G: 4.6564  D(x): 0.8822    D(G(z)): 0.2936 / 0.0155
[1/5][300/1583] Loss_D: 0.4475  Loss_G: 3.5293  D(x): 0.8081    D(G(z)): 0.1607 / 0.0403
[1/5][350/1583] Loss_D: 0.6466  Loss_G: 4.7204  D(x): 0.9084    D(G(z)): 0.3730 / 0.0152
[1/5][400/1583] Loss_D: 0.4475  Loss_G: 4.4130  D(x): 0.9334    D(G(z)): 0.2870 / 0.0199
[1/5][450/1583] Loss_D: 0.5445  Loss_G: 4.2882  D(x): 0.9132    D(G(z)): 0.3155 / 0.0246
[1/5][500/1583] Loss_D: 0.4090  Loss_G: 3.3560  D(x): 0.7937    D(G(z)): 0.1252 / 0.0511
[1/5][550/1583] Loss_D: 0.7900  Loss_G: 3.2994  D(x): 0.6804    D(G(z)): 0.2317 / 0.0648
[1/5][600/1583] Loss_D: 0.8600  Loss_G: 2.1912  D(x): 0.5798    D(G(z)): 0.1380 / 0.1600
[1/5][650/1583] Loss_D: 0.5619  Loss_G: 2.2430  D(x): 0.7028    D(G(z)): 0.1054 / 0.1410
[1/5][700/1583] Loss_D: 0.4090  Loss_G: 2.9897  D(x): 0.7556    D(G(z)): 0.0720 / 0.0819
[1/5][750/1583] Loss_D: 1.0678  Loss_G: 5.9937  D(x): 0.9666    D(G(z)): 0.5960 / 0.0045
[1/5][800/1583] Loss_D: 0.5858  Loss_G: 3.5338  D(x): 0.8912    D(G(z)): 0.3212 / 0.0533
[1/5][850/1583] Loss_D: 0.5105  Loss_G: 2.3191  D(x): 0.6867    D(G(z)): 0.0748 / 0.1284
[1/5][900/1583] Loss_D: 0.4819  Loss_G: 3.6504  D(x): 0.8348    D(G(z)): 0.2189 / 0.0376
[1/5][950/1583] Loss_D: 2.5887  Loss_G: 7.0482  D(x): 0.9870    D(G(z)): 0.8696 / 0.0034
[1/5][1000/1583]        Loss_D: 0.6032  Loss_G: 2.5206  D(x): 0.6502    D(G(z)): 0.0787 / 0.1100
[1/5][1050/1583]        Loss_D: 0.3534  Loss_G: 3.1873  D(x): 0.8904    D(G(z)): 0.1903 / 0.0574
[1/5][1100/1583]        Loss_D: 0.4637  Loss_G: 3.1668  D(x): 0.8057    D(G(z)): 0.1707 / 0.0632
[1/5][1150/1583]        Loss_D: 0.6998  Loss_G: 3.2717  D(x): 0.8345    D(G(z)): 0.3467 / 0.0524
[1/5][1200/1583]        Loss_D: 0.4073  Loss_G: 2.9885  D(x): 0.8461    D(G(z)): 0.1888 / 0.0714
[1/5][1250/1583]        Loss_D: 0.5090  Loss_G: 1.9835  D(x): 0.7045    D(G(z)): 0.0921 / 0.1780
[1/5][1300/1583]        Loss_D: 0.7495  Loss_G: 3.9293  D(x): 0.9240    D(G(z)): 0.4359 / 0.0335
[1/5][1350/1583]        Loss_D: 0.5585  Loss_G: 2.8623  D(x): 0.7890    D(G(z)): 0.2340 / 0.0779
[1/5][1400/1583]        Loss_D: 0.4779  Loss_G: 2.9822  D(x): 0.7271    D(G(z)): 0.0986 / 0.0808
[1/5][1450/1583]        Loss_D: 1.6573  Loss_G: 1.4877  D(x): 0.2696    D(G(z)): 0.0186 / 0.3113
[1/5][1500/1583]        Loss_D: 0.6673  Loss_G: 2.3133  D(x): 0.6610    D(G(z)): 0.1353 / 0.1436
[1/5][1550/1583]        Loss_D: 1.2182  Loss_G: 4.5559  D(x): 0.9353    D(G(z)): 0.6044 / 0.0223
[2/5][0/1583]   Loss_D: 0.5260  Loss_G: 2.3830  D(x): 0.7460    D(G(z)): 0.1696 / 0.1187
[2/5][50/1583]  Loss_D: 0.6826  Loss_G: 2.1603  D(x): 0.6282    D(G(z)): 0.1300 / 0.1499
[2/5][100/1583] Loss_D: 0.5114  Loss_G: 2.4472  D(x): 0.7099    D(G(z)): 0.1078 / 0.1169
[2/5][150/1583] Loss_D: 1.0720  Loss_G: 4.4930  D(x): 0.9212    D(G(z)): 0.5647 / 0.0176
[2/5][200/1583] Loss_D: 0.4492  Loss_G: 2.6146  D(x): 0.8598    D(G(z)): 0.2218 / 0.0988
[2/5][250/1583] Loss_D: 0.6255  Loss_G: 2.4640  D(x): 0.6118    D(G(z)): 0.0442 / 0.1201
[2/5][300/1583] Loss_D: 0.6009  Loss_G: 3.3765  D(x): 0.8803    D(G(z)): 0.3391 / 0.0467
[2/5][350/1583] Loss_D: 0.7020  Loss_G: 3.7924  D(x): 0.8935    D(G(z)): 0.4192 / 0.0294
[2/5][400/1583] Loss_D: 1.9287  Loss_G: 0.4761  D(x): 0.2297    D(G(z)): 0.0488 / 0.6724
[2/5][450/1583] Loss_D: 0.6189  Loss_G: 3.1436  D(x): 0.8498    D(G(z)): 0.3301 / 0.0581
[2/5][500/1583] Loss_D: 0.6772  Loss_G: 3.1658  D(x): 0.8838    D(G(z)): 0.3865 / 0.0577
[2/5][550/1583] Loss_D: 0.9952  Loss_G: 1.1167  D(x): 0.4631    D(G(z)): 0.0971 / 0.3798
[2/5][600/1583] Loss_D: 0.6339  Loss_G: 1.9277  D(x): 0.6725    D(G(z)): 0.1593 / 0.1778
[2/5][650/1583] Loss_D: 0.6164  Loss_G: 2.0238  D(x): 0.6618    D(G(z)): 0.1162 / 0.1631
[2/5][700/1583] Loss_D: 0.6668  Loss_G: 3.4620  D(x): 0.8727    D(G(z)): 0.3600 / 0.0419
[2/5][750/1583] Loss_D: 0.5204  Loss_G: 2.2023  D(x): 0.6693    D(G(z)): 0.0692 / 0.1474
[2/5][800/1583] Loss_D: 0.7820  Loss_G: 0.8576  D(x): 0.5554    D(G(z)): 0.1009 / 0.4613
[2/5][850/1583] Loss_D: 0.8125  Loss_G: 2.5283  D(x): 0.7835    D(G(z)): 0.3851 / 0.1010
[2/5][900/1583] Loss_D: 0.5450  Loss_G: 2.7598  D(x): 0.8645    D(G(z)): 0.2970 / 0.0797
[2/5][950/1583] Loss_D: 0.5866  Loss_G: 3.2705  D(x): 0.8776    D(G(z)): 0.3340 / 0.0526
[2/5][1000/1583]        Loss_D: 0.6416  Loss_G: 3.1707  D(x): 0.8074    D(G(z)): 0.3130 / 0.0564
[2/5][1050/1583]        Loss_D: 0.7074  Loss_G: 1.0624  D(x): 0.5846    D(G(z)): 0.0945 / 0.3938
[2/5][1100/1583]        Loss_D: 1.1414  Loss_G: 5.0316  D(x): 0.8998    D(G(z)): 0.5922 / 0.0102
[2/5][1150/1583]        Loss_D: 0.7948  Loss_G: 1.7607  D(x): 0.5557    D(G(z)): 0.0907 / 0.2180
[2/5][1200/1583]        Loss_D: 0.7528  Loss_G: 3.0219  D(x): 0.8710    D(G(z)): 0.4185 / 0.0622
[2/5][1250/1583]        Loss_D: 1.0173  Loss_G: 2.6637  D(x): 0.6454    D(G(z)): 0.3590 / 0.0959
[2/5][1300/1583]        Loss_D: 0.5150  Loss_G: 3.2834  D(x): 0.8053    D(G(z)): 0.2168 / 0.0596
[2/5][1350/1583]        Loss_D: 0.5156  Loss_G: 3.2547  D(x): 0.8674    D(G(z)): 0.2880 / 0.0496
[2/5][1400/1583]        Loss_D: 0.6287  Loss_G: 2.2635  D(x): 0.7251    D(G(z)): 0.2252 / 0.1301
[2/5][1450/1583]        Loss_D: 0.5374  Loss_G: 3.3125  D(x): 0.8112    D(G(z)): 0.2470 / 0.0514
[2/5][1500/1583]        Loss_D: 0.5558  Loss_G: 2.3804  D(x): 0.7956    D(G(z)): 0.2478 / 0.1176
[2/5][1550/1583]        Loss_D: 1.3468  Loss_G: 0.1723  D(x): 0.3568    D(G(z)): 0.0347 / 0.8570
[3/5][0/1583]   Loss_D: 0.5730  Loss_G: 1.9957  D(x): 0.7489    D(G(z)): 0.2090 / 0.1668
[3/5][50/1583]  Loss_D: 0.7140  Loss_G: 1.1195  D(x): 0.6209    D(G(z)): 0.1602 / 0.3818
[3/5][100/1583] Loss_D: 1.2946  Loss_G: 0.4481  D(x): 0.3415    D(G(z)): 0.0226 / 0.6707
[3/5][150/1583] Loss_D: 0.8376  Loss_G: 3.2857  D(x): 0.8827    D(G(z)): 0.4568 / 0.0486
[3/5][200/1583] Loss_D: 0.5578  Loss_G: 3.6264  D(x): 0.8621    D(G(z)): 0.3113 / 0.0336
[3/5][250/1583] Loss_D: 0.4769  Loss_G: 2.3588  D(x): 0.8131    D(G(z)): 0.2136 / 0.1177
[3/5][300/1583] Loss_D: 0.6321  Loss_G: 2.4122  D(x): 0.8349    D(G(z)): 0.3230 / 0.1159
[3/5][350/1583] Loss_D: 0.6285  Loss_G: 2.0063  D(x): 0.7119    D(G(z)): 0.2033 / 0.1629
[3/5][400/1583] Loss_D: 0.4438  Loss_G: 2.7762  D(x): 0.7756    D(G(z)): 0.1375 / 0.0860
[3/5][450/1583] Loss_D: 0.7211  Loss_G: 2.5315  D(x): 0.7283    D(G(z)): 0.2872 / 0.1019
[3/5][500/1583] Loss_D: 0.5265  Loss_G: 2.0160  D(x): 0.7936    D(G(z)): 0.2246 / 0.1660
[3/5][550/1583] Loss_D: 0.8631  Loss_G: 3.2266  D(x): 0.7293    D(G(z)): 0.3661 / 0.0560
[3/5][600/1583] Loss_D: 0.8367  Loss_G: 0.8854  D(x): 0.5442    D(G(z)): 0.1096 / 0.4430
[3/5][650/1583] Loss_D: 0.5915  Loss_G: 2.5472  D(x): 0.7888    D(G(z)): 0.2597 / 0.1045
[3/5][700/1583] Loss_D: 0.6117  Loss_G: 3.0374  D(x): 0.8493    D(G(z)): 0.3133 / 0.0681
[3/5][750/1583] Loss_D: 0.8155  Loss_G: 1.3866  D(x): 0.6382    D(G(z)): 0.2545 / 0.2864
[3/5][800/1583] Loss_D: 0.5261  Loss_G: 2.9280  D(x): 0.8175    D(G(z)): 0.2480 / 0.0681
[3/5][850/1583] Loss_D: 0.5661  Loss_G: 2.4623  D(x): 0.8358    D(G(z)): 0.2892 / 0.1064
[3/5][900/1583] Loss_D: 0.6336  Loss_G: 3.2586  D(x): 0.8466    D(G(z)): 0.3399 / 0.0510
[3/5][950/1583] Loss_D: 0.5568  Loss_G: 3.3120  D(x): 0.8403    D(G(z)): 0.2827 / 0.0500
[3/5][1000/1583]        Loss_D: 0.5501  Loss_G: 1.8930  D(x): 0.7536    D(G(z)): 0.2043 / 0.1899
[3/5][1050/1583]        Loss_D: 0.8346  Loss_G: 1.3622  D(x): 0.5341    D(G(z)): 0.0915 / 0.3088
[3/5][1100/1583]        Loss_D: 0.6320  Loss_G: 3.1845  D(x): 0.8474    D(G(z)): 0.3426 / 0.0509
[3/5][1150/1583]        Loss_D: 0.7423  Loss_G: 2.2229  D(x): 0.7426    D(G(z)): 0.3102 / 0.1334
[3/5][1200/1583]        Loss_D: 1.0204  Loss_G: 3.7934  D(x): 0.8869    D(G(z)): 0.5414 / 0.0365
[3/5][1250/1583]        Loss_D: 0.9857  Loss_G: 1.3051  D(x): 0.4300    D(G(z)): 0.0338 / 0.3216
[3/5][1300/1583]        Loss_D: 0.5129  Loss_G: 2.0869  D(x): 0.7038    D(G(z)): 0.1073 / 0.1573
[3/5][1350/1583]        Loss_D: 0.6969  Loss_G: 1.8926  D(x): 0.7137    D(G(z)): 0.2508 / 0.1960
[3/5][1400/1583]        Loss_D: 0.6847  Loss_G: 1.1432  D(x): 0.6320    D(G(z)): 0.1387 / 0.3572
[3/5][1450/1583]        Loss_D: 0.7086  Loss_G: 1.9519  D(x): 0.7067    D(G(z)): 0.2623 / 0.1694
[3/5][1500/1583]        Loss_D: 1.5900  Loss_G: 0.7564  D(x): 0.2720    D(G(z)): 0.0319 / 0.5317
[3/5][1550/1583]        Loss_D: 0.5878  Loss_G: 1.9759  D(x): 0.6829    D(G(z)): 0.1394 / 0.1718
[4/5][0/1583]   Loss_D: 1.2915  Loss_G: 4.6033  D(x): 0.9326    D(G(z)): 0.6473 / 0.0147
[4/5][50/1583]  Loss_D: 0.5628  Loss_G: 2.6636  D(x): 0.8021    D(G(z)): 0.2602 / 0.0878
[4/5][100/1583] Loss_D: 0.7918  Loss_G: 3.5693  D(x): 0.8590    D(G(z)): 0.4298 / 0.0400
[4/5][150/1583] Loss_D: 0.7076  Loss_G: 1.4158  D(x): 0.6032    D(G(z)): 0.1286 / 0.2896
[4/5][200/1583] Loss_D: 0.9885  Loss_G: 4.2181  D(x): 0.8642    D(G(z)): 0.5119 / 0.0220
[4/5][250/1583] Loss_D: 0.5747  Loss_G: 2.3627  D(x): 0.7413    D(G(z)): 0.1956 / 0.1242
[4/5][300/1583] Loss_D: 1.2252  Loss_G: 3.8262  D(x): 0.8949    D(G(z)): 0.6160 / 0.0326
[4/5][350/1583] Loss_D: 0.5965  Loss_G: 2.7654  D(x): 0.8323    D(G(z)): 0.2956 / 0.0843
[4/5][400/1583] Loss_D: 0.5151  Loss_G: 3.3570  D(x): 0.9149    D(G(z)): 0.3180 / 0.0475
[4/5][450/1583] Loss_D: 0.6096  Loss_G: 2.0857  D(x): 0.7131    D(G(z)): 0.1878 / 0.1614
[4/5][500/1583] Loss_D: 0.6638  Loss_G: 2.2867  D(x): 0.7521    D(G(z)): 0.2703 / 0.1239
[4/5][550/1583] Loss_D: 0.7364  Loss_G: 1.5381  D(x): 0.6093    D(G(z)): 0.1566 / 0.2562
[4/5][600/1583] Loss_D: 0.4541  Loss_G: 2.4904  D(x): 0.7480    D(G(z)): 0.1169 / 0.1153
[4/5][650/1583] Loss_D: 0.4797  Loss_G: 2.5566  D(x): 0.8418    D(G(z)): 0.2380 / 0.0958
[4/5][700/1583] Loss_D: 0.8636  Loss_G: 4.6097  D(x): 0.9460    D(G(z)): 0.5171 / 0.0141
[4/5][750/1583] Loss_D: 0.9391  Loss_G: 0.8470  D(x): 0.5283    D(G(z)): 0.1705 / 0.4711
[4/5][800/1583] Loss_D: 0.9065  Loss_G: 1.2562  D(x): 0.5222    D(G(z)): 0.1288 / 0.3211
[4/5][850/1583] Loss_D: 0.5825  Loss_G: 2.9573  D(x): 0.8296    D(G(z)): 0.2920 / 0.0677
[4/5][900/1583] Loss_D: 0.8524  Loss_G: 1.5152  D(x): 0.5367    D(G(z)): 0.1316 / 0.2671
[4/5][950/1583] Loss_D: 0.5136  Loss_G: 2.5986  D(x): 0.7812    D(G(z)): 0.2083 / 0.0969
[4/5][1000/1583]        Loss_D: 0.6139  Loss_G: 2.3941  D(x): 0.7550    D(G(z)): 0.2420 / 0.1131
[4/5][1050/1583]        Loss_D: 0.5659  Loss_G: 2.7736  D(x): 0.7934    D(G(z)): 0.2486 / 0.0812
[4/5][1100/1583]        Loss_D: 1.2266  Loss_G: 3.9759  D(x): 0.9701    D(G(z)): 0.6483 / 0.0324
[4/5][1150/1583]        Loss_D: 0.7776  Loss_G: 3.4581  D(x): 0.9172    D(G(z)): 0.4394 / 0.0438
[4/5][1200/1583]        Loss_D: 1.7693  Loss_G: 5.6766  D(x): 0.9839    D(G(z)): 0.7510 / 0.0068
[4/5][1250/1583]        Loss_D: 0.8701  Loss_G: 1.5499  D(x): 0.4913    D(G(z)): 0.0567 / 0.2734
[4/5][1300/1583]        Loss_D: 0.5922  Loss_G: 2.0101  D(x): 0.7149    D(G(z)): 0.1728 / 0.1750
[4/5][1350/1583]        Loss_D: 0.6841  Loss_G: 2.7521  D(x): 0.7518    D(G(z)): 0.2826 / 0.0826
[4/5][1400/1583]        Loss_D: 0.6020  Loss_G: 2.9772  D(x): 0.8505    D(G(z)): 0.3311 / 0.0622
[4/5][1450/1583]        Loss_D: 0.9551  Loss_G: 2.5080  D(x): 0.7079    D(G(z)): 0.4012 / 0.1162
[4/5][1500/1583]        Loss_D: 0.5961  Loss_G: 1.5754  D(x): 0.6590    D(G(z)): 0.1153 / 0.2510
[4/5][1550/1583]        Loss_D: 0.5623  Loss_G: 2.9591  D(x): 0.8391    D(G(z)): 0.2910 / 0.0667

Results

Finally, let's check out how we did. Here, we will look at three different results. First, we will see how D and G's losses changed during training. Second, we will visualize G's output on the fixed_noise batch for every epoch. And third, we will look at a batch of real data next to a batch of fake data from G.

Loss versus training iteration

Below is a plot of D & G's losses versus training iterations.

plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
[Image: plot of generator and discriminator loss during training (sphx_glr_dcgan_faces_tutorial_002.png)]

Visualization of G's progression

Remember how we saved the generator's output on the fixed_noise batch after every epoch of training. Now, we can visualize the training progression of G with an animation. Press the play button to start the animation.

#%%capture
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())
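If you want to keep the animation outside of the notebook, one option (an optional aside; it assumes Pillow is installed for the GIF writer) is to save it to disk:

# Save the animation as a GIF (the filename is just an example)
ani.save("dcgan_progress.gif", writer="pillow", fps=1)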
[Image: frame of the generated-image animation (sphx_glr_dcgan_faces_tutorial_003.png)]

Real Images vs. Fake Images

Finally, let's take a look at some real images and fake images side by side.

# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))

# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()
[Image: real images next to fake images from the last epoch (sphx_glr_dcgan_faces_tutorial_004.png)]

Where to Go Next

We have reached the end of our journey, but there are several places you could go from here. You could:

  • Train for longer to see how good the results get

  • Modify this model to take a different dataset and possibly change the size of the images and the model architecture (a minimal sketch of swapping in a different dataset follows this list)

  • Check out some other cool GAN projects here

  • Create GANs that generate music
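As a starting point for the second item above, here is a minimal sketch (an illustrative example, not from the tutorial) of swapping in CIFAR-10 via torchvision; resizing to 64x64 keeps the existing D and G architectures usable, though the source images are much lower resolution:

import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms

cifar_dataset = dset.CIFAR10(root="data/cifar10", download=True,
                             transform=transforms.Compose([
                                 transforms.Resize(image_size),   # upsample 32x32 -> 64x64
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                             ]))
dataloader = torch.utils.data.DataLoader(cifar_dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)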

Total running time of the script: (270 minutes 59.673 seconds)
