Pytorch Tensor维度变换 红太狼 2022-01-07 05:55 231阅读 0赞 1. 改变shape torch.reshape()、torch.view()可以调整Tensor的shape,返回一个新shape的Tensor,torch.view()是老版本的实现,torch.reshape()是最新的实现,两者在功能上是一样的。 示例代码: import torch a = torch.rand(4, 1, 28, 28) print(a.shape) print(a.view(4 * 1, 28, 28).shape) print(a.reshape(4 * 1, 28, 28).shape) print(a.reshape(4, 1 * 28 * 28).shape) 输出结果: torch.Size([4, 1, 28, 28]) torch.Size([4, 28, 28]) torch.Size([4, 28, 28]) torch.Size([4, 784]) 注意:维度变换的时候要注意实际意义。 2. 增加维度 torch.unsqueeze(index)可以为Tensor增加一个维度,增加的这一个维度的位置由我们自己定义,新增加的这一个维度不会改变数据本身,只是为数据新增加了一个组别,这个组别是什么由我们自己定义。 比如定义了一个Tensor: a = torch.randn(4, 1, 28, 28) 这个Tensor有4个维度,我们可以在现有维度的基础上插入一个新的维度,插入维度的index在\[-a.dim()-1, a.dim()+1\]范围内,并且当index>=0,则在index前面插入这个新增加的维度;当index < 0,则在index后面插入这个新增的维度。 示例代码: print(a.shape) print(a.unsqueeze(0).shape) print(a.unsqueeze(-1).shape) print(a.unsqueeze(3).shape) print(a.unsqueeze(4).shape) print(a.unsqueeze(-4).shape) print(a.unsqueeze(-5).shape) print(a.unsqueeze(5).shape) 输出结果: torch.Size([4, 1, 28, 28]) torch.Size([1, 4, 1, 28, 28]) torch.Size([4, 1, 28, 28, 1]) torch.Size([4, 1, 28, 1, 28]) torch.Size([4, 1, 28, 28, 1]) torch.Size([4, 1, 1, 28, 28]) torch.Size([1, 4, 1, 28, 28]) Traceback (most recent call last): File "/home/lhy/workspace/mmdetection/my_code/pytorch_ws/tensotr_0.py", line 218, in <module> print(a.unsqueeze(5).shape) IndexError: Dimension out of range (expected to be in range of [-5, 4], but got 5) 在执行a.unsqueeze(5)时报错,是因为超出了index的范围。 3. 删减维度 删减维度实际上是一个维度挤压的过程,直观地看是把那些多余的`[]`给去掉,也就是只是去删除那些size=1的维度。 import torch a = torch.Tensor(1, 4, 1, 9) print(a.shape) print(a.squeeze().shape) # 删除所有的size=1的维度 print(a.squeeze(0).shape) # 删除0号维度,ok print(a.squeeze(2).shape) # 删除2号维度,ok print(a.squeeze(3).shape) # 删除3号维度,但是3号维度是9不是1,删除失败 输出结果: torch.Size([1, 4, 1, 9]) torch.Size([4, 9]) torch.Size([4, 1, 9]) torch.Size([1, 4, 9]) torch.Size([1, 4, 1, 9]) 4. 维度扩展 expand就是**在某个size=1的维度上改变size**,改成更大的一个大小,实际就是在每个size=1的维度上的标量的广播操作。 实际用例: 我们有一个shape=\[4, 32, 14,14\]的Tensor data,相当于4张图片,每张图片32个通道,每个通道行为14,列为14的图像数据,需要将每个通道上的所有像素增加一个偏置bias。图像数据的channel=32,因此bias = torch.rand(32),但是还不能完成data + bias的操作,因为两者dim与shape不一致。为了使得dim一致,需要增加bias的维度到4维,这就用到了unsqueeze()函数;为了使得shape一致,需要bias的4个维度的shape=\[4, 32, 14, 14\],这就用到了维度扩展expand。 代码实现: import torch bias = torch.rand(32) data = torch.rand(4, 32, 14, 14) # 想要把bias加到data上面去 # 先进行维度增加 bias = bias.unsqueeze(1).unsqueeze(2).unsqueeze(0) print(bias.shape) # 再进行维度扩展 bias = bias.expand(4, -1, 14, 14) # -1表示这个维度保持不变,这里写32也可以 print(bias.shape) data + bias 输出结果: torch.Size([1, 32, 1, 1]) torch.Size([4, 32, 14, 14]) 5. 维度重复 repeat就是**将每个位置的维度都重复至指定的次数**,以形成新的Tensor,功能与维度扩展一样,但是repeat会重新申请内存空间,repeat()参数表示各个维度指定的重复次数。 代码示例: import torch b = torch.Tensor(1, 32, 1, 1) print(b.shape) # 维度重复,32这里不想进行重复,所以就相当于"重复至1次" b = b.repeat(4, 1, 14, 14) print(b.shape) 输出结果: torch.Size([1, 32, 1, 1]) torch.Size([4, 32, 14, 14]) 6. 转置操作 Pytorch的转置操作只适用于dim=2的Tensor,也就是矩阵。 示例代码: c = torch.Tensor(2, 4) print(c.t().shape) 输出结果: torch.Size([4, 2]) 7. 维度变换 (1) transpose(dim1, dim2)交换dim1与dim2 注意这种交换使得存储不再连续,再执行一些reshape的操作会报错,所以要调用一下`contiguous()`使其变成连续的维度。 示例代码: d = torch.Tensor(6, 3, 1, 2) # 1号维度和3号维度交换 print(d.transpose(1, 3).contiguous().shape) 输出结果: torch.Size([6, 2, 1, 3]) 下面这个例子比较一下每个位置上的元素都是一致的,来验证一下这个交换->压缩shape->展开shape->交换回去是没有问题的。 e = torch.rand(4, 3, 6, 7) e2 = e.transpose(1, 3).contiguous().reshape(4, 7 * 6 * 3).reshape(4, 7, 6, 3).transpose(1, 3) print(e2.shape) # 比较下两个Tensor所有位置上的元素是否都相等,返回1表示相等。 print(torch.all(torch.eq(e, e2))) 输出结果: torch.Size([4, 3, 6, 7]) tensor(1, dtype=torch.uint8) (2) permute 如果四个维度表示上节的\[batch,channel,h,w\] ,如果想把channel放到最后去,形成\[batch,h,w,channel\],那么如果使用前面的维度交换,至少要交换两次(先13交换再12交换)。而使用permute可以直接指定维度新的所处位置,更加方便。 示例代码: a = torch.rand(4, 3, 6, 7) print(a.permute(0, 2, 3, 1).shape) 输出结果: torch.Size([4, 6, 7, 3])
还没有评论,来说两句吧...