【问题标题】:Tensorflow Java API set Placeholder for categorical columns / Tensorflow Java API 为分类列设置占位符
【发布时间】:2018-05-08 16:58:29
【问题描述】:

我想使用 Java API,对用 Python TensorFlow API 训练好的模型进行预测,但在 Java 中输入待预测的特征时遇到了问题。

我的 Python 代码是这样的:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from six.moves.urllib.request import urlopen
import numpy as np
import tensorflow as tf

# Names of the 20 input attributes, in the CSV column order of the data rows.
feature_names = ['Attribute{}'.format(number) for number in range(1, 21)]

#prediction_input = np.array([['A11', 6, 'A34', 'A43', 1169, 'A65', 'A75',     4, 'A93', 'A101', 4, 'A121', 67, 'A143', 'A152', 2, 'A173', 1, 'A192', 'A201'],
#                               ['A12', 18, 'A34', 'A43', 1795, 'A61', 'A75', 3, 'A92', 'A103', 4, 'A121', 48, 'A141', 'A151', 2, 'A173', 1, 'A192', 'A201']])
# Rows to score after training. Each entry is a one-element list holding one
# space-separated record of the 20 attributes (no label column).
prediction_input = [
    [row]
    for row in (
        "A12 12 A32 A40 7472 A65 A71 1 A92 A101 2 A121 24 A143 A151 1 A171 1 A191 A201",
        "A11 36 A32 A40 9271 A61 A74 2 A93 A101 1 A123 24 A143 A152 1 A173 1 A192 A201",
        "A12 15 A30 A40 1778 A61 A72 2 A92 A101 1 A121 26 A143 A151 2 A171 1 A191 A201",
    )
]

def predict_input_fn():
    """Build the input pipeline used by ``classifier.predict``.

    Reads the in-memory ``prediction_input`` rows. Each row is a one-element
    list holding a space-separated record of the 20 attributes, with no label
    column.

    Returns:
        A ``(features, labels)`` pair. ``features`` maps each name in
        ``feature_names`` to a batch of values; ``labels`` is ``None``
        because prediction has no labels.
    """
    # decode_csv defaults, one per column in feature_names order. The type of
    # each default ('' or 0) selects that column's dtype.
    record_defaults = [[''], [0], [''], [''], [0], [''], [''], [0], [''], [''],
                       [0], [''], [0], [''], [''], [0], [''], [0], [''], ['']]

    def decode(line):
        """Split one space-separated record into a {feature name: tensor} dict."""
        parsed_line = tf.decode_csv(line, record_defaults, field_delim=' ')
        # When predicting there is no label column to strip off.
        return dict(zip(feature_names, parsed_line))

    # from_tensor_slices takes an in-memory structure as input.
    dataset = tf.data.Dataset.from_tensor_slices(prediction_input)
    dataset = dataset.map(decode)
    dataset = dataset.batch(1)
    iterator = dataset.make_one_shot_iterator()
    next_feature_batch = iterator.get_next()
    return next_feature_batch, None  # In prediction, we have no labels

# Data sets
def train_test_input_fn(dateipfad, mit_shuffle=False, anzahl_wiederholungen=1):
    """Build a ``(features, labels)`` pipeline from a space-separated data file.

    Args:
        dateipfad: Path of the text file. Each line holds the 20 attributes
            followed by the label.
        mit_shuffle: If true, shuffle records (buffer of 100) before batching.
        anzahl_wiederholungen: Number of passes over the data (``repeat``).

    Returns:
        ``(batch_features, batch_labels)``. ``batch_features`` is a dict
        mapping each name in ``feature_names`` to a batch of values;
        ``batch_labels`` is the matching batch of labels.
    """
    # decode_csv defaults: the 20 feature columns in feature_names order,
    # followed by an integer default for the trailing label column.
    record_defaults = [[''], [0], [''], [''], [0], [''], [''], [0], [''], [''],
                       [0], [''], [0], [''], [''], [0], [''], [0], [''], [''],
                       [0]]

    def parser(line):
        """Split one record into ({feature name: tensor}, label)."""
        parsed_line = tf.decode_csv(line, record_defaults, field_delim=' ')
        # The last column is the label. It is kept as a one-element list
        # (slice, not index) so the label shape stays as before.
        # NOTE(review): DNNClassifier(n_classes=2) expects integer labels in
        # {0, 1}; confirm the data files encode the label that way.
        features, label = parsed_line[:-1], parsed_line[-1:]
        return dict(zip(feature_names, features)), label

    dataset = tf.data.TextLineDataset(dateipfad)
    dataset = dataset.map(parser)
    if mit_shuffle:
        dataset = dataset.shuffle(buffer_size=100)
    dataset = dataset.batch(1)
    dataset = dataset.repeat(anzahl_wiederholungen)
    iterator = dataset.make_one_shot_iterator()

    # `features` is a dictionary in which each value is a batch of values for
    # that feature; `labels` is a batch of labels.
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels

def _categorical(key, vocabulary):
    """Return a one-hot indicator column for string feature ``key`` over ``vocabulary``."""
    return tf.feature_column.indicator_column(
        tf.feature_column.categorical_column_with_vocabulary_list(key, vocabulary))


def _numeric(key):
    """Return a numeric feature column of shape [1] for feature ``key``."""
    return tf.feature_column.numeric_column(key, shape=[1])


def main():
    """Train, evaluate and export the credit classifier, then predict sample rows.

    Trains on ``german.data.train.txt`` and evaluates on
    ``german.data.test.txt`` (both via ``train_test_input_fn``). Checkpoints
    go to ``./summaries`` and a SavedModel is exported under ``./export``.
    Finally it prints a verdict for each row of ``prediction_input``.
    """
    feature_columns = [
        _categorical('Attribute1', ['A11', 'A12', 'A13', 'A14']),
        _numeric('Attribute2'),
        # NOTE(review): the commented-out sample input uses 'A34' for this
        # attribute, but it is not in the vocabulary, so it would presumably be
        # encoded as all zeros. Adding it changes the model's input width, so
        # checkpoints in ./summaries would need retraining. Confirm before changing.
        _categorical('Attribute3', ['A30', 'A31', 'A32', 'A33']),
        _categorical('Attribute4', ['A40', 'A41', 'A42', 'A43', 'A44', 'A45',
                                    'A46', 'A47', 'A48', 'A49', 'A410']),
        _numeric('Attribute5'),
        _categorical('Attribute6', ['A61', 'A62', 'A63', 'A64', 'A65']),
        _categorical('Attribute7', ['A71', 'A72', 'A73', 'A74', 'A75']),
        _numeric('Attribute8'),
        _categorical('Attribute9', ['A91', 'A92', 'A93', 'A94', 'A95']),
        _categorical('Attribute10', ['A101', 'A102', 'A103']),
        _numeric('Attribute11'),
        _categorical('Attribute12', ['A121', 'A122', 'A123', 'A124']),
        _numeric('Attribute13'),
        _categorical('Attribute14', ['A141', 'A142', 'A143']),
        _categorical('Attribute15', ['A151', 'A152', 'A153']),
        _numeric('Attribute16'),
        _categorical('Attribute17', ['A171', 'A172', 'A173', 'A174']),
        _numeric('Attribute18'),
        _categorical('Attribute19', ['A191', 'A192']),
        _categorical('Attribute20', ['A201', 'A202']),
    ]

    classifier = tf.estimator.DNNClassifier(feature_columns=feature_columns,
                                            hidden_units=[100],
                                            n_classes=2,
                                            model_dir="./summaries")

    # Train the model.
    classifier.train(input_fn=lambda: train_test_input_fn("german.data.train.txt", True, 10))

    # Compute the accuracy on the test file.
    accuracy_score = classifier.evaluate(input_fn=lambda: train_test_input_fn("german.data.test.txt", False, 4))["accuracy"]
    print("\nTest Genauigkeit: {0:f}\n".format(accuracy_score))

    # Export a SavedModel whose serving input is a batch of serialized
    # tf.Example protos parsed with the spec derived from feature_columns.
    feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)
    serving_input_receiver_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)
    classifier.export_savedmodel("./export", serving_input_receiver_fn, as_text=True)

    # Predictions are yielded in the same order as prediction_input.
    predict_results = classifier.predict(input_fn=predict_input_fn)
    for row, prediction in zip(prediction_input, predict_results):
        class_id = prediction["class_ids"][0]  # predicted class index
        if class_id == 0:
            print("Ich denke: {}, ist nicht kreditwürdig".format(row))
        elif class_id == 1:
            print("Ich denke: {}, ist kreditwürdig".format(row))

# Run the full train/evaluate/export/predict flow only when executed as a script.
if __name__ == "__main__":
    main()

但我没有找到任何相关资料:如何在 Java 客户端中为这样的分类列提供输入?能否提供一个示例,说明具体该怎么做?

我目前的代码如下,但不知道需要创建什么样的 Tensor,才能在 Java 中用训练好的模型进行预测:

/**
 * Question's attempt: load the exported SavedModel and score one record.
 *
 * NOTE(review): as written this does not work; see the inline notes. The
 * model's "input_example_tensor" input expects a batch of serialized
 * tf.Example protocol buffers, not the raw space-separated line.
 */
public static void main(String[] args) throws Exception {
    // Export directory, built with a Windows-only "\\" path separator.
    String pfad = System.getProperty("user.dir") + "\\1511523781";
    // NOTE(review): the SavedModelBundle returned by load() is never closed.
    Session session = SavedModelBundle.load(pfad, "serve").session();
    // One record in the space-separated format the Python script parses with decode_csv.
    String example = "A12 12 A32 A40 7472 A65 A71 1 A92 A101 2 A121 24 A143 A151 1 A171 1 A191 A201";

    final String xName = "input_example_tensor";
    final String scoresName = "dnn/head/predictions/probabilities:0";

    // NOTE(review): Runner.feed(...) takes a Tensor, so passing the String
    // `example` directly does not compile. The input must be a string Tensor
    // holding serialized tf.Example bytes.
    List<Tensor<?>> outputs = session.runner()
        .feed(xName, example)
        .fetch(scoresName)
        .run();

    // Outer dimension is batch size; inner dimension is number of classes
    // NOTE(review): the probabilities output has shape (-1, 2), so one example
    // needs new float[1][2]. copyTo needs a destination matching the tensor's shape.
    float[][] scores = new float[2][3];
    outputs.get(0).copyTo(scores);
    System.out.println(Arrays.deepToString(scores));
  }

谢谢!

【问题讨论】:

    标签: java python tensorflow tensorflow-serving


    【解决方案1】:

    由于您使用的是 tf.estimator.export.build_parsing_serving_input_receiver_fn,因此您导出的 SavedModel 需要以序列化的 tf.Example protocol buffer 作为输入。

    您可以在 Java 中使用 tf.Example 协议缓冲区(参见 maven、javadoc),方式如下:

    import com.google.protobuf.ByteString;
    import java.nio.file.Paths;
    import java.util.Arrays;
    import org.tensorflow.*;
    import org.tensorflow.example.*;

    /**
     * Scores one record with the SavedModel exported by the Python script.
     *
     * <p>The export uses build_parsing_serving_input_receiver_fn, so its
     * "input_example_tensor" input takes a batch of serialized tf.Example
     * protocol buffers. This class builds one such Example and feeds it.
     */
    public class Main {
      // Returns a Feature containing a BytesList, where each element of the list
      // is the UTF-8 encoded bytes of the Java string. Used for the categorical
      // (string vocabulary) attributes.
      public static Feature feature(String... strings) {
        BytesList.Builder b = BytesList.newBuilder();
        for (String s : strings) {
          b.addValue(ByteString.copyFromUtf8(s));
        }
        return Feature.newBuilder().setBytesList(b).build();
      }

      // Returns a Feature containing a FloatList with the given values. Used for
      // the numeric attributes; int literals at the call sites widen to float.
      public static Feature feature(float... values) {
        FloatList.Builder b = FloatList.newBuilder();
        for (float v : values) {
          b.addValue(v);
        }
        return Feature.newBuilder().setFloatList(b).build();
      }

      public static void main(String[] args) throws Exception {
        // One record (the first row of prediction_input in the Python script),
        // keyed by the feature names used for training.
        Features features =
            Features.newBuilder()
                .putFeature("Attribute1", feature("A12"))
                .putFeature("Attribute2", feature(12))
                .putFeature("Attribute3", feature("A32"))
                .putFeature("Attribute4", feature("A40"))
                .putFeature("Attribute5", feature(7472))
                .putFeature("Attribute6", feature("A65"))
                .putFeature("Attribute7", feature("A71"))
                .putFeature("Attribute8", feature(1))
                .putFeature("Attribute9", feature("A92"))
                .putFeature("Attribute10", feature("A101"))
                .putFeature("Attribute11", feature(2))
                .putFeature("Attribute12", feature("A121"))
                .putFeature("Attribute13", feature(24))
                .putFeature("Attribute14", feature("A143"))
                .putFeature("Attribute15", feature("A151"))
                .putFeature("Attribute16", feature(1))
                .putFeature("Attribute17", feature("A171"))
                .putFeature("Attribute18", feature(1))
                .putFeature("Attribute19", feature("A191"))
                .putFeature("Attribute20", feature("A201"))
                .build();
        Example example = Example.newBuilder().setFeatures(features).build();

        // Export directory under the working directory. Paths.get uses the
        // platform's separator instead of a hard-coded Windows backslash.
        String pfad = Paths.get(System.getProperty("user.dir"), "1511523781").toString();
        try (SavedModelBundle model = SavedModelBundle.load(pfad, "serve")) {
          Session session = model.session();
          final String xName = "input_example_tensor";
          final String scoresName = "dnn/head/predictions/probabilities:0";

          // The input is a batch containing one serialized Example.
          try (Tensor<String> inputBatch = Tensors.create(new byte[][] {example.toByteArray()});
              Tensor<Float> output =
                  session
                      .runner()
                      .feed(xName, inputBatch)
                      .fetch(scoresName)
                      .run()
                      .get(0)
                      .expect(Float.class)) {
            // The output is [batch size, number of classes]. Size the destination
            // from the tensor itself rather than hard-coding [1][2].
            long[] shape = output.shape();
            float[][] probabilities = new float[(int) shape[0]][(int) shape[1]];
            System.out.println(Arrays.deepToString(output.copyTo(probabilities)));
          }
        }
      }
    }
    

    这里的大部分样板代码都用于构建 Example 协议缓冲区。另外,您也可以不使用 build_parsing_serving_input_receiver_fn,而用其他方式设置导出的模型,使其接受不同格式的输入。

    旁注:您可以使用 TensorFlow Python 安装中包含的 saved_model_cli 命令行工具来检查保存的模型。例如,类似:

    saved_model_cli show  \
      --dir ./export/1511523781 \
      --tag_set serve \
      --signature_def predict
    

    将显示如下内容:

    The given SavedModel SignatureDef contains the following input(s):
    inputs['examples'] tensor_info:
        dtype: DT_STRING
        shape: (-1)
        name: input_example_tensor:0
    The given SavedModel SignatureDef contains the following output(s):
    ...
    outputs['probabilities'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 2)
        name: dnn/head/predictions/probabilities:0
    

    这表明该保存的模型接受单个输入,即一批 DT_STRING 元素(序列化的 tf.Example);输出的概率则是一批长度为 2 的浮点向量,每个类别对应一个概率。

    希望对您有所帮助。

    【讨论】:

    • 您知道如何修改它,使其也能处理空字符串特征值吗?如果某个测试行包含空字符串特征,预测返回的概率会与原始 Python 模型的结果完全不同。这是否是因为字符串被解析为字节,而空字符串没有字节表示?谢谢。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2017-10-12
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多