我去掉了不相关的东西,以保留格式和缩进。希望现在应该清楚了。以下代码分批读取 N 行的 CSV 文件(N 在顶部的常量中指定)。每行包含一个日期(第一个单元格),然后是一个浮点列表(480 个单元格)和一个单热向量(3 个单元格)。然后,代码在读取这些日期、浮点数和 one-hot 向量时简单地打印它们。它打印它们的地方通常是您实际运行模型并提供这些代替占位符变量的地方。
请记住,这里它将每一行读取为字符串,然后将该行中的特定单元格转换为浮点数,这仅仅是因为第一个单元格更容易作为字符串读取。如果您的所有数据都是数字,那么只需将默认值设置为浮点/整数而不是“a”,并摆脱将字符串转换为浮点数的代码。否则不需要!
我放了一些 cmets 来澄清它在做什么。如果有不清楚的地方请告诉我。
import tensorflow as tf
fileName = 'YOUR_FILE.csv'
try_epochs = 1
batch_size = 3
TD = 1 # this is my date-label for each row, for internal pruposes
TS = 480 # this is the list of features, 480 in this case
TL = 3 # this is one-hot vector of 3 representing the label
# set defaults to something (TF requires defaults for the number of cells you are going to read)
rDefaults = [['a'] for row in range((TD+TS+TL))]
# function that reads the input file, line-by-line
def read_from_csv(filename_queue):
reader = tf.TextLineReader(skip_header_lines=False) # i have no header file
_, csv_row = reader.read(filename_queue) # read one line
data = tf.decode_csv(csv_row, record_defaults=rDefaults) # use defaults for this line (in case of missing data)
dateLbl = tf.slice(data, [0], [TD]) # first cell is my 'date-label' for internal pruposes
features = tf.string_to_number(tf.slice(data, [TD], [TS]), tf.float32) # cells 2-480 is the list of features
label = tf.string_to_number(tf.slice(data, [TD+TS], [TL]), tf.float32) # the remainin 3 cells is the list for one-hot label
return dateLbl, features, label
# function that packs each read line into batches of specified size
def input_pipeline(fName, batch_size, num_epochs=None):
filename_queue = tf.train.string_input_producer(
[fName],
num_epochs=num_epochs,
shuffle=True) # this refers to multiple files, not line items within files
dateLbl, features, label = read_from_csv(filename_queue)
min_after_dequeue = 10000 # min of where to start loading into memory
capacity = min_after_dequeue + 3 * batch_size # max of how much to load into memory
# this packs the above lines into a batch of size you specify:
dateLbl_batch, feature_batch, label_batch = tf.train.shuffle_batch(
[dateLbl, features, label],
batch_size=batch_size,
capacity=capacity,
min_after_dequeue=min_after_dequeue)
return dateLbl_batch, feature_batch, label_batch
# these are the date label, features, and label:
dateLbl, features, labels = input_pipeline(fileName, batch_size, try_epochs)
with tf.Session() as sess:
gInit = tf.global_variables_initializer().run()
lInit = tf.local_variables_initializer().run()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
try:
while not coord.should_stop():
# load date-label, features, and label:
dateLbl_batch, feature_batch, label_batch = sess.run([dateLbl, features, labels])
print(dateLbl_batch);
print(feature_batch);
print(label_batch);
print('----------');
except tf.errors.OutOfRangeError:
print("Done looping through the file")
finally:
coord.request_stop()
coord.join(threads)