【发布时间】:2019-07-21 22:05:12
【问题描述】:
我想稍后使用 tensorflow 梯度来计算其他数量。我需要将目标函数和梯度作为类中的函数进行数值计算(然后在其余套件中使用此类)。但是,我收到以下代码的错误:
import tensorflow as tf
class MyClass:
def __init__(self):
x=tf.Variable(tf.zeros(2))
func = tf.cos(14.5 * x[0] - 0.3) + (x[1] + 0.2) * x[1] + (x[0] + 0.2) * x[0]
diff_func = tf.gradients(func,x)
sess = tf.Session()
def getFunc(self,coords):
return self.sess.run(self.func,feed_dict={self.x:coords})
def getGrad(self,coords):
grad = self.sess.run(self.diff_func,feed_dict={self.x:coords})
return grad
MyClass = MyClass()
MyClass.getFunc([0.362,0.556])
MyClass.getGrad([0.362,0.556])
我得到的错误是:
Traceback(最近一次调用最后一次):
文件“”,第 19 行,在 MyClass.getFunc([0.362,0.556])
文件“”,第 11 行,在 getFunc 中 return self.sess.run(self.func,feed_dict={self.x:coords})
AttributeError: MyClass 实例没有属性 'sess'
不知道如何让这个类正确运行。 谢谢。
【问题讨论】:
-
你可以在这个类之外的顶部定义
sess = tf.Session()。
标签: python class session tensorflow