Pytorch常用函数功能使用(一)

忘是亡心i 2022-03-15 14:55 335阅读 0赞

1. view

  1. import torch
  2. number_1 = torch.randn(2, 3)
  3. print(number_1)
  4. print(number_1.shape)
  5. print(number_1.view(1, -1))
  6. print(number_1.view(3, -1))

输出:

  1. tensor([[ 1.0506, -0.5875, -1.2477],
  2. [ 0.0635, 0.8997, 0.1551]])
  3. torch.Size([2, 3])
  4. tensor([[ 1.0506, -0.5875, -1.2477, 0.0635, 0.8997, 0.1551]])
  5. tensor([[ 1.0506, -0.5875],
  6. [-1.2477, 0.0635],
  7. [ 0.8997, 0.1551]])

View(a,b)中第一个参数a代表目标张量的行数,b代表列数。为了简便起见,也可以只指定第一个参数a,b这个参数设置成-1,函数会自动计算对应的列数。

2. squeeze

  1. number_2 = torch.randn(2, 1)
  2. print(number_2)
  3. print(torch.squeeze(number_2))
  4. print(torch.squeeze(number_2, 0))
  5. print(torch.squeeze(number_2, 1))

输出:

  1. tensor([[ 0.5856],
  2. [-1.7095]])
  3. tensor([ 0.5856, -1.7095])
  4. tensor([[ 0.5856],
  5. [-1.7095]])
  6. tensor([ 0.5856, -1.7095])

Squeeze的功能是进行维度缩减(维度为1的删除)。Squeeze(a,b)中第一个参数a代表传入的张量,b代表要缩减的维数。如果第二个参数没有指定,则默认删除所有维度为1的维度

  1. number_3 = torch.randn(1, 2)
  2. print(number_3)
  3. print(torch.squeeze(number_3))
  4. print(torch.squeeze(number_3, 0))
  5. print(torch.squeeze(number_3, 1))

输出:

  1. tensor([[ 0.1555, -0.4286]])
  2. tensor([ 0.1555, -0.4286])
  3. tensor([ 0.1555, -0.4286])
  4. tensor([[ 0.1555, -0.4286]])

3. unsqueeze

  1. number_4 = torch.randn(3, 2)
  2. print(number_4)
  3. print(torch.unsqueeze(number_4, 0))
  4. print(torch.unsqueeze(number_4, 1))

输出:

  1. tensor([[ 0.0358, -0.2769],
  2. [-0.3257, 0.1895],
  3. [ 1.9278, -0.9444]])
  4. tensor([[[ 0.0358, -0.2769],
  5. [-0.3257, 0.1895],
  6. [ 1.9278, -0.9444]]])
  7. tensor([[[ 0.0358, -0.2769]],
  8. [[-0.3257, 0.1895]],
  9. [[ 1.9278, -0.9444]]])

Unsqueeze的功能与squeeze相反,可以增加张量的维度。Unqueeze(a,b)中第一个参数a代表传入的张量,b代表要增加维度的维数。

4. max

  1. number_5 = torch.randn(2, 3)
  2. print(number_5)
  3. print(torch.max(number_5, 0))
  4. print(torch.max(number_5, 1))

输出:

  1. tensor([[-0.4916, 1.3999, 1.0527],
  2. [ 1.0194, -2.4695, -0.2378]])
  3. (tensor([1.0194, 1.3999, 1.0527]), tensor([1, 0, 0]))
  4. (tensor([1.3999, 1.0194]), tensor([1, 0]))

Max的功能是返回对应维度最大的数以及对应的索引。Max(a,b)中第一个参数a代表传入的张量,b代表要对应的维数。0代表返回每一列的最大值,1代表返回每一行的最大值。

发表评论

表情:
评论列表 (有 0 条评论,335人围观)

还没有评论,来说两句吧...

相关阅读

    相关 Excel函数

    Excel常用函数(一)的表格下载链接: 文章目录 一、公式中常用的运算符和运算函数 二、函数引用方式 三、格式转换 3.1 货币符号与单位的添加 3.2 英文大小