【Question title】: Implementing the TFLearn IMDB LSTM example in TensorFlow
【Posted】: 2018-05-02 01:04:07
【Question】:

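I am implementing tflearn's lstm imdb example directly in TensorFlow.

I use the same dataset, architecture, and hyperparameters as the tflearn model (embedding size, maximum sentence length, and so on), but my model performs worse than the tflearn example: after 10 epochs it reaches about 52% accuracy, while the example reaches close to 80%.

I would appreciate any suggestions for reaching the example's performance.

Here is my code: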

import tensorflow as tf
from tflearn.data_utils import to_categorical, pad_sequences
from tflearn.datasets import imdb
from tensorflow.contrib.rnn import BasicLSTMCell
import time



n_class = 2
n_words = 10000
EMBEDDING_SIZE = 128
HIDDEN_SIZE = 128
MAX_LENGTH = 100
lr = 1e-3

epoch = 10
TRAIN_SIZE = 22500
validation_size = 2500
batch_size = 128
KP = 0.8

# IMDB Dataset loading
# load_data returns three (X, Y) splits: train, validation, test
train, valid, test = imdb.load_data(path='imdb.pkl', n_words=n_words,
                                    valid_portion=0.1, sort_by_len=False)
trainX, trainY = train
validationX, validationY = valid
testX, testY = test


# Data preprocessing
# Sequence padding
trainX = pad_sequences(trainX, maxlen=MAX_LENGTH, value=0.)
validationX = pad_sequences(validationX, maxlen=MAX_LENGTH, value=0.)
testX = pad_sequences(testX, maxlen=MAX_LENGTH, value=0.)

# Converting labels to binary vectors
trainY = to_categorical(trainY, n_class)
validationY = to_categorical(validationY, n_class)
testY = to_categorical(testY, n_class)

graph = tf.Graph()
with graph.as_default():
    # input
    text = tf.placeholder(tf.int32, [None, MAX_LENGTH])
    labels = tf.placeholder(tf.float32, [None, n_class])
    keep_prob = tf.placeholder(tf.float32)

    embeddings_var = tf.Variable(tf.truncated_normal([n_words, EMBEDDING_SIZE]), trainable=True)
    text_embedded = tf.nn.embedding_lookup(embeddings_var, text)

    print(text_embedded.shape)  # [batch_size, length, embedding_size]
    word_list = tf.unstack(text_embedded, axis=1)

    cell = BasicLSTMCell(HIDDEN_SIZE)
    dropout_cell = tf.contrib.rnn.DropoutWrapper(cell, input_keep_prob=keep_prob, output_keep_prob=keep_prob)
    outputs, encoding = tf.nn.static_rnn(dropout_cell, word_list, dtype=tf.float32)

    logits = tf.layers.dense(outputs[-1], n_class, activation=None)

    loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))
    optimizer = tf.train.AdamOptimizer(lr).minimize(loss)

    prediction = tf.argmax(logits, 1)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(prediction, tf.argmax(labels, 1)), tf.float32))


train_steps = epoch * TRAIN_SIZE // batch_size + 1
print("Train steps: ", train_steps)


with tf.Session(graph=graph) as sess:
    tf.global_variables_initializer().run()
    print("Initialized!")
    s = time.time()
    offset = 0

    for step in range(train_steps):
        offset = (offset * step) % (TRAIN_SIZE - batch_size)
        batch_text = trainX[offset: offset + batch_size, :]
        batch_label = trainY[offset: offset + batch_size, :]
        fd = {text: batch_text, labels: batch_label, keep_prob: KP}
        _, l, acc = sess.run([optimizer, loss, accuracy], feed_dict=fd)

        if step % 100 == 0:
            print("Step: %d  loss: %f  accuracy: %f" % (step, l, acc))

        if step % 500 == 0:
            v_l, v_acc = sess.run([loss, accuracy], feed_dict={
                text: validationX,
                labels: validationY,
                keep_prob: 1.0
            })
            print("------------------------------------------------")
            print("Validation:  step: %d  loss: %f  accuracy: %f" % (step, v_l, v_acc))
            print("------------------------------------------------")
    print("Training finished, time consumed:", time.time() - s, " s")
    print("Test accuracy: %f" % accuracy.eval(feed_dict={
        text: testX,
        labels: testY,
        keep_prob: 1.0
    }))
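
A side note on the training loop above, separate from the question being asked: `offset` starts at 0 and is updated as `(offset * step) % (TRAIN_SIZE - batch_size)`, so the product is always 0 and every step reads the same first batch. The sketch below reproduces that arithmetic in plain Python and compares it with one possible alternative. The alternative is an assumption about the intent (stepping through the training set one batch at a time), not something stated in the post.

```python
TRAIN_SIZE = 22500
batch_size = 128

# Update rule as posted: offset starts at 0, so offset * step is always 0.
offset = 0
offsets_as_posted = []
for step in range(5):
    offset = (offset * step) % (TRAIN_SIZE - batch_size)
    offsets_as_posted.append(offset)

# Hypothetical alternative (assumption): advance by one batch per step.
offsets_stepped = [(step * batch_size) % (TRAIN_SIZE - batch_size)
                   for step in range(5)]

print(offsets_as_posted)  # [0, 0, 0, 0, 0]
print(offsets_stepped)    # [0, 128, 256, 384, 512]
```

With the stepped rule, `trainX[offset: offset + batch_size]` covers a different slice on each step instead of repeating rows 0 to 127.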

【Comments】:

    Tags: tensorflow nlp lstm sentiment-analysis imdb


    【Solution 1】:

    Sorry, I made a silly mistake! The loss:

    loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))

    should be

    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels))

    With that change, the accuracy matches the tflearn example.
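
To illustrate how the two losses differ on one-hot, two-class labels, here is a minimal NumPy sketch. The helper names `sigmoid_cross_entropy` and `softmax_cross_entropy` and the sample logits are made up for illustration. They mirror the textbook formulas rather than calling TensorFlow. The sigmoid version scores each output unit as an independent binary problem, while the softmax version normalises across the two classes, so only the latter is unchanged when the same constant is added to both logits. This shows that the objectives are different; it does not by itself prove why the accuracy gap appeared.

```python
import numpy as np

def sigmoid_cross_entropy(logits, labels):
    # Per-unit binary cross-entropy in the numerically stable form
    # max(x, 0) - x * z + log(1 + exp(-|x|)).
    return (np.maximum(logits, 0) - logits * labels
            + np.log1p(np.exp(-np.abs(logits))))

def softmax_cross_entropy(logits, labels):
    # Log-softmax with a max shift for stability, then -sum(z * log p) per row.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -(labels * log_probs).sum(axis=1)

# Illustrative values only (not taken from the IMDB model).
logits = np.array([[2.0, -1.0], [0.5, 0.5]])
labels = np.array([[1.0, 0.0], [0.0, 1.0]])

sigmoid_loss = sigmoid_cross_entropy(logits, labels).mean()
softmax_loss = softmax_cross_entropy(logits, labels).mean()

# Adding a constant to both logits leaves argmax and the softmax loss
# unchanged, but changes the sigmoid loss.
sigmoid_loss_shifted = sigmoid_cross_entropy(logits + 3.0, labels).mean()
softmax_loss_shifted = softmax_cross_entropy(logits + 3.0, labels).mean()

print(sigmoid_loss, softmax_loss)
print(sigmoid_loss_shifted, softmax_loss_shifted)
```

In TensorFlow 1.5 and later, `tf.nn.softmax_cross_entropy_with_logits` is marked deprecated in favour of `tf.nn.softmax_cross_entropy_with_logits_v2`, which takes the same `logits` and `labels` keyword arguments.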

    【Comments】:
