【问题标题】:How can I import the MNIST dataset that has been manually downloaded?如何导入手动下载的 MNIST 数据集?
【发布时间】:2017-04-03 01:54:03
【问题描述】:

我一直在试验一个 Keras 示例,需要导入 MNIST 数据

from keras.datasets import mnist
import numpy as np
(x_train, _), (x_test, _) = mnist.load_data()

它会生成错误消息,例如Exception: URL fetch failure on https://s3.amazonaws.com/img-datasets/mnist.pkl.gz: None -- [Errno 110] Connection timed out

应该和我使用的网络环境有关。 有没有什么函数或者代码可以让我直接导入手动下载的MNIST数据集?

我尝试了以下方法

import sys
import pickle
import gzip
f = gzip.open('/data/mnist.pkl.gz', 'rb')
  if sys.version_info < (3,):
    data = pickle.load(f)
else:
    data = pickle.load(f, encoding='bytes')
f.close()
import numpy as np
(x_train, _), (x_test, _) = data

然后我收到以下错误消息

Traceback (most recent call last):
File "test.py", line 45, in <module>
(x_train, _), (x_test, _) = data
ValueError: too many values to unpack (expected 2)

【问题讨论】:

    标签: keras


    【解决方案1】:

    好吧,keras.datasets.mnist 文件 is really short。可以手动模拟同样的动作,即:

    1. https://s3.amazonaws.com/img-datasets/mnist.pkl.gz下载数据集
    2. .

      import gzip
      f = gzip.open('mnist.pkl.gz', 'rb')
      if sys.version_info < (3,):
          data = cPickle.load(f)
      else:
          data = cPickle.load(f, encoding='bytes')
      f.close()
      (x_train, _), (x_test, _) = data
      

    【讨论】:

    • 嗨 sygi,谢谢你的建议。但是,我收到了更新后的帖子中显示的错误消息。唯一和你不同的是我用泡菜。看起来它在加载数据时没有给我错误。
    • 我已经检查过了,它可以在我的系统上运行,pickle 和 cPickle 以及 python 2 和 3。你确定你有相同的文件 (md5 b39289ebd4f8755817b1352c8488b486)?
    • 它工作,不知道为什么它之前有错误消息。非常感谢。
    • 在我的情况下,添加这些导入 import sys; import pickle; import gzip; 并使用 pickle 而不是 cPickle - 我在 macOS Mojave 上使用 Python 3.6.7
    【解决方案2】:

    您不需要额外的代码,但可以告诉load_data 首先加载本地版本:

    1. 您可以从另一台具有适当(代理)访问权限的计算机上下载文件https://s3.amazonaws.com/img-datasets/mnist.npz(取自https://github.com/keras-team/keras/blob/master/keras/datasets/mnist.py),
    2. 将其复制到目录~/.keras/datasets/(在 Linux 和 macOS 上)
    3. 并使用正确的文件名运行load_data(path='mnist.npz')

    【讨论】:

    • 这正是我想要的,非常感谢 Tardis
    【解决方案3】:

    Keras 文件位于 Google Cloud Storage 中的新路径中(之前位于 AWS S3 中):

    https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
    

    使用时:

    tf.keras.datasets.mnist.load_data()

    您可以传递path 参数。

    load_data()会调用get_file()作为参数fname,如果路径是完整路径且文件存在,则不会下载。

    例子:

    # gsutil cp gs://tensorflow/tf-keras-datasets/mnist.npz /tmp/data/mnist.npz
    # python3
    >>> import tensorflow as tf
    >>> path = '/tmp/data/mnist.npz'
    >>> (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data(path)
    >>> len(train_images)
    >>> 60000
    

    【讨论】:

    • load_data 函数没有路径参数了
    【解决方案4】:
    1. 下载文件 https://s3.amazonaws.com/img-datasets/mnist.npz
    2. mnist.npz移动到.keras/datasets/目录
    3. 加载数据

      import keras
      from keras.datasets import mnist
      
      (X_train, y_train), (X_test, y_test) = mnist.load_data()
      

    【讨论】:

      【解决方案5】:

      keras.datasets.mnist.load_data() 将尝试从远程存储库中获取,即使指定了本地文件路径。但是,加载下载文件最简单的解决方法是使用numpy.load()just like they do

      path = '/tmp/data/mnist.npz'
      
      import numpy as np
      
      with np.load(path, allow_pickle=True) as f:
          x_train, y_train = f['x_train'], f['y_train']
          x_test, y_test = f['x_test'], f['y_test']
      

      【讨论】:

        【解决方案6】:

        Gogasca 的回答稍作调整就对我有用。对于 Python 3.9,更改 ~/Library/Python/3.9/lib/python/site-packages/keras/datasets/mnist.py 中的代码,以便它使用路径变量作为完整路径而不是添加 origin_folder 可以将任何本地路径传递给下载的文件。

        1. 下载文件:https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
        2. 将其放入 ~/Library/Python/3.9/lib/python/site-packages/keras/datasets/ 或您喜欢的其他位置。
        3. 改变 ~/Library/Python/3.9/lib/python/site-packages/keras/datasets/mnist.py
        path = path
        
        """ origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/' """
        """ path = get_file(
        path,origin=origin_folder + 'mnist.npz',file_hash='731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1') """
        
        with np.load(path, allow_pickle=True) as f:  # pylint:
            disable=unexpected-keyword-arg
            x_train, y_train = f['x_train'], f['y_train']
            x_test, y_test = f['x_test'], f['y_test']
        return (x_train, y_train), (x_test, y_test)
        
        1. 使用以下代码加载数据:
        path = "/Users/username/Library/Python/3.9/lib/python/site-packages/keras/datasets/mnist.npz"
        (train_images, train_labels), (test_images, test_labels ) = mnist.load_data(path=path)```
        

        【讨论】:

          猜你喜欢
          • 2019-11-06
          • 2018-08-11
          • 1970-01-01
          • 2018-06-23
          • 2018-09-27
          • 2021-01-18
          • 1970-01-01
          • 1970-01-01
          • 1970-01-01
          相关资源
          最近更新 更多