You can get the final three-dimensional result E without creating a large intermediate array by using batched_dot:
import theano.tensor as tt
A = tt.tensor3('A') # A.shape = (D, N, H)
B = tt.tensor3('B') # B.shape = (D, H, K)
E = tt.batched_dot(A, B) # E.shape = (D, N, K)
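As a rough illustration of the shape semantics only, here is a NumPy sketch of the same per-slice product. It assumes np.matmul as a stand-in for batched_dot, and the sizes are made up for the example; it is not how Theano implements the operation.

```python
import numpy as np

rng = np.random.default_rng(0)
D, N, H, K = 5, 2, 4, 3  # illustrative sizes, not from the original question

A = rng.standard_normal((D, N, H))
B = rng.standard_normal((D, H, K))

# One (N, H) x (H, K) matrix product per index d along the leading axis.
E = np.matmul(A, B)  # E.shape == (D, N, K)

# Explicit loop over the leading axis, for comparison.
E_loop = np.stack([A[d] @ B[d] for d in range(D)])
```

Only D separate (N, K) results are ever produced, so no (N, K, H, D)-sized intermediate is needed.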
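Unfortunately, this requires you to permute the dimensions of the input and output arrays. While this can be done in Theano using dimshuffle, batched_dot does not seem to cope with arbitrarily strided arrays, so the following raises "ValueError: Some matrix has no unit stride" when E is evaluated: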
import theano.tensor as tt
A = tt.tensor3('A') # A.shape = (N, H, D)
B = tt.tensor3('B') # B.shape = (K, H, D)
A_perm = A.dimshuffle((2, 0, 1)) # A_perm.shape = (D, N, H)
B_perm = B.dimshuffle((2, 1, 0)) # B_perm.shape = (D, H, K)
E_perm = tt.batched_dot(A_perm, B_perm) # E_perm.shape = (D, N, K)
E = E_perm.dimshuffle((1, 2, 0)) # E.shape = (N, K, D)
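To see what "no unit stride" refers to, the permutation can be mirrored in NumPy. This is only an analogy under the assumption that dimshuffle, like a NumPy transpose, yields a strided view rather than a copy; it does not reproduce Theano's actual check, and the sizes are illustrative.

```python
import numpy as np

N, H, D = 3, 4, 5  # illustrative sizes
A = np.zeros((N, H, D), dtype=np.float64)

# Same axis permutation as A.dimshuffle((2, 0, 1)); NumPy returns a view.
A_perm = A.transpose(2, 0, 1)  # shape (D, N, H)

# In the view, stepping along the last axis no longer moves by one element.
last_stride_in_elements = A_perm.strides[-1] // A_perm.itemsize

# An explicit copy restores a contiguous layout with unit stride on the last axis.
A_copy = np.ascontiguousarray(A_perm)
copy_stride_in_elements = A_copy.strides[-1] // A_copy.itemsize
```

The copy costs extra memory and time, which is exactly what the permuted view was meant to avoid.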
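batched_dot uses scan along the first (size D) dimension. Because scan executes sequentially, this may be less efficient than computing all the products in parallel when running on a GPU.
You can trade off the memory efficiency of the batched_dot approach against the parallelism of the broadcasting approach by using scan explicitly. The idea is to compute the full product C in parallel for batches of size M (assuming M is an exact factor of D), and to iterate over the batches with scan: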
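import theano as th
import theano.tensor as tt
M = 4 # batch size (illustrative value); must divide D exactly
A = tt.tensor3('A') # A.shape = (N, H, D)
B = tt.tensor3('B') # B.shape = (K, H, D)
N, H, D = A.shape[0], A.shape[1], A.shape[2] # symbolic sizes
K = B.shape[0]
A_batched = A.reshape((N, H, M, D // M)) # integer division keeps the shape integral
B_batched = B.reshape((K, H, M, D // M))
# scan iterates over the leading axis of the transposed inputs (length D // M);
# at each step a has shape (M, H, N) and b has shape (M, H, K)
E_batched, _ = th.scan(
    lambda a, b: (a[:, :, None, :] * b[:, :, :, None]).sum(1),
    sequences=[A_batched.T, B_batched.T]
)
# E_batched.shape = (D // M, M, K, N); swap the first two axes before merging them
# so the flattened index matches the original ordering d = m * (D // M) + j
E = E_batched.dimshuffle((1, 0, 2, 3)).reshape((D, K, N)).T # E.shape = (N, K, D)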
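For checking the indexing, the same batching idea can be sketched in plain NumPy: loop sequentially over batches of M slices (the role scan plays) and broadcast within each batch. This is a sketch under assumptions, not the Theano graph itself. It uses contiguous batches, made-up sizes, and np.einsum as a reference for the target E[n, k, d] = sum over h of A[n, h, d] * B[k, h, d].

```python
import numpy as np

rng = np.random.default_rng(0)
N, K, H, D, M = 2, 3, 4, 6, 2  # illustrative sizes; M must divide D

A = rng.standard_normal((N, H, D))
B = rng.standard_normal((K, H, D))

E = np.empty((N, K, D))
for start in range(0, D, M):  # sequential over batches, as scan would be
    a = A[:, :, start:start + M]  # shape (N, H, M)
    b = B[:, :, start:start + M]  # shape (K, H, M)
    # Broadcast to (N, K, H, M) and sum over H; the intermediate
    # holds only M slices instead of all D.
    E[:, :, start:start + M] = (a[:, None, :, :] * b[None, :, :, :]).sum(axis=2)

# Reference result computed in one call.
E_ref = np.einsum('nhd,khd->nkd', A, B)
```

Larger M means fewer sequential steps but a larger (N, K, H, M) intermediate. M = 1 corresponds to the batched_dot end of the trade-off and M = D to the fully broadcast one.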