【问题标题】:MultiLayerNetwork to predict simple functionMultiLayerNetwork 预测简单函数
【发布时间】:2016-03-29 09:52:11
【问题描述】:

我正在尝试培养一些机器学习的直觉。我查看了来自https://github.com/deeplearning4j/dl4j-0.4-examples 的示例,我想开发自己的示例。基本上我只是采用了一个简单的函数:a * a + b * b + c * c - a * b * c + a + b + c 并为随机 a、b、c 生成 10000 个输出,并尝试在 90 上训练我的网络% 的输入。问题是无论我做什么,我的网络都无法预测其余的示例。

这是我的代码:

public class BasicFunctionNN {

    private static Logger log = LoggerFactory.getLogger(MlPredict.class);

    public static DataSetIterator generateFunctionDataSet() {
        Collection<DataSet> list = new ArrayList<>();
        for (int i = 0; i < 100000; i++) {
            double a = Math.random();
            double b = Math.random();
            double c = Math.random();

            double output = a * a + b * b + c * c - a * b * c + a + b + c;
            INDArray in = Nd4j.create(new double[]{a, b, c});
            INDArray out = Nd4j.create(new double[]{output});
            list.add(new DataSet(in, out));
        }
        return new ListDataSetIterator(list, list.size());
    }

    public static void main(String[] args) throws Exception {
        DataSetIterator iterator = generateFunctionDataSet();

        Nd4j.MAX_SLICES_TO_PRINT = 10;
        Nd4j.MAX_ELEMENTS_PER_SLICE = 10;

        final int numInputs = 3;
        int outputNum = 1;
        int iterations = 100;

        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .iterations(iterations).weightInit(WeightInit.XAVIER).updater(Updater.SGD).dropOut(0.5)
                .learningRate(.8).regularization(true)
                .l1(1e-1).l2(2e-4)
                .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
                .list(3)
                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(8)
                        .activation("identity")
                        .build())
                .layer(1, new DenseLayer.Builder().nIn(8).nOut(8)
                        .activation("identity")
                        .build())
                .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.RMSE_XENT)//LossFunctions.LossFunction.RMSE_XENT)
                        .activation("identity")
                        .weightInit(WeightInit.XAVIER)
                        .nIn(8).nOut(outputNum).build())
                .backprop(true).pretrain(false)
                .build();


        //run the model
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(Collections.singletonList((IterationListener) new ScoreIterationListener(iterations)));

        //get the dataset using the record reader. The datasetiterator handles vectorization
        DataSet next = iterator.next();
        SplitTestAndTrain testAndTrain = next.splitTestAndTrain(0.9);
        System.out.println(testAndTrain.getTrain());

        model.fit(testAndTrain.getTrain());

        //evaluate the model
        Evaluation eval = new Evaluation(10);
        DataSet test = testAndTrain.getTest();
        INDArray output = model.output(test.getFeatureMatrix());
        eval.eval(test.getLabels(), output);
        log.info(">>>>>>>>>>>>>>");
        log.info(eval.stats());

    }
}

我也玩过学习率,多次出现分数没有提高的情况:

10:48:51.404 [main] DEBUG o.d.o.solvers.BackTrackLineSearch - Exited line search after maxIterations termination condition; score did not improve (bestScore=0.8522868127536543, scoreAtStart=0.8522868127536543). Resetting parameters

作为激活函数,我也尝试了 relu

【问题讨论】:

    标签: machine-learning deep-learning deeplearning4j


    【解决方案1】:

    一个明显的问题是您试图用线性模型对非线性函数进行建模。您的神经网络没有激活函数,因此它只能有效地表达 W1a + W2b + W3c + W4 形式的函数。您创建多少隐藏单元无关紧要 - 只要不使用非线性激活函数,您的网络就会退化为简单的线性模型。

    更新

    还有很多“小怪事”,包括但不限于:

    • 您正在使用巨大的学习率 (0.8)
    • 您正在使用大量正则化(相当复杂,使用 l1 和 l2 正则化器进行回归不是一种常见的方法,尤其是在神经网络中)一个您不需要的问题
    • 整流器单元可能不是表达平方运算以及您正在寻找的乘法运算的最佳单元。整流器非常适合分类,特别是对于更深的架构,但不适用于浅层回归。改用 sigmoid-alike (tanh, sigmoid) 激活。
    • 我不完全确定“迭代”在此实现中的含义,但通常这是用于训练的样本/小批量的数量。因此,对于梯度下降学习而言,仅使用 100 个数量级可能太小了

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2014-01-31
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2017-04-09
      • 2020-11-21
      • 2022-01-14
      相关资源
      最近更新 更多