for numpy 有一个类似的问题,所以我的回答深受他们的解决方案的启发。我将使用perfplot 比较一些提到的方法。我还将概括问题以将映射应用于张量(您的只是一个特定情况)。
为了分析,我将假设映射包含张量中的所有唯一元素以及元素的数量小而恒定。
import torch
def apply(a: torch.Tensor, ids: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
mapping = {k.item(): v.item() for k, v in zip(a, ids)}
return b.clone().apply_(lambda x: mapping.__getitem__(x))
def bucketize(a: torch.Tensor, ids: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
mapping = {k.item(): v.item() for k, v in zip(a, ids)}
# From `https://stackoverflow.com/questions/13572448`.
palette, key = zip(*mapping.items())
key = torch.tensor(key)
palette = torch.tensor(palette)
index = torch.bucketize(b.ravel(), palette)
remapped = key[index].reshape(b.shape)
return remapped
def iterate(a: torch.Tensor, ids: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
mapping = {k.item(): v.item() for k, v in zip(a, ids)}
return torch.tensor([mapping[x.item()] for x in b])
def argmax(a: torch.Tensor, ids: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
return (b.view(-1, 1) == a).int().argmax(dim=1)
if __name__ == "__main__":
import perfplot
a = torch.arange(2, 8)
ids = torch.arange(0, 6)
perfplot.show(
setup=lambda n: torch.randint(2, 8, (n,)),
kernels=[
lambda x: apply(a, ids, x),
lambda x: bucketize(a, ids, x),
lambda x: iterate(a, ids, x),
lambda x: argmax(a, ids, x),
],
labels=["apply", "bucketize", "iterate", "argmax"],
n_range=[2 ** k for k in range(25)],
xlabel="len(a)",
)
运行它会产生以下情节:
因此,根据张量中的元素数量,您可以选择argmax 方法(其中提到的注意事项以及必须从0 to N 映射值的限制)、apply 或@987654331 @。
现在,如果我们增加要映射的元素的数量,比如说数万个,即a = torch.arange(2, 10002) 和ids = torch.arange(0, 10000),我们会得到以下结果:
这意味着bucketize 的速度提升仅对更大的数组可见,但仍优于其他方法(argmax 方法已被终止,因此我不得不将其删除)。
最后,如果我们的映射没有张量中的所有键,我们可以用所有唯一键更新字典:
mapping = {x.item(): x.item() for x in torch.unique(a)}
mapping.update({k.item(): v.item() for k, v in zip(a, ids)})
现在,如果您要映射的唯一元素数量级大于计算数组的数量级,则当 bucketize 比 apply 快时,这可能会改变 n 的值(因为对于应用,您可以更改mapping.__getitem__(x) 为mapping.get(x, x)。