【问题标题】:How to filter tensor from queue based on some predicate in tensorflow?如何根据张量流中的一些谓词从队列中过滤张量?
【发布时间】:2015-11-24 20:52:23
【问题描述】:

如何使用谓词函数过滤存储在队列中的数据?例如,假设我们有一个存储特征张量和标签的队列,我们​​只需要那些满足谓词的。我尝试了以下实现但没有成功:

feature, label = queue.dequeue()
if (predicate(feature, label)):
    enqueue_op = another_queue.enqueue(feature, label)

【问题讨论】:

    标签: tensorflow


    【解决方案1】:

    最直接的方法是使批次出队,通过谓词测试运行它们,使用tf.where 生成与谓词匹配的密集向量,并使用tf.gather 收集结果,并将该批次排入队列。如果您希望自动发生这种情况,您可以在第二个队列上启动队列运行器 - 最简单的方法是使用 tf.train.batch

    例子:

    import numpy as np
    import tensorflow as tf
    
    a = tf.constant(np.array([5, 1, 9, 4, 7, 0], dtype=np.int32))
    
    q = tf.FIFOQueue(6, dtypes=[tf.int32], shapes=[])
    enqueue = q.enqueue_many([a])
    dequeue = q.dequeue_many(6)
    predmatch = tf.less(dequeue, [5])
    selected_items = tf.reshape(tf.where(predmatch), [-1])
    found = tf.gather(dequeue, selected_items)
    
    secondqueue = tf.FIFOQueue(6, dtypes=[tf.int32], shapes=[])
    enqueue2 = secondqueue.enqueue_many([found])
    dequeue2 = secondqueue.dequeue_many(3) # XXX, hardcoded
    
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(enqueue)  # Fill the first queue
      sess.run(enqueue2) # Filter, push into queue 2
      print sess.run(dequeue2) # Pop items off of queue2
    

    谓词产生一个布尔向量; tf.where 生成真实值索引的密集向量,tf.gather 根据这些索引从原始张量中收集项目。

    在这个例子中,很多东西都是硬编码的,当然,你需要在现实中进行非硬编码,但希望它显示你正在尝试做的事情的结构(创建一个过滤管道)。在实践中,您希望 QueueRunners 在那里保持自动搅动。使用tf.train.batch 对自动处理该问题非常有用——请参阅Threading and Queues 了解更多详细信息。

    【讨论】:

    • 是否可以为 SparseTensors 做类似的事情? seemsgather 对他们不起作用。
    • 嘿-谢谢!它仍然是最直接的方式吗?也不需要 numpy 导入是吗?
    • 我想是这样,但我会检查更多。仅此示例运行需要 numpy 导入,因为我使用 numpy 创建了“a”作为常量。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2018-08-30
    • 2022-01-06
    • 2019-12-28
    • 1970-01-01
    • 2022-01-24
    • 2018-07-19
    • 1970-01-01
    相关资源
    最近更新 更多