这取决于你的tf.data.Dataset holds是什么类型的数据,我以字符串数据为例,给出了下面的方法,作为解决方案的补充,我也为其他类型的数据提供了辅助函数。
为了将数据写入TFRecords,我们需要将每个数据点转换为字节串,然后使用tf.io.TFRecordsWriter 写入。
让我们看看如何实现这一目标。
首先,让我们定义tf.data.Dataset
dataset = tf.data.Dataset.from_tensor_slices(["one","two","three","four"])
现在,这个数据集必须被序列化。
def serialize_example(data):
# feature = {
# 'feature0': _int64_feature(feature0),
# 'feature1': _float_feature(feature1),
# 'feature2': _bytes_feature(feature2),
# 'feature3': _bytes_feature(feature3),
# }
feature = {
'feature': _bytes_feature(data)
}
# Create a Features message using tf.train.Example.
example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
return example_proto.SerializeToString()
我已经评论了其他类型的功能,因为我在这里考虑的功能是字符串。您可以根据您的数据从特征字典中选择类型。
提供将数据中的每个特征转换为tf.train.Feature的所有函数。
def _bytes_feature(value):
"""Returns a bytes_list from a string / byte."""
# If the value is an eager tensor BytesList won't unpack a string from an EagerTensor.
if isinstance(value, type(tf.constant(0))):
value = value.numpy()
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _float_feature(value):
"""Returns a float_list from a float / double."""
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
def _int64_feature(value):
"""Returns an int64_list from a bool / enum / int / uint."""
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
现在,我们都准备好使用TFRecordWriter 创建 TFRecord。
file_path = 'data.tfrecords'
with tf.io.TFRecordWriter(file_path) as writer:
for i in dataset.take(-1):
serialized_example = serialize_example(i)
writer.write(serialized_example)
注意: dataset.take(-1) 将获取您的tf.data.Dataset 中存在的所有记录
现在您的 TFRecord 文件已创建。
file_paths = [file_path]
tfrecord_dataset = tf.data.TFRecordDataset(file_paths)
现在这个数据集中的每个数据点都是由 serialize_example 函数返回的原始字节串。以下函数读取一个 serialized_example 并使用功能描述对其进行解析。
def read_tfrecord(serialized_example):
# feature_description = {
# 'feature0': tf.io.FixedLenFeature((), tf.int64),
# 'feature1': tf.io.FixedLenFeature((), tf.float32),
# 'feature2': tf.io.FixedLenFeature((), tf.string),
# 'feature3': tf.io.FixedLenFeature((), tf.string),
# }
feature_description = {
'feature': tf.io.FixedLenFeature((), tf.string)
}
example = tf.io.parse_single_example(serialized_example, feature_description)
feature = example['feature']
return feature
parsed_dataset = tfrecord_dataset.map(read_tfrecord)
结果:
for data in parsed_dataset.take(-1):
print(data)
tf.Tensor(b'one', shape=(), dtype=string)
tf.Tensor(b'two', shape=(), dtype=string)
tf.Tensor(b'three', shape=(), dtype=string)
tf.Tensor(b'four', shape=(), dtype=string)
如果您使用图像作为数据,您可以关注this beautiful post 来实现相同的目的。