【问题标题】:Keras tensors - Get values with indices coming from another tensorKeras 张量 - 使用来自另一个张量的索引获取值
【发布时间】:2018-03-13 14:30:40
【问题描述】:

假设我有这两个张量:

  • valueMatrix,形如(?, 3),其中?是批量大小
  • indexMatrix,形如(?, 1)

我想在indexMatrix 中包含的索引处从valueMatrix 检索值。

示例(伪代码):

valueMatrix = [[7,15,5],[4,6,8]] -- shape=(2,3) -- type=float 
indexMatrix = [[1],[0]] -- shape = (2,1) -- type=int

我想从这个例子中做类似的事情:

valueMatrix[indexMatrix] --> returns --> [[15],[4]]

比起其他后端,我更喜欢 Tensorflow,但答案必须与使用 Lambda 层或其他适合任务的层的 Keras 模型兼容。

【问题讨论】:

    标签: python-3.x tensorflow keras slice tensor


    【解决方案1】:
    import tensorflow as tf
    valueMatrix = tf.constant([[7,15,5],[4,6,8]])
    indexMatrix = tf.constant([[1],[0]])
    
    # create the row index with tf.range
    row_idx = tf.reshape(tf.range(indexMatrix.shape[0]), (-1,1))
    # stack with column index
    idx = tf.stack([row_idx, indexMatrix], axis=-1)
    # extract the elements with gather_nd
    values = tf.gather_nd(valueMatrix, idx)
    
    with tf.Session() as sess:
        print(sess.run(values))
    #[[15]
    # [ 4]]
    

    【讨论】:

    • 太棒了!谢谢 - 我找不到 tf.gather_nd 的 keras 替代品,但无论如何它都可以在 lambda 层内工作。
    猜你喜欢
    • 2016-06-20
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2018-11-22
    • 2022-01-21
    • 1970-01-01
    • 1970-01-01
    • 2022-10-23
    相关资源
    最近更新 更多