【问题标题】:Tensorflow csv dataset usageTensorFlow csv 数据集使用
【发布时间】:2019-02-13 18:01:05
【问题描述】:

我有一个 csv 文件,格式和数据如下:

ID  nr1 nr2 nr3 nr4 nr5 next_nr
1   1   2   3   4   5   6
2   2   3   4   5   6   7
3   3   4   5   6   7   8
4   4   5   6   7   8   9
5   5   6   7   8   9   10
6   6   7   8   9   10  11
7   7   8   9   10  11  12
8   8   9   10  11  12  13
9   9   10  11  12  13  14
10  10  11  12  13  14  15

所以,包括我的火车数据在内有 10 行。我想使用 tf.contrib.data.CsvDataset 来读取数据。这是阅读它的示例代码:

import tensorflow as tf
import numpy as np

ITERATOR_BATCH_SIZE = 2
NR_EPOCHS = 3

train1_path = 'train1_short.csv'

dataset = tf.contrib.data.CsvDataset(train1_path,
                                     [tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32],
                                     header=True)

dataset = dataset.batch(ITERATOR_BATCH_SIZE)

with tf.Session() as sess:

    for i in range (NR_EPOCHS):
        print('\nepoch: ', i)
        iterator = dataset.make_one_shot_iterator()
        next_element = iterator.get_next()
        while True:            
            try:
              data_and_target = sess.run([next_element])
            except tf.errors.OutOfRangeError:
              break
            print("\n\n", data_and_target)

当我运行此代码时,我希望输出在每批中包含 2 行数据。但是我得到的数据看起来很奇怪。这是第一个 epoch 的输出:

epoch:  0


 [(array([1., 2.], dtype=float32), array([1., 2.], dtype=float32), array([2., 3.], dtype=float32), array([3., 4.], dtype=float32), array([4., 5.], dtype=float32), array([5., 6.], dtype=float32), array([6., 7.], dtype=float32))]


 [(array([3., 4.], dtype=float32), array([3., 4.], dtype=float32), array([4., 5.], dtype=float32), array([5., 6.], dtype=float32), array([6., 7.], dtype=float32), array([7., 8.], dtype=float32), array([8., 9.], dtype=float32))]


 [(array([5., 6.], dtype=float32), array([5., 6.], dtype=float32), array([6., 7.], dtype=float32), array([7., 8.], dtype=float32), array([8., 9.], dtype=float32), array([ 9., 10.], dtype=float32), array([10., 11.], dtype=float32))]


 [(array([7., 8.], dtype=float32), array([7., 8.], dtype=float32), array([8., 9.], dtype=float32), array([ 9., 10.], dtype=float32), array([10., 11.], dtype=float32), array([11., 12.], dtype=float32), array([12., 13.], dtype=float32))]


 [(array([ 9., 10.], dtype=float32), array([ 9., 10.], dtype=float32), array([10., 11.], dtype=float32), array([11., 12.], dtype=float32), array([12., 13.], dtype=float32), array([13., 14.], dtype=float32), array([14., 15.], dtype=float32))]

相反,我会 - 例如 - 期望第一批喜欢以下内容:

[(array([1., 1., 2., 3., 4., 5., 6], dtype=float32), array([2., 2., 3., 4., 5., 6., 7.], dtype=float32)]

这个问题可能非常微不足道,但我就是不明白为什么它看起来像这样。也许该领域更有经验的人可以立即看到它。

【问题讨论】:

    标签: python tensorflow machine-learning dataset tensorflow-datasets


    【解决方案1】:

    CsvDatset 的每条记录都必须转换为张量。让我知道这是否适合您:

    dataset = tf.contrib.data.CsvDataset(train1_path,
                                         [tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32],
                                         header=True, field_delim=' ')
    
    dataset = dataset.map(lambda *x: tf.convert_to_tensor(x))
    dataset = dataset.batch(ITERATOR_BATCH_SIZE)
    
    with tf.Session() as sess:
        for i in range (NR_EPOCHS):
            print('\nepoch: ', i)
            iterator = dataset.make_one_shot_iterator()
            next_element = iterator.get_next()
            while True:            
                try:
                  data_and_target = sess.run(next_element)
                except tf.errors.OutOfRangeError:
                  break
                print("\n\n", data_and_target)
    

    对于我的测试,我必须设置 field_delim 参数才能使其正常工作。

    【讨论】:

    • 这太好了!谢谢你,@MatthewScarpino!略有不同的是,您的代码在我的情况下无法正常工作。但是我删除了 field_delim=' ' 参数,它工作得非常好。它可能与您可能用于 csv 文件的应用程序有关,我不知道。但是这个答案完美地解决了这个问题。
    • 如果不是所有类型都是tf.float32,该怎么办?例如,我有一个tf.string 类型的列和另一个tf.float32 类型的列。如何适当适配tf.convert_to_tensor
    猜你喜欢
    • 1970-01-01
    • 2019-09-17
    • 1970-01-01
    • 2020-11-20
    • 2018-10-01
    • 2021-06-26
    • 2023-04-07
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多