Ref: http://blog.csdn.net/mebiuw/article/details/60780813

Ref: https://medium.com/@erikhallstrm/hello-world-rnn-83cd7105b767 [Nice]

Ref: https://medium.com/@erikhallstrm/tensorflow-rnn-api-2bb31821b185 [Nice]

 

Code Analysis 

Download and preprocess the data

# Implementing an RNN in Tensorflow
#----------------------------------
#
# We implement an RNN in Tensorflow to predict spam/ham from texts
#
# Jeffrey: the NLP data-preprocessing pipeline used here is fairly advanced.

import os
import re
import io
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from zipfile import ZipFile
import urllib.request

from tensorflow.python.framework import ops
# Discard any ops left in the default graph by an earlier run in the same
# process, so the model is built on a clean graph.
ops.reset_default_graph()

# TF1-style session; with no explicit graph it uses the default graph that
# was just reset above.
sess = tf.Session()

# RNN / training hyperparameters. None of them is read in this part of the
# script; presumably the model-building and training code further down uses
# them (TODO confirm).
epochs = 30
batch_size = 250
max_sequence_length = 40
rnn_size = 10
embedding_size = 50
min_word_frequency = 10
learning_rate = 0.0005

# Graph input: dropout keep-probability. As a tf.placeholder it has no value
# of its own and must be supplied via feed_dict whenever it is evaluated.
dropout_keep_prob = tf.placeholder(tf.float32)


# Download or open data
# On the first run the SMS Spam Collection archive is downloaded from UCI and
# cached as a text file with one "<label>\t<message>" row per line. Later
# runs read the cached copy instead of downloading again.
data_dir = 'temp'
data_file = 'text_data.txt'
data_path = os.path.join(data_dir, data_file)
# exist_ok replaces the racy "if not exists: makedirs" check-then-create.
os.makedirs(data_dir, exist_ok=True)

if not os.path.isfile(data_path):
    zip_url = 'http://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip'
    # Context managers close the HTTP response and the archive even on error.
    with urllib.request.urlopen(zip_url) as page:
        zip_bytes = page.read()
    with ZipFile(io.BytesIO(zip_bytes)) as z:
        raw_bytes = z.read('SMSSpamCollection')

    # Format Data: keep ASCII only. Non-ASCII is thrown away below anyway, so
    # decoding with errors='ignore' just stops one invalid byte from aborting
    # the whole download.
    text_data = raw_bytes.decode('utf-8', errors='ignore')
    text_data = text_data.encode('ascii', errors='ignore').decode('ascii')
    # Drop a CR left over from CRLF line endings so rows match what the
    # cached-file branch reads back (text mode normalises line endings).
    text_data = [row.rstrip('\r') for row in text_data.split('\n')]

    # Save data to text file
    with open(data_path, 'w', encoding='utf-8') as file_conn:
        for text in text_data:
            file_conn.write("{}\n".format(text))
else:
    # Open data from text file. The line terminator is stripped so these rows
    # match the download branch. Blank rows (e.g. the trailing one written
    # above) are filtered out below, instead of unconditionally dropping the
    # last line, which lost a real message when the data lacked a final
    # newline.
    with open(data_path, 'r', encoding='utf-8') as file_conn:
        text_data = [row.rstrip('\n') for row in file_conn]

# Keep non-blank rows and split each on the FIRST tab only. Every row is then
# (label, message), even when a message itself contains a tab.
text_data = [x.split('\t', 1) for x in text_data if x.strip()]

# A row without a tab (or no rows at all) would otherwise surface as an
# opaque unpacking ValueError on the transpose below.
bad_rows = [x for x in text_data if len(x) != 2]
if not text_data or bad_rows:
    raise ValueError(
        '{}: expected "<label>\\t<message>" rows, found {} malformed out of {}'
        .format(data_path, len(bad_rows), len(text_data)))

# Transpose the rows into two parallel lists: labels and raw message texts.
[text_data_target, text_data_train] = [list(x) for x in zip(*text_data)]


# Create a text cleaning function
def clean_text(text_string):
    """Normalise one message for tokenisation.

    Deletes every run of punctuation/symbol characters, underscores and
    ASCII digits (without inserting a space, so "a,b" becomes "ab").
    Collapses all whitespace runs to single spaces, trimming both ends.
    Lowercases the result.
    """
    without_symbols = re.sub(r'([^\s\w]|_|[0-9])+', '', text_string)
    tokens = without_symbols.split()
    normalized = " ".join(tokens)
    return normalized.lower()
    
# Clean texts: run every raw message through clean_text, keeping order and
# one output per input so it stays aligned with text_data_target.
text_data_train = list(map(clean_text, text_data_train))

# Jeffrey: debug output; uncomment to inspect the first cleaned messages and
# their labels.
#print("[x]:", text_data_train[:10][:10])
#print("[y]:", text_data_target[:10])
View Code

相关文章: