先说结论
model.state_dict()
是浅拷贝,返回的参数仍然会随着网络的训练而变化。
应该使用deepcopy(model.state_dict())
,或将参数及时序列化到硬盘。
再讲故事,前几天在做一个模型的交叉验证训练时,通过model.state_dict()保存了每一组交叉验证模型的参数,后根据效果选择准确率最佳的模型load回去,结果每一次都是最后一个模型,从地址来看,每一个保存的state_dict()都具有不同的地址,但进一步发现state_dict()下的各个模型参数的地址是共享的,而我又使用了in-place的方式重置模型参数,进而导致了上述问题。
补充:pytorch中state_dict的理解
在PyTorch中,state_dict是一个Python字典对象(在这个有序字典中,key是各层参数名,value是各层参数),包含模型的可学习参数(即权重和偏差,以及bn层的的参数) 优化器对象(torch.optim)也具有state_dict,其中包含有关优化器状态以及所用超参数的信息。
其实看了如下代码的输出应该就懂了
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
|
import torch import torch.nn as nn import torchvision import numpy as np from torchsummary import summary # Define model class TheModelClass(nn.Module): def __init__( self ): super (TheModelClass, self ).__init__() self .conv1 = nn.Conv2d( 3 , 6 , 5 ) self .pool = nn.MaxPool2d( 2 , 2 ) self .conv2 = nn.Conv2d( 6 , 16 , 5 ) self .fc1 = nn.Linear( 16 * 5 * 5 , 120 ) self .fc2 = nn.Linear( 120 , 84 ) self .fc3 = nn.Linear( 84 , 10 ) def forward( self , x): x = self .pool(F.relu( self .conv1(x))) x = self .pool(F.relu( self .conv2(x))) x = x.view( - 1 , 16 * 5 * 5 ) x = F.relu( self .fc1(x)) x = F.relu( self .fc2(x)) x = self .fc3(x) return x # Initialize model model = TheModelClass() # Initialize optimizer optimizer = torch.optim.SGD(model.parameters(), lr = 0.001 , momentum = 0.9 ) # Print model's state_dict print ( "Model's state_dict:" ) for param_tensor in model.state_dict(): print (param_tensor, "\t" , model.state_dict()[param_tensor].size()) # Print optimizer's state_dict print ( "Optimizer's state_dict:" ) for var_name in optimizer.state_dict(): print (var_name, "\t" , optimizer.state_dict()[var_name]) |
输出如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
|
Model's state_dict: conv1.weight torch.Size([ 6 , 3 , 5 , 5 ]) conv1.bias torch.Size([ 6 ]) conv2.weight torch.Size([ 16 , 6 , 5 , 5 ]) conv2.bias torch.Size([ 16 ]) fc1.weight torch.Size([ 120 , 400 ]) fc1.bias torch.Size([ 120 ]) fc2.weight torch.Size([ 84 , 120 ]) fc2.bias torch.Size([ 84 ]) fc3.weight torch.Size([ 10 , 84 ]) fc3.bias torch.Size([ 10 ]) Optimizer's state_dict: state {} param_groups [{ 'lr' : 0.001 , 'momentum' : 0.9 , 'dampening' : 0 , 'weight_decay' : 0 , 'nesterov' : False , 'params' : [ 2238501264336 , 2238501329800 , 2238501330016 , 2238501327136 , 2238501328576 , 2238501329728 , 2238501327928 , 2238501327064 , 2238501330808 , 2238501328288 ]}] |
我是刚接触深度学西的小白一个,希望大佬可以为我指出我的不足,此博客仅为自己的笔记!!!!
补充:pytorch保存模型时报错***object has no attribute 'state_dict'
定义了一个类BaseNet并实例化该类:
1
|
net = BaseNet() |
保存net时报错 object has no attribute 'state_dict'
1
|
torch.save(net.state_dict(), models_dir) |
原因是定义类的时候不是继承nn.Module类,比如:
1
2
|
class BaseNet( object ): def __init__( self ): |
把类定义改为
1
2
3
|
class BaseNet(nn.Module): def __init__( self ): super (BaseNet, self ).__init__() |
以上为个人经验,希望能给大家一个参考,也希望大家多多支持服务器之家。如有错误或未考虑完全的地方,望不吝赐教。
原文链接:https://www.cnblogs.com/LukeStepByStep/p/11248361.html