【问题标题】:dl4j lstm not successfuldl4j lstm 不成功
【发布时间】:2020-06-22 22:12:30
【问题描述】:

我试图在此链接的页面中间复制练习: https://d2l.ai/chapter_recurrent-neural-networks/sequence.html

练习使用正弦函数在 -1 到 1 之间创建 1000 个数据点,并使用循环网络来逼近该函数。

下面是我使用的代码。我会回去研究更多为什么这不起作用,因为现在我可以轻松地使用前馈网络来近似这个函数,这对我来说没有多大意义。

      //get data
    ArrayList<DataSet> list = new ArrayList();
   
    DataSet dss = DataSetFetch.getDataSet(Constants.DataTypes.math, "sine", 20, 500, 0, 0);

    DataSet dsMain = dss.copy();

    if (!dss.isEmpty()){
        list.add(dss);
    }

   
    if (list.isEmpty()){

        return;
    }

    //format dataset
   list = DataSetFormatter.formatReccurnent(list, 0);

    //get network
    int history = 10;
    ArrayList<LayerDescription> ldlist = new ArrayList<>();
    LayerDescription l = new LayerDescription(1,history, Activation.RELU);
    ldlist.add(l);     
    LayerDescription ll = new LayerDescription(history, 1, Activation.IDENTITY, LossFunctions.LossFunction.MSE);
    ldlist.add(ll);

    ListenerDescription ld = new ListenerDescription(20, true, false);

    MultiLayerNetwork network = Reccurent.getLstm(ldlist, 123, WeightInit.XAVIER, new RmsProp(), ld);


    //train network
    final List<DataSet> lister = list.get(0).asList();
    DataSetIterator iter = new ListDataSetIterator<>(lister, 50);
    network.fit(iter, 50);
    network.rnnClearPreviousState();


    //test network
    ArrayList<DataSet> resList = new ArrayList<>();
    DataSet result = new DataSet();
    INDArray arr = Nd4j.zeros(lister.size()+1);     
    INDArray holder;

    if (list.size() > 1){
        //test on training data
        System.err.println("oops");

    }else{
        //test on original or scaled data
        for (int i = 0; i < lister.size(); i++) {

            holder = network.rnnTimeStep(lister.get(i).getFeatures());
            arr.putScalar(i,holder.getFloat(0));

        }
    }


    //add originaldata
    resList.add(dsMain);
    //result       
    result.setFeatures(dsMain.getFeatures());
  
    result.setLabels(arr);
    resList.add(result);

    //display
    DisplayData.plot2DScatterGraph(resList);

你能解释一下我需要一个 1 分 10 隐藏和 1 出 lstm 网络来逼近正弦函数的代码吗?

我没有使用任何归一化,因为函数已经是 -1:1 并且我使用 Y 输入作为特征,然后使用以下 Y 输入作为标签来训练网络。

您注意到我正在构建一个可以更轻松地构建网络的类,并且我尝试对问题进行了许多更改,但我厌倦了猜测。

以下是我的结果的一些示例。蓝色是数据,红色是结果

【问题讨论】:

    标签: java deep-learning dl4j nd4j


    【解决方案1】:

    这是你从想知道为什么这不起作用到我原来的结果怎么和他们一样好的时候之一。

    我的失败在于没有清楚地理解文档,也没有理解 BPTT。

    对于前馈网络,每次迭代都存储为一行,每个输入存储为一列。一个例子是[dataset.size, network inputs.size]

    但是,对于循环输入,它的反转是每行是一个输入,每列是一次迭代,这是激活 lstm 事件链状态所必需的时间。至少我的输入需要是 [0, networkinputs.size, dataset.size] 但也可以是 [dataset.size, networkinputs.size, statelength.size]

    在我之前的示例中,我使用这种格式的数据训练网络 [dataset.size, networkinputs.size, 1]。因此,根据我对低分辨率的理解,lstm 网络根本不应该工作,但至少以某种方式产生了一些东西。

    将数据集转换为列表也可能存在一些问题,因为我也更改了为网络提供数据的方式,但我认为大部分问题是数据结构问题。

    以下是我的新结果

    【讨论】:

      【解决方案2】:

      如果没有看到完整的代码,很难知道发生了什么。首先,我没有看到指定的 RnnOutputLayer。您可以查看this,它向您展示了如何在 DL4J 中构建 RNN。 如果您的 RNN 设置正确,这可能是一个调优问题。你可以找到更多关于调整here。对于更新程序,Adam 可能是比 RMSProp 更好的选择。 tanh 可能是激活输出层的不错选择,因为它的范围是 (-1,1)。需要检查/调整的其他事项 - 学习率、时期数、数据设置(例如,您是否试图预测很远?)。

      【讨论】:

      • 非常感谢您的回答。我的代码被混淆了,因为我正在构建一个自动和迭代的网络生成器。我同意您的解决方案会产生更好的结果,但我相信这也应该有效。我试图理解为什么这不起作用,所以我更多的是策略而不是炼金术。
      猜你喜欢
      • 2021-12-12
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2019-03-02
      • 2020-05-13
      • 1970-01-01
      • 2021-12-23
      相关资源
      最近更新 更多