【发布时间】:2017-02-11 17:54:23
【问题描述】:
我已尽力遵循有关神经网络结构的在线指南,但我肯定遗漏了一些基本知识。给定一组经过训练的权重及其偏差,我想简单地使用这些权重手动预测输入,而不使用 predict 方法。
使用带有 keras 的 MNIST 图像我尝试手动编辑我的数据以包含一个额外的偏差特征,但是这种努力似乎没有提供比完全不使用偏差更好的图像精度,而且绝对比使用精度低得多keras 预测方法。我的代码和我的尝试一起在下面。
请注意底部附近的两个 cmets 使用 keras 方法预测来获得准确的图像表示,然后我尝试手动获取权重并添加偏差。
from keras.datasets import mnist
import numpy as np
import time
from keras.models import Sequential
from keras.layers import Dense
import tensorflow as tf
from matplotlib import pyplot as plt
comptime=time.time()
with tf.device('/cpu:0'):
tf.placeholder(tf.float32, shape=(None, 20, 64))
seed = 7
np.random.seed(seed)
model = Sequential()
(x_train, _), (x_test, _) = mnist.load_data()
x_train = x_train.astype('float32') / 255.
priorShape_x_train=x_train.shape #prior shape of training set
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))
x_train_shaped=x_train
model.add(Dense(32, input_dim=784, init='uniform', activation='relu'))
model.add(Dense(784, init='uniform', activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adadelta', metrics=['accuracy'])
model.fit(x_train[1:2500], x_train[1:2500], nb_epoch=10)
#proper keras prediction
prediction_real=model.predict(x_train[57:58])
prediction_real=prediction_real.reshape((28,28))
#manual weight prediction attempt
x_train=np.hstack([x_train,np.zeros(x_train.shape[0]).reshape(x_train.shape[0],1)]) #add extra column for bias
x_train[:,-1]=1 #add placeholder as 1
weights=np.vstack([model.get_weights()[0],model.get_weights()[1]]) #add trained weights as extra row vector
prediction=np.dot(x_train,weights) #now take dot product.. repeat pattern for next layer
prediction=np.hstack([prediction,np.zeros(prediction.shape[0]).reshape(prediction.shape[0],1)])
prediction[:,-1]=1
weights=np.vstack([model.get_weights()[2],model.get_weights()[3]])
prediction=np.dot(prediction,weights)
prediction=prediction.reshape(priorShape_x_train)
plt.imshow(prediction[57], interpolation='nearest',cmap='gray')
plt.savefig('myprediction.png') #my prediction, not accurate
plt.imshow(prediction_real,interpolation='nearest',cmap='gray')
plt.savefig('realprediction.png') #in-built keras method, accurate
【问题讨论】:
标签: neural-network keras