【问题标题】:Split a dataset issue in Tensorflow dataset API在 Tensorflow 数据集 API 中拆分数据集问题
【发布时间】:2018-12-30 15:07:59
【问题描述】:

我正在使用tf.contrib.data.make_csv_dataset 读取一个csv 文件以形成一个数据集,然后我使用命令take() 形成另一个只有一个元素的数据集,但它仍然返回所有元素。

这里有什么问题?我带来了下面的代码:

import tensorflow as tf
import os
tf.enable_eager_execution()

# Constants

column_names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species']
class_names = ['Iris setosa', 'Iris versicolor', 'Iris virginica']
batch_size   = 1
feature_names = column_names[:-1]
label_name = column_names[-1]

# to reorient data strucute
def pack_features_vector(features, labels):
  """Pack the features into a single array."""
  features = tf.stack(list(features.values()), axis=1)
  return features, labels

# Download the file
train_dataset_url = "http://download.tensorflow.org/data/iris_training.csv"
train_dataset_fp = tf.keras.utils.get_file(fname=os.path.basename(train_dataset_url),
                                       origin=train_dataset_url)

# form the dataset
train_dataset = tf.contrib.data.make_csv_dataset(
train_dataset_fp,
batch_size, 
column_names=column_names,
label_name=label_name,
num_epochs=1)

# perform the mapping
train_dataset = train_dataset.map(pack_features_vector)

# construct a databse with one element 
train_dataset= train_dataset.take(1)

# inspect elements
for step in range(10):
    features, labels = next(iter(train_dataset))
    print(list(features))

【问题讨论】:

    标签: python tensorflow tensorflow-datasets


    【解决方案1】:

    基于this 的答案,我们可以使用Dataset.take()Dataset.skip() 拆分数据集:

    train_size = int(0.7 * DATASET_SIZE)
    
    train_dataset = full_dataset.take(train_size)
    test_dataset = full_dataset.skip(train_size)
    

    如何修复你的代码?

    不要在循环中多次创建迭代器,而是使用一个迭代器:

    # inspect elements
    for feature, label in train_dataset:
        print(feature)
    

    您的代码中发生了什么会导致这种行为?

    1) 内置 python iter 函数从对象中获取迭代器,或者对象本身必须提供自己的迭代器。所以当你调用iter(train_dataset)时,就相当于调用Dataset.make_one_shot_iterator()

    2) 默认情况下,tf.contrib.data.make_csv_dataset() 中的 shuffle 参数为 True (shuffle=True)。因此,每次调用 iter(train_dataset) 时,它都会创建包含不同数据的新迭代器。

    3) 最后,当通过for step in range(10) 循环遍历时,类似的是,您创建了 10 个不同的迭代器,大小为 1,每个迭代器都有自己的数据,因为它们被打乱了。

    建议:如果你想避免这样的事情在循环外初始化(创建)迭代器:

    train_dataset = train_dataset.take(1)
    iterator = train_dataset.make_one_shot_iterator()
    # inspect elements
    for step in range(10):
        features, labels = next(iterator)
        print(list(features))
        # throws exception because size of iterator is 1
    

    【讨论】:

      猜你喜欢
      • 2020-06-24
      • 1970-01-01
      • 1970-01-01
      • 2018-12-10
      • 2020-12-29
      • 1970-01-01
      • 2019-02-08
      • 1970-01-01
      • 2022-01-01
      相关资源
      最近更新 更多