【问题标题】:Transfer Learning on MNIST: wrong labels errorMNIST 上的迁移学习:错误标签错误
【发布时间】:2020-05-30 10:26:20
【问题描述】:

所以我在 MNIST 数据集上训练了一个 tensorflow 感知器,但只有数字 0 到 4。然后我制作了一个新模型,它具有所有相同的层和权重,但新的输出层也有 5 个输出节点。我想训练这个新模型对数字 5 到 9 进行分类。

我生成了一个只有数字 5 到 9 的新 x_train 和 y_train,然后运行

transfer_model.fit(x_train[train_filter],y_train[train_filter], epoch=5)

其中 train_filter 定义为 np.where(np.logical_and(x_train<=5,x_train>=9))

在训练的第一步,我得到这个错误:

InvalidArgumentError:收到的标签值 9 超出了 [0, 5) 的有效范围。标签值:5 9 7 8 9 8 7 6 8 7 6 9 5 5 8 7 6 9 9 7 6 7 6 8 7 7 9 7 6 8 5 6

这是有道理的,因为我最初训练网络在 [0,5) 范围内进行分类,但现在我想在 [5,10) 范围内进行分类。我在这里错过了一步吗?我不确定我错过了什么......如何定义每个输出神经元对应的内容?

这是我的模型摘要:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten_7 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_49 (Dense)             (None, 100)               78500     
_________________________________________________________________
batch_normalization_10 (Batc (None, 100)               400       
_________________________________________________________________
dropout_5 (Dropout)          (None, 100)               0         
_________________________________________________________________
dense_50 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_11 (Batc (None, 100)               400       
_________________________________________________________________
dropout_6 (Dropout)          (None, 100)               0         
_________________________________________________________________
dense_51 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_12 (Batc (None, 100)               400       
_________________________________________________________________
dropout_7 (Dropout)          (None, 100)               0         
_________________________________________________________________
dense_52 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_13 (Batc (None, 100)               400       
_________________________________________________________________
dropout_8 (Dropout)          (None, 100)               0         
_________________________________________________________________
dense_53 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_14 (Batc (None, 100)               400       
_________________________________________________________________
dropout_9 (Dropout)          (None, 100)               0         
_________________________________________________________________
dense_55 (Dense)             (None, 5)                 505       
=================================================================
Total params: 121,405
Trainable params: 505
Non-trainable params: 120,900
_________________________________________________________________

【问题讨论】:

    标签: python tensorflow transfer-learning


    【解决方案1】:

    您需要将 5-9 映射到 0-4。类标签可能是通过一种热编码完成的,你有 5 个唯一标签,所以它只需要一个长度为 5 的向量来表示它。但由于标签是 5-9,它将超出范围。您不需要调整模型,只需将地图添加到标签输出。

    【讨论】:

    • 有道理,但是我应该在我的 keras 代码中的什么时候做这张地图?我认为 keras 有这方面的东西,但我不知道到底是什么。
    • 将地图查找添加到数据加载器部分,例如当您获取数据和标签时,将标签映射到相应的 0 索引值。然后在查看输出时,使用逆映射将输出映射回适当的标签。
    【解决方案2】:

    由于您使用numpy,您可以尝试以下方法

    import tensorflow as tf
    import numpy as np
    
    arr = np.array([5,6,7,8,9,8,7,6,5])
    arr = tf.one_hot(arr,10,axis=0).numpy()
    arr = arr[5:]
    
    tf.argmax(arr).numpy() # returns array([0, 1, 2, 3, 4, 3, 2, 1, 0])
    

    或使用tf.map_fn

    arr = np.array([5,6,7,8,9,8,7,6,5])
    
    tf.map_fn(lambda x : x-5, arr).numpy() # array([0, 1, 2, 3, 4, 3, 2, 1, 0])
    

    【讨论】:

      猜你喜欢
      • 2018-08-04
      • 1970-01-01
      • 1970-01-01
      • 2017-04-09
      • 1970-01-01
      • 2021-11-12
      • 1970-01-01
      • 2018-05-30
      • 1970-01-01
      相关资源
      最近更新 更多