【问题标题】:Tensorflow session in a class in python?python中的一个类中的Tensorflow会话?
【发布时间】:2019-07-21 22:05:12
【问题描述】:

我想稍后使用 tensorflow 梯度来计算其他数量。我需要将目标函数和梯度作为类中的函数进行数值计算(然后在其余套件中使用此类)。但是,我收到以下代码的错误:

import tensorflow as tf
class MyClass:
    def __init__(self):
        x=tf.Variable(tf.zeros(2))
        func = tf.cos(14.5 * x[0] - 0.3) + (x[1] + 0.2) * x[1] + (x[0] + 0.2) * x[0]
        diff_func = tf.gradients(func,x)

        sess = tf.Session()

    def getFunc(self,coords):
        return self.sess.run(self.func,feed_dict={self.x:coords})

    def getGrad(self,coords):
        grad = self.sess.run(self.diff_func,feed_dict={self.x:coords})
        return grad

MyClass = MyClass()
MyClass.getFunc([0.362,0.556])
MyClass.getGrad([0.362,0.556])

我得到的错误是:

Traceback(最近一次调用最后一次):

文件“”,第 19 行,在 MyClass.getFunc([0.362,0.556])

文件“”,第 11 行,在 getFunc 中 return self.sess.run(self.func,feed_dict={self.x:coords})

AttributeError: MyClass 实例没有属性 'sess'

不知道如何让这个类正确运行。 谢谢。

【问题讨论】:

  • 你可以在这个类之外的顶部定义sess = tf.Session()

标签: python class session tensorflow


【解决方案1】:

基本上,'self' 告诉哪些变量和方法属于一个类。所以你必须告诉 (x, func, diff_func 和 sess) 属于 MyClass。所以修改代码如下:

import tensorflow as tf


class MyClass:
    def __init__(self):
        self.x = tf.Variable(tf.zeros(2))
        self.func = tf.cos(14.5 * self.x[0] - 0.3) + (self.x[1] + 0.2) * self.x[1] + (self.x[0] + 0.2) * self.x[0]
        self.diff_func = tf.gradients(self.func, self.x)

        self.sess = tf.Session()

    def getFunc(self, coords):
        return self.sess.run(self.func, feed_dict={self.x: coords})

    def getGrad(self, coords):
        grad = self.sess.run(self.diff_func, feed_dict={self.x: coords})
        return grad


MyClass = MyClass()
MyClass.getFunc([0.362, 0.556])
print(MyClass.getGrad([0.362, 0.556]))

【讨论】:

    【解决方案2】:

    sess = tf.Session() 替换为self.sess = tf.Session()

    【讨论】:

    • 没用。仍然错误“回溯(最近一次调用最后一次):文件“”,第 19 行,在 MyClass.getFunc([0.362,0.556])文件“",第 11 行,在 getFunc 中返回 self.sess.run(self.func,feed_dict={self.x:coords}) AttributeError: MyClass instance has no attribute 'func'"
    • 这个问题的解决方法类似。如果您希望func 在课堂上的任何地方都可以访问,您必须写self.func。在您发布的 sn-p 中,__init__ 中定义的变量只能从 __init__ 中访问。 This post 详细说明了它的工作原理。更广泛地说,自学 Python 和 TensorFlow 是一项挑战。同时学习它们是一种挫败感。在继续使用 TensorFlow 之前,您可以先了解 Python 的基础知识,从而省去一些麻烦。快乐的黑客攻击!
    • 我有一个非常具体的问题。我已经制定了一个可以重现问题的最小代码。提供正确的完整工作代码不是更有效吗?然后,我可以在它的基础上处理更复杂的情况。
    【解决方案3】:

    如果你在构造函数中创建一个变量,即__init__。然后您需要关联self 以表明该变量属于该特定类。 本例中__init__MyClass的构造函数,那么定义sess需要在self之前关联self.var_name。这样类函数getFunc()getGrad() 就可以理解sess 变量是一个对象变量并且可以使用它。 所以只需在构造函数中将self 替换为self.sess

    【讨论】:

      猜你喜欢
      • 2018-02-07
      • 2017-01-20
      • 2020-02-10
      • 2016-12-13
      • 1970-01-01
      • 2020-01-04
      • 2010-10-23
      • 2018-03-18
      • 2019-05-08
      相关资源
      最近更新 更多