【发布时间】:2017-10-13 05:47:43
【问题描述】:
我一直在wikigold.conll NER data set上运行this LSTM tutorial
training_data 包含序列和标签的元组列表,例如:
training_data = [
("They also have a song called \" wake up \"".split(), ["O", "O", "O", "O", "O", "O", "I-MISC", "I-MISC", "I-MISC", "I-MISC"]),
("Major General John C. Scheidt Jr.".split(), ["O", "O", "I-PER", "I-PER", "I-PER"])
]
我写下了这个函数
def predict(indices):
"""Gets a list of indices of training_data, and returns a list of predicted lists of tags"""
for index in indicies:
inputs = prepare_sequence(training_data[index][0], word_to_ix)
tag_scores = model(inputs)
values, target = torch.max(tag_scores, 1)
yield target
这样我可以得到训练数据中特定索引的预测标签。
但是,我如何评估所有训练数据的准确度得分。
准确度是,所有句子中正确分类的单词数量除以单词数。
这是我想出来的,极其缓慢和丑陋:
y_pred = list(predict([s for s, t in training_data]))
y_true = [t for s, t in training_data]
c=0
s=0
for i in range(len(training_data)):
n = len(y_true[i])
#super ugly and ineffiicient
s+=(sum(sum(list(y_true[i].view(-1, n) == y_pred[i].view(-1, n).data))))
c+=n
print ('Training accuracy:{a}'.format(a=float(s)/c))
如何在 pytorch 中有效地做到这一点?
附注: 我一直在尝试使用sklearn's accuracy_score 失败
【问题讨论】:
-
你能举个训练数据的例子吗?
-
数据是“wikigold.conll”(链接:downloads.schwa.org/wikiner/wikigold.conll.txt),将其解析为张量的完整代码在这里:pastebin.com/NxuRwh7D
-
CoNLL 数据链接不再工作。
标签: python scikit-learn deep-learning pytorch