【问题标题】:Converting CSV files to TF Records将 CSV 文件转换为 TF 记录
【发布时间】:2018-01-03 17:31:49
【问题描述】:

我已经运行我的脚本超过 5 个小时了。我有 258 个 CSV 文件要转换为 TF 记录。我编写了以下脚本,正如我所说,我已经运行了 5 个多小时:

import argparse
import os
import sys
import standardize_data
import tensorflow as tf

FLAGS = None
PATH = '/home/darth/GitHub Projects/gru_svm/dataset/train'

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def convert_to(dataset, name):
    """Converts a dataset to tfrecords"""

    filename_queue = tf.train.string_input_producer(dataset)

    # TF reader
    reader = tf.TextLineReader()

    # default values, in case of empty columns
    record_defaults = [[0.0] for x in range(24)]

    key, value = reader.read(filename_queue)

    duration, service, src_bytes, dest_bytes, count, same_srv_rate, \
    serror_rate, srv_serror_rate, dst_host_count, dst_host_srv_count, \
    dst_host_same_src_port_rate, dst_host_serror_rate, dst_host_srv_serror_rate, \
    flag, ids_detection, malware_detection, ashula_detection, label, src_ip_add, \
    src_port_num, dst_ip_add, dst_port_num, start_time, protocol = \
    tf.decode_csv(value, record_defaults=record_defaults)

    features = tf.stack([duration, service, src_bytes, dest_bytes, count, same_srv_rate,
                        serror_rate, srv_serror_rate, dst_host_count, dst_host_srv_count,
                        dst_host_same_src_port_rate, dst_host_serror_rate, dst_host_srv_serror_rate,
                        flag, ids_detection, malware_detection, ashula_detection, src_ip_add,
                        src_port_num, dst_ip_add, dst_port_num, start_time, protocol])

    filename = os.path.join(FLAGS.directory, name + '.tfrecords')
    print('Writing {}'.format(filename))
    writer = tf.python_io.TFRecordWriter(filename)
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            while not coord.should_stop():
                example, l = sess.run([features, label])
                print('Writing {dataset} : {example}, {label}'.format(dataset=sess.run(key),
                        example=example, label=l))
                example_to_write = tf.train.Example(features=tf.train.Features(feature={
                    'duration' : _float_feature(example[0]),
                    'service' : _int64_feature(int(example[1])),
                    'src_bytes' : _float_feature(example[2]),
                    'dest_bytes' : _float_feature(example[3]),
                    'count' : _float_feature(example[4]),
                    'same_srv_rate' : _float_feature(example[5]),
                    'serror_rate' : _float_feature(example[6]),
                    'srv_serror_rate' : _float_feature(example[7]),
                    'dst_host_count' : _float_feature(example[8]),
                    'dst_host_srv_count' : _float_feature(example[9]),
                    'dst_host_same_src_port_rate' : _float_feature(example[10]),
                    'dst_host_serror_rate' : _float_feature(example[11]),
                    'dst_host_srv_serror_rate' : _float_feature(example[12]),
                    'flag' : _int64_feature(int(example[13])),
                    'ids_detection' : _int64_feature(int(example[14])),
                    'malware_detection' : _int64_feature(int(example[15])),
                    'ashula_detection' : _int64_feature(int(example[16])),
                    'label' : _int64_feature(int(l)),
                    'src_ip_add' : _float_feature(example[17]),
                    'src_port_num' : _float_feature(example[18]),
                    'dst_ip_add' : _float_feature(example[19]),
                    'dst_port_num' : _float_feature(example[20]),
                    'start_time' : _float_feature(example[21]),
                    'protocol' : _int64_feature(int(example[22])),
                    }))
                writer.write(example_to_write.SerializeToString())
            writer.close()
        except tf.errors.OutOfRangeError:
            print('Done converting -- EOF reached.')
        finally:
            coord.request_stop()

        coord.join(threads)

def main(unused_argv):
    files = standardize_data.list_files(path=PATH)

    convert_to(dataset=files, name='train')

它已经让我想到它可能陷入了无限循环?我想要做的是读取每个 CSV 文件(258 个 CSV 文件)中的所有行,并将这些行写入 TF 记录(当然是一个特征和一个标签)。然后,当没有更多行可用或 CSV 文件已经用完时停止循环。

standardize_data.list_files(path) 是我在另一个模块中编写的函数。我只是将它重新用于这个脚本。它的作用是返回在PATH 中找到的所有文件的列表。请注意,我的PATH 中的文件仅包含 CSV 文件。

【问题讨论】:

    标签: python csv tensorflow file-io dataset


    【解决方案1】:

    string_input_producer 中设置num_epochs=1。另一个注意事项:将这些csv 转换为 tfrecords 可能不会提供您在 tfrecords 中查看的任何优势,这种数据的开销非常高(具有大量单个特征/标签)。您可能想试验这部分。

    【讨论】:

    • 那么,换句话说,您不建议将它们转换为 TF Records 吗?
    • 做这个实验:只转换一个文件,然后检查各自的大小。您的数据对于tfrecords 表示效率不高。每个特征都与标签一起保存,所以我猜它的大小比保存为 csv 的要大。
    • 一个示例 CSV 文件为 10.1 MB,其 tfrecord 等效文件为 9.6 MB
    • 不知道为什么它更少。不是 csv 中的每个样本都是用逗号分隔的浮点数(忽略保存在第一行中的标头),但是在这种情况下 tfrecord 中的每个样本都是“键:值”对,所以不应该是大吗?
    • 我已将其他值转换为 int64 而不是 float。我认为这可以解释为什么它更少。你怎么看?
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-08-05
    • 2021-02-28
    • 2012-02-02
    • 2021-06-06
    • 2021-09-11
    • 2012-02-16
    相关资源
    最近更新 更多