1. Image classification network
This is a binary classifier and can be any of AlexNet, VGG, or ResNet. It classifies each image as either a real image or a generated one.
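2. Image generation network
The input is a random noise vector and the output is an image. The network is built from transposed-convolution ("deconvolution") layers.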
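Anyone who has studied deep learning should be able to write these two networks. If you can't, no problem: someone has already written them for you.
First, the image classification network:
Put simply, it is convolution + ReLU + sigmoid (the code here uses LeakyReLU together with batch normalization). It can be swapped for any other classification network, such as VGG or ResNet.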
```python
import torch.nn as nn

# nc (number of image channels) and ndf (base number of discriminator
# feature maps) are hyperparameters assumed to be defined at module level.

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            # squash the single output value into a probability in (0, 1)
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)
```
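The shape comments in the discriminator can be checked by hand with the standard output-size formula for a convolution with dilation 1: out = floor((in + 2 * padding - kernel) / stride) + 1. The snippet below is a minimal sketch in plain Python; the helper name `conv_out` is made up for illustration and is not part of PyTorch.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a 2-D convolution with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1

# (kernel, stride, padding) of the five Conv2d layers in Discriminator
d_layers = [(4, 2, 1)] * 4 + [(4, 1, 0)]

size = 64  # input images are 64 x 64
d_sizes = [size]
for kernel, stride, padding in d_layers:
    size = conv_out(size, kernel, stride, padding)
    d_sizes.append(size)

print(d_sizes)  # [64, 32, 16, 8, 4, 1]
```

Each stride-2 layer halves the resolution, and the last layer collapses the 4 x 4 map into a single value per image, which the sigmoid turns into a probability.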
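The generator is the key part.
The code is below. It is really just transposed convolution + BN + ReLU, with a Tanh on the output layer.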
```python
import torch.nn as nn

# nz (size of the latent vector), ngf (base number of generator feature maps)
# and nc (number of image channels) are hyperparameters assumed to be
# defined at module level.

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            # Tanh maps pixel values into [-1, 1]
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)
```
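The generator's shape comments follow from the output-size formula for a transposed convolution with dilation 1 and no output padding: out = (in - 1) * stride - 2 * padding + kernel. As a rough sketch (the helper name `deconv_out` is hypothetical, not a PyTorch function), the growth from a 1 x 1 noise input to a 64 x 64 image can be traced like this:

```python
def deconv_out(size, kernel, stride, padding):
    """Spatial output size of a transposed convolution (dilation 1, output_padding 0)."""
    return (size - 1) * stride - 2 * padding + kernel

# (kernel, stride, padding) of the five ConvTranspose2d layers in Generator
g_layers = [(4, 1, 0)] + [(4, 2, 1)] * 4

size = 1  # the noise tensor has spatial size 1 x 1
g_sizes = [size]
for kernel, stride, padding in g_layers:
    size = deconv_out(size, kernel, stride, padding)
    g_sizes.append(size)

print(g_sizes)  # [1, 4, 8, 16, 32, 64]
```

The first layer expands the noise to a 4 x 4 map, and each stride-2 layer after it doubles the resolution, mirroring the discriminator in reverse.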
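Frankly, both of the networks above are quite simple.
Now for the real key point: how to train them.
Each step consists of three sub-steps:
- Train the binary classifier
1. Feed in real images and run them through the classifier with the target "real"; the resulting gradients are used to update the classifier.
2. Feed noise through the generator to produce images, pass them to the classifier with the target "fake", and update the classifier.
- Train the generator
3. Feed noise through the generator to produce images, pass them to the classifier with the target "real", and update the generator.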
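The three sub-steps differ only in which network is updated and which target label goes into the loss. Assuming the loss is binary cross-entropy (a natural fit for a sigmoid output, though this is an assumption about the loss function), the per-sample values can be sketched in plain Python. The function `bce` and the two scores are illustrative, not taken from a real run.

```python
import math

def bce(p, target):
    """Binary cross-entropy for one predicted probability p and a 0/1 target."""
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))

d_real = 0.9  # hypothetical classifier score on a real image
d_fake = 0.2  # hypothetical classifier score on a generated image

loss_d_real = bce(d_real, 1)  # sub-step 1: real image, target "real"
loss_d_fake = bce(d_fake, 0)  # sub-step 2: generated image, target "fake"
loss_g = bce(d_fake, 1)       # sub-step 3: generated image, target "real"

print(f"{loss_d_real:.4f} {loss_d_fake:.4f} {loss_g:.4f}")  # 0.1054 0.2231 1.6094
```

With these scores the classifier is doing well, so its two losses are small, while the generator's loss is large: the same score of 0.2 is a good result for the classifier and a bad one for the generator.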
Without further ado, here is the code:
```python
import torch
import torchvision.utils as vutils

# Assumed to be defined earlier: netD, netG, criterion, optimizerD, optimizerG,
# dataloader, device, nz, num_epochs, real_label, fake_label, fixed_noise,
# the lists img_list, G_losses, D_losses, and the counter iters.

for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        # dtype=torch.float keeps the labels floating point even if real_label is an int
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D; detach() keeps gradients out of G here
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch (accumulated with the real-batch gradients)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Total D loss over the all-real and all-fake batches (used for logging)
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs - 1) and (i == len(dataloader) - 1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1
```
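The comment for step (2) says the generator maximizes log(D(G(z))) instead of minimizing log(1 - D(G(z))). A quick numeric comparison hints at why this form is commonly preferred: early in training D(G(z)) tends to be close to 0, and there the derivative of log(1 - p) is small while that of -log(p) is large. The sketch below compares the two derivatives with respect to p = D(G(z)) only (not with respect to the network weights); both function names are hypothetical.

```python
def grad_saturating(p):
    """d/dp of log(1 - p): the loss G would minimize in the plain minimax form."""
    return -1.0 / (1.0 - p)

def grad_non_saturating(p):
    """d/dp of -log(p): the loss G minimizes when fake images are labeled 'real'."""
    return -1.0 / p

# Compare gradient magnitudes when D rejects, is unsure about, or accepts the fakes
for p in (0.01, 0.5, 0.99):
    print(p, abs(grad_saturating(p)), abs(grad_non_saturating(p)))
```

At p = 0.01 the second form gives a gradient roughly a hundred times larger, so the generator still receives a strong signal while the classifier is confidently rejecting its images.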
Original article: https://blog.csdn.net/xz1308579340/article/details/105883090