【问题标题】:Creating a pytorch tensor binary mask using specific values使用特定值创建 pytorch 张量二进制掩码
【发布时间】:2021-02-22 03:46:27
【问题描述】:

我得到了一个带有整数的 pytorch 2-D 张量,并且 2 个整数总是出现在张量的每一行中。 我想创建一个二进制掩码,在这 2 个整数的 两个 出现之间包含 1,否则为 0。例如,如果整数是 4 和 2,并且一维数组是 [1,1,9,4,6,5,1,2,9,9,11,4,3,6,5,2,3,4] ,返回的掩码为:[0,0,0,1,1,1,1,1,0,0,0,1,1,1,1,1,0,0,0]. 是否有任何有效且快速的方法来计算此掩码而无需迭代?

【问题讨论】:

    标签: python pytorch numpy-ndarray tensor binary-matrix


    【解决方案1】:

    可能有点混乱,但它无需迭代即可工作。在下文中,我假设一个示例张量 m 应用解决方案,用它来解释比使用一般符号更容易。

    import torch
    
    vals=[2,8]#let's assume those are the constant values that appear in each row
    
    #target tensor
    m=torch.tensor([[1., 2., 7., 8., 5.],
        [4., 7., 2., 1., 8.]])
    
    #let's find the indexes of those values
    k=m==vals[0]
    p=m==vals[1]
    
    v=(k.int()+p.int()).bool()
    nz_indexes=v.nonzero()[:,1].reshape(m.shape[0],2)
    
    #let's create a tiling of the indexes
    q=torch.arange(m.shape[1])
    q=q.repeat(m.shape[0],1)
    
    #you only need two masks, no matter the size of m. see explanation below
    msk_0=(nz_indexes[:,0].repeat(m.shape[1],1).transpose(0,1))<=q
    msk_1=(nz_indexes[:,1].repeat(m.shape[1],1).transpose(0,1))>=q
    
    final_mask=msk_0.int() * msk_1.int()
    
    print(final_mask)
    

    我们得到

    tensor([[0, 1, 1, 1, 0],
            [0, 0, 1, 1, 1]], dtype=torch.int32)
    

    关于mask_0mask_1 这两个掩码,如果不清楚它们是什么,请注意nz_indexes[:,0] 包含,对于m 的每一行,找到vals[0] 的列索引,以及@对于m 的每一行,987654329@ 类似地包含找到vals[1] 的列索引。

    【讨论】:

    • 我的代码在 {nz_indexes=v.nonzero()[:,1].reshape(m.shape[0],2) } 行失败,为什么你在 v 上使用非零只取第一列?
    • @Codevan 嗯...是的,我认为如果 vals[0]vals[1] 值之一在一行中发生不止一次,这将失败。在编写您提到的行时,我的假设是每个值每行仅出现一次,因此由于 v.nonzero()[:,0] is the row indexes, I could discard this column as I would already know that v.nonzero()[0,:]` 和 v.nonzero()[1:] 对应于 @ 的行 0 987654337@等等。
    • 我认为一种可能的解决方法是更改​​nonzero(),使用一个返回非零元素的第一次和最后一次出现的函数。
    • 我相应地编辑了我的问题。你现在知道如何适合你的答案了吗?谢谢!
    【解决方案2】:

    完全基于之前的解决方案,这里是修改后的解决方案:

    import torch
    
    vals=[2,8]#let's assume those are the constant values that appear in each row
    
    #target tensor
    m=torch.tensor([[1., 2., 7., 8., 5., 2., 6., 5., 8., 4.],
        [4., 7., 2., 1., 8., 2., 6., 5., 6., 8.]])
    
    #let's find the indexes of those values
    k=m==vals[0]
    p=m==vals[1]
    
    v=(k.int()+p.int()).bool()
    nz_indexes=v.nonzero()[:,1].reshape(m.shape[0],4)
    
    #let's create a tiling of the indexes
    q=torch.arange(m.shape[1])
    q=q.repeat(m.shape[0],1)
    
    #you only need two masks, no matter the size of m. see explanation below
    msk_0=(nz_indexes[:,0].repeat(m.shape[1],1).transpose(0,1))<=q
    msk_1=(nz_indexes[:,1].repeat(m.shape[1],1).transpose(0,1))>=q
    msk_2=(nz_indexes[:,2].repeat(m.shape[1],1).transpose(0,1))<=q
    msk_3=(nz_indexes[:,3].repeat(m.shape[1],1).transpose(0,1))>=q
    
    final_mask=msk_0.int() * msk_1.int() + msk_2.int() * msk_3.int()
    
    print(final_mask)
    

    我们终于得到了

    tensor([[0, 1, 1, 1, 0, 1, 1, 1, 1, 0],
            [0, 0, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)
    

    【讨论】:

    • 对于以后的问题,你已经接受了一个答案之后,请不要再改问题了,这不仅不礼貌,而且会让回答你的人看起来很完整白痴。至于解决方案本身,最好找到一些泛化更好的东西:如果我们在数组中有 N 个重复怎么办?拥有 N*2 mask_i 数组似乎不是很优雅/高效。
    • @Ash 是对的,你至少应该对他的回答投赞成票。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2021-07-31
    • 2019-10-31
    • 2021-02-09
    • 2022-01-09
    • 2018-05-28
    • 1970-01-01
    • 2020-07-25
    相关资源
    最近更新 更多