【发布时间】:2020-03-23 18:45:46
【问题描述】:
我了解您可以将批量大小分配给数据集并返回一个新的数据集对象。是否有 API 可以查询给定数据集对象的批量大小?
我正在尝试在以下位置查找电话:
【问题讨论】:
标签: tensorflow
我了解您可以将批量大小分配给数据集并返回一个新的数据集对象。是否有 API 可以查询给定数据集对象的批量大小?
我正在尝试在以下位置查找电话:
【问题讨论】:
标签: tensorflow
当您调用.batch(32) 方法时,它会返回一个tensorflow.python.data.ops.dataset_ops.BatchDataset 对象。如Tensorflow Documentation 中所述,这种对象具有称为._batch_size 的私有属性,其中包含一个batch_size 的张量。
在 tensorflow 2.X 中,您只需调用此张量的 .numpy() 方法即可将其转换为 numpy.int64 类型。
在 tensorflow 1.X 中需要调用 .eval() 方法。
【讨论】:
我不知道您是否可以将其作为属性获取,但您可以遍历数据集一次并打印形状:
# create a simple tf.data.Dataset with batchsize 3
import tensorflow as tf
f = tf.data.Dataset.range(10).batch(3) # Dataset with batch_size 3
# iterating once
for one_batch in f:
print('batch size:', one_batch.shape[0])
break
如果您知道您的数据集也有目标/标签,则必须进行如下迭代:
# iterating once
for one_batch_x, one_batch_y in f:
print('batch size:', one_batch_x.shape[0])
break
在这两种情况下,它都会打印:
batch size: 3
【讨论】:
在 TensorFlow 1.* 中通过 dataset._dataset._batch_size 访问 batch_size:
import tensorflow as tf
import numpy as np
print(tf.__version__) # 1.14.0
dataset = tf.data.Dataset.from_tensor_slices(np.random.randint(0, 2, 100)).batch(10)
with tf.compat.v1.Session() as sess:
batch_size = sess.run(dataset._dataset._batch_size)
print(batch_size) # 10
在 TensorFlow 2 中,您可以通过 dataset._batch_size 访问:
import tensorflow as tf
import numpy as np
print(tf.__version__) # 2.0.1
dataset = tf.data.Dataset.from_tensor_slices(np.random.randint(0, 2, 100)).batch(10)
batch_size = dataset._batch_size.numpy()
print(batch_size) # 10
【讨论】: