【问题标题】:Finding non-intersection of two pytorch tensors找到两个 pytorch 张量的不相交
【发布时间】:2019-08-02 06:00:48
【问题描述】:

提前感谢大家的帮助!我在 PyTorch 中尝试做的事情类似于 numpy 的 setdiff1d。例如给定以下两个张量:

t1 = torch.tensor([1, 9, 12, 5, 24]).to('cuda:0')
t2 = torch.tensor([1, 24]).to('cuda:0')

预期的输出应该是(排序的或未排序的):

torch.tensor([9, 12, 5])

理想情况下,操作在 GPU 上完成,GPU 和 CPU 之间没有来回。非常感谢!

【问题讨论】:

  • 您可以直接在 Torch 张量上使用 numpy 操作而无需复制:torch.from_numpy(np.setdiff1d(t1.numpy(),t2.numpy()))
  • 非常感谢@romeric,我很抱歉我的问题没有明确表达。我希望为此使用 CUDA 张量并仅在 GPU 上进行操作,而转换为 ndarray 需要先将张量发送回 cpu。

标签: python numpy pytorch


【解决方案1】:

我遇到了同样的问题,但是在使用更大的数组时,建议的解决方案太慢了。以下简单的解决方案适用于 CPU 和 GPU,并且比其他建议的解决方案要快得多:

combined = torch.cat((t1, t2))
uniques, counts = combined.unique(return_counts=True)
difference = uniques[counts == 1]
intersection = uniques[counts > 1]

【讨论】:

  • 如果 t1 有重复值,结果似乎不同
  • 如果数据不是一维数据,您只需将dim=0 添加到unique 即可!
  • @Leonid 这可以通过在连接它们之前先采用 torch.unique 形式 t1 和 t2 来解决
【解决方案2】:

如果您不想要 for 循环,这可以一次比较所有值。

你也可以轻松获得非交叉点

t1 = torch.tensor([1, 9, 12, 5, 24])
t2 = torch.tensor([1, 24])

# Create a tensor to compare all values at once
compareview = t2.repeat(t1.shape[0],1).T

# Intersection
print(t1[(compareview == t1).T.sum(1)==1])
# Non Intersection
print(t1[(compareview != t1).T.prod(1)==1])
tensor([ 1, 24])
tensor([ 9, 12,  5])

【讨论】:

  • 更改为compareview = t2.expand(t1.shape[0], t2.shape[0]).T 应该可以节省内存,因为它会创建一个视图。
  • 如果t1 包含在t2 中的重复项,它们将在交叉点中出现多次。
【解决方案3】:

如果您不想离开 cuda,解决方法可能是:

t1 = torch.tensor([1, 9, 12, 5, 24], device = 'cuda')
t2 = torch.tensor([1, 24], device = 'cuda')
indices = torch.ones_like(t1, dtype = torch.uint8, device = 'cuda')
for elem in t2:
    indices = indices & (t1 != elem)  
intersection = t1[indices]  

【讨论】:

  • 用 cpu 迭代元素(使用 for 循环)错过了快速实现 cuda 的意义
猜你喜欢
  • 1970-01-01
  • 2022-10-07
  • 2021-02-15
  • 2021-03-05
  • 2020-07-07
  • 1970-01-01
  • 1970-01-01
  • 2019-08-24
  • 2018-12-01
相关资源
最近更新 更多