【问题标题】:Learn checksum rule with Keras使用 Keras 学习校验和规则
【发布时间】:2022-01-25 05:12:21
【问题描述】:

我有一组数据,前面有固定的 9 个数字,最后一个位置有 1 个数字作为未知规则中的校验和。我尝试建立一个学习模型来使用 Keras 找出它,但未能训练。

所以我使用校验和为 mod 10 的特定规则生成测试数据,但仍然无法训练。我对这 9 个数字进行 one-hot 编码,将数据集形成 (N,9,10) 的形状,然后发送到密集层,同时损失了交叉熵。

这是我的代码:

import numpy as np
from keras import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.utils import to_categorical

# generate test data
test_input = []
test_output = []
for _ in range(10000):
    value = int(np.round(np.random.rand()*1E9,0))
    chk = value % 10
    no = str(value).rjust(9, '0')
    test_input.append(no)
    test_output.append(chk)

test_input = [[int(s) for s in c_no] for c_no in test_input]
test_input = to_categorical(test_input)
test_output = to_categorical(test_output)

# build model
model = Sequential()
model.add(Dense(50, input_shape=(9, 10), activation='relu'))
model.add(Dense(30, activation='relu'))
model.add(Dense(20, activation='relu'))
model.add(Flatten())
model.add(Dense(10))
model.summary()

# train model
epoch_num = 20
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
history = model.fit(test_input, test_output, epochs=epoch_num, verbose=2, batch_size=50, validation_split=0.2)

但是,即使使用像这样简单的校验和规则,我的模型仍然无法成功训练。损失没有减少,准确率保持在0.1左右。

我想知道我犯了什么错误,谢谢!

【问题讨论】:

  • 不要更新您的问题,使现有答案变得无关紧要(回滚)。改为打开一个问题(它是免费的!)。

标签: python machine-learning keras checksum


【解决方案1】:

你只需要在最后一个dense layer添加softmax激活函数

model.add(Dense(10, activation= 'softmax'))

就这样,我在第 14 个 epoch 得到了 1.0 的准确率,跟随输出

Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_8 (Dense)             (None, 9, 50)             550       
                                                                 
 dense_9 (Dense)             (None, 9, 30)             1530      
                                                                 
 dense_10 (Dense)            (None, 9, 20)             620       
                                                                 
 flatten_2 (Flatten)         (None, 180)               0         
                                                                 
 dense_11 (Dense)            (None, 10)                1810      
                                                                 
=================================================================
Total params: 4,510
Trainable params: 4,510
Non-trainable params: 0
_________________________________________________________________
Epoch 1/20
160/160 - 1s - loss: 2.2983 - accuracy: 0.1203 - val_loss: 2.2850 - val_accuracy: 0.1520 - 865ms/epoch - 5ms/step
Epoch 2/20
160/160 - 0s - loss: 2.2735 - accuracy: 0.1605 - val_loss: 2.2584 - val_accuracy: 0.1925 - 337ms/epoch - 2ms/step
Epoch 3/20
160/160 - 0s - loss: 2.2449 - accuracy: 0.2153 - val_loss: 2.2252 - val_accuracy: 0.2585 - 343ms/epoch - 2ms/step
Epoch 4/20
160/160 - 0s - loss: 2.2050 - accuracy: 0.2892 - val_loss: 2.1759 - val_accuracy: 0.3275 - 339ms/epoch - 2ms/step
Epoch 5/20
160/160 - 0s - loss: 2.1401 - accuracy: 0.3769 - val_loss: 2.0904 - val_accuracy: 0.4140 - 342ms/epoch - 2ms/step
Epoch 6/20
160/160 - 0s - loss: 2.0201 - accuracy: 0.4697 - val_loss: 1.9275 - val_accuracy: 0.5145 - 330ms/epoch - 2ms/step
Epoch 7/20
160/160 - 0s - loss: 1.7985 - accuracy: 0.5654 - val_loss: 1.6385 - val_accuracy: 0.6090 - 339ms/epoch - 2ms/step
Epoch 8/20
160/160 - 0s - loss: 1.4392 - accuracy: 0.6821 - val_loss: 1.2205 - val_accuracy: 0.7570 - 321ms/epoch - 2ms/step
Epoch 9/20
160/160 - 0s - loss: 1.0012 - accuracy: 0.8110 - val_loss: 0.7857 - val_accuracy: 0.8690 - 336ms/epoch - 2ms/step
Epoch 10/20
160/160 - 0s - loss: 0.6177 - accuracy: 0.9005 - val_loss: 0.4681 - val_accuracy: 0.9380 - 324ms/epoch - 2ms/step
Epoch 11/20
160/160 - 0s - loss: 0.3714 - accuracy: 0.9532 - val_loss: 0.2821 - val_accuracy: 0.9745 - 331ms/epoch - 2ms/step
Epoch 12/20
160/160 - 0s - loss: 0.2273 - accuracy: 0.9834 - val_loss: 0.1801 - val_accuracy: 0.9910 - 339ms/epoch - 2ms/step
Epoch 13/20
160/160 - 0s - loss: 0.1420 - accuracy: 0.9980 - val_loss: 0.1104 - val_accuracy: 0.9995 - 356ms/epoch - 2ms/step
Epoch 14/20
160/160 - 0s - loss: 0.0914 - accuracy: 1.0000 - val_loss: 0.0733 - val_accuracy: 1.0000 - 334ms/epoch - 2ms/step
Epoch 15/20
160/160 - 0s - loss: 0.0620 - accuracy: 1.0000 - val_loss: 0.0520 - val_accuracy: 1.0000 - 327ms/epoch - 2ms/step
Epoch 16/20
160/160 - 0s - loss: 0.0447 - accuracy: 1.0000 - val_loss: 0.0390 - val_accuracy: 1.0000 - 331ms/epoch - 2ms/step
Epoch 17/20
160/160 - 0s - loss: 0.0340 - accuracy: 1.0000 - val_loss: 0.0302 - val_accuracy: 1.0000 - 324ms/epoch - 2ms/step
Epoch 18/20
160/160 - 0s - loss: 0.0269 - accuracy: 1.0000 - val_loss: 0.0245 - val_accuracy: 1.0000 - 337ms/epoch - 2ms/step
Epoch 19/20
160/160 - 0s - loss: 0.0220 - accuracy: 1.0000 - val_loss: 0.0204 - val_accuracy: 1.0000 - 319ms/epoch - 2ms/step
Epoch 20/20
160/160 - 0s - loss: 0.0185 - accuracy: 1.0000 - val_loss: 0.0173 - val_accuracy: 1.0000 - 330ms/epoch - 2ms/step

【讨论】:

    猜你喜欢
    • 2018-11-30
    • 2020-12-19
    • 2011-12-22
    • 2014-04-02
    • 2017-02-19
    • 2017-06-29
    • 2015-04-02
    • 2017-09-30
    • 1970-01-01
    相关资源
    最近更新 更多