[Posted]: 2020-05-30 10:26:20
[Question]:
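So I trained a TensorFlow multilayer perceptron on the MNIST dataset, but only on the digits 0 through 4. I then built a new model with all the same layers and weights, except for a new output layer, which also has 5 output nodes. I want to train this new model to classify the digits 5 through 9.
I built new x_train and y_train arrays containing only the digits 5 through 9 and then ran
transfer_model.fit(x_train[train_filter], y_train[train_filter], epochs=5)
where train_filter is defined as np.where(np.logical_and(y_train >= 5, y_train <= 9)).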
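A minimal sketch of the filtering step, assuming `x_train` and `y_train` are NumPy arrays like those returned by `tf.keras.datasets.mnist.load_data()`. The mask has to be built from the labels, not the pixel values, and the bounds must be `>= 5` and `<= 9`. `np.where(condition)` would also work, since it returns a tuple of indices that NumPy accepts for indexing, but a boolean mask can index the arrays directly. The tiny arrays below are made-up stand-ins for the real data.

```python
import numpy as np

# Hypothetical stand-ins: 6 tiny "images" (2 pixels each) and their digit labels.
x = np.arange(12, dtype=np.float32).reshape(6, 2)
y = np.array([0, 5, 3, 9, 7, 4])

# Boolean mask over the *labels*: keep digits 5 through 9 inclusive.
mask = (y >= 5) & (y <= 9)

# The same mask selects matching rows from both arrays, so images and labels stay aligned.
x_hi, y_hi = x[mask], y[mask]
print(y_hi)  # [5 9 7]
```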
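At the first training step I get this error:
InvalidArgumentError: Received a label value of 9 which is outside the valid range of [0, 5). Label values: 5 9 7 8 9 8 7 6 8 7 6 9 5 5 8 7 6 9 9 7 6 7 6 8 7 7 9 7 6 8 5 6
That makes sense, because I originally trained the network to classify labels in the range [0, 5), but now I want it to classify labels in the range [5, 10). Am I missing a step here? I'm not sure what I'm missing... How do I define which label each output neuron corresponds to?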
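A minimal sketch of one common fix, not the only one, assuming the model is compiled with `sparse_categorical_crossentropy` (the loss this error message points to). With that loss, an integer label `k` is treated as the index of output unit `k`, so a 5-unit output layer only accepts labels 0 to 4. Shifting the labels down by 5 before training makes unit 0 stand for digit 5, unit 1 for digit 6, and so on; predictions are shifted back up after `argmax`. The `OFFSET` constant, the `encode`/`decode` helpers, and the sample arrays are hypothetical.

```python
import numpy as np

OFFSET = 5        # hypothetical: smallest digit in the new task
NUM_CLASSES = 5   # size of the new output layer (dense_55 has 5 units)

def encode(labels):
    """Map digits 5..9 to the class indices 0..4 that a 5-unit softmax expects."""
    idx = np.asarray(labels) - OFFSET
    # Sparse categorical cross-entropy requires every index to lie in [0, NUM_CLASSES).
    if idx.min() < 0 or idx.max() >= NUM_CLASSES:
        raise ValueError("label outside the range covered by the output layer")
    return idx

def decode(probs):
    """Turn softmax outputs of shape (batch, 5) back into digit labels."""
    return np.argmax(probs, axis=1) + OFFSET

y_new = np.array([5, 9, 7, 8, 6])
print(encode(y_new))  # [0 4 2 3 1]

fake_probs = np.array([[0.1, 0.7, 0.1, 0.05, 0.05],
                       [0.0, 0.0, 0.0, 0.1, 0.9]])
print(decode(fake_probs))  # [6 9]
```

In the setup above this would amount to passing `encode(y_train[train_filter])` as the labels to `fit`, and applying `decode` to the output of `predict`.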
Here is my model summary:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
flatten_7 (Flatten) (None, 784) 0
_________________________________________________________________
dense_49 (Dense) (None, 100) 78500
_________________________________________________________________
batch_normalization_10 (Batc (None, 100) 400
_________________________________________________________________
dropout_5 (Dropout) (None, 100) 0
_________________________________________________________________
dense_50 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_11 (Batc (None, 100) 400
_________________________________________________________________
dropout_6 (Dropout) (None, 100) 0
_________________________________________________________________
dense_51 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_12 (Batc (None, 100) 400
_________________________________________________________________
dropout_7 (Dropout) (None, 100) 0
_________________________________________________________________
dense_52 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_13 (Batc (None, 100) 400
_________________________________________________________________
dropout_8 (Dropout) (None, 100) 0
_________________________________________________________________
dense_53 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_14 (Batc (None, 100) 400
_________________________________________________________________
dropout_9 (Dropout) (None, 100) 0
_________________________________________________________________
dense_55 (Dense) (None, 5) 505
=================================================================
Total params: 121,405
Trainable params: 505
Non-trainable params: 120,900
_________________________________________________________________
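The counts in this summary can be checked by hand. A `Dense` layer has `inputs * units` weights plus `units` biases, and a `BatchNormalization` layer stores 4 values per feature (gamma and beta are trainable; the moving mean and variance never are). With every layer except the new `dense_55` frozen, only its 505 parameters remain trainable. The helper names below are hypothetical.

```python
def dense_params(inputs, units):
    # weight matrix (inputs x units) + one bias per unit
    return inputs * units + units

def batchnorm_params(features):
    # gamma, beta, moving_mean, moving_variance: one of each per feature
    return 4 * features

# dense_49 takes the 784 flattened pixels; dense_50..dense_53 are 100 -> 100.
hidden = [dense_params(784, 100)] + [dense_params(100, 100)] * 4
norms = [batchnorm_params(100)] * 5   # batch_normalization_10..14
head = dense_params(100, 5)           # dense_55, the new output layer

total = sum(hidden) + sum(norms) + head
trainable = head  # only the new output layer is left trainable
print(total, trainable, total - trainable)  # 121405 505 120900
```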
Tags: python tensorflow transfer-learning