PyTorch入门学习 3-索引 index_select, gather

avatar 2024年04月14日19:30:20 0 58 views
博主分享免费Java教学视频,B站账号:Java刘哥

代码

import torch
print('# 1、索引选择,一维矩阵')
x = torch.arange(0, 5, 1)  # 从0到5,步长为1。等价于 x = torch.tensor([0, 1, 2, 3, 4])
print(x)
print(x[2], '\n')  # 选择第3个元素

print('# 2、索引选择, 二维矩阵')
x = torch.arange(9).reshape(3, 3)  # 从0到9,步长为1,重塑为3x3的矩阵。reshape可以改成view
# 等价于 x = torch.tensor([[0, 1, 2], [3, 4, 5], [6,7,8]])
print(x, '\n')
print('# 选择第2行第2列的元素', x[1, 1], '\n')

print('# 选择第2行的元素', x[1, :])
print('# 选择第2列的元素', x[:, 1], '\n')

print('# 选择最后一行的元素', x[-1, :])
print('# 选择最后一列的元素', x[:, -1], '\n')

print('# 选择<2行的元素', x[:1])
print('# 选择>=2行元素', x[1:], '\n')

print('# 选择<2列的元素', x[:, :1])
print('# 选择>=2列的元素', x[:, 1:], '\n')
print('# 选择第1行第1列到第2行第2列的元素', x[0:2, 0:2], '\n')

print('# 3、索引选择, 布尔选择')
x = torch.tensor([1, 2, 3, 4, 5, 6, 7])
print(x[x > 3], '\n')  # 选择大于3的元素
print(x[x % 2 == 0], '\n')  # 选择偶数元素

print('# 4、索引选择, index_select')  # 返回一个新的张量,该张量使用 LongTensor 索引中的条目沿维度 dim 对输入张量进行索引。
x = torch.arange(9).view(3, 3)
print(x)
indices = torch.tensor([0, 1])  # 选择第0行,第1行
print(indices)
print(torch.index_select(x, 0, indices))  # 选择第0行,第1行,dim=0 表示行
print(torch.index_select(x, 1, indices), '\n')  # 选择第0列和第1列,dim=1 表示列
#
print('# 5、索引选择, masked_select')  # 返回一个新的一维张量,该张量根据布尔掩码(即 BoolTensor)对输入张量进行索引。
x = torch.arange(9).view(3, 3)
print(x)
print(x.ge(5))  # 判断每个元素是否>=5
print(torch.masked_select(x, x.ge(5)), '\n')  # 选择>=5的元素,返回一维矩阵

print('# 6、索引选择, masked_select')  # 返回一个新的一维张量,该张量根据布尔掩码(即 BoolTensor)对输入张量进行索引。
x = torch.arange(9).view(3, 3)
print(x)
print(x.ge(5))  # 判断每个元素是否>=5
print(torch.masked_select(x, x.ge(5)), '\n')  # 选择>=5的元素,返回一维矩阵

print('# 7、索引选择, nonzero')  # 返回一个包含输入张量 x 中非零元素索引的张量。
x = torch.tensor([0, 1, 0, 1, 0, 0])
print(x)
print(torch.nonzero(x))  # 返回非0元素的索引
print(torch.nonzero(x, as_tuple=True), '\n')  # 返回非0元素的索引,返回元组
#
print('# 8、索引选择, gather')
input = torch.tensor([[1, 2], [3, 4]])
# 1 2
# 3 4
print(input, '\n')
# dim = 1 (按行索引)表示从input中的每一行中,选择第index列对应的元素,组成一个新的tensor
# index=[0, 0], [1, 0] 表示取第一行的第0个,第0个。即 1 1。取第二行的第1个,第0个。即 4 3
x = torch.gather(input, 1, torch.tensor([[0, 0], [1, 0]]))
print(x, '\n')

# dim = 0 (按列索引)表示从input中的每一列中,选择第index行对应的元素,组成一个新的tensor
# index=[0, 0] 表示第一列第0个,第二列第0个
# index=[1, 0] 表示第一列第1个,第二列第0个
# 换言之,按列索引,先找第一列的的第0个、第1个;再找第二列的第0个、第0个
x = torch.gather(input, 0, torch.tensor([[0, 0], [1, 0]]))
print(x, '\n')

运行结果

# 1、索引选择,一维矩阵
tensor([0, 1, 2, 3, 4])
tensor(2) 

# 2、索引选择, 二维矩阵
tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]]) 

# 选择第2行第2列的元素 tensor(4) 

# 选择第2行的元素 tensor([3, 4, 5])
# 选择第2列的元素 tensor([1, 4, 7]) 

# 选择最后一行的元素 tensor([6, 7, 8])
# 选择最后一列的元素 tensor([2, 5, 8]) 

# 选择<2行的元素 tensor([[0, 1, 2]])
# 选择>=2行元素 tensor([[3, 4, 5],
        [6, 7, 8]]) 

# 选择<2列的元素 tensor([[0],
        [3],
        [6]])
# 选择>=2列的元素 tensor([[1, 2],
        [4, 5],
        [7, 8]]) 

# 选择第1行第1列到第2行第2列的元素 tensor([[0, 1],
        [3, 4]]) 

# 3、索引选择, 布尔选择
tensor([4, 5, 6, 7]) 

tensor([2, 4, 6]) 

# 4、索引选择, index_select
tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
tensor([0, 1])
tensor([[0, 1, 2],
        [3, 4, 5]])
tensor([[0, 1],
        [3, 4],
        [6, 7]]) 

# 5、索引选择, masked_select
tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
tensor([[False, False, False],
        [False, False,  True],
        [ True,  True,  True]])
tensor([5, 6, 7, 8]) 

# 6、索引选择, masked_select
tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
tensor([[False, False, False],
        [False, False,  True],
        [ True,  True,  True]])
tensor([5, 6, 7, 8]) 

# 7、索引选择, nonzero
tensor([0, 1, 0, 1, 0, 0])
tensor([[1],
        [3]])
(tensor([1, 3]),) 

# 8、索引选择, gather
tensor([[1, 2],
        [3, 4]]) 

tensor([[1, 1],
        [4, 3]]) 

tensor([[1, 2],
        [3, 2]]) 

 

关于 gather 还需要理解。

目前记住dim=1,按行索引,然后依次逐行找index里的每个元素

比如 dim=1, index=[[1,2], [3,4]] 表示在第一行找第1个和第2个;在第二行找第3个和第4个(注意:阿拉伯数字下标从0开始)

 

关于 dim=0的,不知道我有没有理解错,后面深入学习后再来补充或修改

 

  • 微信
  • 交流学习,有偿服务
  • weinxin
  • 博客/Java交流群
  • 资源分享,问题解决,技术交流。群号:590480292
  • weinxin
avatar

发表评论

avatar 登录者:匿名
匿名评论,评论回复后会有邮件通知

  

已通过评论:0   待审核评论数:0