【问题标题】:Tensorflow: using a session/graph in methodTensorflow:在方法中使用会话/图形
【发布时间】:2017-02-24 12:02:57
【问题描述】:

我的情况是这样的:

我有一个训练 tensorflow 模型的脚本。在这个脚本中,我实例化了一个提供训练数据的类。该类的初始化反过来又实例化了另一个名为“image”的类来执行各种数据增强操作。

main script -> instantiates data_feed class -> instantiates image class

我的问题是我正在尝试使用 tensorflow 通过传递会话本身或图形来在此图像类中执行一些操作。但我收效甚微。

可行的方法(但太慢)

我现在所拥有的,但工作速度非常缓慢,是这样的(简化的):

class image(object):
    def __init__(self, im):
        self.im = im

    def augment(self):
        aux_im = tf.image.random_saturation(self.im, 0.6)

        sess = tf.Session(graph=aux_im.graph)
        self.im = sess.run(aux_im)

class data_feed(object):
    def __init__(self, data_dir):
        self.images = load_data(data_dir)

    def process_data(self):
        for im in self.images:
            image = image(im)
            image.augment()

if __name__ == "__main__":
    # initialize everything tensorflow related here, including model
    sess = tf.Session()
    # next load the data
    data_feed = data_feed(TRAIN_DATA_DIR)
    train_data = data_feed.process_data()

这种方法有效,但它为每张图片创建一个新会话:

I tensorflow/core/common_runtime/gpu/gpu_device.cc:975] Creating TensorFlow device (/gpu:0) -> (device: 0, name: GeForce GTX 1070, pci bus id: 0000:01:00.0)
I tensorflow/core/common_runtime/gpu/gpu_device.cc:975] Creating TensorFlow device (/gpu:0) -> (device: 0, name: GeForce GTX 1070, pci bus id: 0000:01:00.0)
I tensorflow/core/common_runtime/gpu/gpu_device.cc:975] Creating TensorFlow device (/gpu:0) -> (device: 0, name: GeForce GTX 1070, pci bus id: 0000:01:00.0)
I tensorflow/core/common_runtime/gpu/gpu_device.cc:975] Creating TensorFlow device (/gpu:0) -> (device: 0, name: GeForce GTX 1070, pci bus id: 0000:01:00.0)
etc ...

行不通的方法(应该更快)

例如,什么不起作用,我不知道为什么,是从我的主脚本传递图形或会话,如下所示:

class image(object):
    def __init__(self, im):
        self.im = im

    def augment(self, tf_sess):
        with tf_sess.as_default():
            aux_im = tf.image.random_saturation(self.im, 0.6)

            self.im = tf_sess.run(aux_im)

class data_feed(object):
    def __init__(self, data_dir, tf_sess):
        self.images = load_data(data_dir)
        self.tf_sess = tf_sess

    def process_data(self):
        for im in self.images:
            image = image(im)
            image.augment(self.tf_sess)

if __name__ == "__main__":
    # initialize everything tensorflow related here, including model
    sess = tf.Session()
    # next load the data
    data_feed = data_feed(TRAIN_DATA_DIR, sess)
    train_data = data_feed.process_data()

这是我得到的错误:

Traceback (most recent call last):
  File "/usr/lib/python2.7/threading.py", line 801, in __bootstrap_inner
    self.run()
  File "/usr/lib/python2.7/threading.py", line 754, in run
    self.__target(*self.__args, **self.__kwargs)
  File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 409, in data_generator_task
    generator_output = next(generator)
  File "/home/mathetes/Dropbox/ML/load_gluc_data.py", line 198, in generate
    yield self.next_batch()
  File "/home/mathetes/Dropbox/ML/load_gluc_data.py", line 192, in next_batch
    X, y, l = self.process_image(json_im, X, y, l)
  File "/home/mathetes/Dropbox/ML/load_gluc_data.py", line 131, in process_image
    im.augment_with_tf(self.tf_sess)
  File "/home/mathetes/Dropbox/ML/load_gluc_data.py", line 85, in augment_with_tf
    self.im = sess.run(saturation, {im_placeholder: np.asarray(self.im)})
  File "/home/mathetes/.local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 766, in run
    run_metadata_ptr)
  File "/home/mathetes/.local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 921, in _run
    + e.args[0])
TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("Placeholder:0", shape=(96, 96, 3), dtype=float32) is not an element of this graph.

任何帮助将不胜感激!

【问题讨论】:

    标签: python tensorflow keras


    【解决方案1】:

    不如创建一个 ImageAugmenter 类而不是 Image 类,该类在初始化时接受一个会话,然后使用 Tensorflow 处理您的图像?你可以这样做:

    import tensorflow as tf
    import numpy as np
    
    class ImageAugmenter(object):
        def __init__(self, sess):
            self.sess = sess
            self.im_placeholder = tf.placeholder(tf.float32, shape=[1,784,3])
    
        def augment(self, image):
            augment_op = tf.image.random_saturation(self.im_placeholder, 0.6, 0.8)
            return self.sess.run(augment_op, {self.im_placeholder: image})
    
    class DataFeed(object):
        def __init__(self, data_dir, sess):
            self.images = load_data(data_dir)
            self.augmenter = ImageAugmenter(sess)
    
        def process_data(self):
            processed_images = []
            for im in self.images:
                processed_images.append(self.augmenter.augment(im))
            return processed_images
    
    def load_data(data_dir):
        # True method would read images from disk
        # This is just a mockup
        images = []
        images.append(np.random.random([1,784,3]))
        images.append(np.random.random([1,784,3]))
        return images
    
    if __name__ == "__main__":
        TRAIN_DATA_DIR = '/some/dir/'
        sess = tf.Session()
        data_feed = DataFeed(TRAIN_DATA_DIR, sess)
        train_data = data_feed.process_data()
        print(train_data)
    

    这样你就不会为每张图片创建一个新的会话,它应该给你你想要的。

    注意sess.run() 的调用方式;我传递给它的 feed dict 的键是 上面定义的占位符张量。根据您的错误跟踪,您可能正试图从代码中未定义 im_placeholder 的部分调用 sess.run(),或者它已被定义为 tf.placeholder 以外的部分。

    此外,您可以通过更改ImageAugmenter.augment() 方法来进一步改进代码,以接收lower 和upper 参数作为tf.image.random_saturation() 方法的输入,或者您可以使用特定形状初始化ImageAugmenter 而不是使用例如,它是硬编码的。

    【讨论】:

    • 感谢您的回答,但我仍然收到相同的错误TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("Placeholder:0", shape=(96, 96, 3), dtype=float32) is not an element of this graph.
    • 我已经更新了我的问题以反映完整的回溯,但是上面的代码中没有显示的功能
    • 我用一个可运行的代码 sn-p 更新了我的答案,并进一步解释了为什么会出现该错误。
    猜你喜欢
    • 2017-11-02
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2023-03-14
    • 1970-01-01
    • 1970-01-01
    • 2020-03-23
    • 1970-01-01
    相关资源
    最近更新 更多