【问题标题】:Tensorflow Dataset API doubles graph protobuff filesizeTensorflow Dataset API 将图形 protobuff 文件大小加倍
【发布时间】:2023-04-11 11:05:01
【问题描述】:

总结:使用新的 tf.contrib.data.Dataset 会使我的图形 protobuff 文件的大小翻倍,我无法在 Tensorboard 中可视化图形。

详情:

我正在尝试新的 TensorFlow tf.contrib.data.Dataset 功能以及 tf.contrib.learn.Experiment 框架。我的输入数据定义为input functions,它返回特征和标签的张量。

如果我使用tf.train.slice_input_producer 函数创建我的输入函数,如下面的代码块(完整代码here),那么我生成的graph.pbtxt 文件为620M,.meta 文件大小约为165M。

def train_inputs():
    with tf.name_scope('Training_data'):
        x = tf.constant(mnist.train.images.reshape([-1, 28, 28, 1]))
        y = tf.constant(mnist.train.labels)
        sliced_input = tf.train.slice_input_producer(
            tensor_list=[x, y], shuffle=True)
        return tf.train.shuffle_batch(
            sliced_input, batch_size=batch_size,
            capacity=10000, min_after_dequeue=batch_size*10)

现在,如果我使用新的tf.contrib.data.Dataset.from_tensor_slices 创建我的输入函数,就像下面的代码块(完整代码here)一样,那么我生成的graph.pbtxt 文件的大小会翻倍到1.3G,.meta 文件会翻倍大小为330M。

def train_inputs():
    with tf.name_scope('Training_data'):
        images = mnist.train.images.reshape([-1, 28, 28, 1])
        labels = mnist.train.labels
        dataset = tf.contrib.data.Dataset.from_tensor_slices(
            (images, labels))
        dataset = dataset.repeat(None)  # Infinite
        dataset = dataset.shuffle(buffer_size=10000)
        dataset = dataset.batch(batch_size)
        iterator = dataset.make_one_shot_iterator()
        next_example, next_label = iterator.get_next()
        return next_example, next_label

现在因为graph.pbtxt 文件太大了,TensorBoard 需要很长时间才能解析这个文件,而且我无法直观地调试我的模型图。 我在Dataset documentation 中发现这种大小的增加来自:“数组的内容将被复制多次”solution 将使用占位符。但是,在这种情况下,我需要将 numpy 数组输入到具有活动会话的占位符中以初始化迭代器:

sess.run(iterator.initializer, feed_dict={features_placeholder: features, labels_placeholder: labels})

但是,在使用 tf.contrib.learn.Experiment 框架时,这似乎超出了我的控制范围。

如何使用 Experiment 框架初始化迭代器的初始化器?或者找到在不增加图表大小的情况下使用 Dataset API 的解决方法?

【问题讨论】:

    标签: python tensorflow skflow


    【解决方案1】:

    我使用tf.train.SessionRunHook 找到了解决问题的方法。我创建了一个SessionRunHook 对象,在创建会话后初始化迭代器:

    class IteratorInitializerHook(tf.train.SessionRunHook):
        def __init__(self):
            super(IteratorInitializerHook, self).__init__()
            self.iterator_initiliser_func = None
    
        def after_create_session(self, session, coord):
            self.iterator_initiliser_func(session)
    

    初始化函数在创建数据集迭代器时设置:

    iterator_initiliser_hook.iterator_initiliser_func = \
        lambda sess: sess.run(
            iterator.initializer,
            feed_dict={images_placeholder: images,
                       labels_placeholder: labels})
    

    我将钩子对象传递给tf.contrib.learn.Experimenttrain_monitorseval_hooks 参数。

    生成的 graph.pbtxt 文件现在只有 500K,而 .meta 文件只有 244K。

    Full example here.

    【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2021-01-13
    • 1970-01-01
    • 1970-01-01
    • 2012-07-06
    • 2018-09-11
    • 2019-10-29
    • 1970-01-01
    相关资源
    最近更新 更多