【发布时间】:2020-01-28 09:08:42
【问题描述】:
我正在为 mnist 数据集使用 tesorflow_datasets 库在 GCP AI 平台上进行训练。我正在使用 tf.gan 估计器。我编写了一个使用 tfds 库读取 mnist 数据的输入管道。
import tensorflow_datasets as tfds
ds = tfds.load('mnist', split=self.split, shuffle_files=self.shuffle)
我已经在实例上使用相同的“tensorflow_datasets”库训练了我的 gan 模型,并且模型训练良好。我已将代码打包到包中,以便在 AI Platform 上运行。在 AI Platform 上训练期间,训练卡住并显示警告,
Dataset mnist is hosted on GCS. It will automatically be downloaded to your local data
directory. If you'd instead prefer to read directly from our public GCS bucket.
虽然训练卡住了,但消耗的 ML 单位不断增加。
【问题讨论】:
标签: python tensorflow generative-adversarial-network gcp-ai-platform-training