【问题标题】:Sampling point pairs from a grid in Pytorch从 Pytorch 中的网格中采样点对
【发布时间】:2020-02-17 05:54:45
【问题描述】:

我需要 PyTorch 中网格中的样本点对。

我有一个大小为 (1 x 500 x 1000) 的张量。我还有一个大小为(1 x 500 x 1000)的掩码张量,表示一个点是否有效。我想从这个网格中采样 200k 点对。换句话说,我想将采样点对的坐标作为大小为 (200k x 4) 的张量,表示所有 200k 点对的 (x1, y1, x2, y2)。对中的所有点都应该是有效点。

这会重复很多次,所以我需要有一个有效的方法来执行这个过程。在 PyTorch 中实现这一点的优雅方式是什么?

【问题讨论】:

    标签: python python-3.x performance numpy pytorch


    【解决方案1】:

    这里不是专家,但我确实花了一些时间尝试一下。
    结果证明对一维数组进行操作要快得多(方法二)。

    import time
    import torch
    class Timer():
        def __init__(self):
            pass
        def __enter__(self):
            self.time = time.time()
        def __exit__(self, *exc):
            print(f'time used: {time.time() - self.time:.2f}s')
    
    # a = torch.rand([1,500,1000])
    m = torch.randint(2, [1, 500, 1000]) # mask tensor
    valid_len = (m==1).nonzero().size()[0] # number of valid points
    rand_one = torch.randint(valid_len, [200000]) # sample 200k of random int
    rand_two = torch.randint(valid_len, [200000]) # sample 200k of random int
    
    # method one
    m0 = m == 1 # mask of shape torch.Size([1, 500, 1000])
    m0 = m0.nonzero() # valid points of shape torch.Size([valid_len, 3])
    m0 = m0[:, 1:] # reshape to shape torch.Size([valid_len, 2])
    with Timer():
        one0 = torch.index_select(m0, 0, rand_one) # take 200k valid points
        two0 = torch.index_select(m0, 0, rand_two) # take 200k valid points again
        coor0 = torch.cat([one0, two0], dim=1) # stack them up
    # >>> time used: 1.05s
    
    # method two
    m1 = m.reshape(-1) # reshape mask to torch.Size([500000])
    m1 = m1==1 # mask of shape torch.Size([500000])
    m1 = m1.nonzero() # valid points of shape torch.Size([valid_len, 1])
    m1 = m1.reshape(-1) # valid points of shape torch.Size([valid_len])
    with Timer():
        one1 = m1.take(rand_one) # take 200k valid points
        two1 = m1.take(rand_two) # again
        # transform them to coordinates and stack them up
        coor1 = torch.stack([one1 // 1000, one1 % 1000, two1 // 1000, two1 % 1000], dim=1)
    # >>> time used: 0.07s
    
    assert torch.sum(coor0 == coor1) == 800000 # make sure consistent result 
    

    干杯

    【讨论】:

      猜你喜欢
      • 2017-03-17
      • 1970-01-01
      • 1970-01-01
      • 2013-06-15
      • 1970-01-01
      • 2021-07-09
      • 1970-01-01
      • 2020-12-18
      • 1970-01-01
      相关资源
      最近更新 更多