Let's go straight to the code:
    import torch
    import torch.utils.data as Data

    torch.manual_seed(1)    # reproducible

    # BATCH_SIZE = 5
    BATCH_SIZE = 8          # feed 8 samples into the network per batch

    x = torch.linspace(1, 10, 10)   # this is x data (torch tensor)
    y = torch.linspace(10, 1, 10)   # this is y data (torch tensor)

    torch_dataset = Data.TensorDataset(x, y)
    loader = Data.DataLoader(
        dataset=torch_dataset,      # torch TensorDataset format
        batch_size=BATCH_SIZE,      # mini batch size
        shuffle=False,              # do not shuffle here (set True for random shuffling during training)
        num_workers=2,              # use two subprocesses for loading data
    )

    def show_batch():
        for epoch in range(3):      # train on the entire dataset 3 times
            for step, (batch_x, batch_y) in enumerate(loader):  # for each training step
                # train your data...
                print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                      batch_x.numpy(), '| batch y: ', batch_y.numpy())

    if __name__ == '__main__':
        show_batch()
Output with BATCH_SIZE = 8 and three passes over the full dataset:
    Epoch: 0 | Step: 0 | batch x: [ 1. 2. 3. 4. 5. 6. 7. 8. ] | batch y: [ 10. 9. 8. 7. 6. 5. 4. 3. ]
    Epoch: 0 | Step: 1 | batch x: [ 9. 10. ] | batch y: [ 2. 1. ]
    Epoch: 1 | Step: 0 | batch x: [ 1. 2. 3. 4. 5. 6. 7. 8. ] | batch y: [ 10. 9. 8. 7. 6. 5. 4. 3. ]
    Epoch: 1 | Step: 1 | batch x: [ 9. 10. ] | batch y: [ 2. 1. ]
    Epoch: 2 | Step: 0 | batch x: [ 1. 2. 3. 4. 5. 6. 7. 8. ] | batch y: [ 10. 9. 8. 7. 6. 5. 4. 3. ]
    Epoch: 2 | Step: 1 | batch x: [ 9. 10. ] | batch y: [ 2. 1. ]
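The two steps per epoch follow from simple arithmetic: 10 samples with a batch size of 8 give ceil(10 / 8) = 2 batches, and the second batch holds the 2 leftover samples. Here is a minimal sketch using the same data. Note that `drop_last` is a standard `DataLoader` argument that the original code does not use; it is shown only for comparison, and `num_workers` is left at its default to keep the sketch simple.

```python
import math

import torch
import torch.utils.data as Data

BATCH_SIZE = 8
x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
dataset = Data.TensorDataset(x, y)

# Default behaviour: the final, smaller batch is kept.
loader = Data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
batch_sizes = [len(batch_x) for batch_x, _ in loader]

# drop_last=True discards the incomplete final batch instead.
loader_drop = Data.DataLoader(dataset, batch_size=BATCH_SIZE,
                              shuffle=False, drop_last=True)
batch_sizes_drop = [len(batch_x) for batch_x, _ in loader_drop]

expected_batches = math.ceil(len(dataset) / BATCH_SIZE)

print(len(loader), batch_sizes)            # batches per epoch and their sizes
print(len(loader_drop), batch_sizes_drop)  # same, with the short batch dropped
print(expected_batches)
```

Whether to keep or drop the short final batch depends on the model; layers that are sensitive to batch size are one reason to consider dropping it.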
Addendum: a PyTorch batch-training bug
Problem description:
When batch-training a neural network in PyTorch, the following error sometimes appears:
TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'torch.autograd.variable.Variable'>
Solution:
Step 1:
Check the following (this is the key point!):
    train_dataset = Data.TensorDataset(train_x, train_y)
Check the types of train_x and train_y: both must be tensors. My first error occurred because I passed in a Variable.
The data can be converted to a tensor like this:
    train_x = torch.FloatTensor(train_x)
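A quick type check before building the dataset can catch this problem early. The sketch below assumes the raw data starts out as NumPy arrays, which are hypothetical stand-ins for the author's data. `torch.as_tensor` and `torch.is_tensor` are standard PyTorch calls, though they are not what the original code uses.

```python
import numpy as np
import torch
import torch.utils.data as Data

# Hypothetical raw data, used only for illustration.
raw_x = np.arange(10, dtype=np.float64).reshape(5, 2)
raw_y = np.arange(5, dtype=np.float64)

# Convert to float32 tensors (the same dtype torch.FloatTensor produces).
train_x = torch.as_tensor(raw_x, dtype=torch.float32)
train_y = torch.as_tensor(raw_y, dtype=torch.float32)

# Verify both inputs are tensors before handing them to TensorDataset.
if not (torch.is_tensor(train_x) and torch.is_tensor(train_y)):
    raise TypeError("train_x and train_y must both be tensors")

train_dataset = Data.TensorDataset(train_x, train_y)
print(len(train_dataset), train_x.dtype)
```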
Step 2:
    train_loader = Data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True
    )
This instantiates a DataLoader object.
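To confirm what the loader actually yields, one batch can be pulled out and inspected before writing the training loop. This is a small sketch; the random data, its shapes, and the `batch_size` value are assumptions chosen only for illustration.

```python
import torch
import torch.utils.data as Data

# Hypothetical data: 20 samples with 3 features each, one target per sample.
train_x = torch.randn(20, 3)
train_y = torch.randn(20)
batch_size = 4

train_dataset = Data.TensorDataset(train_x, train_y)
train_loader = Data.DataLoader(dataset=train_dataset,
                               batch_size=batch_size,
                               shuffle=True)

# Pull a single batch to inspect its type and shape.
batch_x, batch_y = next(iter(train_loader))
print(type(batch_x), tuple(batch_x.shape), tuple(batch_y.shape))
```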
Step 3:
    from torch.autograd import Variable

    for epoch in range(epochs):
        for step, (batch_x, batch_y) in enumerate(train_loader):
            batch_x, batch_y = Variable(batch_x), Variable(batch_y)
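With this in place, batch training works.
Note that train_loader yields tensors; in the PyTorch version used here, they must be wrapped in Variable before being fed to the network. (Since PyTorch 0.4, Variable has been merged into Tensor, so the wrapping is no longer needed.)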
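Putting the three steps together, a minimal end-to-end sketch might look like the following. The synthetic data, the single `nn.Linear` model, the SGD optimizer, and all hyperparameters are assumptions for illustration and are not part of the original post. The sketch is written for PyTorch 0.4 or later, where tensors from the loader can be passed to the model directly without a `Variable` wrapper.

```python
import torch
import torch.nn as nn
import torch.utils.data as Data

torch.manual_seed(1)

# Synthetic regression data (hypothetical): y = 2x + 1
train_x = torch.linspace(-1, 1, 100).unsqueeze(1)
train_y = 2 * train_x + 1

# Step 1: build the dataset from tensors.
train_dataset = Data.TensorDataset(train_x, train_y)

# Step 2: instantiate the DataLoader.
train_loader = Data.DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)

model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

# Step 3: iterate over mini-batches for several epochs.
epochs = 50
for epoch in range(epochs):
    for step, (batch_x, batch_y) in enumerate(train_loader):
        optimizer.zero_grad()                    # clear old gradients
        loss = loss_fn(model(batch_x), batch_y)  # forward pass on one batch
        loss.backward()                          # backpropagate
        optimizer.step()                         # update parameters

# Evaluate on the full dataset without tracking gradients.
with torch.no_grad():
    final_loss = loss_fn(model(train_x), train_y).item()
print('final loss:', final_loss)
```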
The above is based on my personal experience; I hope it serves as a useful reference.
Original link: https://www.cnblogs.com/yangzhaonan/p/10439839.html