一、Tensorflow 的基础使用
graphs
session
tensor
Variable
用graph表示计算任务,graphs中的节点是op(operation),一个op获得0个或多个Tensor,执行计算之后产生0个或多个Tensor,Tensor可看作是一个n维的数组或列表,图必须在Session中启动
结构如图所示:
程序实践:
1.创建回话:
import tensorflow as tf
# 创建常量op
m1=tf.constant([[3,3]])
m2=tf.constant([[3],[4]])
res=tf.matmul(m1,m2)
# print(res)不能直接输出,因为在图中需要启动会话
# 定义会话,法一:
sess=tf.Session()
print(sess.run(res))
print(res)
sess.close()
# 法二:不需要执行关闭操作,可以自己关闭
with tf.Session() as sess:
res=sess.run(res)
print(res)
2.变量使用
x=tf.Variable([1,2])
a=tf.constant([3,3])
sub=tf.subtract(x,a)
add=tf.add(x,sub)
# 变量要注意进行初始化操作
init=tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
print(sess.run(sub))
print(sess.run(add))
# 创建变量
ini_num=tf.Variable(0)
# 创建op,加一
new_num=tf.add(ini_num,1)
# 赋值op
update_num=tf.assign(ini_num,new_num)
ini=tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(ini)
for i in range(10):
print(sess.run(update_num))
3.Fetch和Feed
- Fetch表示在一个回话中可以运行多个op,在run部分加入多个op,可以选择最后以列表还是元组的形式呈现
a=tf.constant(3)
b=tf.constant(5)
c=tf.constant(2)
add=tf.add(a,b)
mul=tf.multiply(b,add)
with tf.Session() as sess:
res=sess.run([add,mul])
print(type(res),res)
- Feed是创建占位符, 具体的值可以在run时候再传入,传入值的形式是字典形式,即a,b表示的是字典中的key,在run部分再加入建立字典的value
#创建占位符
a=tf.placeholder(tf.float32)
b=tf.placeholder(tf.float32)
mul=tf.multiply(a,b)
# 在session中以字典形式传入值
with tf.Session() as sess:
res=sess.run(mul,feed_dict={a:[3],b:[5]})
print(res)
4.线性模型预测模拟
import tensorflow as tf
import numpy as np
# 创建初始数据
input_data=np.random.rand(100)
y=input_data*0.5+3
# 建立模型
k=tf.Variable(0.)
b=tf.Variable(0.)
y_predict=tf.add(tf.multiply(k,input_data),b)
# 建立loss和优化器
loss=tf.reduce_mean(tf.square(y-y_predict))
optimizer=tf.train.GradientDescentOptimizer(0.3)
tri=optimizer.minimize(loss)
ini=tf.global_variables_initializer()
# 建立session
with tf.Session() as sess:
sess.run(ini)
# 迭代200次
for i in range(201):
sess.run(tri)
if i%20==0:
print(i,sess.run([k,b]))
最后结果为:
0 [0.99093753, 1.9482267]
20 [0.8271484, 2.8253014]
40 [0.63067365, 2.9302197]
60 [0.5521954, 2.9721274]
80 [0.5208487, 2.9888668]
100 [0.50832766, 2.995553]
120 [0.50332654, 2.9982235]
140 [0.5013287, 2.9992905]
160 [0.50053096, 2.9997165]
180 [0.5002122, 2.9998868]
200 [0.50008476, 2.9999547]
需要注意的是:
最后一步输出k,b的值时,也需要启动sess.run,并且之后需要用列表或元组的形式输出,否则会报错。TypeError: input must be a dictionary