【Question title】: Tensorflow: read variable length data, via Dataset (tfrecord)
【Posted】: 2018-04-09 12:42:44
【Question description】:

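Hi all,

I would like to read some TFRecord data.
This works, but only for fixed-length data; now I would like to do the same for variable-length data using VarLenFeature.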

import tensorflow as tf  # TensorFlow 1.x graph-mode API

def load_tfrecord_fixed(serialized_example):

    context_features = {
        'length':tf.FixedLenFeature([],dtype=tf.int64),
        'type':tf.FixedLenFeature([],dtype=tf.string)
    }

    sequence_features = {
        "values":tf.FixedLenSequenceFeature([], dtype=tf.int64)
    }


    context_parsed, sequence_parsed = tf.parse_single_sequence_example(
        serialized=serialized_example,
        context_features=context_features,
        sequence_features=sequence_features
    )


    return context_parsed,sequence_parsed

tf.reset_default_graph()

with tf.Session() as sess:

    # fp is the file object of the TFRecord file (its creation is not shown here)
    filenames = [fp.name]

    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(load_tfrecord_fixed)
    dataset = dataset.repeat()
    dataset = dataset.batch(2)

    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    sess.run(iterator.initializer)

    for i in range(3):
        a = sess.run(next_element)
        print(a)

Result:

({'length': array([3, 3], dtype=int64), 'type': array([b'FIXED_length', b'FIXED_length'], dtype=object)}, {'values': array([[82,  2,  2],
       [42,  5,  1]], dtype=int64)})
({'length': array([3, 3], dtype=int64), 'type': array([b'FIXED_length', b'FIXED_length'], dtype=object)}, {'values': array([[2, 3, 1],
       [1, 2, 3]], dtype=int64)})
({'length': array([3, 3], dtype=int64), 'type': array([b'FIXED_length', b'FIXED_length'], dtype=object)}, {'values': array([[  1, 100, 200],
       [123,  12,  12]], dtype=int64)})

This is the map function I am trying to use, but in the end it gives me an error :'(

def load_tfrecord_variable(serialized_example):

    context_features = {
        'length':tf.FixedLenFeature([],dtype=tf.int64),
        'batch_size':tf.FixedLenFeature([],dtype=tf.int64),
        'type':tf.FixedLenFeature([],dtype=tf.string)
    }

    sequence_features = {
        "values":tf.VarLenFeature(tf.int64)
    }


    context_parsed, sequence_parsed = tf.parse_single_sequence_example(
        serialized=serialized_example,
        context_features=context_features,
        sequence_features=sequence_features
    )
    #return context_parsed, sequence_parsed (which is sparse)

    # return context_parsed, sequence_parsed
    batched_data = tf.train.batch(
        tensors=[sequence_parsed['values']],
        batch_size= 2,
        dynamic_pad=True
    )

    # make dense data
    dense_data = tf.sparse_tensor_to_dense(batched_data)

    return context_parsed, dense_data

Error:

OutOfRangeError: Attempted to repeat an empty dataset infinitely.
     [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[], [], [], [?,?,?]], output_types=[DT_INT64, DT_INT64, DT_STRING, DT_INT64], _device="/job:localhost/replica:0/task:0/device:CPU:0"](Iterator)]]

During handling of the above exception, another exception occurred:

So could you please help me out? Also, I am using the TensorFlow nightly build, so I don't think I am missing much...

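【Comments】:

  • Don't use tf.train.batch. If you have a VarLenFeature, you can use Dataset.padded_batch to batch and pad the sequences.
  • @MaosiChen When I use padded_batch, I get this error: "If shallow structure is a sequence, input must also be a sequence. Input has type: ." dataset = dataset.padded_batch(4, padded_shapes=([None]))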
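
A likely cause of this error, assuming each dataset element is a tuple of several components: in Python, ([None]) is simply the list [None], because parentheses without a trailing comma only group an expression. padded_batch then receives one shape where it expects one shape per component. A minimal plain-Python sketch of the difference (the variable names are illustrative only):

```python
# Parentheses alone do not create a tuple; this is the list [None].
shapes_without_comma = ([None])

# A trailing comma creates a 1-tuple that holds one shape.
shapes_with_comma = ([None],)

# A structure with one entry per component, e.g. three scalars and one
# variable-length vector, is written as a real tuple of shapes.
shapes_for_four_components = ([], [], [], [None])

print(type(shapes_without_comma).__name__)   # list
print(type(shapes_with_comma).__name__)      # tuple
print(len(shapes_for_four_components))       # 4
```

Whether this is the whole story depends on what the map function returns; padded_shapes has to mirror that structure exactly.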

Tags: python-3.x tensorflow tfrecord


【Solution 1】:
def load_tfrecord_variable(serialized_example):

    context_features = {
        'length':tf.FixedLenFeature([],dtype=tf.int64),
        'batch_size':tf.FixedLenFeature([],dtype=tf.int64),
        'type':tf.FixedLenFeature([],dtype=tf.string)
    }

    sequence_features = {
        "values":tf.VarLenFeature(tf.int64)
    }

    context_parsed, sequence_parsed = tf.parse_single_sequence_example(
        serialized=serialized_example,
        context_features=context_features,
        sequence_features=sequence_features
    )

    length = context_parsed['length']
    batch_size = context_parsed['batch_size']
    type = context_parsed['type']

    values = sequence_parsed['values'].values

    return tf.tuple([length, batch_size, type, values])

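# Build the input pipeline.
filenames = [fp.name]
batch_size = 2  # assumed value, matching the batch size used in the question

dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(load_tfrecord_variable)
dataset = dataset.repeat()
dataset = dataset.padded_batch(
    batch_size,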
    padded_shapes=(
        tf.TensorShape([]),
        tf.TensorShape([]),
        tf.TensorShape([]),
        tf.TensorShape([None])  # if you reshape 'values' in load_tfrecord_variable, add the added dims after None, e.g. [None, 3]
        ),
    padding_values = (
        tf.constant(0, dtype=tf.int64),
        tf.constant(0, dtype=tf.int64),
        tf.constant(""),
        tf.constant(0, dtype=tf.int64)
        )
    )

iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    a = sess.run(iterator.initializer)
    for i in range(3):
        [length_vals, batch_size_vals, type_vals, values_vals] = sess.run(next_element)
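
For readers on TensorFlow 2.x, the padded_batch idea can be tried without a TFRecord file. The following is a minimal sketch and not the answerer's code: it assumes TF 2.4 or later (for output_signature) with eager execution, and the toy records and the generator name gen are hypothetical stand-ins for the parsed (length, values) pairs.

```python
import tensorflow as tf  # assumes TensorFlow 2.4+ with eager execution

# Hypothetical toy records: (length, values) pairs of different lengths.
records = [(3, [82, 2, 2]), (1, [42]), (2, [5, 1])]

def gen():
    for length, values in records:
        yield length, values

dataset = tf.data.Dataset.from_generator(
    gen,
    output_signature=(
        tf.TensorSpec(shape=(), dtype=tf.int64),       # scalar length
        tf.TensorSpec(shape=(None,), dtype=tf.int64),  # variable-length values
    ),
)

# padded_shapes has one entry per component of each dataset element.
dataset = dataset.padded_batch(
    2,
    padded_shapes=(tf.TensorShape([]), tf.TensorShape([None])),
    padding_values=(tf.constant(0, tf.int64), tf.constant(0, tf.int64)),
)

# Each batch is padded only up to the longest sequence in that batch.
batches = [(l.numpy().tolist(), v.numpy().tolist()) for l, v in dataset]
for lengths, values in batches:
    print(lengths, values)
```

Keeping the true length as a separate component is what lets downstream code tell real values apart from the zero padding.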

【Comments】:

  • Should the dataset.map() line read dataset = dataset.map(load_tfrecord_variable)?

【Solution 2】:

I ran into the same problem. I created a TFRecord file for the VoxCeleb audio dataset. The dataset consists of audio files that are 1-20 seconds long.

1) Read the audio file: audio = tf.io.read_file(audio_file_name)

2) Decode it: waveform, sr = tf.audio.decode_wav(audio)

3) Store it as waveform.numpy().flatten()
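
Steps 1) to 3) produce a flat array whose length differs from file to file. One way to store such an array is as a FloatList inside a tf.train.Example. The following is a minimal sketch under stated assumptions: the feature key 'waveform' is taken from the error message quoted in this answer, synthetic NumPy arrays stand in for decoded audio, and the helper name waveform_to_example is hypothetical.

```python
import numpy as np
import tensorflow as tf

def waveform_to_example(waveform):
    """Pack a 1-D float array of any length into a tf.train.Example."""
    feature = {
        "waveform": tf.train.Feature(
            float_list=tf.train.FloatList(value=waveform.tolist())
        ),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))

# Synthetic stand-ins for two decoded clips of different durations.
short_clip = np.zeros(4, dtype=np.float32)
long_clip = np.ones(7, dtype=np.float32)

serialized = [
    waveform_to_example(clip).SerializeToString()
    for clip in (short_clip, long_clip)
]

# Decode the protos again to confirm that each keeps its own length.
stored_lengths = [
    len(tf.train.Example.FromString(s).features.feature["waveform"].float_list.value)
    for s in serialized
]
print(stored_lengths)
```

Each serialized string can then be written with tf.io.TFRecordWriter; nothing in the file format requires the records to have equal length.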

But when trying to read the data back, I initially used tf.io.FixedLenFeature in the feature description, which threw an error:

InvalidArgumentError: Key: waveform.  Can't parse serialized Example.

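TensorFlow 2.x provides a type designed specifically for variable-length data: RaggedTensor

To read variable-length data from a TFRecord file, you only need to use tf.io.RaggedFeature(dtype) in the feature description dictionary.

For example: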

feature_description = {
    'feature0': tf.io.RaggedFeature(tf.float32),
    'feature1': tf.io.FixedLenFeature([], tf.int64, default_value=0),
     ...
}

With RaggedFeature I was able to read the data successfully.
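
To make this concrete, here is a minimal round-trip sketch rather than the answerer's actual pipeline. It assumes a TensorFlow 2.x release that ships tf.io.RaggedFeature, writes two toy records to a temporary file, and parses them as one batch. The helper make_example, the file name toy.tfrecord and the toy values are hypothetical.

```python
import os
import tempfile

import tensorflow as tf  # assumes a TF 2.x release that has tf.io.RaggedFeature

def make_example(values, label):
    """Build an Example with a variable-length float list and a scalar int."""
    return tf.train.Example(features=tf.train.Features(feature={
        "feature0": tf.train.Feature(float_list=tf.train.FloatList(value=values)),
        "feature1": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    }))

# Write two records of different lengths to a temporary TFRecord file.
path = os.path.join(tempfile.mkdtemp(), "toy.tfrecord")
with tf.io.TFRecordWriter(path) as writer:
    writer.write(make_example([0.5, 1.5, 2.5], 0).SerializeToString())
    writer.write(make_example([4.0], 1).SerializeToString())

feature_description = {
    "feature0": tf.io.RaggedFeature(tf.float32),
    "feature1": tf.io.FixedLenFeature([], tf.int64, default_value=0),
}

# Batch the serialized strings first, then parse the whole batch at once,
# so that feature0 comes back as a RaggedTensor with one row per record.
dataset = (
    tf.data.TFRecordDataset(path)
    .batch(2)
    .map(lambda s: tf.io.parse_example(s, feature_description))
)

batch = next(iter(dataset))
rows = batch["feature0"].to_list()
labels = batch["feature1"].numpy().tolist()
print(rows, labels)
```

Parsing after batching is a design choice in this sketch: with tf.io.parse_single_example, a RaggedFeature without partitions yields a plain 1-D tensor per record, which a regular batch() cannot stack when the lengths differ.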

【Comments】:
