Pytorch常用函数功能使用(一)
1. view
import torch
number_1 = torch.randn(2, 3)
print(number_1)
print(number_1.shape)
print(number_1.view(1, -1))
print(number_1.view(3, -1))
输出:
tensor([[ 1.0506, -0.5875, -1.2477],
[ 0.0635, 0.8997, 0.1551]])
torch.Size([2, 3])
tensor([[ 1.0506, -0.5875, -1.2477, 0.0635, 0.8997, 0.1551]])
tensor([[ 1.0506, -0.5875],
[-1.2477, 0.0635],
[ 0.8997, 0.1551]])
View(a,b)中第一个参数a代表目标张量的行数,b代表列数。为了简便起见,也可以只指定第一个参数a,b这个参数设置成-1,函数会自动计算对应的列数。
2. squeeze
number_2 = torch.randn(2, 1)
print(number_2)
print(torch.squeeze(number_2))
print(torch.squeeze(number_2, 0))
print(torch.squeeze(number_2, 1))
输出:
tensor([[ 0.5856],
[-1.7095]])
tensor([ 0.5856, -1.7095])
tensor([[ 0.5856],
[-1.7095]])
tensor([ 0.5856, -1.7095])
Squeeze的功能是进行维度缩减(维度为1的删除)。Squeeze(a,b)中第一个参数a代表传入的张量,b代表要缩减的维数。如果第二个参数没有指定,则默认删除所有维度为1的维度
number_3 = torch.randn(1, 2)
print(number_3)
print(torch.squeeze(number_3))
print(torch.squeeze(number_3, 0))
print(torch.squeeze(number_3, 1))
输出:
tensor([[ 0.1555, -0.4286]])
tensor([ 0.1555, -0.4286])
tensor([ 0.1555, -0.4286])
tensor([[ 0.1555, -0.4286]])
3. unsqueeze
number_4 = torch.randn(3, 2)
print(number_4)
print(torch.unsqueeze(number_4, 0))
print(torch.unsqueeze(number_4, 1))
输出:
tensor([[ 0.0358, -0.2769],
[-0.3257, 0.1895],
[ 1.9278, -0.9444]])
tensor([[[ 0.0358, -0.2769],
[-0.3257, 0.1895],
[ 1.9278, -0.9444]]])
tensor([[[ 0.0358, -0.2769]],
[[-0.3257, 0.1895]],
[[ 1.9278, -0.9444]]])
Unsqueeze的功能与squeeze相反,可以增加张量的维度。Unqueeze(a,b)中第一个参数a代表传入的张量,b代表要增加维度的维数。
4. max
number_5 = torch.randn(2, 3)
print(number_5)
print(torch.max(number_5, 0))
print(torch.max(number_5, 1))
输出:
tensor([[-0.4916, 1.3999, 1.0527],
[ 1.0194, -2.4695, -0.2378]])
(tensor([1.0194, 1.3999, 1.0527]), tensor([1, 0, 0]))
(tensor([1.3999, 1.0194]), tensor([1, 0]))
Max的功能是返回对应维度最大的数以及对应的索引。Max(a,b)中第一个参数a代表传入的张量,b代表要对应的维数。0代表返回每一列的最大值,1代表返回每一行的最大值。
还没有评论,来说两句吧...