服务器之家

服务器之家 > 正文

pytorch 带batch的tensor类型图像显示操作

时间:2021-11-08 10:41     来源/作者:Xavier Jiezou

项目场景

pytorch训练时我们一般把数据集放到数据加载器里,然后分批拿出来训练。训练前我们一般还要看一下训练数据长啥样,也就是训练数据集可视化。

那么如何显示dataloader里面带batch的tensor类型的图像呢?

显示图像

绘图最常用的库就是matplotlib:

?
1
pip install matplotlib

显示图像会用到matplotlib.pyplot.imshow方法。查阅官方文档可知,该方法接收的图像的通道数要放到后面:

pytorch 带batch的tensor类型图像显示操作

数据加载器中数据的维度是[b, c, h, w],我们每次只拿一个数据出来就是[c, h, w],而matplotlib.pyplot.imshow要求的输入维度是[h, w, c],所以我们需要交换一下数据维度,把通道数放到最后面,这里用到pytorch里面的permute方法(transpose方法也行,不过要交换两次,没这个方便,numpy中的transpose方法倒是可以一次交换完成)

用法示例如下:

?
1
2
3
4
5
>>> x = torch.randn(2, 3, 5)
>>> x.size()
torch.size([2, 3, 5])
>>> x.permute(1, 2, 0).size()
torch.size([3, 5, 2])

代码示例

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
#%% 导入模块
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data import dataloader
from torchvision import datasets, transforms
#%% 下载数据集
train_file = datasets.mnist(
    root='./dataset/',
    train=true,
    transform=transforms.compose([
        transforms.totensor(),
        transforms.normalize((0.1307,), (0.3081,))
    ]),
    download=true
)
#%% 制作数据加载器
train_loader = dataloader(
    dataset=train_file,
    batch_size=9,
    shuffle=true
)
#%% 训练数据可视化
images, labels = next(iter(train_loader))
print(images.size())  # torch.size([9, 1, 28, 28])
plt.figure(figsize=(9, 9))
for i in range(9):
    plt.subplot(3, 3, i+1)
    plt.title(labels[i].item())
    plt.imshow(images[i].permute(1, 2, 0), cmap='gray')
    plt.axis('off')
plt.show()

这里以mnist数据集为例,演示一下显示效果。我这个代码其实还有一点小问题。数据增强的时候我不是进行标准化了嘛,就是在第7行代码:normalize((0.1307,), (0.3081,))。

所以,如果你想查看训练集的原始图像,还得反标准化。

标准化:image = (image-mean)/std

反标准化:image = image*std+mean

我拿imagenet中的一个蚂蚁和蜜蜂的子集做了一下实验,标准化前后的区别还是很明显的:

pytorch 带batch的tensor类型图像显示操作

最终效果

pytorch 带batch的tensor类型图像显示操作

补充:pil,plt显示tensor类型的图像

该方法针对显示dataloader读取的图像

pil 与plt中对应操作不同,但原理是一样的,我试过用下方代码image的方法在plt上show失败了,原因暂且不知。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 方法1:image.show()
# transforms.topilimage()中有一句
# npimg = np.transpose(pic.numpy(), (1, 2, 0))
# 因此pic只能是3-d tensor,所以要用image[0]消去batch那一维
img = transforms.topilimage(image[0])
img.show()
 
# 方法2:plt.imshow(ndarray)
img = image[0] # plt.imshow()只能接受3-d tensor,所以也要用image[0]消去batch那一维
img = img.numpy() # floattensor转为ndarray
img = np.transpose(img, (1,2,0)) # 把channel那一维放到最后
# 显示图片
plt.imshow(img)
plt.show()
cnt += 1

以上为个人经验,希望能给大家一个参考,也希望大家多多支持服务器之家。

原文链接:https://blog.csdn.net/qq_42951560/article/details/109962828

标签:

相关文章

热门资讯

yue是什么意思 网络流行语yue了是什么梗
yue是什么意思 网络流行语yue了是什么梗 2020-10-11
2020微信伤感网名听哭了 让对方看到心疼的伤感网名大全
2020微信伤感网名听哭了 让对方看到心疼的伤感网名大全 2019-12-26
背刺什么意思 网络词语背刺是什么梗
背刺什么意思 网络词语背刺是什么梗 2020-05-22
2021年耽改剧名单 2021要播出的59部耽改剧列表
2021年耽改剧名单 2021要播出的59部耽改剧列表 2021-03-05
苹果12mini价格表官网报价 iPhone12mini全版本价格汇总
苹果12mini价格表官网报价 iPhone12mini全版本价格汇总 2020-11-13
返回顶部

1523
Weibo Article 1 Weibo Article 2 Weibo Article 3 Weibo Article 4 Weibo Article 5 Weibo Article 6 Weibo Article 7 Weibo Article 8 Weibo Article 9 Weibo Article 10 Weibo Article 11 Weibo Article 12 Weibo Article 13 Weibo Article 14 Weibo Article 15 Weibo Article 16 Weibo Article 17 Weibo Article 18 Weibo Article 19 Weibo Article 20 Weibo Article 21 Weibo Article 22 Weibo Article 23 Weibo Article 24 Weibo Article 25 Weibo Article 26 Weibo Article 27 Weibo Article 28 Weibo Article 29 Weibo Article 30 Weibo Article 31 Weibo Article 32 Weibo Article 33 Weibo Article 34 Weibo Article 35 Weibo Article 36 Weibo Article 37 Weibo Article 38 Weibo Article 39 Weibo Article 40