【问题标题】:Online LSTM classification model giving very high number of wrong predictions在线 LSTM 分类模型给出非常多的错误预测
【发布时间】:2019-03-01 19:39:27
【问题描述】:

我正在尝试使用 20 个新闻组数据集来实现一个在线分类模型,以将帖子分类到相关组中。

预处理:我正在浏览所有帖子并用单词制作字典。然后我从 1 开始索引单词。然后我遍历所有帖子和每个单词在一篇文章中,我正在搜索词汇表并将相关的索引号放入一个数组中。然后我通过在末尾添加 0 来填充所有数组,使它们的大小都相同(6577)。

然后我正在创建嵌入层(嵌入大小=300)。并且每个输入在被馈送到 LSTM 层之前都会经过这个嵌入层(LSTM 输入形状= (1,6577,300))。

在我的模型中,我有一个 LSTM 层(大小 = 200)和一个隐藏层(大小 = 25)。为此,我在 tensorflow 中使用 dynamic_rnn 单元格,并将序列长度参数设置为帖子的实际长度(没有填充 0 的长度)以避免分析填充的 0。然后从 LSTM 层的输出中,我只将相关输出提供给隐藏层。

从那里开始,它就像一个普通的 LSTM 实现。我已经尽我所能提高模型的准确性,但错误预测的数量非常多:

数据点数:18,846
错误:17876
错误率:0.9485301920832007

注意:在反向传播期间,我正在训练嵌入层和隐藏层。

问题:我想知道我在这里做错了什么,或者有什么想法可以改进模型。提前谢谢你。

我的完整代码如下所示:

from collections import Counter
import tensorflow as tf
from sklearn.datasets import fetch_20newsgroups
import matplotlib as mplt
mplt.use('agg') # Must be before importing matplotlib.pyplot or pylab!
import matplotlib.pyplot as plt
from string import punctuation
from sklearn.preprocessing import LabelBinarizer
import numpy as np
from nltk.corpus import stopwords
import nltk
nltk.download('stopwords')



def pre_process():
    newsgroups_data = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))

    words = []
    temp_post_text = []
    print(len(newsgroups_data.data))

    for post in newsgroups_data.data:

        all_text = ''.join([text for text in post if text not in punctuation])
        all_text = all_text.split('\n')
        all_text = ''.join(all_text)
        temp_text = all_text.split(" ")

        for word in temp_text:
            if word.isalpha():
                temp_text[temp_text.index(word)] = word.lower()

        # temp_text = [word for word in temp_text if word not in stopwords.words('english')]
        temp_text = list(filter(None, temp_text))
        temp_text = ' '.join([i for i in temp_text if not i.isdigit()])
        words += temp_text.split(" ")
        temp_post_text.append(temp_text)

    # temp_post_text = list(filter(None, temp_post_text))

    dictionary = Counter(words)
    # deleting spaces
    # del dictionary[""]
    sorted_split_words = sorted(dictionary, key=dictionary.get, reverse=True)
    vocab_to_int = {c: i for i, c in enumerate(sorted_split_words,1)}

    message_ints = []
    for message in temp_post_text:
        temp_message = message.split(" ")
        message_ints.append([vocab_to_int[i] for i in temp_message])


    # maximum message length = 6577

    # message_lens = Counter([len(x) for x in message_ints])AAA

    seq_length = 6577
    num_messages = len(temp_post_text)
    features = np.zeros([num_messages, seq_length], dtype=int)
    for i, row in enumerate(message_ints):
        # print(features[i, -len(row):])
        # features[i, -len(row):] = np.array(row)[:seq_length]
        features[i, :len(row)] = np.array(row)[:seq_length]
        # print(features[i])

    lb = LabelBinarizer()
    lbl = newsgroups_data.target
    labels = np.reshape(lbl, [-1])
    labels = lb.fit_transform(labels)

    sequence_lengths = [len(msg) for msg in message_ints]
    return features, labels, len(sorted_split_words)+1, sequence_lengths


def get_batches(x, y, sql, batch_size=1):
    for ii in range(0, len(y), batch_size):
        yield x[ii:ii + batch_size], y[ii:ii + batch_size], sql[ii:ii+batch_size]


def plot(noOfWrongPred, dataPoints):
    font_size = 14
    fig = plt.figure(dpi=100,figsize=(10, 6))
    mplt.rcParams.update({'font.size': font_size})
    plt.title("Distribution of wrong predictions", fontsize=font_size)
    plt.ylabel('Error rate', fontsize=font_size)
    plt.xlabel('Number of data points', fontsize=font_size)

    plt.plot(dataPoints, noOfWrongPred, label='Prediction', color='blue', linewidth=1.8)
    # plt.legend(loc='upper right', fontsize=14)

    plt.savefig('distribution of wrong predictions.png')
    # plt.show()



def train_test():
    features, labels, n_words, sequence_length = pre_process()

    print(features.shape)
    print(labels.shape)

    # Defining Hyperparameters

    lstm_layers = 1
    batch_size = 1
    lstm_size = 200
    learning_rate = 0.01

    # --------------placeholders-------------------------------------

    # Create the graph object
    graph = tf.Graph()
    # Add nodes to the graph
    with graph.as_default():

        tf.set_random_seed(1)

        inputs_ = tf.placeholder(tf.int32, [None, None], name="inputs")
        # labels_ = tf.placeholder(dtype= tf.int32)
        labels_ = tf.placeholder(tf.float32, [None, None], name="labels")
        sql_in = tf.placeholder(tf.int32, [None], name= 'sql_in')

        # output_keep_prob is the dropout added to the RNN's outputs, the dropout will have no effect on the calculation of the subsequent states.
        keep_prob = tf.placeholder(tf.float32, name="keep_prob")

        # Size of the embedding vectors (number of units in the embedding layer)
        embed_size = 300

        # generating random values from a uniform distribution (minval included and maxval excluded)
        embedding = tf.Variable(tf.random_uniform((n_words, embed_size), -1, 1),trainable=True)
        embed = tf.nn.embedding_lookup(embedding, inputs_)

        print(embedding.shape)
        print(embed.shape)
        print(embed[0])

        # Your basic LSTM cell
        lstm =  tf.contrib.rnn.BasicLSTMCell(lstm_size)

        # Getting an initial state of all zeros
        initial_state = lstm.zero_state(batch_size, tf.float32)

        outputs, final_state = tf.nn.dynamic_rnn(lstm, embed, initial_state=initial_state, sequence_length=sql_in)

        out_batch_size = tf.shape(outputs)[0]
        out_max_length = tf.shape(outputs)[1]
        out_size = int(outputs.get_shape()[2])
        index = tf.range(0, out_batch_size) * out_max_length + (sql_in - 1)
        flat = tf.reshape(outputs, [-1, out_size])
        relevant = tf.gather(flat, index)

        # hidden layer
        hidden = tf.layers.dense(relevant, units=25, activation=tf.nn.relu,trainable=True)

        print(hidden.shape)

        logit = tf.contrib.layers.fully_connected(hidden, num_outputs=20, activation_fn=None)

        cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logit, labels=labels_))


        optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)


        saver = tf.train.Saver()

    # ----------------------------online training-----------------------------------------

    with tf.Session(graph=graph) as sess:
        tf.set_random_seed(1)
        sess.run(tf.global_variables_initializer())
        iteration = 1
        state = sess.run(initial_state)
        wrongPred = 0
        noOfWrongPreds = []
        dataPoints = []

        for ii, (x, y, sql) in enumerate(get_batches(features, labels, sequence_length, batch_size), 1):

            feed = {inputs_: x,
                    labels_: y,
                    sql_in : sql,
                    keep_prob: 0.5,
                    initial_state: state}

            predictions = tf.nn.softmax(logit).eval(feed_dict=feed)

            print("----------------------------------------------------------")
            print("sez: ",sql)
            print("Iteration: {}".format(iteration))

            isequal = np.equal(np.argmax(predictions[0], 0), np.argmax(y[0], 0))

            print(np.argmax(predictions[0], 0))
            print(np.argmax(y[0], 0))

            if not (isequal):
                wrongPred += 1

            print("nummber of wrong preds: ",wrongPred)

            if iteration%50 == 0:
                noOfWrongPreds.append(wrongPred/iteration)
                dataPoints.append(iteration)

            loss, states, _ = sess.run([cost, outputs, optimizer], feed_dict=feed)

            print("Train loss: {:.3f}".format(loss))
            iteration += 1

        saver.save(sess, "checkpoints/sentiment.ckpt")
        errorRate = wrongPred / len(labels)
        print("ERRORS: ", wrongPred)
        print("ERROR RATE: ", errorRate)
        plot(noOfWrongPreds, dataPoints)


if __name__ == '__main__':
    train_test()

编辑

【问题讨论】:

    标签: python tensorflow machine-learning lstm text-classification


    【解决方案1】:

    需要考虑的几件事-:

    1. 绘制损失与迭代图。知道您的网络正在学习应该是向下的。您可以使用 tensorboard 来生成这些图表。还产生准确性与迭代。
    2. 将批量大小从 1 增加到 64,128 的小批量,具体取决于您的系统配置 (RAM)
    3. 使用双向 LSTM,因为您在训练模型之前有完整的句子以提高准确性。

    编辑

    您的模型没有正确学习权重。 运行您的代码,模型仅预测 0 类。看看您的预测和预测 1。预测总是 0。

    迭代:1 0 10 错误预测数:1

    火车损失:3.116

    迭代:2 0 3 错误预测数:2

    火车损失:3.163

    迭代:3 0 17 错误预测数:3

    火车损失:3.212

    迭代:4 0 3 错误预测数:4

    火车损失:2.992

    迭代:5 0 4 错误预测数:5

    火车损失:2.892

    迭代:6 0 12 错误预测数:6

    火车损失:3.077

    迭代:7 0 4 错误预测数:7

    火车损失:2.546

    迭代:8 0 10 错误预测数:8

    火车损失:3.459

    迭代:9 0 10 错误预测数:9

    火车损失:2.341

    迭代:10 0 19 错误预测数:10

    火车损失:3.303

    迭代:11 0 19 错误预测数:11

    火车损失:3.193

    迭代:12 0 11 错误预测数:12

    火车损失:3.323

    迭代:13 0 19 错误预测数:13

    火车损失:2.773

    迭代:14 0 13 错误预测数:14

    火车损失:3.129

    迭代:15 0 0 错误预测数:14

    火车损失:3.992

    迭代:16 0 17 错误预测数:15

    火车损失:3.010

    迭代:17 0 12 错误预测数:16

    火车损失:2.534

    迭代次数:18 0 12 错误预测数:17

    火车损失:2.804

    迭代次数:19 0 11 错误预测数:18

    火车损失:4.369

    迭代:20 0 8 错误预测数:19

    火车损失:4.028

    迭代次数:21 0 7 错误预测数:20

    火车损失:3.844

    迭代次数:22 0 5 错误预测数:21

    火车损失:3.579

    迭代次数:23 0 1 错误预测数:22

    火车损失:3.418

    迭代次数:24 0 8 错误预测数:23

    火车损失:4.337

    迭代次数:25 0 10 错误预测数:24

    火车损失:2.328

    迭代次数:26 0 14 错误预测数:25

    火车损失:4.216

    迭代次数:27 0 16 错误预测数:26

    火车损失:3.155

    迭代次数:28 0 1 错误预测数:27

    火车损失:3.307

    迭代次数:29 0 6 错误预测数:28

    火车损失:3.744

    迭代次数:30 0 0 错误预测数:28

    火车损失:4.180

    迭代次数:31 0 7 错误预测数:29

    火车损失:3.400

    迭代次数:32 0 16 错误预测数:30

    火车损失:2.706

    迭代次数:33 0 5 错误预测数:31

    火车损失:2.994

    迭代次数:34 0 9 错误预测数:32

    火车损失:3.610

    迭代次数:35 0 13 错误预测数:33

    火车损失:2.689

    迭代次数:36 0 4 错误预测数:34

    火车损失:2.755

    迭代次数:37 0 4 错误预测数:35

    火车损失:2.778

    迭代次数:38 0 18 错误预测数:36

    火车损失:3.361

    迭代次数:39 0 8 错误预测数:37

    火车损失:3.640

    迭代:40 0 8 错误预测数:38

    火车损失:3.276

    迭代次数:41 0 19 错误预测数:39

    火车损失:2.796

    迭代次数:42 0 1 错误预测数:40

    火车损失:3.189

    迭代次数:43 0 12 错误预测数:41

    火车损失:2.901

    迭代次数:44 0 7 错误预测数:42

    火车损失:2.913

    迭代次数:45 0 10 错误预测数:43

    火车损失:2.875

    迭代次数:46 0 5 错误预测数:44

    火车损失:3.005

    迭代次数:47 0 2 错误预测数:45

    火车损失:3.246

    迭代次数:48 0 6 错误预测数:46

    火车损失:3.071

    迭代次数:49 0 11 错误预测数:47

    火车损失:2.971

    迭代次数:50 0 2 错误预测数:48

    火车损失:3.192

    迭代次数:51 0 12 错误预测数:49

    火车损失:2.894

    迭代次数:52 0 7 错误预测数:50

    火车损失:2.980

    【讨论】:

    • 这里没有纪元(基本上是一个纪元)。这是一个在线学习实例。
    • 很抱歉我在复制和粘贴代码时犯了一个小错误。我现在修好了。它给出了不同的预测,但错误预测的数量太高了。
    • 我还绘制了错误率与它在编辑中显示的迭代(数据点)的关系
    猜你喜欢
    • 2019-02-23
    • 2019-02-25
    • 2018-06-12
    • 2020-04-16
    • 2021-03-12
    • 2022-07-28
    • 1970-01-01
    • 1970-01-01
    • 2021-12-28
    相关资源
    最近更新 更多