最近基于bi-lstm做了一个辱骂识别模型准备部署到线上,之前打算用python 启动一个service 通过http请求来调用,发现公司平台是基于rpc服务的,开发部署起来也较蛋疼,今天下午闲来没事,看到tensorflow中有提供官方例子,通过python中训练好模型,用java来调用,刚刚好摸索了下,动手写了下代码,总算能在java中调用,废话不多说,直接看代码实现情况。
tensorflow版本情况:
In [1]: import tensorflow as tfIn [2]: tf.__version__Out[2]: '1.2.1'
java需要1.8的版本
maven依赖:
<dependency><groupId>org.tensorflow</groupId><artifactId>tensorflow</artifactId><version>1.2.1</version></dependency>
参考资料:
tensorflow训练模型时候要保存的模型参数,主要有是三个,一个是模型输入的tensor大小,一个是dropout参数,一个是模型预测的logits(score/pred_y 表示name_scope下的pred_y)值,也就是y;模型保存为一个二进制文件,可以在java中加载:
if i%500==0 and i>0:graph = tf.graph_util.convert_variables_to_constants(session, session.graph_def,["keep_prob", "input_x", "score/pred_y"])tf.train.write_graph(graph, ".", "/Users/shuubiasahi/Desktop/tensorflow/modelsavegraph/graph.db",as_text=False)
java代码如下,其中gettexttoid方法参考tensorflow中 tensorflow.contrib.keras.preprocessing.sequence.pad_sequences下的实现,用于做文本预测:
package com.meituan.test;import java.io.BufferedReader;import java.io.File;import java.io.FileInputStream;import java.io.IOException;import java.io.InputStreamReader;import java.nio.ByteBuffer;import java.nio.ByteOrder;import java.nio.IntBuffer;import java.nio.file.Files;import java.nio.file.Paths;import java.nio.file.Path;import java.util.ArrayList;import java.util.Arrays;import java.util.Collection;import java.util.HashMap;import java.util.List;import java.util.Map;import org.apache.commons.io.FileUtils;import org.apache.commons.lang.StringUtils;import org.tensorflow.Graph;import org.tensorflow.Session;import org.tensorflow.Tensor;public class TensorflowEx {private static String path = "/Users/shuubiasahi/Documents/python/credit-tftextclassify-abuse/vocab_cnews.txt";private static Map<String, Integer> word_to_id = new HashMap<String, Integer>();static {try {BufferedReader buffer = null;buffer = new BufferedReader(new InputStreamReader(new FileInputStream(path)));int i=0;String line=buffer.readLine().trim();while(line!=null){word_to_id.put(line, i++);line=buffer.readLine().trim();}buffer.close();} catch (Exception e) {}System.out.println("word_to_id.size is:"+word_to_id.size());}public static void main(String[] args) {byte[] graphDef = readAllBytesOrExit(Paths.get("/Users/shuubiasahi/Desktop/tensorflow/modelsavegraph","graph.db"));Graph g = new Graph();g.importGraphDef(graphDef);Session sess = new Session(g);String text="艹你麻痹的垃圾店家,劳资点的香干回锅肉套餐,你他麻痹炒个香干炒肉过来凑数,套餐内所有的东西都没看到,还尼玛口口声声说退款?退你麻痹,留着给你家人买棺材用吧,狗日的东西!";int[][] arr=gettexttoid(text);Tensor input = Tensor.create(arr);Tensor x = Tensor.create(1.0f);Tensor result = sess.runner().feed("input_x", input).feed("keep_prob", x).fetch("score/pred_y").run().get(0);long[] rshape = result.shape();/** 模型为一个二分类模型,故nlabels=2,模型预测一条数据故batchsize=1* 预测出来是一个1*2的数组,一条数据有两个概率** */int nlabels = (int) rshape[1];int batchSize = (int) rshape[0];float[][] logits = result.copyTo(new float[batchSize][nlabels]);System.out.println("辱骂模型识别的概率为:"+logits[0][1]);System.out.println("sucess");}private static byte[] readAllBytesOrExit(Path path) {try {return Files.readAllBytes(path);} catch (IOException e) {System.err.println("Failed to read [" + path + "]: "+ e.getMessage());System.exit(1);}return null;}/** 序列默人长度为300* */public static int[][] gettexttoid(String text){int[][] xpad = new int[1][300];if(StringUtils.isBlank(text)){return xpad;}char[] chs=text.trim().toLowerCase().toCharArray();List<Integer> list=new ArrayList<Integer>();for(int i=0;i<chs.length;i++){String element=Character.toString(chs[i]);if(word_to_id.containsKey(element)){list.add(word_to_id.get(element));}}if(list.size()==0){return xpad;}int size = list.size();Integer[] targetInter= (Integer[]) list.toArray(new Integer[size]);int[] target= Arrays.stream(targetInter).mapToInt(Integer::valueOf).toArray();if(size<=300){System.arraycopy(target, 0, xpad[0], xpad[0].length-size, target.length);}else{System.arraycopy(target, size-xpad[0].length, xpad[0], 0, xpad[0].length);}return xpad;}/** 自定义长度* */public static int[][] gettexttoid(String text,int maxlen){if(maxlen<1){throw new IllegalArgumentException("maxlen长度必须大于等于1");}int[][] xpad = new int[1][maxlen];if(StringUtils.isBlank(text)){return xpad;}char[] chs=text.trim().toLowerCase().toCharArray();List<Integer> list=new ArrayList<Integer>();for(int i=0;i<chs.length;i++){String element=Character.toString(chs[i]);if(word_to_id.containsKey(element)){list.add(word_to_id.get(element));}}if(list.size()==0){return xpad;}int size = list.size();Integer[] targetInter= (Integer[]) list.toArray(new Integer[size]);int[] target= Arrays.stream(targetInter).mapToInt(Integer::valueOf).toArray();if(size<=maxlen){System.arraycopy(target, 0, xpad[0], xpad[0].length-size, target.length);}else{System.arraycopy(target, size-xpad[0].length, xpad[0], 0, xpad[0].length);}return xpad;}}
结果对比:
java结果:
python启动的service结果:
结果一致,下周计划写个java service项目,把模型部署上线。
不过我碰到过问题,在java中做预测,1秒最多只能预测十来条文本,这感觉太慢了,不知道什么原因,我机器用的cpu,不知道是否要用gpu做预测,有知道的告诉我
联系我 xuxu_ge