【问题标题】:Tensorflow: stack all row pairs from a tensorTensorflow:从张量中堆叠所有行对
【发布时间】:2018-05-25 20:54:45
【问题描述】:

给定一个张量t=[[1,2], [3,4]],我需要产生ts=[[1,2,1,2], [1,2,3,4], [3,4,1,2], [3,4,3,4]]。也就是说,我需要将所有行对堆叠在一起。 重要:张量具有维度 [None, 2],即。第一个维度是可变的。

我试过了:

  • 使用tf.while_loop 生成索引列表idx=[[0, 0], [0, 1], [1, 0], [1, 1]],然后使用tf.gather(ts, idx)。这可行,但很混乱,我不知道如何处理渐变。
  • 2 for 循环遍历 tf.unstack(t),将堆叠的行添加到缓冲区,然后 tf.stack(buffer)。如果第一个维度是可变的,这将不起作用。
  • 在广播中寻找灵感。例如,给定 x=t.expand_dims(t, 0), y=t.expand_dims(t, 1), s=tf.reshape(tf.add(x, y), [-1, 2]) s 将是 [[2, 4], [4, 6], [4, 6], [6, 8]],即。每行组合的总和。但是我怎样才能做堆叠而不是总和?我已经失败了 2 天 :)

【问题讨论】:

    标签: python tensorflow machine-learning


    【解决方案1】:

    tf.meshgrid() 的解决方案和一些重塑:

    import tensorflow as tf
    import numpy as np
    
    t = tf.placeholder(tf.int32, [None, 2])
    num_rows, size_row = tf.shape(t)[0], tf.shape(t)[1] # actual dynamic dimensions
    
    # Getting pair indices using tf.meshgrid:
    idx_range = tf.range(num_rows)
    pair_indices = tf.stack(tf.meshgrid(*[idx_range, idx_range]))
    pair_indices = tf.transpose(pair_indices, perm=[1, 2, 0])
    
    # Finally gathering the rows accordingly:
    res = tf.reshape(tf.gather(t, pair_indices), (-1, size_row * 2))
    
    with tf.Session() as sess:
        print(sess.run(res, feed_dict={t: np.array([[1,2], [3,4], [5,6]])}))
        # [[1 2 1 2]
        #  [3 4 1 2]
        #  [5 6 1 2]
        #  [1 2 3 4]
        #  [3 4 3 4]
        #  [5 6 3 4]
        #  [1 2 5 6]
        #  [3 4 5 6]
        #  [5 6 5 6]]
    

    使用笛卡尔积的解决方案:

    import tensorflow as tf
    import numpy as np
    
    t = tf.placeholder(tf.int32, [None, 2])
    num_rows, size_row = tf.shape(t)[0], tf.shape(t)[1] # actual dynamic dimensions
    
    # Getting pair indices by computing the indices cartesian product:
    row_idx = tf.range(num_rows)
    row_idx_a = tf.expand_dims(tf.tile(tf.expand_dims(row_idx, 1), [1, num_rows]), 2)
    row_idx_b = tf.expand_dims(tf.tile(tf.expand_dims(row_idx, 0), [num_rows, 1]), 2)
    pair_indices = tf.concat([row_idx_a, row_idx_b], axis=2)
    
    # Finally gathering the rows accordingly:
    res = tf.reshape(tf.gather(t, pair_indices), (-1, size_row * 2))
    
    with tf.Session() as sess:
        print(sess.run(res, feed_dict={t: np.array([[1,2], [3,4], [5,6]])}))
        # [[1 2 1 2]
        #  [1 2 3 4]
        #  [1 2 5 6]
        #  [3 4 1 2]
        #  [3 4 3 4]
        #  [3 4 5 6]
        #  [5 6 1 2]
        #  [5 6 3 4]
        #  [5 6 5 6]]
    

    【讨论】:

      【解决方案2】:

      可以通过以下方式实现:

      tf.concat([tf.tile(tf.expand_dims(t,1), [1, tf.shape(t)[0], 1]), tf.tile(tf.expand_dims(t,0), [tf.shape(t)[0], 1, 1])], axis=2)
      

      详细步骤:

      t = tf.placeholder(tf.int32, shape=[None, 2])
      #repeat each row of t
      d = tf.tile(tf.expand_dims(t,1), [1, tf.shape(t)[0], 1])
      #Output:
      #[[[1 2] [1 2]]
      # [[3 4] [3 4]]]
      
      #repeat the entire input t
      e = tf.tile(tf.expand_dims(t,0), [tf.shape(t)[0], 1, 1])
      #Output:
      #[[[1 2] [3 4]]
      # [[1 2] [3 4]]]
      
      #concat
      f = tf.concat([d, e], axis=2)
      
      with tf.Session() as sess:
         print(sess.run(f, {t:np.asarray([[1,2],[3,4]])}))  
      #Output
      #[[[1 2 1 2]
      #[1 2 3 4]]
      #[[3 4 1 2]
      #[3 4 3 4]]]
      

      【讨论】:

        猜你喜欢
        • 2018-06-05
        • 1970-01-01
        • 1970-01-01
        • 2020-03-05
        • 1970-01-01
        • 2020-12-23
        • 1970-01-01
        • 2017-09-17
        • 1970-01-01
        相关资源
        最近更新 更多