【问题标题】:Produce a dataset of stridded slices from a tfrecords dataset从 tfrecords 数据集生成跨步切片数据集
【发布时间】:2017-10-06 11:07:23
【问题描述】:

继续this 问题和here 的讨论 - 我正在尝试使用 Dataset API 获取可变长度张量的数据集并将它们切割成等长的切片(段)。比如:

Dataset = tf.contrib.data.Dataset
segment_len = 6
batch_size = 16

with tf.Graph().as_default() as g:
    # get the tfrecords dataset
    dataset = tf.contrib.data.TFRecordDataset(filenames).map(
        partial(record_type.parse_single_example, graph=g)).batch(batch_size)
    # zip it with the number of segments we need to slice each tensor
    dataset2 = Dataset.zip((dataset, Dataset.from_tensor_slices(
        tf.constant(num_segments, dtype=tf.int64))))
    it2 = dataset2.make_initializable_iterator()
    def _dataset_generator():
        with g.as_default():
            while True:
                try:
                    (im, length), count = sess.run(it2.get_next())
                    dataset3 = Dataset.zip((
                        # repeat each tensor then use map to take a stridded slice
                        Dataset.from_tensors((im, length)).repeat(count),
                        Dataset.range(count))).map(lambda x, c: (
                            x[0][:, c: c + segment_len],
                            x[0][:, c + 1: (c + 1) + segment_len],
                    ))
                    it = dataset3.make_initializable_iterator()
                    it_init = it.initializer
                    try:
                        yield it_init
                        while True:
                            yield sess.run(it.get_next())
                    except tf.errors.OutOfRangeError:
                        continue
                except tf.errors.OutOfRangeError:
                    return
    # Dataset.from_generator need tensorflow > 1.3 !
    das_dataset = Dataset.from_generator(
        _dataset_generator,
        (tf.float32, tf.float32),
        # (tf.TensorShape([]), tf.TensorShape([]))
    )
    das_dataset_it = das_dataset.make_one_shot_iterator()


with tf.Session(graph=g) as sess:
    while True:
        print(sess.run(it2.initializer))
        print(sess.run(das_dataset_it.get_next()))

当然我不想在生成器中传递会话,但这应该通过链接中给出的技巧来解决(创建一个虚拟数据集并映射另一个的迭代器)。上面的代码因圣经而失败:

tensorflow.python.framework.errors_impl.InvalidArgumentError: TypeError: If shallow structure is a sequence, input must also be a sequence. Input has type: <class 'tensorflow.python.framework.ops.Operation'>.
         [[Node: PyFunc = PyFunc[Tin=[DT_INT64], Tout=[DT_FLOAT, DT_FLOAT], token="pyfunc_1"](arg0)]]
         [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[<unknown>, <unknown>], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](OneShotIterator)]]

我猜这是因为我尝试产生迭代器的初始化程序,但我的问题基本上是我是否可以使用数据集 API 实现我正在尝试的所有内容。

【问题讨论】:

    标签: python tensorflow tensorflow-datasets


    【解决方案1】:

    从嵌套的Dataset 构建Dataset 的最简单方法是使用Dataset.flat_map() 转换。此转换将一个函数应用于输入数据集的每个元素(在您的示例中为dataset2),该函数返回一个嵌套的Dataset(在您的示例中很可能是dataset3),然后转换将所有嵌套数据集展平为单个Dataset

    dataset2 = ...  # As above.
    
    def get_slices(im_and_length, count):
      im, length = im_and_length
      # Repeat each tensor then use map to take a strided slice.
      return Dataset.zip((
          Dataset.from_tensors((im, length)).repeat(count),
          Dataset.range(count))).map(lambda x, c: (
              x[0][:, c + segment_len: (c + 1) + segment_len],
              x[0][:, c + 1 + segment_len: (c + 2) + segment_len],
      ))
    
    das_dataset = dataset2.flat_map(get_slices)
    

    【讨论】:

    • 非常感谢 - 我没想到 flat_map 是这项工作的工具
    • 仅供参考,这可能无法与 MonitoredTrainingSession 配合使用 - 迭代器有时会出乎意料地先进(因为它绑定到 model.cost 之类的摘要?) - 或者我可能完全错了,这是我的错。将不得不进行更多调查,但与此同时,由于在 github 上讨论了 MonitoredTrainingSession 和数据集集成,我只是注意到所以你也要记住这一点——也就是说,我们至少必须警告人们小心推进隐藏在内部的操作中的迭代器MonitoredTrainingSession。
    猜你喜欢
    • 2019-04-19
    • 1970-01-01
    • 2016-03-27
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2018-12-14
    • 2021-02-03
    相关资源
    最近更新 更多