hisweety

实现流程

 

1、准备数据

2、全连接结果计算

3、损失优化(梯度下降)

4、模型评估(计算准确性)

5、加入tensorboard图

6、使用训练后的模型进行预测

 

 1 def full_connect():
 2     #使用占位符时,tersorflow2.X以上会出现tf.placeholder() is not compatible with eager execution报错,需要加下面这段语,避免程序报此错误。
 3     tf.compat.v1.disable_eager_execution()
 4     #获取真实的数据
 5     mnist = input_data.read_data_sets("./tmp/mnist/", one_hot=True)
 6     #1、建立数据的占位符 ,X[None,784] y_true [None,10]
 7     with tf.compat.v1.variable_scope(\'data\'):
 8         x=tf.compat.v1.placeholder(tf.float32,[None,784])
 9         y_true=tf.compat.v1.placeholder(tf.int32,[None,10])
10 
11     #2、建立一个全链接层的神经网络 w[784,10],b=[10]
12     with tf.compat.v1.variable_scope(\'fc_model\'):
13         #随机初始化权重和偏置,权重和偏置后面会跟着训练自动优化
14         weight=tf.Variable(tf.compat.v1.random_normal([784,10],mean=0.0,stddev=1.0),name=\'weight\')
15         bias=tf.Variable(tf.constant(0.0,shape=[10]))
16         #预测Nonew个样本的输出结果matrix [None,784]*[784*10]+[10]=[None,10],即矩阵[None,784]样本的特征*权重[784,10]+偏置[10]=预测结果[None,10]
17         y_predict=tf.matmul(x,weight)*bias
18     #计算交叉熵损失
19     with tf.compat.v1.variable_scope(\'soft_cross\'):
20         #返回交叉熵的列表结果,对交叉熵求平均值
21         loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true,logits=y_predict))
22 
23     #梯度下降求出损失
24     with tf.compat.v1.variable_scope(\'optimizer\'):
25         train_op=tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize(loss)
26     #5、计算准确率,预测准确置为1
27     with tf.compat.v1.variable_scope(\'acc\'):
28         #equal_list None个样本[1,0,1,1,.....]
29         equal_list=tf.equal(tf.argmax(y_true,1),tf.argmax(y_predict,1))
30         accuray=tf.reduce_mean(tf.cast(equal_list,tf.float32))
31     #收集变量,单个数字值收集
32     tf.compat.v1.summary.scalar("losses",loss)
33     tf.compat.v1.summary.scalar("acc", accuray)
34 
35     #高纬度变量收集
36     tf.compat.v1.summary.histogram(\'weight\',weight)
37     tf.compat.v1.summary.histogram(\'biases\',bias)
38 
39     #定义一个合并的op
40     merged=tf.compat.v1.summary.merge_all()
41 
42     #因为有变量,故要定义初始化变量的op
43     init_op=tf.compat.v1.global_variables_initializer()
44     #开启回话去训练
45     with tf.compat.v1.Session() as sess:
46         #初始化变量
47         sess.run(init_op)
48         filewriter=tf.compat.v1.summary.FileWriter(\'./tmp/summary/test/\',graph=sess.graph)
49         #迭代步数去训练 ,更新参数预测
50         for i in range(2000):
51             mnist_x,mnist_y=mnist.train.next_batch(50)
52             #feed_dict实时提供的数据 x训练集,y为真实的目标值
53             #运行op训练
54             sess.run(train_op,feed_dict={x:mnist_x,y_true:mnist_y})
55             #写入每步训练的值
56             summary=sess.run(merged,feed_dict={x:mnist_x,y_true:mnist_y})
57             filewriter.add_summary(summary,i)
58 
59             print(\'训练第%d步,准确率为:%f\'%(i,sess.run(accuray,feed_dict={x:mnist_x,y_true:mnist_y})))
60     return None

注意:在tensorflow2.X版本,如果出现报No module named \'tensorflow.examples.tutorials\' ,手动下载tutorials文件包,并放到本地电脑tersorflow/examples目录下。

下载链接:https://share.weiyun.com/fpYSBj4X 密码:qu73et

 

出现报错tensorflow报AttributeError: __enter__,将tf.compat.v1.Session后面加上括号()

 

 

在以上的代码基础上增加一下代码(红色字体)

FLAGS=tf.compat.v1.flags.FLAGS
tf.compat.v1.flags.DEFINE_integer("is_train",1,\'指定程序是预测还是训练\')


 #开启回话去训练
    with tf.compat.v1.Session() as sess:
        #初始化变量
        sess.run(init_op)
        filewriter=tf.compat.v1.summary.FileWriter(\'./tmp/summary/test/\',graph=sess.graph)
        #迭代步数去训练 ,更新参数预测
        if FLAGS.is_train ==1:
            for i in range(2000):
                mnist_x,mnist_y=mnist.train.next_batch(50)
                #feed_dict实时提供的数据 x训练集,y为真实的目标值
                #运行op训练
                sess.run(train_op,feed_dict={x:mnist_x,y_true:mnist_y})
                #写入每步训练的值
                summary=sess.run(merged,feed_dict={x:mnist_x,y_true:mnist_y})
                filewriter.add_summary(summary,i)

                print(\'训练第%d步,准确率为:%f\'%(i,sess.run(accuray,feed_dict={x:mnist_x,y_true:mnist_y})))
            #保存模型
            saver.save(sess,"./tmp/ckpt/tc_model")
        else:
            #加载模型,如果不加载模型,则参数不会被新的覆盖
            saver.restore(sess,\'./tmp/ckpt/tc_model\')
            #如果是0,做出预测
            for i in range(100):
                #每次测试一张图片[0,0,0,1,0,..0]
                x_test,y_test=mnist.test.next_batch(1)
                print(\'第%d张图片,手写数字目标是%d,预测结果是:%d\'%(i,
                                                   tf.argmax(y_test,1).eval(),
                                                   #图片与预测值的概率,再求其最大概率值
                                                  tf.argmax( sess.run(y_predict,feed_dict={x:x_test,y_true:y_test}),1).eval()))


    return None

 

 

 

cmd运行程序结果如下:

 

分类:

技术点:

相关文章:

  • 2022-12-23
  • 2022-02-06
  • 2021-05-15
  • 2021-10-14
  • 2021-06-24
  • 2023-01-31
  • 2021-05-06
  • 2021-06-24
猜你喜欢
  • 2021-05-14
  • 2021-06-16
  • 2022-01-12
  • 2022-12-23
  • 2022-01-19
  • 2022-12-23
  • 2022-12-23
相关资源
相似解决方案