【发布时间】:2021-02-03 15:45:39
【问题描述】:
在带有 Tensorflow-2.0.0 的 jupyter notebook 上,以这种方式执行了 80-10-10 的 train-validation-test 拆分:
import tensorflow_datasets as tfds
from os import getcwd
splits = tfds.Split.ALL.subsplit(weighted=(80, 10, 10))
filePath = f"{getcwd()}/../tmp2/"
splits, info = tfds.load('fashion_mnist', with_info=True, as_supervised=True, split=splits, data_dir=filePath)
但是,当尝试在本地运行相同的代码时,我得到了错误
AttributeError: type object 'Split' has no attribute 'ALL'
我已经看到我可以通过这种方式创建两个集合:
splits, info = tfds.load('fashion_mnist', with_info=True, as_supervised=True, split=['train[:80]','test[80:90]'], data_dir=filePath)
但我不知道如何添加第三组。
【问题讨论】:
标签: python tensorflow tensorflow-datasets train-test-split