WGAN与GAN的不同
- 去除sigmoid
- 使用具有动量的优化方法,比如使用RMSProp
- 要对Discriminator的权重做修整限制以确保lipschitz连续约
WGAN实战卷积生成动漫头像
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
|
import torch import torch.nn as nn import torchvision.transforms as transforms from torch.utils.data import DataLoader from torchvision.utils import save_image import os from anime_face_generator.dataset import ImageDataset batch_size = 32 num_epoch = 100 z_dimension = 100 dir_path = './wgan_img' # 创建文件夹 if not os.path.exists(dir_path): os.mkdir(dir_path) def to_img(x): """因为我们在生成器里面用了tanh""" out = 0.5 * (x + 1 ) return out dataset = ImageDataset() dataloader = DataLoader(dataset, batch_size = 32 , shuffle = False ) class Generator(nn.Module): def __init__( self ): super ().__init__() self .gen = nn.Sequential( # 输入是一个nz维度的噪声,我们可以认为它是一个1*1*nz的feature map nn.ConvTranspose2d( 100 , 512 , 4 , 1 , 0 , bias = False ), nn.BatchNorm2d( 512 ), nn.ReLU( True ), # 上一步的输出形状:(512) x 4 x 4 nn.ConvTranspose2d( 512 , 256 , 4 , 2 , 1 , bias = False ), nn.BatchNorm2d( 256 ), nn.ReLU( True ), # 上一步的输出形状: (256) x 8 x 8 nn.ConvTranspose2d( 256 , 128 , 4 , 2 , 1 , bias = False ), nn.BatchNorm2d( 128 ), nn.ReLU( True ), # 上一步的输出形状: (256) x 16 x 16 nn.ConvTranspose2d( 128 , 64 , 4 , 2 , 1 , bias = False ), nn.BatchNorm2d( 64 ), nn.ReLU( True ), # 上一步的输出形状:(256) x 32 x 32 nn.ConvTranspose2d( 64 , 3 , 5 , 3 , 1 , bias = False ), nn.Tanh() # 输出范围 -1~1 故而采用Tanh # nn.Sigmoid() # 输出形状:3 x 96 x 96 ) def forward( self , x): x = self .gen(x) return x def weight_init(m): # weight_initialization: important for wgan class_name = m.__class__.__name__ if class_name.find( 'Conv' ) ! = - 1 : m.weight.data.normal_( 0 , 0.02 ) elif class_name.find( 'Norm' ) ! = - 1 : m.weight.data.normal_( 1.0 , 0.02 ) class Discriminator(nn.Module): def __init__( self ): super ().__init__() self .dis = nn.Sequential( nn.Conv2d( 3 , 64 , 5 , 3 , 1 , bias = False ), nn.LeakyReLU( 0.2 , inplace = True ), # 输出 (64) x 32 x 32 nn.Conv2d( 64 , 128 , 4 , 2 , 1 , bias = False ), nn.BatchNorm2d( 128 ), nn.LeakyReLU( 0.2 , inplace = True ), # 输出 (128) x 16 x 16 nn.Conv2d( 128 , 256 , 4 , 2 , 1 , bias = False ), nn.BatchNorm2d( 256 ), nn.LeakyReLU( 0.2 , inplace = True ), # 输出 (256) x 8 x 8 nn.Conv2d( 256 , 512 , 4 , 2 , 1 , bias = False ), nn.BatchNorm2d( 512 ), nn.LeakyReLU( 0.2 , inplace = True ), # 输出 (512) x 4 x 4 nn.Conv2d( 512 , 1 , 4 , 1 , 0 , bias = False ), nn.Flatten(), # nn.Sigmoid() # 输出一个数(概率) ) def forward( self , x): x = self .dis(x) return x def weight_init(m): # weight_initialization: important for wgan class_name = m.__class__.__name__ if class_name.find( 'Conv' ) ! = - 1 : m.weight.data.normal_( 0 , 0.02 ) elif class_name.find( 'Norm' ) ! = - 1 : m.weight.data.normal_( 1.0 , 0.02 ) def save(model, filename = "model.pt" , out_dir = "out/" ): if model is not None : if not os.path.exists(out_dir): os.mkdir(out_dir) torch.save({ 'model' : model.state_dict()}, out_dir + filename) else : print ( "[ERROR]:Please build a model!!!" ) import QuickModelBuilder as builder if __name__ = = '__main__' : one = torch.FloatTensor([ 1 ]).cuda() mone = - 1 * one is_print = True # 创建对象 D = Discriminator() G = Generator() D.weight_init() G.weight_init() if torch.cuda.is_available(): D = D.cuda() G = G.cuda() lr = 2e - 4 d_optimizer = torch.optim.RMSprop(D.parameters(), lr = lr, ) g_optimizer = torch.optim.RMSprop(G.parameters(), lr = lr, ) d_scheduler = torch.optim.lr_scheduler.ExponentialLR(d_optimizer, gamma = 0.99 ) g_scheduler = torch.optim.lr_scheduler.ExponentialLR(g_optimizer, gamma = 0.99 ) fake_img = None # ##########################进入训练##判别器的判断过程##################### for epoch in range (num_epoch): # 进行多个epoch的训练 pbar = builder.MyTqdm(epoch = epoch, maxval = len (dataloader)) for i, img in enumerate (dataloader): num_img = img.size( 0 ) real_img = img.cuda() # 将tensor变成Variable放入计算图中 # 这里的优化器是D的优化器 for param in D.parameters(): param.requires_grad = True # ########判别器训练train##################### # 分为两部分:1、真的图像判别为真;2、假的图像判别为假 # 计算真实图片的损失 d_optimizer.zero_grad() # 在反向传播之前,先将梯度归0 real_out = D(real_img) # 将真实图片放入判别器中 d_loss_real = real_out.mean( 0 ).view( 1 ) d_loss_real.backward(one) # 计算生成图片的损失 z = torch.randn(num_img, z_dimension).cuda() # 随机生成一些噪声 z = z.reshape(num_img, z_dimension, 1 , 1 ) fake_img = G(z).detach() # 随机噪声放入生成网络中,生成一张假的图片。 # 避免梯度传到G,因为G不用更新, detach分离 fake_out = D(fake_img) # 判别器判断假的图片, d_loss_fake = fake_out.mean( 0 ).view( 1 ) d_loss_fake.backward(mone) d_loss = d_loss_fake - d_loss_real d_optimizer.step() # 更新参数 # 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c=0.01 for parm in D.parameters(): parm.data.clamp_( - 0.01 , 0.01 ) # ==================训练生成器============================ # ###############################生成网络的训练############################### for param in D.parameters(): param.requires_grad = False # 这里的优化器是G的优化器,所以不需要冻结D的梯度,因为不是D的优化器,不会更新D g_optimizer.zero_grad() # 梯度归0 z = torch.randn(num_img, z_dimension).cuda() z = z.reshape(num_img, z_dimension, 1 , 1 ) fake_img = G(z) # 随机噪声输入到生成器中,得到一副假的图片 output = D(fake_img) # 经过判别器得到的结果 # g_loss = criterion(output, real_label) # 得到的假的图片与真实的图片的label的loss g_loss = torch.mean(output).view( 1 ) # bp and optimize g_loss.backward(one) # 进行反向传播 g_optimizer.step() # .step()一般用在反向传播后面,用于更新生成网络的参数 # 打印中间的损失 pbar.set_right_info(d_loss = d_loss.data.item(), g_loss = g_loss.data.item(), real_scores = real_out.data.mean().item(), fake_scores = fake_out.data.mean().item(), ) pbar.update() try : fake_images = to_img(fake_img.cpu()) save_image(fake_images, dir_path + '/fake_images-{}.png' . format (epoch + 1 )) except : pass if is_print: is_print = False real_images = to_img(real_img.cpu()) save_image(real_images, dir_path + '/real_images.png' ) pbar.finish() d_scheduler.step() g_scheduler.step() save(D, "wgan_D.pt" ) save(G, "wgan_G.pt" ) |
到此这篇关于Pytorch实现WGAN用于动漫头像生成的文章就介绍到这了,更多相关Pytorch实现WGAN用于动漫头像生成内容请搜索服务器之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持服务器之家!
原文链接:https://blog.csdn.net/bu_fo/article/details/109808354