【发布时间】:2021-06-24 20:26:14
【问题描述】:
我将double[][] 类型的大型(1M 样本 * 100 个特征)数组加载到内存中,我需要有效地创建DataSetIterator(随机批次)以提供MultiLayerNetwork。我该怎么做?
我发现的大多数 deeplearning4j 示例都侧重于从文件中加载数据,但是当数据已经在内存中时该怎么办?
【问题讨论】:
标签: deeplearning4j
我将double[][] 类型的大型(1M 样本 * 100 个特征)数组加载到内存中,我需要有效地创建DataSetIterator(随机批次)以提供MultiLayerNetwork。我该怎么做?
我发现的大多数 deeplearning4j 示例都侧重于从文件中加载数据,但是当数据已经在内存中时该怎么办?
【问题讨论】:
标签: deeplearning4j
DataSetIterator 只是一种创建小批量的方法。小批量数据集是一个数据集对象,它本身只是 2 个 INDArray(特征和标签) 如果您的数据已经在内存中,一个简单的:
double[][] data = ...;
INDArray arr = Nd4j.create(data);
请注意,当您这样做时,这会将一堆数据移出堆并分配大量额外内存。我们不使用普通 java 数据结构的一个重要原因是(虽然易于使用)由于性能原因,我们所有的计算都发生在 c++ 中。
用于创建您将使用的数据集
DataSet d = new DataSet(arr,arr);
请注意,构造函数中的第二个 arr 实际上应该是一个单独的第二个数组,代表您的数据集。
对于迭代器(它只是在幕后创建这些),您可以使用以下内容: https://github.com/eclipse/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/impl/ListDataSetIterator.java
与:
DataSetIterator listIterator = new ViewIterator(data,5);
注意第二个参数是你的批量大小。
【讨论】: