1
|
permute(dims) |
将tensor的维度换位。
参数:参数是一系列的整数,代表原来张量的维度。比如三维就有0,1,2这些dimension。
例:
1
2
3
4
5
6
7
|
import torch import numpy as np a = np.array([[[ 1 , 2 , 3 ],[ 4 , 5 , 6 ]]]) unpermuted = torch.tensor(a) print (unpermuted.size()) # ——> torch.Size([1, 2, 3]) permuted = unpermuted.permute( 2 , 0 , 1 ) print (permuted.size()) # ——> torch.Size([3, 1, 2]) |
再比如图片img的size比如是(28,28,3)就可以利用img.permute(2,0,1)得到一个size为(3,28,28)的tensor。
利用这个函数permute(1,3,2)可以把Tensor([[[1,2,3],[4,5,6]]]) 转换成
1
2
3
|
tensor([[[ 1. , 4. ], [ 2. , 5. ], [ 3. , 6. ]]]) |
如果使用view(1,3,2),可以得到
1
2
3
|
tensor([[[ 1. , 2. ], [ 3. , 4. ], [ 5. , 6. ]]]) |
以上这篇PyTorch中permute的用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/qq_40231500/article/details/90606872