【问题标题】:Filter data in pytorch tensor在 pytorch 张量中过滤数据
【发布时间】:2019-12-25 11:00:27
【问题描述】:

我有一个张量X[0.1, 0.5, -1.0, 0, 1.2, 0]一样,我想实现一个叫filter_positive()的函数,它可以将正数据过滤成一个新的张量并返回原始张量的索引。例如:

new_tensor, index = filter_positive(X)

new_tensor = [0.1, 0.5, 1.2]
index = [0, 1, 4]

如何在 pytorch 中最有效地实现这个功能?

【问题讨论】:

    标签: python pytorch


    【解决方案1】:

    看一下torch.nonzero,它大致相当于np.where。它将二进制掩码转换为索引:

    >>> X = torch.tensor([0.1, 0.5, -1.0, 0, 1.2, 0])
    >>> mask = X >= 0
    >>> mask
    tensor([1, 1, 0, 1, 1, 1], dtype=torch.uint8)
    
    >>> indices = torch.nonzero(mask)
    >>> indices
    tensor([[0],
            [1],
            [3],
            [4],
            [5]])
    
    >>> X[indices]
    tensor([[0.1000],
            [0.5000],
            [0.0000],
            [1.2000],
            [0.0000]])
    

    一个解决方案是这样写:

    mask = X >= 0
    new_tensor = X[mask]
    indices = torch.nonzero(mask)
    

    【讨论】:

      【解决方案2】:

      如果不需要索引,你可以这样做:

      X = X[X > 0]
      

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 2021-02-12
        • 2021-08-20
        • 2022-08-27
        • 2021-02-07
        • 1970-01-01
        • 1970-01-01
        • 2021-10-11
        • 2022-10-17
        相关资源
        最近更新 更多