【发布时间】:2020-11-26 00:38:03
【问题描述】:
我正在使用 Tensorflow 2 和 Keras 训练深度学习模型。我使用tf.data.experimental.make_csv_dataset 读取了我的大 CSV 文件,然后将其拆分为训练和测试数据集。但是,我需要将我的训练数据集分成三部分,因为我的深度学习模型需要两组不同层的输入,所以我需要将[x1_train, x2_train],y_train 传递给model.fit。
我的问题是如何将train_dataset 拆分为x1_train,x2_train 和y_train? (有些功能应该在x1_train,有些功能应该在x2_train)。
我的代码:
def get_dataset(file_path, **kwargs):
dataset = tf.data.experimental.make_csv_dataset(
file_path,
batch_size=64,
label_name=LABEL_COLUMN,
na_value="?",
num_epochs=1,
ignore_errors=True,
**kwargs)
return dataset
full_dataset = get_dataset(dataset_path)
full_dataset = full_dataset.shuffle(buffer_size=400000)
train_dataset = full_dataset.take(360000)
test_dataset = full_dataset.skip(360000)
test_dataset = test_dataset.take(40000)
x1_train =train_dataset[:,0:2820]
x2_train =train_dataset[:,2820:2822]
y_train=train_dataset[:,2822]
x1_test =x_test[:,0:2820]
x2_test =x_test[:,2820:2822]
y_test=test_dataset[:,2822]
model.fit([x1_train,x2_train],y_train,validation_data=[x1_test,x2_test],y_test, callbacks=callbacks_list, verbose=1,epochs=EPC)
错误信息:
x1_train =train_dataset[:,0:2820]
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: 'TakeDataset' object is not subscriptable
【问题讨论】:
-
在构建数据集之后,在做任何其他事情之前,使用
tf.data.Dataset对象的map方法来拆分每个批次。 -
您能告诉我如何使用 map 来拆分每个批次吗?
-
make_csv_dataset返回的每个批次的第一个元素是一个字典,将列名映射到它们的值。因此,在 map 函数中,您可以将此字典拆分为两个单独的字典(并可能根据模型的输入格式将每个字典中的项目组合成一个单独的张量)。
标签: python tensorflow machine-learning keras tensorflow-datasets