【问题标题】:How to do leave one out cross validation with tensor-flow (Keras)?如何使用张量流(Keras)进行交叉验证?
【发布时间】:2020-03-04 13:37:45
【问题描述】:

我有 20 个科目,我想在训练使用 Tensorflow 实现的模型时使用留一法交叉验证。我按照一些说明进行操作,最后这是我的伪代码:

for train_index, test_index in loo.split(data):
print("TRAIN:", train_index, "TEST:", test_index)
train_X=np.concatenate(np.array([data[ii][0] for ii in train_index]))
train_y=np.concatenate(np.array([data[ii][1] for ii in train_index]))

test_X=np.concatenate(np.array([data[ii][0] for ii in test_index]))
test_y=np.concatenate(np.array([data[ii][1] for ii in test_index]))


train_X,train_y = shuffle(train_X, train_y)
test_X,test_y = shuffle(test_X, test_y)



#Compile the model
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

#keras.initializers.RandomNormal(mean=0.0, stddev=0.05, seed=None)

#Train the model
model.fit(train_X, train_y, batch_size=16, epochs=30,validation_split=.10)#,callbacks=[])

#test accuracy
test_loss, test_acc = model.evaluate(test_X,test_y)
print('\nTest accuracy:', test_acc)

但是第一个主题之后的结果是这样的:

Epoch 30/30
3590/3590 [==============================] - 4s 1ms/sample - loss: 0.5976 - 
**acc: 0.8872** - val_loss: 1.3873 - val_acc: 0.6591


255/255 [==============================] - 0s 774us/sample - loss: 1.8592 - 
acc: 0.4471

Test accuracy: 0.44705883

第二次迭代(主题):

TRAIN: [ 0  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17] TEST: [1]

Train on 3582 samples, validate on 398 samples
Epoch 1/30
3582/3582 [==============================] - 5s 1ms/sample - loss: 0.7252 - 
**acc: 0.8238** - val_loss: 1.0627 - val_acc: 0.6859

听起来模型使用了以前的权重!如果我们看第二次迭代的第一个精度,它从 acc: 0.8238 开始!

我的实现是否正确?或者我需要更多步骤来为每个主题设置初始权重?

【问题讨论】:

    标签: python tensorflow machine-learning conv-neural-network cross-validation


    【解决方案1】:

    0.8238 是训练数据,而不是您的测试数据。您的 fit() 方法还对训练数据进行了验证拆分。

    就我所见,该模型运行良好。你的实现是正确的。

    【讨论】:

    • 不确定!对于每个科目,训练应该独立于其他科目,但事实并非如此!
    【解决方案2】:

    在 for 循环中编译之前,您需要添加以下行:

    tf.keras.backend.clear_session()
    

    这将删除 Tensorflow 存储的所有图表和会话信息,包括您的图表权重。您可以查看源代码here 以及它的作用说明here

    【讨论】:

      猜你喜欢
      • 2020-04-19
      • 2017-05-04
      • 2017-04-12
      • 2022-01-25
      • 2021-09-30
      • 2020-11-30
      • 2017-06-29
      • 2017-04-10
      • 2015-06-11
      相关资源
      最近更新 更多