【发布时间】:2021-03-19 07:44:45
【问题描述】:
如果给你一个n x n 矩阵的集合,比如m,pytorch 中是否有一个预定义的函数可以将所有这些对角线嵌入到更大的维度矩阵nm x nm 中?
具体来说,我正在寻找的是说你有两个2 x 2 单位矩阵,那么它们的对角线嵌入到4 x 4 矩阵中将是单位4 x 4 矩阵。
类似:
torch.block_diag
但这需要您将每个矩阵作为单独的参数提供。
【问题讨论】:
如果给你一个n x n 矩阵的集合,比如m,pytorch 中是否有一个预定义的函数可以将所有这些对角线嵌入到更大的维度矩阵nm x nm 中?
具体来说,我正在寻找的是说你有两个2 x 2 单位矩阵,那么它们的对角线嵌入到4 x 4 矩阵中将是单位4 x 4 矩阵。
类似:
torch.block_diag
但这需要您将每个矩阵作为单独的参数提供。
【问题讨论】:
您的问题没有说明如何获得 m 张量。假设你有
# channel first tensors
a = torch.ones(4,2,2)
或
# a list of tensors
a = [torch.ones(2,2) for _ in range(4)]
然后你可以在block_diag:
>>> torch.block_diag(*a)
tensor([[1., 1., 0., 0., 0., 0., 0., 0.],
[1., 1., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 1., 0., 0., 0., 0.],
[0., 0., 1., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 1., 0., 0.],
[0., 0., 0., 0., 1., 1., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 1.],
[0., 0., 0., 0., 0., 0., 1., 1.]])
【讨论】: