【问题标题】:Broadcast tf.matmul with dynamic shapes使用动态形状广播 tf.matmul
【发布时间】:2018-12-25 16:33:31
【问题描述】:

我想在等级 2 和 3 的两个张量之间广播 tf.matmul 操作,其中一个包含“未知”形状的维度(基本上是特定维度中的“无”值)。

问题在于动态维度 tf.reshapetf.broadcast_to 似乎不起作用。

x = tf.placeholder(shape=[None, 5, 10], dtype=tf.float32)
w = tf.ones([10, 20])
y = x @ w
with tf.Session() as sess:
  r1 = sess.run(y, feed_dict={x: np.ones([3, 5, 10])})
  r2 = sess.run(y, feed_dict={x: np.ones([7, 5, 10])})

以上面的代码为例。在这种情况下,我要喂两个不同批次的 3 个和 7 个元素。我希望 r1r2 成为矩阵乘以 w 与这些批次中的 3 或 7 个元素中的每一个的结果。因此,r1r2 的结果形状分别为 (3, 5, 20) 和 (7, 5, 20),但我得到的是:

ValueError: Shape 必须是 2 级,但 'matmul' 是 3 级(操作: 'MatMul') 输入形状:[?,5,10], [10,20]。

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    w 可以扩展为 rank-3 张量,其批量大小等于输入的大小。然后就可以进行matmul运算了

    x = tf.placeholder(shape=[None, 5, 10], dtype=tf.float32)
    w = tf.ones([10, 20])
    
    number_batches = tf.shape(x)[0]
    w = tf.tile(tf.expand_dims(w, 0), [number_batches, 1, 1])
    y = x @ w
    with tf.Session() as sess:
      print(sess.run(y, feed_dict={x: np.ones([2, 5, 10])}))
      print(sess.run(y, feed_dict={x: np.ones([3, 5, 10])}))
    

    直播代码here

    【讨论】:

      猜你喜欢
      • 2016-10-29
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2017-09-12
      • 2018-06-06
      • 1970-01-01
      • 2021-07-20
      相关资源
      最近更新 更多