【发布时间】:2018-09-25 02:29:13
【问题描述】:
我正在尝试在 tensorflow 中保存和恢复我的模型,我尝试搜索并找到了许多教程,但没有一个给出明确的说明,即在恢复模型时我应该使用训练期间使用的相同程序还是只恢复型号??
这是 tensorflow 中的简单线性回归模型:
import numpy as np
import tensorflow as tf
tf.set_random_seed(777)
x_data = [[73., 80., 75.],
[93., 88., 93.],
[89., 91., 90.],
[96., 98., 100.],
[73., 66., 70.]]
y_data = [[152.],
[185.],
[180.],
[196.],
[142.]]
class regression_model():
def __init__(self):
input_x = tf.placeholder(tf.float32,shape=[None,3])
output_y=tf.placeholder(tf.float32,shape=[None,1])
self.placeholder={'input':input_x,'output':output_y}
weights= tf.get_variable('weights',shape=[3,1],dtype=tf.float32,initializer=tf.random_uniform_initializer(-0.01,0.01))
bias = tf.get_variable('bias',shape=[1],dtype=tf.float32,initializer=tf.random_uniform_initializer(-0.01,0.01))
result=tf.matmul(input_x,weights) + bias
cost=tf.square(result-output_y)
loss=tf.reduce_mean(cost)
train=tf.train.GradientDescentOptimizer(learning_rate=1e-5).minimize(loss)
self.out ={'result':result,'loss':loss,'train':train}
def exe_func(model):
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for step in range(2001):
out=sess.run(model.out,feed_dict={model.placeholder['input']:x_data,model.placeholder['output']:y_data})
print("loss", out['loss'], "prediction", out['result'])
if __name__=='__main__':
model=regression_model()
exe_func(model)
当我运行时,我得到这个输出:
......
loss 0.73689765 prediction [[152.12286]
[184.14502]
[180.76541]
[196.88777]
[140.74924]]
loss 0.7366613 prediction [[152.12263]
[184.1452 ]
[180.76535]
[196.88771]
[140.74948]]
Process finished with exit code 0
现在我如何保存这个模型以及如何在新文件中恢复?我尝试了这个 stackoverflow question 并做了这样的事情:
def exe_func(model):
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for step in range(2001):
out=sess.run(model.out,feed_dict={model.placeholder['input']:x_data,model.placeholder['output']:y_data})
print("loss", out['loss'], "prediction", out['result'])
saver.save(sess, '/Users/exepaul/Desktop/only_rnn_1/')
if __name__=='__main__':
model=regression_model()
exe_func(model)
但我不知道如何使用这个保存的模型以及如何为模型提供输入并获得预测输出?
【问题讨论】:
标签: python python-3.x tensorflow deep-learning regression