【Question Title】: TensorFlow Python reading 2 files
【Posted】: 2018-09-06 13:56:33
【Question】:

I'm trying to run the following (shortened) code:

import tensorflow as tf

# `sess` is an existing tf.Session (its setup is omitted in this shortened snippet).
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

try:
    while not coord.should_stop():
        # Run some code... (reading some data from file 1)
        ...

        coord_dev = tf.train.Coordinator()
        threads_dev = tf.train.start_queue_runners(sess=sess, coord=coord_dev)

        try:
            while not coord_dev.should_stop():
                # Run some other code... (reading data from file 2)
                ...
        except tf.errors.OutOfRangeError:
            print('Reached end of file 2')
        finally:
            coord_dev.request_stop()
            coord_dev.join(threads_dev)

except tf.errors.OutOfRangeError:
    print('Reached end of file 1')
finally:
    coord.request_stop()
    coord.join(threads)

What the code above is supposed to do:

  • File 1 is a CSV file containing the training data for my neural network.
  • File 2 contains the dev set data.

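While iterating over file 1 during training, I also want to compute the cost and accuracy on the dev set data (from file 2) every now and then. But when the inner loop has finished reading file 2, it obviously raises the exception

"tf.errors.OutOfRangeError"

which also makes my code leave the outer loop: the inner loop's exception is simply handled as the outer loop's exception as well. After finishing file 2, however, I want my code to continue training on file 1 in the outer loop.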
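A plain-Python sketch of the intended control flow may make the goal clearer. It uses no TensorFlow: lists stand in for the two files, `StopIteration` plays the role of `tf.errors.OutOfRangeError`, and the names `train_with_periodic_eval`, `make_dev_batches` and `eval_every` are hypothetical. The point it illustrates is that the end of the dev data is caught inside the inner loop only, and a fresh dev iterator is created for every evaluation pass, so the outer training loop simply carries on.

```python
from typing import Callable, Iterable, List, Sequence


def train_with_periodic_eval(
    train_batches: Sequence[int],
    make_dev_batches: Callable[[], Iterable[int]],
    eval_every: int,
) -> List[str]:
    """Plain-Python analogue of the intended control flow (no TensorFlow)."""
    log: List[str] = []
    for step, batch in enumerate(train_batches, start=1):
        log.append(f"train {batch}")           # stand-in for one training step
        if step % eval_every == 0:
            dev_iter = iter(make_dev_batches())  # fresh dev iterator for every pass
            dev_total = 0
            while True:
                try:
                    dev_batch = next(dev_iter)
                except StopIteration:          # like OutOfRangeError, but caught here only
                    break
                dev_total += dev_batch         # stand-in for computing cost/accuracy
            log.append(f"dev pass done (sum={dev_total})")
    return log


if __name__ == "__main__":
    for entry in train_with_periodic_eval([1, 2, 3, 4], lambda: iter([10, 20]), eval_every=2):
        print(entry)
```

With four training batches and `eval_every=2`, the log interleaves two complete dev passes with the training steps instead of stopping after the first dev pass.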

(I removed some details, such as num_epochs for training, to keep the code readable.)

Does anyone have suggestions on how to solve this? I'm fairly new to this.

Thanks in advance!

【Question Comments】:

    Tags: python tensorflow machine-learning neural-network


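    【Solution 1】:

    Solved.

    Apparently, using queue runners is not the right approach. The TensorFlow documentation indicates that the Dataset API should be used instead, which takes some time to understand. The code below does what I was trying to do before. I'm sharing it here in case someone else needs it too.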
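As a rough illustration of what the tf.data chain in this answer's code does, here is a pure-Python mirror of it. It is an analogue written for explanation, not TensorFlow code: `csv_pipeline`, `read_lines` and `decode` are hypothetical names, and the buffered sampling only mimics `Dataset.shuffle(buffer_size)`. Calling the function again with a different filename list corresponds to re-running `iterator5.initializer` with a new `feed_dict`, which is what lets one pipeline serve both the training and the dev files.

```python
import random
from typing import Callable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def csv_pipeline(
    filenames: Sequence[str],
    read_lines: Callable[[str], Sequence[str]],
    decode: Callable[[str], T],
    buffer_size: int = 1000,
    batch_size: int = 7,
    seed: int = 0,
) -> Iterator[List[T]]:
    """Pure-Python mirror (not TensorFlow) of the chain
    from_tensor_slices -> flat_map(TextLineDataset(f).skip(1).map(decode))
    -> shuffle(buffer_size) -> batch(batch_size).
    Each call starts a fresh pass, like re-running iterator5.initializer.
    """
    rng = random.Random(seed)

    def records() -> Iterator[T]:
        for name in filenames:                       # files one after another (flat_map)
            for line in list(read_lines(name))[1:]:  # skip(1): drop the header line
                yield decode(line)                   # map(decode)

    def shuffled(items: Iterator[T]) -> Iterator[T]:
        buffer: List[T] = []
        for item in items:
            buffer.append(item)
            if len(buffer) >= buffer_size:           # buffer full: emit a random element
                yield buffer.pop(rng.randrange(len(buffer)))
        rng.shuffle(buffer)                          # input exhausted: drain the rest randomly
        yield from buffer

    batch: List[T] = []
    for item in shuffled(records()):
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:                                        # keep the final, smaller batch
        yield batch
```

For real files, `read_lines` could be `lambda name: Path(name).read_text().splitlines()` (with `from pathlib import Path`); a second call such as `csv_pipeline(["dev_data1.csv"], ...)` then starts an independent pass over the dev data.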

    I've put some additional training code under www.github.com/loheden/tf_examples/dataset api. I struggled a bit to find complete examples myself.

    import tensorflow as tf  # TensorFlow 1.x API (tf.decode_csv, make_initializable_iterator)
    
    # READING DATA FROM train and validation (dev set) CSV FILES by using INITIALIZABLE ITERATORS
    
    # All CSV files have the same number of columns. The first column is assumed to be a training example ID,
    # the next 5 columns are feature columns, and the last column is the label column.
    
    # ASSUMPTIONS (otherwise, the decode_csv function needs to be updated):
    # 1) The first column is NOT a feature. (It is most probably a training example ID or similar)
    # 2) The last column is always the label. And there is ONLY 1 column that represents the label.
    #    If more than 1 column represents the label, see the next example down below
    
    feature_names = ['f1','f2','f3','f4','f5']
    record_defaults = [[""], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0]]
    
    
    def decode_csv(line):
        parsed_line = tf.decode_csv(line, record_defaults)
        label = parsed_line[-1]        # the label is the last element of the list
        del parsed_line[-1]            # delete the last element from the list
        del parsed_line[0]             # also delete the first element, because it is assumed NOT to be a feature
        features = tf.stack(parsed_line)  # stack the features so that forward prop. etc. can be vectorized later
        # label = tf.stack(label)        # NOT needed; only required if more than 1 column makes up the label
        return features, label         # one (features, label) example; batching happens in the dataset
    
    filenames = tf.placeholder(tf.string, shape=[None])
    dataset5 = tf.data.Dataset.from_tensor_slices(filenames)
    dataset5 = dataset5.flat_map(lambda filename: tf.data.TextLineDataset(filename).skip(1).map(decode_csv))
    dataset5 = dataset5.shuffle(buffer_size=1000)
    dataset5 = dataset5.batch(7)
    iterator5 = dataset5.make_initializable_iterator()
    next_element5 = iterator5.get_next()
    
    # Filenames used to initialize the iterator with training data.
    training_filenames = ["train_data1.csv", 
                          "train_data2.csv"]
    
    # Filenames used to initialize the iterator with validation (dev) data.
    validation_filenames = ["dev_data1.csv"]
    
    with tf.Session() as sess:
        # Train 2 epochs. Then validate train set. Then validate dev set.
        for _ in range(2):
            sess.run(iterator5.initializer, feed_dict={filenames: training_filenames})
            while True:
                try:
                    features, labels = sess.run(next_element5)
                    # Train...
                    print("(train) features: ")
                    print(features)
                    print("(train) labels: ")
                    print(labels)
                except tf.errors.OutOfRangeError:
                    print("Out of range error triggered (looped through training set 1 time)")
                    break
    
        # (Validating cost and accuracy on the train set would go here.)
        print("\nDone with the first iterator\n")
    
        sess.run(iterator5.initializer, feed_dict={filenames: validation_filenames})
        while True:
            try:
                features, labels = sess.run(next_element5)
                # Validate (cost, accuracy) on dev set
                print("(dev) features: ")
                print(features)
                print("(dev) labels: ")
                print(labels)
            except tf.errors.OutOfRangeError:
                print("Out of range error triggered (looped through dev set 1 time only)")
                break
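
The column-layout assumptions stated in the comments at the top of the code can also be checked without TensorFlow. The sketch below parses one CSV line the way `decode_csv` is meant to: the first column (example ID) is dropped, the last column is the label, the five columns in between are features, and empty numeric fields fall back to `0.0`, as the `record_defaults` entries specify. `decode_row` and `NUM_COLUMNS` are hypothetical names, and the column-count check only roughly mirrors the error `tf.decode_csv` reports for a row with the wrong number of fields.

```python
import csv
from typing import List, Tuple

NUM_COLUMNS = 7  # ID + 5 features + label, matching record_defaults in the answer


def decode_row(line: str) -> Tuple[List[float], float]:
    """Parse one CSV line into (features, label), like decode_csv is meant to."""
    fields = next(csv.reader([line]))
    if len(fields) != NUM_COLUMNS:
        raise ValueError(f"expected {NUM_COLUMNS} fields, got {len(fields)}")
    # fields[0] is the example ID and is not used as a feature.
    # Empty numeric fields fall back to 0.0, like the record_defaults entries.
    features = [float(f) if f != "" else 0.0 for f in fields[1:-1]]
    label = float(fields[-1]) if fields[-1] != "" else 0.0
    return features, label


if __name__ == "__main__":
    print(decode_row("ex1,0.5,,2,3,4,1"))
```

Keeping the parsing rules in one small function like this makes it easy to see which lines must change if, for example, the label spans more than one column.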
    

    【Answer Comments】:
