服务器之家

服务器之家 > 正文

pytorch 带batch的tensor类型图像显示操作

时间:2021-11-08 10:41     来源/作者:Xavier Jiezou

项目场景

pytorch训练时我们一般把数据集放到数据加载器里,然后分批拿出来训练。训练前我们一般还要看一下训练数据长啥样,也就是训练数据集可视化。

那么如何显示dataloader里面带batch的tensor类型的图像呢?

显示图像

绘图最常用的库就是matplotlib:

?
1
pip install matplotlib

显示图像会用到matplotlib.pyplot.imshow方法。查阅官方文档可知,该方法接收的图像的通道数要放到后面:

pytorch 带batch的tensor类型图像显示操作

数据加载器中数据的维度是[b, c, h, w],我们每次只拿一个数据出来就是[c, h, w],而matplotlib.pyplot.imshow要求的输入维度是[h, w, c],所以我们需要交换一下数据维度,把通道数放到最后面,这里用到pytorch里面的permute方法(transpose方法也行,不过要交换两次,没这个方便,numpy中的transpose方法倒是可以一次交换完成)

用法示例如下:

?
1
2
3
4
5
>>> x = torch.randn(2, 3, 5)
>>> x.size()
torch.size([2, 3, 5])
>>> x.permute(1, 2, 0).size()
torch.size([3, 5, 2])

代码示例

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
#%% 导入模块
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data import dataloader
from torchvision import datasets, transforms
#%% 下载数据集
train_file = datasets.mnist(
    root='./dataset/',
    train=true,
    transform=transforms.compose([
        transforms.totensor(),
        transforms.normalize((0.1307,), (0.3081,))
    ]),
    download=true
)
#%% 制作数据加载器
train_loader = dataloader(
    dataset=train_file,
    batch_size=9,
    shuffle=true
)
#%% 训练数据可视化
images, labels = next(iter(train_loader))
print(images.size())  # torch.size([9, 1, 28, 28])
plt.figure(figsize=(9, 9))
for i in range(9):
    plt.subplot(3, 3, i+1)
    plt.title(labels[i].item())
    plt.imshow(images[i].permute(1, 2, 0), cmap='gray')
    plt.axis('off')
plt.show()

这里以mnist数据集为例,演示一下显示效果。我这个代码其实还有一点小问题。数据增强的时候我不是进行标准化了嘛,就是在第7行代码:normalize((0.1307,), (0.3081,))。

所以,如果你想查看训练集的原始图像,还得反标准化。

标准化:image = (image-mean)/std

反标准化:image = image*std+mean

我拿imagenet中的一个蚂蚁和蜜蜂的子集做了一下实验,标准化前后的区别还是很明显的:

pytorch 带batch的tensor类型图像显示操作

最终效果

pytorch 带batch的tensor类型图像显示操作

补充:pil,plt显示tensor类型的图像

该方法针对显示dataloader读取的图像

pil 与plt中对应操作不同,但原理是一样的,我试过用下方代码image的方法在plt上show失败了,原因暂且不知。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 方法1:image.show()
# transforms.topilimage()中有一句
# npimg = np.transpose(pic.numpy(), (1, 2, 0))
# 因此pic只能是3-d tensor,所以要用image[0]消去batch那一维
img = transforms.topilimage(image[0])
img.show()
 
# 方法2:plt.imshow(ndarray)
img = image[0] # plt.imshow()只能接受3-d tensor,所以也要用image[0]消去batch那一维
img = img.numpy() # floattensor转为ndarray
img = np.transpose(img, (1,2,0)) # 把channel那一维放到最后
# 显示图片
plt.imshow(img)
plt.show()
cnt += 1

以上为个人经验,希望能给大家一个参考,也希望大家多多支持服务器之家。

原文链接:https://blog.csdn.net/qq_42951560/article/details/109962828

标签:

相关文章

热门资讯

yue是什么意思 网络流行语yue了是什么梗
yue是什么意思 网络流行语yue了是什么梗 2020-10-11
2020微信伤感网名听哭了 让对方看到心疼的伤感网名大全
2020微信伤感网名听哭了 让对方看到心疼的伤感网名大全 2019-12-26
背刺什么意思 网络词语背刺是什么梗
背刺什么意思 网络词语背刺是什么梗 2020-05-22
2021年耽改剧名单 2021要播出的59部耽改剧列表
2021年耽改剧名单 2021要播出的59部耽改剧列表 2021-03-05
苹果12mini价格表官网报价 iPhone12mini全版本价格汇总
苹果12mini价格表官网报价 iPhone12mini全版本价格汇总 2020-11-13
返回顶部