首先,使用Keras创建神经网络主要分为四部分:
1.定义训练测试数据:输入张量和目标张量。
2.定义网络(模型),将输入映射到目标。
3.配置学习过程:选择损失函数,优化器,和监控指标。
4.fit方法进行训练迭代。
5.evaluate测试。
6.predict评估。
基本过程:
1.定义数据:
2.定义网络:
1.Sequential类(层的线性堆叠)常用。
from keras import models
from keras import layers
model=models.Sequential()
model.add(layers.Dense(32,activation='relu',inpute_shape=(784,)))
(创建了一个层,只接受第一个维度大小为784 的 2D 张量(第0 轴是批量维度,其大小没有指定,因此可以任意取值)作为输入。这个层将返回一个张量,第一个维度的大小变成 了 32。 )
(,逗号从外往里数)
model.add(layers.Dense(10, activation='softmax')) //10为类别,10维向量
//如果要对 N 个类别的数据点进行分类,网络的最后一层应该是大小为 N 的 Dense 层。
//对于回归问题来说,只需要为1,不需要进行softmax。
2.函数式API(有向无环图).
3.配置学习过程:
model.compile(optimizer=optimizers.RMSprop(lr=0.001), / 或者:optimizer='rmsprop'
loss='mse'
metrics=['accuracy'] )
4.训练:
model.fit(input_tensor, target_tensor, batch_size=128, epochs=10)
5.测试:
results=model.evaluate(x_test,y_test)
6.预测:
model.predict(x_test)
具体选择:
简单的向量数据保存在 形状为 (samples, features) 的 2D 张量中,通常用密集连接层[densely connected layer,也 叫全连接层(fully connected layer)或密集层(dense layer),对应于Keras 的 Dense 类]来处理。序列数据保存在形状为 (samples, timesteps, features) 的 3D 张量中,通常用循环 层(recurrent layer,比如Keras 的 LSTM 层)来处理。图像数据保存在4D 张量中,通常用二维 卷积层(Keras 的 Conv2D)来处理。
损失函数:对于二分类问题,你可以使用二元交叉熵(binary crossentropy)损失函数;对于多分类问题,可以用分类交叉熵(categorical crossentropy)损失函数;对于回归 问题,可以用均方误差(mean-squared error)损失函数;对于序列学习问题,可以用联结主义 时序分类(CTC,connectionist temporal classification)损失函数。
无论你的问题是什么,rmsprop 优化器通常都是足够好的选择。
绘图小知识:
import matplotlib.pyplot as plt
import numpy as np
x=np.array([1,3,5,2,7])
y=np.array([2,5,7,9,0])
plt.clf() //清空图像
plt.plot(x,y,label='标签')
plt.title('测试')
plt.xlabel('横坐标')
plt.legend() //加图例
plt.show() //显示
无监督知识:
定义:在没有标签的情况下找到数据的变换。
主要是进行数据分析,降维和聚类都是无监督学习方法。(原理都为函数的映射)
机器学习领域中所谓的降维就是指采用某种映射方法,将原高维空间中的数据点映射到低维度的空间中。
聚类:K-means:随机选取初始点,计算距离,归类,再重新划分中心。重复。
强化学习:
智能体(agent)接收有关其环境的信息,并学会选择使某种奖励最大化的行为。
为什么不是两个集合:一个训练集和一个测试集?还需要验证集?
原因在于开发模型时总是需要调节模型配置,比如选择层数或每层大小[这叫作模型的超 参数(hyperparameter),以便与模型参数(即权重)区分开]。这个调节过程需要使用模型在验 证数据上的性能作为反馈信号。这个调节过程本质上就是一种学习:在某个参数空间中寻找良 好的模型配置。因此,如果基于模型在验证集上的性能来调节模型配置,会很快导致模型在验 证集上过拟合,即使你并没有在验证集上直接训练模型也会如此。 会导致在新的数据上出现错误。