[Question title]: Problems with big-/little-endian integers while reading the MNIST dataset in Java
[Posted]: 2018-02-17 06:49:10
[Question]:

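I have been trying to read the MNIST dataset so that I can format it for use in a neural network, but I'm having trouble getting the big-endian to little-endian conversion to work.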

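When I read the data, the first integer in the output is 529205256. Converted to little-endian it is 134777631, which is still far larger than the expected "magic number" 2051.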
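Printing these values in hexadecimal makes them easier to interpret. The sketch below (the class and method names are hypothetical) shows that 529205256 is 0x1F8B0808 and that its first two bytes, 0x1F 0x8B, match the gzip signature defined in RFC 1952. That suggests the bytes being read are still compressed. The expected magic number 2051 is 0x00000803.

```java
public class MagicNumberHex {
    // Format an int as 8 zero-padded, uppercase hex digits (two per byte).
    static String toHex(int value) {
        return String.format("%08X", value);
    }

    // True if the two most significant bytes are 0x1F 0x8B,
    // the gzip signature (ID1, ID2) from RFC 1952.
    static boolean looksLikeGzip(int firstWord) {
        return ((firstWord >>> 24) & 0xFF) == 0x1F
            && ((firstWord >>> 16) & 0xFF) == 0x8B;
    }

    public static void main(String[] args) {
        System.out.println(toHex(529205256)); // 1F8B0808: value returned by bb.getInt()
        System.out.println(toHex(134777631)); // 08088B1F: the same four bytes swapped
        System.out.println(toHex(2051));      // 00000803: the expected magic number
        System.out.println(looksLikeGzip(529205256)); // true
    }
}
```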

Whichever solution I try, I get the same wrong number, so I'd be grateful if someone could explain my mistake.

Part of the code is borrowed from GitHub.

Here is the part of my code where the error occurs:

public static List<int[][]> getImages(String infile) {
    ByteBuffer bb = loadFileToByteBuffer(infile);

    assertMagicNumber(IMAGE_FILE_MAGIC_NUMBER, bb.getInt());
    int numImages = bb.getInt();
    int numRows = bb.getInt();
    int numColumns = bb.getInt();

    List<int[][]> images = new ArrayList<>();

    for (int i = 0; i < numImages; i++)
        images.add(readImage(numRows, numColumns, bb));

    return images;
}

When bb.getInt() is called, it returns the integer 529205256, and even after converting it with this code

public static int swap(int value)
  {
    int b1 = (value >>  0) & 0xff;
    int b2 = (value >>  8) & 0xff;
    int b3 = (value >> 16) & 0xff;
    int b4 = (value >> 24) & 0xff;

    return b1 << 24 | b2 << 16 | b3 << 8 | b4 << 0;
  }

it still doesn't produce the correct number, so assertMagicNumber throws an exception because the values are not equal.
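The swap method itself appears to be correct: it does the same thing as the JDK's Integer.reverseBytes, and the same result can be obtained by setting ByteOrder.LITTLE_ENDIAN on the buffer instead of swapping by hand. The sketch below checks this, assuming the raw bytes are 1F 8B 08 08 (the bytes that decode big-endian to 529205256); the class name is hypothetical. Because no ordering of those four bytes equals 00 00 08 03, no byte swap can turn this value into 2051.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class SwapCheck {
    // The swap method from the question: reverses the four bytes of an int.
    static int swap(int value) {
        int b1 = (value >> 0) & 0xff;
        int b2 = (value >> 8) & 0xff;
        int b3 = (value >> 16) & 0xff;
        int b4 = (value >> 24) & 0xff;
        return b1 << 24 | b2 << 16 | b3 << 8 | b4;
    }

    public static void main(String[] args) {
        int raw = 529205256;
        System.out.println(swap(raw));                              // 134777631
        System.out.println(swap(raw) == Integer.reverseBytes(raw)); // true

        // Same result by asking the buffer to decode little-endian.
        byte[] bytes = {0x1F, (byte) 0x8B, 0x08, 0x08};
        int littleEndian = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).getInt();
        System.out.println(littleEndian);                           // 134777631
    }
}
```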

If needed, here is the rest of the class:

package core;

import static java.lang.String.format;


import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.List;

public class MnistReader {
public static final int LABEL_FILE_MAGIC_NUMBER = 2049;
public static final int IMAGE_FILE_MAGIC_NUMBER = 2051;

public static int[] getLabels(String infile) {

    ByteBuffer bb = loadFileToByteBuffer(infile);

    assertMagicNumber(LABEL_FILE_MAGIC_NUMBER, bb.getInt());

    int numLabels = bb.getInt();
    int[] labels = new int[numLabels];

    for (int i = 0; i < numLabels; ++i)
        labels[i] = bb.get() & 0xFF; // To unsigned

    return labels;
}

public static List<int[][]> getImages(String infile) {
    ByteBuffer bb = loadFileToByteBuffer(infile);

    assertMagicNumber(IMAGE_FILE_MAGIC_NUMBER, bb.getInt());
    int numImages = bb.getInt();
    int numRows = bb.getInt();
    int numColumns = bb.getInt();

    List<int[][]> images = new ArrayList<>();

    for (int i = 0; i < numImages; i++)
        images.add(readImage(numRows, numColumns, bb));

    return images;
}

private static int[][] readImage(int numRows, int numCols, ByteBuffer bb) {
    int[][] image = new int[numRows][];
    for (int row = 0; row < numRows; row++)
        image[row] = readRow(numCols, bb);
    return image;
}

private static int[] readRow(int numCols, ByteBuffer bb) {
    int[] row = new int[numCols];
    for (int col = 0; col < numCols; ++col)
        row[col] = bb.get() & 0xFF; // To unsigned
    return row;
}

public static void assertMagicNumber(int expectedMagicNumber, int magicNumber) {

    System.out.println(expectedMagicNumber);
    System.out.println(magicNumber);

    if (expectedMagicNumber != magicNumber) {
        switch (expectedMagicNumber) {
        case LABEL_FILE_MAGIC_NUMBER:
            throw new RuntimeException("This is not a label file.");
        case IMAGE_FILE_MAGIC_NUMBER:
            throw new RuntimeException("This is not an image file.");
        default:
            throw new RuntimeException(
                    format("Expected magic number %d, found %d", expectedMagicNumber, magicNumber));
        }
    }
}
//
//
//
//

public static ByteBuffer loadFileToByteBuffer(String infile) {
    return ByteBuffer.wrap(loadFile(infile));
}

public static byte[] loadFile(String infile) {
    try {
        RandomAccessFile f = new RandomAccessFile(infile, "r");
        FileChannel chan = f.getChannel();
        long fileSize = chan.size();
        ByteBuffer bb = ByteBuffer.allocate((int) fileSize);
        chan.read(bb);
        bb.flip();
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        for (int i = 0; i < fileSize; i++)
            baos.write(bb.get());
        chan.close();
        f.close();
        return baos.toByteArray();
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}



public static String renderImage(int[][] image) {
    StringBuffer sb = new StringBuffer();

    for (int row = 0; row < image.length; row++) {
        sb.append("|");
        for (int col = 0; col < image[row].length; col++) {
            int pixelVal = image[row][col];
            if (pixelVal == 0)
                sb.append(" ");
            else if (pixelVal < 256 / 3)
                sb.append(".");
            else if (pixelVal < 2 * (256 / 3))
                sb.append("x");
            else
                sb.append("X");
        }
        sb.append("|\n");
    }

    return sb.toString();
}

public static String repeat(String s, int n) {
    StringBuilder sb = new StringBuilder();
    for (int i = 0; i < n; i++)
        sb.append(s);
    return sb.toString();
}

/* (Added method)
 * converts the image data from a 2-dimensional to a 1-dimensional array
 * and scales the pixel values to the range 0 to 1
 */

public static double[] convertImage(int[][] source) {

    double[] convertedImage = new double[784]; // 28 x 28 pixels
    int currentPos = 0;
    for(int i = 0; i < source.length; i++) {
        for(int j = 0; j < source[i].length; j++) {
            // advance the position for every pixel; 255.0 forces floating-point division
            convertedImage[currentPos++] = source[i][j] / 255.0;
        }
    }
    return convertedImage;
}

/* (Added method)
 * converts the label data from an Integer to a vector that can be used
 * as output for the neural network
 */

public static double[] convertLabel(int label) {

    double[] convertedLabel = new double[10];
    convertedLabel[label] = 1;
    return convertedLabel;
}

public static int swap(int value)
  {
    int b1 = (value >>  0) & 0xff;
    int b2 = (value >>  8) & 0xff;
    int b3 = (value >> 16) & 0xff;
    int b4 = (value >> 24) & 0xff;

    return b1 << 24 | b2 << 16 | b3 << 8 | b4 << 0;
  }

}
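Separately from the endianness issue, scaling pixels into [0, 1] as convertImage intends requires floating-point division: in Java, dividing an int by the int literal 255 truncates, so every pixel below 255 becomes 0. The write position into the flat array also has to advance once per pixel. A minimal standalone sketch of that flattening (the class and method names are hypothetical, and it assumes a rectangular image):

```java
import java.util.Arrays;

public class PixelNormalization {
    // Flattens a rectangular image row by row and scales pixels from 0..255 to 0.0..1.0.
    static double[] flattenAndScale(int[][] source) {
        int rows = source.length;
        int cols = rows == 0 ? 0 : source[0].length;
        double[] out = new double[rows * cols];
        int pos = 0; // advances once per pixel
        for (int[] row : source) {
            for (int pixel : row) {
                out[pos++] = pixel / 255.0; // 255.0 makes this floating-point division
            }
        }
        return out;
    }

    public static void main(String[] args) {
        System.out.println(128 / 255); // 0: int / int truncates toward zero
        System.out.println(Arrays.toString(flattenAndScale(new int[][] {{0, 255}, {51, 102}})));
        // [0.0, 1.0, 0.2, 0.4]
    }
}
```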

I really don't know where my mistake is, so any help would be appreciated.

MNIST link: http://yann.lecun.com/exdb/mnist/

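Edit: It turned out the problem was with decompressing the files themselves. After fixing that, everything started working as expected.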
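One way to avoid a separate decompression step is to read the .gz downloads directly. This is a minimal sketch, not necessarily the fix the asker used (the edit does not say how the files were decompressed). The class name GzipAwareLoader is hypothetical, and it assumes Java 9+ for InputStream.readAllBytes. It checks for the gzip signature and wraps the bytes in java.util.zip.GZIPInputStream only when the signature is present, so the result can be handed to ByteBuffer.wrap just as loadFileToByteBuffer does.

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.zip.GZIPInputStream;

public class GzipAwareLoader {
    // Returns the file's bytes, decompressing them first if the file
    // starts with the gzip signature 0x1F 0x8B (RFC 1952).
    static byte[] loadFile(Path path) throws IOException {
        byte[] raw = Files.readAllBytes(path);
        boolean gzipped = raw.length >= 2
                && (raw[0] & 0xFF) == 0x1F
                && (raw[1] & 0xFF) == 0x8B;
        if (!gzipped) {
            return raw; // already an uncompressed IDX file
        }
        try (InputStream in = new GZIPInputStream(new ByteArrayInputStream(raw))) {
            return in.readAllBytes(); // Java 9+
        }
    }

    public static void main(String[] args) throws IOException {
        // Usage: java GzipAwareLoader train-images-idx3-ubyte.gz
        ByteBuffer bb = ByteBuffer.wrap(loadFile(Paths.get(args[0])));
        // Should print 2051 for an image file or 2049 for a label file.
        System.out.println(bb.getInt());
    }
}
```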

[Comments]:

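  • I suggest printing your numbers in hexadecimal rather than decimal. Train yourself to think in binary, because that is the level at which you are handling the data.
  • No conversion is needed. The integers in the MNIST dataset are big-endian, which is also the default byte order of a ByteBuffer.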
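The second comment can be checked directly. According to the MNIST page, an uncompressed image file begins with the bytes 00 00 08 03 (0x00000803 = 2051) and a label file with 00 00 08 01 (2049). A freshly wrapped ByteBuffer is big-endian, so getInt() decodes these values with no swapping. A small sketch (the class name and the hard-coded header are illustrative assumptions):

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class DefaultOrderCheck {
    // First four bytes of an uncompressed MNIST image file.
    static final byte[] IMAGE_HEADER = {0x00, 0x00, 0x08, 0x03};

    // Reads the magic number using ByteBuffer's default (big-endian) order.
    static int readMagic(byte[] header) {
        return ByteBuffer.wrap(header).getInt();
    }

    public static void main(String[] args) {
        System.out.println(ByteBuffer.wrap(IMAGE_HEADER).order() == ByteOrder.BIG_ENDIAN); // true
        System.out.println(readMagic(IMAGE_HEADER));                     // 2051
        System.out.println(Integer.toHexString(readMagic(IMAGE_HEADER))); // 803
    }
}
```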

Tags: java mnist


[Answer 1]:

Let me suggest another solution. In this class you can find these lines:

List<MNISTSample> trainSamples = MNISTSamples.stream(
        trainLabelPath, 
        trainImagePath
    ).limit(TRAIN_LIMIT).collect(Collectors.toList());

List<MNISTSample> testSamples = MNISTSamples.stream(
        testLabelPath, 
        testImagePath
    ).limit(TEST_LIMIT).collect(Collectors.toList());

There are also 3 additional classes that can easily be reused.
