【发布时间】:2017-06-30 04:49:00
【问题描述】:
在解决了我的输入形状问题后,我运行了我的程序,问题是程序打印的总损失太高了(例如,如果我将它与快速入门教程中的比较)。
我的目标是通过使用过去的数据来预测未来条目的拥塞情况(我有超过 10M 的条目并标记了分数)所以我应该不会有训练问题。
这是我的代码:
import numpy as np
import tflearn
# Load CSV file, indicate that the first column represents labels
from tflearn.data_utils import load_csv
data, labels = load_csv('nowcastScaled.csv', has_header=True, n_classes=2)
# Preprocessing function
def preprocess(data):
return np.array(data, dtype=np.float32)
# Preprocess data
data = preprocess(data)
# Build neural network
net = tflearn.input_data(shape=[None, 2])
net = tflearn.fully_connected(net, 32)
net = tflearn.fully_connected(net, 32)
net = tflearn.fully_connected(net, 2, activation='linear')
data = np.reshape(data, (-1, 2))
labels = np.reshape(labels, (-1, 2))
net = tflearn.regression(net, optimizer='adam', learning_rate=0.001,
loss='categorical_crossentropy')
# Define model
model = tflearn.DNN(net)
# Start training (apply gradient descent algorithm)
model.fit(data, labels, n_epoch=15, batch_size=16, show_metric=True)
# Training
model.save('test_Model')
model.load('test_Model')
score = model.evaluate(data, labels, batch_size=16)
我的 excel 文件有这种外观(2 列,100 000 法分)
calculed_at , congestion
1 , 56
2 , 21
这是结果的样子(15 epoch):
Training samples: 50000
Validation samples: 0
....
--
Training Step: 40625 | total loss: 15.27961 | time: 17.659s
| Adam | epoch: 013 | loss: 15.27961 - acc: 0.7070 -- iter: 50000/50000
--
Training Step: 43750 | total loss: 15.66268 | time: 17.549s
| Adam | epoch: 014 | loss: 15.66268 - acc: 0.7247 -- iter: 50000/50000
--
Training Step: 46875 | total loss: 15.94696 | time: 18.037s
| Adam | epoch: 015 | loss: 15.94696 - acc: 0.7581 -- iter: 50000/50000
--
您知道什么会导致如此高的损失吗?这似乎很奇怪,因为打印的准确性似乎并不算太差。感谢您的帮助。
编辑:我接受这些值似乎是个好时机,因为当我刚刚尝试时,我的总损失超过 280(准确度低于 0.3 或略高于)。
【问题讨论】:
-
您介意共享 csv 文件进行调试吗?
-
我立即想到的一件事是,这是一个完全线性的网络。如果您查看 tflearn API,您的激活默认为
linear。先把它改成非线性,看看能不能用。另外,既然这是一个回归,为什么你的损失categorical_crossentropy?详细说明您的data和labels会有所帮助。 -
@jkschin 我不太习惯 tflearn,我尝试了“softmax”而不是线性,但效果并不好。对于 categorical_crossentropy,我使用它是因为它在快速入门教程中,我不知道我应该在什么基础上选择一个或另一个。详细阐述数据和标签是什么意思?
-
举几个例子说明
data和labels的作用。详细说明“您的 excel 文件具有这种外观”。 -
@jkschin 可以在here 找到包含所有信息的文件。我把它上传到了 mediafire,因为我不知道在 SO 中处理火灾的方法是什么。
标签: python machine-learning tensorflow tflearn