TL;DR
建议使用选项 C,定义如下。
完整答案
创建tf.data.Dataset 对象是为了不必将所有对象加载到内存中。因此,默认情况下使用 tf.gather 将不起作用。您可以选择三个选项:
选项 A:将 ds 加载到内存中 + tf.gather
如果您想使用收集,则必须将整个数据集加载到内存中,然后选择一个子集:
m_X_ds = list(m_X_ds) # Load into memory.
m_X_ds = tf.gather(m_X_ds, arr)) # Gather as usual.
print(m_X_ds)
# Example result: <tf.Tensor: shape=(3,), dtype=int32, numpy=array([8, 6, 2], dtype=int32)>
请注意,这并不总是可行的,尤其是对于庞大的数据集。
选项 B:遍历数据集,过滤不需要的样本
您还可以遍历数据集并手动选择具有所需索引的样本。这可以通过filter 和tf.py_function 的组合实现
m_X_ds = m_X_ds.enumerate() # Create index,value pairs in the dataset.
# Create filter function:
def filter_fn(idx, value):
return idx in arr
# The above is not going to work in graph mode
# We are wrapping it with py_function to execute it eagerly
def py_function_filter(idx, value):
return tf.py_function(filter_fn, (idx, value), tf.bool)
# Filter the dataset as usual:
filtered_ds = m_X_ds.filter(py_function_filter)
选项 C:将选项 B 与 tf.lookup.StaticHashTable 结合
选项 B 很好,除了在转换图张量 -> 急切张量 -> 图张量时可以预期性能会受到影响。 tf.py_function 很有用,但要付出代价。
相反,我们可以声明一个字典,其中所需的索引将返回 true,而不存在的索引可能返回 false。这个字典可能看起来像这样
my_table = {3: True, 4: True, 5: True}.
我们不能使用张量作为字典键,但我们可以声明一个tensorflow's hash table 来让我们检查“好”索引。
m_X_ds = m_X_ds.enumerate() # Do not repeat this if executed in Option B.
keys_tensor = tf.constant(arr)
vals_tensor = tf.ones_like(keys_tensor) # Ones will be casted to True.
table = tf.lookup.StaticHashTable(
tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor),
default_value=0) # If index not in table, return 0.
def hash_table_filter(index, value):
table_value = table.lookup(index) # 1 if index in arr, else 0.
index_in_arr = tf.cast(table_value, tf.bool) # 1 -> True, 0 -> False
return index_in_arr
filtered_ds = m_X_ds.filter(hash_table_filter)
无论选项 B 或 C,剩下的就是从您的 fileterd 数据集中删除索引。我们可以使用简单的地图,带有 lambda 函数:
final_ds = filtered_ds.map(lambda idx,value: value)
for entry in final_ds:
print(entry)
# tf.Tensor(7, shape=(), dtype=int32)
# tf.Tensor(13, shape=(), dtype=int32)
# tf.Tensor(6, shape=(), dtype=int32)
祝你好运。