彭某的技术折腾笔记

彭某的技术折腾笔记

PyTorch 高级索引

2024-05-16

PyTorch 高级索引

2024年5月15日

摘要

在使用 PyTorch 的过程中,常规的访问 Tensor 的索引方式并不能够很好的适应一些复杂场景。本文将介绍一些在 PyTorch 中访问 Tensor 的高级索引方式。

索引方式

假设存在一个 ​2\times 4 的 2D Tensor,第一排从 0 到 3,第二排从 4 到 7:

import torch

sample = torch.arrange(8).reshape(2, 4)
# tensor([[0, 1, 2, 3],
#         [4, 5, 6, 7]])

基础方式

访问元素

有两种方式可以进行基础的元素访问:

# Divide dimension by comma
element = sample[0, 1]
# tensor(1)

# C-Style
element = sample[0][1]
# tensor(1)

其次,也可以使用负数进行索引,-1 代表最后一个元素,-2 代表倒数第二个,以此类推:

element = sample[-1, -2]
element = sample[-1][-2]
# tensor(6)

访问切片

可以用 a:b 表示选取 [a, b) 的索引:

token = sample[0, 1:3]
token = sample[0][1:3]
# tensor([1, 2])

其中,当 a == b 时,当前维度切片长度为 0:

token = sample[0, 1:1]
token = sample[0][1:1]
# tensor([], dtype=torch.int64)

b == a + 1 时,当前维度切片长度为 1,需要注意的是,切片长度为 1 并不代表这个维度不存在了:

token = sample[0:1, 1:2]
token = sample[0:1][1:2]
# tensor([[1]])

element = sample[0, 1]
element = sample[0][1]
# tensor(1)

可以看到,使用 b == a + 1 的形式作为索引,取出的值虽然和直接用 a 做索引一致,但是维度数量能保持不变,直接使用标量 a 做索引会去除当前维度。

切片中一样可以使用负数索引代表倒数第几个元素。

[a, b) 超出 Tensor 的范围时,超出部分的索引将被忽略。

a 不写则为 1,例如 :bb 不写则为 -1,例如 a:

高级方式

间隔访问切片

除了连续访问,还能够间隔以等差数列索引来访问,格式为 start:end:step,其将生成由 start 开始(闭区间),end 结束(开区间),以 step 为步长的等差数列列表,并将这个列表内的值作为索引实现间隔取值,当 Tensor 中不存在索引列表所需的值时,此索引将被忽略。

token = sample[0, 1:100:2]
token = sample[0][1:100:2]
# tensor([1, 3])

其中,startend 可以是负数,step 不行。

其中,start 不写则为 1end 不写则为 -1step 不写则为 1

列表索引访问

除了使用 start:end:step 生成等差数列索引列表,我们还可以直接手动提供索引列表:

token = sample[1, [1, 3]]
token = sample[1][[1, 3]]
# tensor([5, 7])

Bool 访问

除了使用标量列表索引访问,我们还可以使用 Bool 列表来访问:

token = sample[1, [True, False, False, True]]
token = sample[1][[True, False, False, True]]
# tensor([4, 7])

slice

如果一种索引规则比较常用,我们可以创建一个 slice 对象来储存,以后用这个 slice 对象来索引:

index = slice(1, None, 2)
token = sample[1][index]
# tensor([5, 7])

slice 的规则和间隔访问切片

2024年5月15日

摘要

在使用 PyTorch 的过程中,常规的访问 Tensor 的索引方式并不能够很好的适应一些复杂场景。本文将介绍一些在 PyTorch 中访问 Tensor 的高级索引方式。

索引方式

假设存在一个 ​ 2\times 4 的 2D Tensor,第一排从 0 到 3,第二排从 4 到 7:

import torch

sample = torch.arrange(8).reshape(2, 4)
# tensor([[0, 1, 2, 3],
#         [4, 5, 6, 7]])

基础方式

访问元素

有两种方式可以进行基础的元素访问:

# Divide dimension by comma
element = sample[0, 1]
# tensor(1)

# C-Style
element = sample[0][1]
# tensor(1)

其次,也可以使用负数进行索引,-1 代表最后一个元素,-2 代表倒数第二个,以此类推:

element = sample[-1, -2]
element = sample[-1][-2]
# tensor(6)

访问切片

可以用 a:b 表示选取 [a, b) 的索引:

token = sample[0, 1:3]
token = sample[0][1:3]
# tensor([1, 2])

其中,当 a == b 时,当前维度切片长度为 0:

token = sample[0, 1:1]
token = sample[0][1:1]
# tensor([], dtype=torch.int64)

b == a + 1 时,当前维度切片长度为 1,需要注意的是,切片长度为 1 并不代表这个维度不存在了:

token = sample[0:1, 1:2]
token = sample[0:1][1:2]
# tensor([[1]])

element = sample[0, 1]
element = sample[0][1]
# tensor(1)

可以看到,使用 b == a + 1 的形式作为索引,取出的值虽然和直接用 a 做索引一致,但是维度数量能保持不变,直接使用标量 a 做索引会去除当前维度。

切片中一样可以使用负数索引代表倒数第几个元素。

[a, b) 超出 Tensor 的范围时,超出部分的索引将被忽略。

a 不写则为 1,例如 :bb 不写则为 -1,例如 a:

高级方式

间隔访问切片

除了连续访问,还能够间隔以等差数列索引来访问,格式为 start:end:step,其将生成由 start 开始(闭区间),end 结束(开区间),以 step 为步长的等差数列列表,并将这个列表内的值作为索引实现间隔取值,当 Tensor 中不存在索引列表所需的值时,此索引将被忽略。

token = sample[0, 1:100:2]
token = sample[0][1:100:2]
# tensor([1, 3])

其中,startend 可以是负数,step 不行。

其中,start 不写则为 1end 不写则为 -1step 不写则为 1

列表索引访问

除了使用 start:end:step 生成等差数列索引列表,我们还可以直接手动提供索引列表:

token = sample[1, [1, 3]]
token = sample[1][[1, 3]]
# tensor([5, 7])

Bool 访问

除了使用标量列表索引访问,我们还可以使用 Bool 列表来访问:

token = sample[1, [True, False, False, True]]
token = sample[1][[True, False, False, True]]
# tensor([4, 7])

slice

如果一种索引规则比较常用,我们可以创建一个 slice 对象来储存,以后用这个 slice 对象来索引:

index = slice(1, None, 2)
token = sample[1][index]
# tensor([5, 7])

slice 的规则和间隔访问切片一致。

  • 0