【发布时间】:2022-01-24 10:55:13
【问题描述】:
我在 tensorflow map 函数中的数据集结构有问题。 这就是我的数据的样子:
简单
`train_examples = tf.data.Dataset.from_tensor_slices(train_data)
[[0,1,2,3,4,5,...],
[32,33,34,35,36,...]],
真实
print(train_data[0])
[[array([2,539, 400, 513, 398, 523, 485, 533, 568, 566, 402, 565, 491,
570, 576, 539, 351, 538, 297, 539, 262, 564, 313, 581, 370, 589,
421, 514, 314, 501, 370, 489, 420,3]), array([2, 534, 403, 507, 401, 519, 487, 531, 567, 562, 405, 544, 495,
537, 588, 528, 354, 526, 300, 534, 259, 555, 315, 575, 370, 589,
421, 499, 315, 489, 372, 483, 423,3])]]
我转换为管道<TensorSliceDataset shapes: (2, 34), types: tf.int64>的张量
train_examples 包含 17k 行的 2D 张量 [[source],[target]]。
def make_batches(ds):
return (
ds
.cache()
.shuffle(BUFFER_SIZE)
.batch(BATCH_SIZE)
.map(lambda x_int,y_int: x_int,y_int, num_parallel_calls=tf.data.experimental.AUTOTUNE)
.prefetch(tf.data.experimental.AUTOTUNE))
train_batches = make_batches(train_examples)
对于地图,我希望数据结构输出分别带有源和目标。我尝试了函数map(prepare, num_parallel_calls=tf.data.experimental.AUTOTUNE)
def prepare(ds):
srcs = tf.ragged.constant(ds.numpy().[0],tf.int64)
trgs = tf.ragged.constant(ds.numpy().[1],tf.int64)
srcs = srcs.to_tensor()
trgs = trgs.to_tensor()
return srcs,trgs
但是 tensorflow 不允许在 map 函数中急切执行。 如果还有其他关于 Tensorflow 中 map 函数使用的遗漏,请告诉我。谢谢。
Tensorflow 版本 = 2.7
【问题讨论】:
-
你能提供一个完整的例子来说明你的数据是什么样的吗?
-
对不起,我编辑了完整的数据示例。谢谢
-
好的,每个样本由两个数组组成,每个数组的形状为
(34,)? -
是的,数据结构包含 2 个数组,每个数组的形状为 (34,)。
-
train_data是数据集 csv 文件的解释。
标签: python tensorflow tensorflow-datasets