【Question title】: Spark | ML | Random Forest | Load a trained model from the .txt output of RandomForestClassificationModel.toDebugString
【Posted】: 2017-05-01 20:34:21
【Question】:

Using Spark 1.6 and the ML library, I am saving the result of a trained RandomForestClassificationModel with toDebugString:

 val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel]
 val stringModel = rfModel.toDebugString
 // save stringModel to a .txt file on the driver

My idea is to read the .txt file later and load the trained random forest from it. Is that possible?

Thanks!

【Question discussion】:

    Tags: apache-spark serialization random-forest apache-spark-ml


    【Solution 1】:

    That won't work. toDebugString is only debugging output, meant to help you understand how the model computes its result.

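    If you want to keep the model for later use, you can do what we do (although we work in pure Java): simply serialize the RandomForestModel object. Default Java serialization can run into version incompatibilities, so we use Hessian instead. It has survived version upgrades: we started on Spark 1.6.1 and it still works with Spark 2.0.2.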
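The idea above boils down to writing the model object to bytes and reading it back later. The sketch below shows that round trip with plain `java.io` serialization rather than Hessian, so it inherits the version-compatibility caveat mentioned above. `TinyModel` is a hypothetical stand-in for the trained model object and is not part of Spark; the helper names are illustrative as well.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

// Hypothetical stand-in for a trained model; a real model object would take its place.
class TinyModel implements Serializable {
    // Pinning the UID avoids spurious mismatches when the class is recompiled unchanged.
    private static final long serialVersionUID = 1L;
    final double threshold;

    TinyModel(double threshold) {
        this.threshold = threshold;
    }

    double predict(double x) {
        return x <= threshold ? 0.0 : 1.0;
    }
}

class ModelSerializationSketch {
    // Serializes any Serializable object into a byte array.
    static byte[] toBytes(Serializable model) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(model);
        }
        return bytes.toByteArray();
    }

    // Restores an object previously written by toBytes.
    static Object fromBytes(byte[] data) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(data))) {
            return in.readObject();
        }
    }

    public static void main(String[] args) throws Exception {
        TinyModel restored = (TinyModel) fromBytes(toBytes(new TinyModel(0.5)));
        System.out.println(restored.predict(0.2));
        System.out.println(restored.predict(0.9));
    }
}
```

In practice the bytes would go to a file or a blob store instead of staying in memory, and the same classes (in a compatible version) must be on the classpath when reading them back.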

    【Discussion】:

      【Solution 2】:

      If you don't have to stick with ml, use the mllib implementation: the RandomForestModel you get from mllib has a save function.

      【Discussion】:

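        【Solution 3】:

        At least for Spark 2.1.0, you can do this with the following Java code (sorry, no Scala). However, relying on an undocumented format that may change without notice is probably not the wisest idea.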

        import org.slf4j.Logger;
        import org.slf4j.LoggerFactory;
        
        import java.io.*;
        import java.net.URL;
        import java.util.*;
        import java.util.function.Predicate;
        import java.util.regex.Matcher;
        import java.util.regex.Pattern;
        
        import static java.nio.charset.StandardCharsets.US_ASCII;
        
        /**
         * RandomForest.
         */
        public abstract class RandomForest {
        
            private static final Logger LOG = LoggerFactory.getLogger(RandomForest.class);
        
            protected final List<Node> trees = new ArrayList<>();
        
            /**
             * @param model model file (format is Spark's RandomForestClassificationModel toDebugString())
             * @throws IOException
             */
            public RandomForest(final URL model) throws IOException {
                try (final BufferedReader reader = new BufferedReader(new InputStreamReader(model.openStream(), US_ASCII))) {
                    Node node;
                    while ((node = load(reader)) != null) {
                        trees.add(node);
                    }
                }
                if (trees.isEmpty()) throw new IOException("Failed to read trees from " + model);
                if (LOG.isDebugEnabled()) LOG.debug("Found " + trees.size() + " trees.");
            }
        
            private static Node load(final BufferedReader reader) throws IOException {
                final Pattern ifPattern = Pattern.compile("If \\(feature (\\d+) (in|not in|<=|>) (.*)\\)");
                // Accept negative leaf values and exponents such as 1.0E10 or 1.0E-4.
                final Pattern predictPattern = Pattern.compile("Predict: (-?\\d+\\.\\d+(E-?\\d+)?)");
                Node root = null;
                final List<Node> stack = new ArrayList<>();
                String line;
                while ((line = reader.readLine()) != null) {
                    final String trimmed = line.trim();
                    //System.out.println(trimmed);
                    if (trimmed.startsWith("RandomForest")) {
                        // skip the first "Tree ..." header line, which follows the model header
                        reader.readLine();
                    } else if (trimmed.startsWith("Tree")) {
                        break;
                    } else if (trimmed.startsWith("If")) {
                        // extract feature index
                        final Matcher m = ifPattern.matcher(trimmed);
                        if (!m.matches()) throw new IOException("Unexpected split line: " + trimmed);
                        final int featureIndex = Integer.parseInt(m.group(1));
                        final String operator = m.group(2);
                        final String operand = m.group(3);
                        final Predicate<Float> predicate;
                        if ("<=".equals(operator)) {
                            predicate = new LessOrEqual(Float.parseFloat(operand));
                        } else if (">".equals(operator)) {
                            predicate = new Greater(Float.parseFloat(operand));
                        } else if ("in".equals(operator)) {
                            predicate = new In(parseFloatArray(operand));
                        } else if ("not in".equals(operator)) {
                            predicate = new NotIn(parseFloatArray(operand));
                        } else {
                            throw new IOException("Unsupported operator: " + operator);
                        }
                        final Node node = new Node(featureIndex, predicate);
        
                        if (stack.isEmpty()) {
                            root = node;
                        } else {
                            insert(stack, node);
                        }
                        stack.add(node);
                    } else if (trimmed.startsWith("Predict")) {
                        final Matcher m = predictPattern.matcher(trimmed);
                        if (!m.matches()) throw new IOException("Unexpected predict line: " + trimmed);
                        final Object node = Float.parseFloat(m.group(1));
                        insert(stack, node);
                    }
                }
                return root;
            }
        
            private static void insert(final List<Node> stack, final Object node) {
                Node parent = stack.get(stack.size() - 1);
                while (parent.getLeftChild() != null && parent.getRightChild() != null) {
                    stack.remove(stack.size() - 1);
                    parent = stack.get(stack.size() - 1);
                }
                if (parent.getLeftChild() == null) parent.setLeftChild(node);
                else parent.setRightChild(node);
            }
        
            private static float[] parseFloatArray(final String set) {
                final StringTokenizer st = new StringTokenizer(set, "{,}");
                final float[] floats = new float[st.countTokens()];
                for (int i=0; st.hasMoreTokens(); i++) {
                    floats[i] = Float.parseFloat(st.nextToken());
                }
                return floats;
            }
        
            public abstract float predict(final float[] features);
        
            public String toDebugString() {
                try {
                    final StringWriter sw = new StringWriter();
                    for (int i=0; i<trees.size(); i++) {
                        sw.write("Tree " + i + ":\n");
                        print(sw, "", trees.get(i));
                    }
                    return sw.toString();
                } catch (IOException e) {
                    throw new UncheckedIOException(e);
                }
            }
        
            private static void print(final Writer w, final String indent, final Object object) throws IOException {
                if (object instanceof Number) {
                    w.write(indent + "Predict: " + object + "\n");
                } else if (object instanceof Node) {
                    final Node node = (Node) object;
                    // left node
                    w.write(indent + node + "\n");
                    print(w, indent + " ", node.getLeftChild());
                    w.write(indent + "Else\n");
                    print(w, indent + " ", node.getRightChild());
                }
            }
        
            @Override
            public String toString() {
                return getClass().getSimpleName() + "{numTrees=" + trees.size() + "}";
            }
        
            /**
             * Node.
             */
            protected static class Node {
        
                private final int featureIndex;
                private final Predicate<Float> predicate;
                private Object leftChild;
                private Object rightChild;
        
                public Node(final int featureIndex, final Predicate<Float> predicate) {
                    Objects.requireNonNull(predicate);
                    this.featureIndex = featureIndex;
                    this.predicate = predicate;
                }
        
                public void setLeftChild(final Object leftChild) {
                    this.leftChild = leftChild;
                }
        
                public void setRightChild(final Object rightChild) {
                    this.rightChild = rightChild;
                }
        
                public Object getLeftChild() {
                    return leftChild;
                }
        
                public Object getRightChild() {
                    return rightChild;
                }
        
                public Object eval(final float[] features) {
                    Object result = this;
                    do {
                        final Node node = (Node)result;
                        result = node.predicate.test(features[node.featureIndex]) ? node.leftChild : node.rightChild;
                    } while (result instanceof Node);
        
                    return result;
                }
        
                @Override
                public String toString() {
                    return "If (feature " + featureIndex + " " + predicate + ")";
                }
        
            }
        
            private static class LessOrEqual implements Predicate<Float> {
                private final float value;
        
                public LessOrEqual(final float value) {
                    this.value = value;
                }
        
                @Override
                public boolean test(final Float f) {
                    return f <= value;
                }
        
                @Override
                public String toString() {
                    return "<= " + value;
                }
            }
        
            private static class Greater implements Predicate<Float> {
                private final float value;
        
                public Greater(final float value) {
                    this.value = value;
                }
        
                @Override
                public boolean test(final Float f) {
                    return f > value;
                }
        
                @Override
                public String toString() {
                    return "> " + value;
                }
            }
        
            private static class In implements Predicate<Float> {
                private final float[] array;
        
                public In(final float[] array) {
                    this.array = array;
                }
        
                @Override
                public boolean test(final Float f) {
                    for (int i=0; i<array.length; i++) {
                        if (array[i] == f) return true;
                    }
                    return false;
                }
        
                @Override
                public String toString() {
                    return "in " + Arrays.toString(array);
                }
            }
        
            private static class NotIn implements Predicate<Float> {
                private final float[] array;
        
                public NotIn(final float[] array) {
                    this.array = array;
                }
        
                @Override
                public boolean test(final Float f) {
                    for (int i=0; i<array.length; i++) {
                        if (array[i] == f) return false;
                    }
                    return true;
                }
        
                @Override
                public String toString() {
                    return "not in " + Arrays.toString(array);
                }
            }
        }
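To make the line-by-line parsing in the class above easier to check, here is a small standalone sketch that applies the same kind of regular expressions to a few sample lines. The sample lines are hypothetical and only mimic the shape the parser expects; real `toDebugString` output may differ between Spark versions. The leaf pattern here also accepts a leading minus sign and a positive exponent, which is an assumption about what regression output can contain. `DebugLineSketch` and its method names are illustrative.

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

class DebugLineSketch {
    // Split lines: feature index, operator, operand.
    static final Pattern SPLIT =
            Pattern.compile("If \\(feature (\\d+) (in|not in|<=|>) (.*)\\)");
    // Leaf lines. Assumption: values may be negative or use an exponent such as 1.0E-4.
    static final Pattern LEAF =
            Pattern.compile("Predict: (-?\\d+\\.\\d+(E-?\\d+)?)");

    // Returns {featureIndex, operator, operand}, or null if the line is not an "If" split.
    static String[] parseSplit(String line) {
        Matcher m = SPLIT.matcher(line.trim());
        return m.matches() ? new String[] {m.group(1), m.group(2), m.group(3)} : null;
    }

    // Returns the leaf value, or null if the line is not a "Predict" line.
    static Double parseLeaf(String line) {
        Matcher m = LEAF.matcher(line.trim());
        return m.matches() ? Double.valueOf(m.group(1)) : null;
    }

    public static void main(String[] args) {
        // Hypothetical lines in the shape the parser expects.
        String[] sample = {
            "    If (feature 2 <= 0.5)",
            "     Predict: 0.0",
            "    Else (feature 2 > 0.5)",
            "     If (feature 7 in {0.0,3.0})",
            "      Predict: 1.0"
        };
        for (String line : sample) {
            String[] split = parseSplit(line);
            Double leaf = parseLeaf(line);
            if (split != null) {
                System.out.println("split: feature " + split[0] + " " + split[1] + " " + split[2]);
            } else if (leaf != null) {
                System.out.println("leaf: " + leaf);
            } else {
                System.out.println("ignored: " + line.trim());
            }
        }
    }
}
```

"Else" lines match neither pattern, which mirrors the class above: the right branch is implied by the order in which nodes appear, so the parser only needs the "If" and "Predict" lines.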
        

        To use the class for classification, use:

        import java.io.IOException;
        import java.net.URL;
        import java.util.HashMap;
        import java.util.Map;
        
        /**
         * RandomForestClassifier.
         */
        public class RandomForestClassifier extends RandomForest {
        
            public RandomForestClassifier(final URL model) throws IOException {
                super(model);
            }
        
            @Override
            public float predict(final float[] features) {
                final Map<Object, Integer> counts = new HashMap<>();
                trees.stream().map(node -> node.eval(features))
                        .forEach(result -> {
                            Integer count = counts.get(result);
                            if (count == null) {
                                counts.put(result, 1);
                            } else {
                                counts.put(result, count + 1);
                            }
                        });
                return (Float)counts.entrySet()
                        .stream()
                        .sorted((o1, o2) -> Integer.compare(o2.getValue(), o1.getValue()))
                        .map(Map.Entry::getKey)
                        .findFirst().get();
            }
        }
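The vote counting in `predict` above can be looked at in isolation. The sketch below is a minimal, self-contained version that takes the per-tree predictions as an array and returns the most frequent label; `MajorityVoteSketch` is an illustrative name. One deliberate difference, stated as an assumption: ties are broken in favour of the smaller label so the result is deterministic, whereas sorting `HashMap` entries by count leaves ties to the map's iteration order. Spark's own classifier may aggregate per-tree class probabilities rather than hard votes, depending on the version, so a plain majority vote should be treated as an approximation.

```java
import java.util.HashMap;
import java.util.Map;

class MajorityVoteSketch {
    // Returns the label predicted by the most trees; ties go to the smaller label.
    static float majority(float[] treePredictions) {
        if (treePredictions.length == 0) {
            throw new IllegalArgumentException("no predictions");
        }
        Map<Float, Integer> counts = new HashMap<>();
        for (float p : treePredictions) {
            counts.merge(p, 1, Integer::sum); // add one vote for this label
        }
        float best = Float.NaN;
        int bestCount = 0;
        for (Map.Entry<Float, Integer> e : counts.entrySet()) {
            int c = e.getValue();
            if (c > bestCount || (c == bestCount && e.getKey() < best)) {
                best = e.getKey();
                bestCount = c;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        System.out.println(majority(new float[] {1f, 0f, 1f, 2f}));
    }
}
```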
        

        For regression:

        import java.io.IOException;
        import java.net.URL;
        
        /**
         * RandomForestRegressor.
         */
        public class RandomForestRegressor extends RandomForest {
        
            public RandomForestRegressor(final URL model) throws IOException {
                super(model);
            }
        
            @Override
            public float predict(final float[] features) {
                return (float)trees
                        .stream()
                        .mapToDouble(node -> ((Number)node.eval(features)).doubleValue())
                        .average()
                        .getAsDouble();
            }
        }
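The regression case above is just the mean of the per-tree outputs. As a minimal standalone sketch (the class name `AverageSketch` is illustrative), the helper below averages in double precision, narrows to float only at the end as the regressor above does, and reports an empty forest with an explicit message.

```java
import java.util.Arrays;

class AverageSketch {
    // Averages per-tree outputs in double precision and narrows to float at the end.
    static float average(double[] perTree) {
        return (float) Arrays.stream(perTree)
                .average()
                .orElseThrow(() -> new IllegalArgumentException("no tree outputs"));
    }

    public static void main(String[] args) {
        System.out.println(average(new double[] {1.0, 2.0, 3.0}));
    }
}
```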
        

        【Discussion】:
