【问题标题】:How to map element in pytorch tensor to id?如何将pytorch张量中的元素映射到id?
【发布时间】:2021-04-10 09:50:59
【问题描述】:

给定一个张量:

A = torch.tensor([2., 3., 4., 5., 6., 7.])

然后,给A 中的每个元素一个id:

id = torch.arange(A.shape[0], dtype = torch.int)   # tensor([0,1,2,3,4,5])

换句话说,A2. 的id 为0,A3. 的id 为1:

2. -> 0
3. -> 1
4. -> 2
5. -> 3
6. -> 4
7. -> 5

那么,我有一个新的张量:

B = torch.tensor([3., 6., 6., 5., 4., 4., 4.])

在pytorch中,Pytorch中有没有办法将B中的每个元素映射到id? 也就是说,我要获取tensor([1, 4, 4, 3, 2, 2, 2]),其中每个元素都是B中元素的id。

【问题讨论】:

    标签: python-3.x pytorch


    【解决方案1】:

    您的问题可以通过慢慢迭代整个B 矩阵并根据A 的所有元素检查其中的每个元素然后检索每个元素的索引来完成:

    In [*]: for x in B:
        ...:     print(torch.where(x==A)[0][0])
        ...:
        ...:
    tensor(1)
    tensor(4)
    tensor(4)
    tensor(3)
    tensor(2)
    tensor(2)
    tensor(2)
    

    这里我使用torch.where 来查找矩阵x==A 中的所有True 元素,其中x 取矩阵B 中每个元素的值。这确实很慢,但它允许您添加一些功能来处理B 的某些元素未出现在矩阵A 中的情况

    用线性代数运算得到你想要的东西的快速而肮脏的方法是:

    In [*]: (B.view(-1,1) == A).int().argmax(dim=1)
    Out[*]: tensor([1, 4, 4, 3, 2, 2, 2])
    

    这个技巧利用了argmax返回dim=1中每个向量的第一个'max'索引这一事实。

    这里的大警告,如果矩阵中不存在该元素,则不会引发错误,并且对于A 中不存在的所有元素,结果将为0

    In [*]: C = torch.tensor([100, 1000, 1, 3, 9999])
    
    In [*]: (C.view(-1,1) == A).int().argmax(dim=1)
    Out[*]: tensor([0, 0, 0, 1, 0])
    

    【讨论】:

      【解决方案2】:

      我认为 PyTorch 中没有这样的函数来映射张量。

      通过将B 中的每个值与B 中的值进行比较来解决这个问题似乎很不合理。

      这里有两种可能的解决方案来解决这个问题。


      使用字典作为地图

      您可以使用字典pure-PyTorch 解决方案不是很多,但很可能是最快和最安全的方法...

      只需创建一个dict来将每个元素映射到一个id,然后用它来映射B

      >>> map = {x.item(): i for i, x in enumerate(A)}
      
      >>> torch.tensor([map[x.item()] for x in B])
      tensor([1, 4, 4, 3, 2, 2, 2])
      

      改变基准方法

      仅使用torch.Tensors 的替代方案。这将要求您要映射的值 - A 的内容 - 是整数,因为它们将用于索引张量。

      1. A的内容编码为one-hot编码:

        >>> A_enc = torch.zeros((int(A.max())+1,)*2)
        >>> A_enc[A, torch.arange(A.shape[0])] = 1
        
        >>> A_enc
        tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0., 0., 0., 0.],
                [1., 0., 0., 0., 0., 0., 0., 0.],
                [0., 1., 0., 0., 0., 0., 0., 0.],
                [0., 0., 1., 0., 0., 0., 0., 0.],
                [0., 0., 0., 1., 0., 0., 0., 0.],
                [0., 0., 0., 0., 1., 0., 0., 0.],
                [0., 0., 0., 0., 0., 1., 0., 0.]])
        
      2. 我们将使用A_enc 作为映射整数的基础:

        >>> v = torch.argmax(A_enc, dim=0)
        tensor([0, 0, 0, 1, 2, 3, 4, 5])
        
      3. 现在,给定一个整数,例如 x=3,我们可以将其编码为 one-hot-encoding:x_enc = [0, 0, 0, 1, 0, 0, 0, 0]。然后,使用v 对其进行映射。使用简单的点积,您可以获得x_enc 的映射:这里<v/x_enc> 给出了1,这是所需的结果(映射的第一个元素-B)。但是我们不会给出x_enc,而是计算v 和encoded-B 之间的矩阵乘法。先编码B,然后计算矩阵乘法vxB_enc

        >>> B_enc = torch.zeros(A_enc.shape[0], B.shape[0])
        >>> B_enc[B, torch.arange(B.shape[0])] = 1
        
        >>> B_enc
        tensor([[0., 0., 0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0., 0., 0.],
                [1., 0., 0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 1., 1., 1.],
                [0., 0., 0., 1., 0., 0., 0.],
                [0., 1., 1., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0., 0., 0.]])
        
        >>> v@B_enc.long()
        tensor([1, 4, 4, 3, 2, 2, 2])
        

      注意 - 你必须用 Long 类型定义你的张量。

      【讨论】:

        【解决方案3】:

        for numpy 有一个类似的问题,所以我的回答深受他们的解决方案的启发。我将使用perfplot 比较一些提到的方法。我还将概括问题以将映射应用于张量(您的只是一个特定情况)。

        为了分析,我将假设映射包含张量中的所有唯一元素以及元素的数量小而恒定。

        import torch
        
        
        def apply(a: torch.Tensor, ids: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
            mapping = {k.item(): v.item() for k, v in zip(a, ids)}
            return b.clone().apply_(lambda x: mapping.__getitem__(x))
        
        
        def bucketize(a: torch.Tensor, ids: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
            mapping = {k.item(): v.item() for k, v in zip(a, ids)}
        
            # From `https://stackoverflow.com/questions/13572448`.
            palette, key = zip(*mapping.items())
            key = torch.tensor(key)
            palette = torch.tensor(palette)
        
            index = torch.bucketize(b.ravel(), palette)
            remapped = key[index].reshape(b.shape)
        
            return remapped
        
        
        def iterate(a: torch.Tensor, ids: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
            mapping = {k.item(): v.item() for k, v in zip(a, ids)}
            return torch.tensor([mapping[x.item()] for x in b])
        
        
        def argmax(a: torch.Tensor, ids: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
            return (b.view(-1, 1) == a).int().argmax(dim=1)
        
        
        if __name__ == "__main__":
            import perfplot
        
            a = torch.arange(2, 8)
            ids = torch.arange(0, 6)
        
            perfplot.show(
                setup=lambda n: torch.randint(2, 8, (n,)),
                kernels=[
                    lambda x: apply(a, ids, x),
                    lambda x: bucketize(a, ids, x),
                    lambda x: iterate(a, ids, x),
                    lambda x: argmax(a, ids, x),
                ],
                labels=["apply", "bucketize", "iterate", "argmax"],
                n_range=[2 ** k for k in range(25)],
                xlabel="len(a)",
            )
        

        运行它会产生以下情节:

        因此,根据张量中的元素数量,您可以选择argmax 方法(其中提到的注意事项以及必须从0 to N 映射值的限制)、apply 或@987654331 @。

        现在,如果我们增加要映射的元素的数量,比如说数万个,即a = torch.arange(2, 10002)ids = torch.arange(0, 10000),我们会得到以下结果:

        这意味着bucketize 的速度提升仅对更大的数组可见,但仍优于其他方法(argmax 方法已被终止,因此我不得不将其删除)。

        最后,如果我们的映射没有张量中的所有键,我们可以用所有唯一键更新字典:

        mapping = {x.item(): x.item() for x in torch.unique(a)}
        mapping.update({k.item(): v.item() for k, v in zip(a, ids)})
        

        现在,如果您要映射的唯一元素数量级大于计算数组的数量级,则当 bucketizeapply 快时,这可能会改变 n 的值(因为对于应用,您可以更改mapping.__getitem__(x)mapping.get(x, x)

        【讨论】:

          猜你喜欢
          • 2018-01-17
          • 2020-07-20
          • 2021-05-25
          • 2021-09-14
          • 2019-09-10
          • 1970-01-01
          • 1970-01-01
          • 2020-10-04
          • 1970-01-01
          相关资源
          最近更新 更多