【问题标题】:Filter Tensorflow dataset by id按 id 过滤 TensorFlow 数据集
【发布时间】:2021-02-28 14:55:03
【问题描述】:

问题

我正在尝试基于包含我希望子集的索引的 numpy 数组来过滤 Tensorflow 2.4 数据集。该数据集有 1000 张形状 (28,28,1) 的图像。

玩具示例代码

m_X_ds = tf.data.Dataset.from_tensor_slices(list(range(1, 21))).shuffle(10, reshuffle_each_iteration=False)
arr = np.array([3, 4, 5])
m_X_ds = tf.gather(m_X_ds, arr)  # This is the offending code

错误信息

ValueError: Attempt to convert a value (<ShuffleDataset shapes: (), types: tf.int32>) with an unsupported type (<class 'tensorflow.python.data.ops.dataset_ops.ShuffleDataset'>) to a Tensor.

研究至今

我找到了 thisthis,但它们特定于它们的用例,而我正在寻找一种更通用的子集方法(即基于外部派生的索引数组)。

我对 Tensorflow 数据集非常陌生,到目前为止,我发现学习曲线非常陡峭。希望能得到一些帮助。提前致谢!

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    TL;DR

    建议使用选项 C,定义如下。

    完整答案

    创建tf.data.Dataset 对象是为了不必将所有对象加载到内存中。因此,默认情况下使用 tf.gather 将不起作用。您可以选择三个选项:

    选项 A:将 ds 加载到内存中 + tf.gather

    如果您想使用收集,则必须将整个数据集加载到内存中,然后选择一个子集:

    m_X_ds = list(m_X_ds)  # Load into memory.
    m_X_ds = tf.gather(m_X_ds, arr))  # Gather as usual.
    print(m_X_ds)  
    # Example result: <tf.Tensor: shape=(3,), dtype=int32, numpy=array([8, 6, 2], dtype=int32)>
    

    请注意,这并不总是可行的,尤其是对于庞大的数据集。

    选项 B:遍历数据集,过滤不需要的样本

    您还可以遍历数据集并手动选择具有所需索引的样本。这可以通过filtertf.py_function 的组合实现

    m_X_ds = m_X_ds.enumerate()  # Create index,value pairs in the dataset.
    
    # Create filter function:
    def filter_fn(idx, value):
        return idx in arr
    
    # The above is not going to work in graph mode
    # We are wrapping it with py_function to execute it eagerly
    def py_function_filter(idx, value):
        return tf.py_function(filter_fn, (idx, value), tf.bool)
    
    # Filter the dataset as usual:
    filtered_ds = m_X_ds.filter(py_function_filter)
    

    选项 C:将选项 B 与 tf.lookup.StaticHashTable 结合

    选项 B 很好,除了在转换图张量 -> 急切张量 -> 图张量时可以预期性能会受到影响。 tf.py_function 很有用,但要付出代价。

    相反,我们可以声明一个字典,其中所需的索引将返回 true,而不存在的索引可能返回 false。这个字典可能看起来像这样

    my_table = {3: True, 4: True, 5: True}.
    

    我们不能使用张量作为字典键,但我们可以声明一个tensorflow's hash table 来让我们检查“好”索引。

    m_X_ds = m_X_ds.enumerate()  # Do not repeat this if executed in Option B.
    
    keys_tensor = tf.constant(arr)
    vals_tensor = tf.ones_like(keys_tensor)  # Ones will be casted to True.
    
    table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor),
        default_value=0)  # If index not in table, return 0.
    
    
    def hash_table_filter(index, value):
        table_value = table.lookup(index)  # 1 if index in arr, else 0.
        index_in_arr =  tf.cast(table_value, tf.bool) # 1 -> True, 0 -> False
        return index_in_arr
    
    filtered_ds = m_X_ds.filter(hash_table_filter)
    

    无论选项 B 或 C,剩下的就是从您的 fileterd 数据集中删除索引。我们可以使用简单的地图,带有 lambda 函数:

    final_ds = filtered_ds.map(lambda idx,value: value)
    for entry in final_ds:
        print(entry)
    
    # tf.Tensor(7, shape=(), dtype=int32)
    # tf.Tensor(13, shape=(), dtype=int32)
    # tf.Tensor(6, shape=(), dtype=int32)
    

    祝你好运。

    【讨论】:

    • 感谢您清晰详细的解释!我希望 TF 文档同样简单明了。还有一件事:对于选项 B,你能告诉我如何获得一个数据集没有每个元组中的所有第一个元素(ID)吗?我打算使用过滤后的数据集进行训练。
    • 为了澄清我之前的评论,结果数据集应该只包含过滤后的 3 个形状为 (28,28,1) 的张量。
    • 感谢您的客气话。我编辑了我的答案,包括删除索引和问题的第三个解决方案。万事如意!
    • 嗨@sebastian-sz,很抱歉再次给您带来麻烦,但是当我尝试实现选项C时,filtered_ds = m_X_ds.filter(hash_table_filter) 行出现错误:ValueError: 'predicate' return type must be convertible to a scalar boolean tensor. Was NoneTensorSpec(). 即使使用我的玩具示例代码+选项C也没有不行。你愿意帮忙吗?
    • 如果重要的话,我使用的是 TF v2.4
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2018-09-26
    • 2016-11-02
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多