【发布时间】:2018-11-27 21:14:57
【问题描述】:
我了解 Dataset API 是一种迭代器,它不会将整个数据集加载到内存中,因此无法找到数据集的大小。我说的是存储在文本文件或 tfRecord 文件中的大型数据语料库。这些文件通常使用tf.data.TextLineDataset 或类似的东西读取。查找使用 tf.data.Dataset.from_tensor_slices 加载的数据集的大小很简单。
我询问数据集大小的原因如下: 假设我的数据集大小是 1000 个元素。批量大小 = 50 个元素。然后训练步骤/批次(假设 1 个 epoch)= 20。在这 20 个步骤中,我想以指数方式将我的学习率从 0.1 衰减到 0.01
tf.train.exponential_decay(
learning_rate = 0.1,
global_step = global_step,
decay_steps = 20,
decay_rate = 0.1,
staircase=False,
name=None
)
在上面的代码中,我有“和”想设置decay_steps = number of steps/batches per epoch = num_elements/batch_size。只有事先知道数据集中元素的数量,才能计算出这一点。
提前知道大小的另一个原因是使用tf.data.Dataset.take()、tf.data.Dataset.skip() 方法将数据分成训练集和测试集。
PS:我不是在寻找暴力方法,例如遍历整个数据集并更新计数器以计算元素数量或 putting a very large batch size and then finding the size of the resultant dataset 等。
【问题讨论】:
标签: python tensorflow tensorflow-datasets