【问题标题】:How to create mini-batches using tensorflow.data.experimental.CsvDataset compatible with model's input shape?如何使用与模型输入形状兼容的 tensorflow.data.experimental.CsvDataset 创建小批量?
【发布时间】:2020-09-26 08:48:09
【问题描述】:

我将在 TensorFlow 2 中使用 tensorflow.data.experimental.CsvDataset 训练小批量。但 Tensor 的形状不适合我模型的输入形状。

请告诉我通过 TensorFlow 数据集进行小批量训练的最佳方法是什么。

我尝试如下:

# I have a dataset with 4 features and 1 label
feature = tf.data.experimental.CsvDataset(['C:/data/iris_0.csv'], record_defaults=[.0] * 4, header=True, select_cols=[0,1,2,3])
label = tf.data.experimental.CsvDataset(['C:/data/iris_0.csv'], record_defaults=[.0] * 1, header=True, select_cols=[4])
dataset = tf.data.Dataset.zip((feature, label))

# and I try to minibatch training:
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(loss='mse', optimizer='sgd')
model.fit(dataset.repeat(1).batch(3), epochs=1)

我遇到了一个错误:

ValueError:检查输入时出错:预期的 dense_6_input 有 形状 (4,) 但得到了形状 (1,) 的数组

因为:CsvDataset() 返回形状为 (features, batch) 的张量,但我需要它的形状为 (batch, features)

参考代码:

for feature, label in dataset.repeat(1).batch(3).take(1):
    print(feature)

# (<tf.Tensor: id=487, shape=(3,), dtype=float32, numpy=array([5.1, 4.9, 4.7], dtype=float32)>, <tf.Tensor: id=488, shape=(3,), dtype=float32, numpy=array([3.5, 3. , 3.2], dtype=float32)>, <tf.Tensor: id=489, shape=(3,), dtype=float32, numpy=array([1.4, 1.4, 1.3], dtype=float32)>, <tf.Tensor: id=490, shape=(3,), dtype=float32, numpy=array([0.2, 0.2, 0.2], dtype=float32)>)

【问题讨论】:

    标签: python tensorflow keras tensorflow2.0 tensorflow-datasets


    【解决方案1】:

    tf.data.experimental.CsvDataset 创建一个数据集,其中数据集的每个元素对应于 CSV 文件中的一行,并由多个张量组成,即每列都有一个单独的张量。因此,首先您需要使用数据集的map 方法将所有这些张量堆叠成一个张量,以便与模型期望的输入形状兼容:

    def map_func(features, label):
        return tf.stack(features, axis=1), tf.stack(label, axis=1)
    
    dataset = dataset.map(map_func).batch(BATCH_SIZE)
    

    【讨论】:

    • 使用map(),形状的张量从(特征,批次)到(批次,特征)变化很好。谢谢。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-12-28
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多