【Question Title】: Regress the max function over a neural network
【Posted】: 2016-02-25 10:00:41
【Question】:

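I'm teaching myself neural networks, and there is one function I can't get my network to learn: f(x) = max(x_1, x_2). It looks like a very simple function with 2 inputs and 1 output, yet a 3-layer neural network trained on a thousand samples for more than 2000 epochs gets it completely wrong. I'm using deeplearning4j.

Is the max function hard for a neural network to learn, or am I just tuning it wrong?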

【Discussion】:

    Tags: neural-network deep-linking deeplearning4j


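    【Answer 1】:

    Just wanted to point out: if you use relu instead of tanh, there is actually an exact solution. I'd guess that if you shrank the network to exactly that size (1 hidden layer with 3 nodes), you would always end up with these weights, modulo a permutation of the hidden nodes and a rescaling of the weights (the first layer scaled by some gamma > 0, the second by 1/gamma):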

    max(a,b) = ((1, 1, -1)) * relu( ((1,-1), (0,1), (0,-1)) * ((a,b)) )
    

    where * is matrix multiplication.

    This equation translates the following human-readable version into neural-network form:

    max(a,b) = relu(a-b) + b = relu(a-b) + relu(b) - relu(-b)
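A short case check of why the matrix form and the readable form agree, writing relu(z) = max(0, z); the last line is the property that makes the gamma rescaling harmless:

```latex
\begin{aligned}
\begin{pmatrix} 1 & -1 \\ 0 & 1 \\ 0 & -1 \end{pmatrix}
\begin{pmatrix} a \\ b \end{pmatrix}
  &= \begin{pmatrix} a-b \\ b \\ -b \end{pmatrix} \\
\operatorname{relu}(b) - \operatorname{relu}(-b) &= b
  \qquad (b \ge 0:\ b - 0;\quad b < 0:\ 0 - (-b)) \\
(1,\ 1,\ -1)\cdot\operatorname{relu}\begin{pmatrix} a-b \\ b \\ -b \end{pmatrix}
  &= \operatorname{relu}(a-b) + b
   = \begin{cases} (a-b) + b = a, & a \ge b \\ 0 + b = b, & a < b \end{cases}
   = \max(a,b) \\
\operatorname{relu}(\gamma z) &= \gamma\,\operatorname{relu}(z) \qquad (\gamma > 0)
\end{aligned}
```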
    

    I haven't actually tested this; my point is that, in theory, this function should be easy for a network to learn.
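The hand-set construction itself can be checked without any training. Below is a minimal plain-Java sketch (no deeplearning4j; the class name `ExactReluMax` and the sampling range are illustrative choices, not from the answer) that evaluates the 2-3-1 ReLU network with the weights above and no biases, for a few values of the rescaling gamma, and compares it with `Math.max`:

```java
import java.util.Random;

// Hypothetical sketch: the exact 2-3-1 ReLU network from the answer, no biases,
// evaluated directly in plain Java (no deep-learning library involved).
public class ExactReluMax {
    // First layer: 3 hidden units x 2 inputs, rows (1,-1), (0,1), (0,-1).
    static final double[][] W1 = {{1, -1}, {0, 1}, {0, -1}};
    // Second layer: one output weight per hidden unit.
    static final double[] W2 = {1, 1, -1};

    static double relu(double z) {
        return Math.max(0.0, z);
    }

    // Forward pass with the first layer scaled by gamma and the second by 1/gamma.
    static double forward(double a, double b, double gamma) {
        double out = 0.0;
        for (int j = 0; j < W1.length; j++) {
            double h = relu(gamma * (W1[j][0] * a + W1[j][1] * b));
            out += (W2[j] / gamma) * h;
        }
        return out;
    }

    public static void main(String[] args) {
        Random rng = new Random(42);
        double worst = 0.0;
        for (int i = 0; i < 1000; i++) {
            double a = 20 * rng.nextDouble() - 10; // inputs drawn from [-10, 10)
            double b = 20 * rng.nextDouble() - 10;
            for (double gamma : new double[] {0.5, 1.0, 3.0}) {
                worst = Math.max(worst, Math.abs(forward(a, b, gamma) - Math.max(a, b)));
            }
        }
        System.out.println("largest deviation from Math.max: " + worst);
    }
}
```

The reported deviation should only reflect floating-point rounding, for any gamma > 0; a negative gamma would break the construction, since relu is only positively homogeneous.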

    Edit: I just tested it, and the result is what I expected:

    [[-1.0714666e+00 -7.9943770e-01  9.0549403e-01]
     [ 1.0714666e+00 -7.7552663e-08  2.6146751e-08]]
    

    [[ 0.93330014]
     [-1.250879  ]
     [ 1.1043695 ]]
    

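    These are the first and second layers, respectively. Transposing the second layer and multiplying it elementwise with the first set of weights (each hidden unit's input weights times its output weight, so the per-unit scaling gamma cancels) gives a normalized version that is easy to compare with my theoretical result: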

    [[-9.9999988e-01  9.9999988e-01  1.0000000e+00]
     [ 9.9999988e-01  9.7009000e-08  2.8875675e-08]]
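The normalization step can be reproduced from the printed numbers. This is a plain-Java sketch under the assumption that the printed first layer is laid out as rows = inputs (x1, x2) and columns = hidden units, which is what the 2x3 shape suggests; the class and field names are hypothetical. It also applies the same step to the answer's exact weights, transposed into that layout:

```java
import java.util.Arrays;

// Hypothetical helper: redo the answer's normalization step on the printed weights.
public class NormalizeWeights {
    // Learned first layer as printed: rows = inputs (x1, x2), columns = 3 hidden units.
    static final double[][] LEARNED1 = {
        {-1.0714666e+00, -7.9943770e-01, 9.0549403e-01},
        { 1.0714666e+00, -7.7552663e-08, 2.6146751e-08}
    };
    // Learned second layer as printed (a 3x1 column), one weight per hidden unit.
    static final double[] LEARNED2 = {0.93330014, -1.250879, 1.1043695};

    // The answer's exact weights, transposed to the same layout (rows = a, b).
    static final double[][] THEORY1 = {{1, 0, 0}, {-1, 1, -1}};
    static final double[] THEORY2 = {1, 1, -1};

    // Scale column j of the first layer by output weight j. A per-unit rescaling
    // (gamma on the way in, 1/gamma on the way out) cancels in this product.
    static double[][] normalize(double[][] w1, double[] w2) {
        double[][] out = new double[w1.length][w2.length];
        for (int i = 0; i < w1.length; i++) {
            for (int j = 0; j < w2.length; j++) {
                out[i][j] = w1[i][j] * w2[j];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        System.out.println("learned: " + Arrays.deepToString(normalize(LEARNED1, LEARNED2)));
        System.out.println("theory:  " + Arrays.deepToString(normalize(THEORY1, THEORY2)));
    }
}
```

The learned rows come out approximately equal to the theoretical rows in reverse order, i.e. the trained network appears to have found relu(x2 - x1) + x1 rather than relu(x1 - x2) + x2, which is the same function with the roles of the two inputs swapped.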
    

    【Discussion】:

      【Answer 2】:

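      It's not hard, at least if you restrict x1 and x2 to an interval, e.g. [0, 3]. I took the "RegressionSum" example from the deeplearning4j examples and quickly rewrote it to learn max instead of sum, and it works well, giving me results like these: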

      Max(0.6815540048808918,0.3112081053899819) = 0.64
      Max(2.0073597506364407,1.93796211086664) = 2.09
      Max(1.1792029272560556,2.5514324329058233) = 2.58
      Max(2.489185375059013,0.0818746888836388) = 2.46
      Max(2.658169689797984,1.419135581889197) = 2.66
      Max(2.855509810112818,2.9661811672685086) = 2.98
      Max(2.774757710538552,1.3988513143140069) = 2.79
      Max(1.5852295273047565,1.1228662895771744) = 1.56
      Max(0.8403435207065576,2.5595015474951195) = 2.60
      Max(0.06913178775631723,2.61883825802004) = 2.54
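To put a number on "works well", a small plain-Java sketch (hypothetical class name) that computes the absolute error of each of the ten printed predictions against the true maximum:

```java
// Hypothetical sketch: absolute error of the ten sample predictions printed above.
public class MaxErrors {
    // {x1, x2, printed network output}, copied from the output above.
    static final double[][] SAMPLES = {
        {0.6815540048808918, 0.3112081053899819, 0.64},
        {2.0073597506364407, 1.93796211086664, 2.09},
        {1.1792029272560556, 2.5514324329058233, 2.58},
        {2.489185375059013, 0.0818746888836388, 2.46},
        {2.658169689797984, 1.419135581889197, 2.66},
        {2.855509810112818, 2.9661811672685086, 2.98},
        {2.774757710538552, 1.3988513143140069, 2.79},
        {1.5852295273047565, 1.1228662895771744, 1.56},
        {0.8403435207065576, 2.5595015474951195, 2.60},
        {0.06913178775631723, 2.61883825802004, 2.54}
    };

    // Largest |max(x1, x2) - prediction| over all samples.
    static double worstError() {
        double worst = 0.0;
        for (double[] s : SAMPLES) {
            worst = Math.max(worst, Math.abs(Math.max(s[0], s[1]) - s[2]));
        }
        return worst;
    }

    public static void main(String[] args) {
        System.out.printf("largest absolute error: %.3f%n", worstError());
    }
}
```

On these ten samples the largest error is about 0.08 (the second sample), on inputs spanning [0, 3]; since the outputs were printed to two decimals, each error also carries up to 0.005 of rounding.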
      

      Below is my modified version of the RegressionSum example (originally by Anwar, 3/15/16):

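      // Imports are an assumption based on DL4J releases of roughly the 0.6-0.8 era
      // (string activations, .iterations(), the Updater enum); package names differ
      // in other versions.
      import java.util.Collections;
      import java.util.List;
      import java.util.Random;
      
      import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
      import org.deeplearning4j.nn.api.OptimizationAlgorithm;
      import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
      import org.deeplearning4j.nn.conf.Updater;
      import org.deeplearning4j.nn.conf.layers.DenseLayer;
      import org.deeplearning4j.nn.conf.layers.OutputLayer;
      import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
      import org.deeplearning4j.nn.weights.WeightInit;
      import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
      import org.nd4j.linalg.api.ndarray.INDArray;
      import org.nd4j.linalg.dataset.DataSet;
      import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
      import org.nd4j.linalg.factory.Nd4j;
      import org.nd4j.linalg.lossfunctions.LossFunctions;
      
      public class RegressionMax {
          //Random number generator seed, for reproducibility
          public static final int seed = 12345;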
          //Number of iterations per minibatch
          public static final int iterations = 1;
          //Number of epochs (full passes of the data)
          public static final int nEpochs = 200;
          //Number of data points
          public static final int nSamples = 10000;
          //Batch size: i.e., each epoch has nSamples/batchSize parameter updates
          public static final int batchSize = 100;
          //Network learning rate
          public static final double learningRate = 0.01;
          // The range of the sample data. Networks are sensitive to the input range (e.g. 0-1); try other ranges and see how they affect the results,
          // and also try changing the range together with the activation function
          public static int MIN_RANGE = 0;
          public static int MAX_RANGE = 3;
      
          public static final Random rng = new Random(seed);
      
          public static void main(String[] args){
      
              //Generate the training data
              DataSetIterator iterator = getTrainingData(batchSize,rng);
      
              //Create the network
              int numInput = 2;
              int numOutputs = 1;
              int nHidden = 10;
              MultiLayerNetwork net = new MultiLayerNetwork(new NeuralNetConfiguration.Builder()
                      .seed(seed)
                      .iterations(iterations)
                      .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                      .learningRate(learningRate)
                      .weightInit(WeightInit.XAVIER)
                      .updater(Updater.NESTEROVS).momentum(0.9)
                      .list()
                      .layer(0, new DenseLayer.Builder().nIn(numInput).nOut(nHidden)
                              .activation("tanh")
                              .build())
                      .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                              .activation("identity")
                              .nIn(nHidden).nOut(numOutputs).build())
                      .pretrain(false).backprop(true).build()
              );
              net.init();
              net.setListeners(new ScoreIterationListener(1));
      
      
              //Train the network on the full data set for nEpochs epochs
              for( int i=0; i<nEpochs; i++ ){
                  iterator.reset();
                  net.fit(iterator);
              }
      
              // Test the max of some numbers (Try different numbers here)
              Random rand = new Random();
              for (int i= 0; i< 10; i++) {
                  double d1 = MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble();
                  double d2 =  MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble();
                  INDArray input = Nd4j.create(new double[] { d1, d2 }, new int[] { 1, 2 });
                  INDArray out = net.output(input, false);
                  System.out.println("Max(" + d1 + "," + d2 + ") = " + out);
              }
      
          }
      
          private static DataSetIterator getTrainingData(int batchSize, Random rand){
              double [] max = new double[nSamples];
              double [] input1 = new double[nSamples];
              double [] input2 = new double[nSamples];
              for (int i= 0; i< nSamples; i++) {
                  input1[i] = MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble();
                  input2[i] =  MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble();
                  max[i] = Math.max(input1[i], input2[i]);
              }
              INDArray inputNDArray1 = Nd4j.create(input1, new int[]{nSamples,1});
              INDArray inputNDArray2 = Nd4j.create(input2, new int[]{nSamples,1});
              INDArray inputNDArray = Nd4j.hstack(inputNDArray1,inputNDArray2);
              INDArray outPut = Nd4j.create(max, new int[]{nSamples, 1});
              DataSet dataSet = new DataSet(inputNDArray, outPut);
              List<DataSet> listDs = dataSet.asList();
              Collections.shuffle(listDs,rng);
              return new ListDataSetIterator(listDs,batchSize);
      
          }
      }
      

      【Discussion】:
