【问题标题】:Best way to vectorize generating a batch of randomly rotated matrices in Numpy/PyTorch?在 Numpy/PyTorch 中矢量化生成一批随机旋转矩阵的最佳方法?
【发布时间】:2020-05-18 13:58:59
【问题描述】:

我想根据初始起始矩阵(例如,形状为(4096, 3))生成一批随机旋转的矩阵,其中应用于批处理中每个矩阵的旋转是从一个一组旋转矩阵(在我原始帖子的代码中,我只想从 8 个可能的旋转角度中随机选择)。因此,我最终得到的是一个形状为 (batch_size, 4096, 3) 的张量。

我目前的方法是预先制作可能的旋转矩阵(因为我只处理 8 个可能的随机旋转),然后使用 for 循环通过随机选择八个预制矩阵中的一个来生成批次批次中每个项目的旋转矩阵。这不是超级高效,所以我希望以某种方式将整个过程矢量化。

现在,这就是我循环一批以逐个生成一批旋转矩阵的方式:

for view_i in range(batch_size):
        # Get rotated view grid points randomly
        idx = torch.randint(0, 8, (1,))
        pointsf = rotated_points[idx]

在下面的代码中,我生成了一组预制的随机旋转矩阵,这些矩阵是从批次的 for 循环中随机选择的。

make_3d_grid 函数生成一个(grid_dim * grid_dim * grid_dim, 3) 形状的矩阵(基本上是一个由 x、y、z 坐标点组成的二维数组)。 get_rotation_matrix 函数返回一个 (3, 3) 旋转矩阵,其中 theta 用于绕 x 轴旋转。

rotated_points = []
grid_dim = 16

pointsf = make_3d_grid((-1,)*3, (1,)*3, (grid_dim,)*3)

view_angles = torch.tensor([0, np.pi / 4.0, np.pi / 2.0, 3 * np.pi / 4.0, np.pi, 5 * np.pi / 4.0, 3 * np.pi / 2.0, 7 * np.pi / 4.0])

for i in range(len(view_angles)):
    theta = view_angles[i]

    rot = get_rotation_matrix(theta, torch.tensor(0.0), torch.tensor(0.0))

    pointsf_rot = torch.mm(pointsf, rot)

    rotated_points.append(pointsf_rot)

在矢量化这方面的任何帮助将不胜感激!如果这方面的代码可以在 Numpy 中完成,那也可以正常工作,因为我可以自己将其转换为 PyTorch。

【问题讨论】:

    标签: python numpy matrix pytorch


    【解决方案1】:

    您可以将旋转矩阵预先生成为(batch_size, 3, 3) 数组,然后乘以广播到(batch_size, N, 3)(N, 3) 点数组。

    rotated_points = np.dot(pointsf, rots)
    

    np.dot 将在pointsf 的最后一个轴和rots 的倒数第二个轴上求和积,将pointsf 的尺寸放在首位。这意味着您的结果将是(N, batch_size, 3) 而不是(batch_size, N, 3)。您当然可以通过简单的轴交换来解决此问题:

    rotated_points = np.dot(pointsf, rots).transpose(1, 0, 2)
    

    rotated_points = np.swapaxes(np.dot(pointsf, rots), 0, 1)
    

    不过,我建议您将 rots 设为您之前所拥有的逆(转置)旋转矩阵。在这种情况下,您可以计算:

    rotated_points = np.dot(transposed_rots, pointsf.T)
    

    您应该能够相当简单地将np.dot 转换为torch.mm

    【讨论】:

    • @Yuerno 很高兴它对你有用。欢迎来到精彩的广播世界。
    猜你喜欢
    • 1970-01-01
    • 2019-10-18
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2016-06-30
    • 1970-01-01
    • 2021-06-04
    • 1970-01-01
    相关资源
    最近更新 更多