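Here is a solution. It requires a "control input" that selects which batch to use; you set it according to which dataset is exhausted first, which can be detected from the exception that is raised.
To explain the solution, I'll first present an attempt that does not work.
Attempted solution #1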
import tensorflow as tf
ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])
ds1 = ds1.batch(2)
ds2 = ds2.batch(1)
iter1 = ds1.make_one_shot_iterator()
iter2 = ds2.make_one_shot_iterator()
batch1 = iter1.get_next(name='batch1')
batch2 = iter2.get_next(name='batch2')
batch12 = tf.concat((batch1, batch2), 0)
# this is a "control" placeholder. Its value determines whether to use `batch1`,`batch2` or `batch12`
which_batch = tf.placeholder(tf.int32)
batch = tf.cond(
    tf.equal(which_batch,0), # if `which_batch`==0, use `batch12`
    lambda:batch12,
    lambda:tf.cond(tf.equal(which_batch,1), # elif `which_batch`==1, use `batch1`
        lambda:batch1,
        lambda:batch2)) # else, use `batch2`
sess = tf.Session()
which = 0 # this value will be fed into the control placeholder `which_batch`
while True:
    try:
        print(sess.run(batch,feed_dict={which_batch:which}))
    except tf.errors.OutOfRangeError as e:
        # use the error to detect which dataset was consumed, and update `which` accordingly
        if which==0:
            if 'batch2' in e.op.name:
                which = 1
            else:
                which = 2
        else:
            break
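This solution does not work because, for any value of which_batch, tf.cond() evaluates all the predecessors of both of its branches (see this answer). So even when which_batch is 1, batch2 is still evaluated and an OutOfRangeError is thrown.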
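A rough plain-Python analogy may help here. It is not TensorFlow code, and the names pick_eager, pick_lazy and make_iters are made up for illustration. Values computed before the selection happens are always computed, whichever one is selected, while values produced inside callables are computed only when that callable is chosen.

```python
def pick_eager(use_first, first, second):
    """Both values are computed by the caller before this function runs."""
    return first if use_first else second


def pick_lazy(use_first, first_fn, second_fn):
    """Only the selected callable is invoked."""
    return first_fn() if use_first else second_fn()


def make_iters():
    """A source with data left and a source that is already exhausted."""
    return iter([[5, 5], [5]]), iter([])


# Eager: `next(it2)` runs while the arguments are being built, so it raises
# even though only the first value was requested.
it1, it2 = make_iters()
try:
    pick_eager(True, next(it1), next(it2))
    eager_result = "ok"
except StopIteration:
    eager_result = "StopIteration"

# Lazy: the exhausted source is never touched.
it1, it2 = make_iters()
lazy_result = pick_lazy(True, lambda: next(it1), lambda: next(it2))

print(eager_result)
print(lazy_result)
```

In the analogy, the eager call plays the role of tensors created outside the tf.cond() branches, and the lazy call plays the role of tensors created inside the branch functions.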
Attempted solution #2
This problem can be fixed by moving the definitions of batch1, batch2 and batch12 into functions.
import tensorflow as tf
ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])
ds1 = ds1.batch(2)
ds2 = ds2.batch(1)
iter1 = ds1.make_one_shot_iterator()
iter2 = ds2.make_one_shot_iterator()
def get_batch1():
    batch1 = iter1.get_next(name='batch1')
    return batch1
def get_batch2():
    batch2 = iter2.get_next(name='batch2')
    return batch2
def get_batch12():
    batch1 = iter1.get_next(name='batch1_')
    batch2 = iter2.get_next(name='batch2_')
    batch12 = tf.concat((batch1, batch2), 0)
    return batch12
# this is a "control" placeholder. Its value determines whether to use `batch1`,`batch2` or `batch12`
which_batch = tf.placeholder(tf.int32)
batch = tf.cond(
    tf.equal(which_batch,0), # if `which_batch`==0, use `batch12`
    get_batch12,
    lambda:tf.cond(tf.equal(which_batch,1), # elif `which_batch`==1, use `batch1`
        get_batch1,
        get_batch2)) # elif `which_batch`==2, use `batch2`
sess = tf.Session()
which = 0 # this value will be fed into the control placeholder `which_batch`
while True:
    try:
        print(sess.run(batch,feed_dict={which_batch:which}))
    except tf.errors.OutOfRangeError as e:
        # use the error to detect which dataset was consumed, and update `which` accordingly
        if which==0:
            if 'batch2' in e.op.name:
                which = 1
            else:
                which = 2
        else:
            break
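However, this does not work either. The reason is that in the step where batch12 is being formed and dataset ds2 turns out to be exhausted, a batch has already been taken from ds1, and it is "dropped" without being used.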
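The dropped batch can be reproduced with a small plain-Python simulation. This is only a sketch of the control flow, not TensorFlow code; batches, Exhausted, pull and naive_merge are hypothetical helpers, with Exhausted loosely standing in for OutOfRangeError and its name attribute for e.op.name.

```python
class Exhausted(Exception):
    """Raised when a named source runs out of batches."""

    def __init__(self, name):
        super().__init__(name)
        self.name = name


def batches(values, size):
    """Yield consecutive chunks of `values` with at most `size` items each."""
    for start in range(0, len(values), size):
        yield values[start:start + size]


def pull(it, name):
    """Fetch the next batch, or report which source ran out."""
    try:
        return next(it)
    except StopIteration:
        raise Exhausted(name) from None


def naive_merge(values1, size1, values2, size2):
    it1, it2 = batches(values1, size1), batches(values2, size2)
    out = []
    which = 0  # 0: combined batch, 1: source 1 only, 2: source 2 only
    while True:
        try:
            if which == 0:
                b1 = pull(it1, "batch1")
                b2 = pull(it2, "batch2")  # if this raises, b1 is lost
                out.append(b1 + b2)
            elif which == 1:
                out.append(pull(it1, "batch1"))
            else:
                out.append(pull(it2, "batch2"))
        except Exhausted as e:
            if which != 0:
                break
            which = 1 if e.name == "batch2" else 2
    return out


result = naive_merge([5, 5, 5, 5, 5], 2, [4, 4], 1)
print(result)
```

With the same data as in the answer, five 5s go in but only four come out of the simulation: the last batch of the first source is fetched in the same step in which the second source fails, so it is never emitted.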
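Solution
We need a mechanism that ensures no batch is "dropped" when the other dataset is exhausted. We can do this by defining a variable that is assigned the current batch of ds1, but only immediately before trying to obtain batch12; otherwise the variable keeps its previous value. Then, if batch12 fails because ds1 is exhausted, the assignment fails, no batch of ds2 is dropped, and we can use it next time. Otherwise, if batch12 fails because ds2 is exhausted, we have a backup of batch1 in the variable we defined; after using this backup, we can continue fetching batch1.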
import tensorflow as tf
ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])
ds1 = ds1.batch(2)
ds2 = ds2.batch(1)
iter1 = ds1.make_one_shot_iterator()
iter2 = ds2.make_one_shot_iterator()
# this variable will store a backup of `batch1`, in case it is dropped
batch1_backup = tf.Variable(0, trainable=False, validate_shape=False)
def get_batch12():
    batch1 = iter1.get_next(name='batch1')
    # form the combined batch `batch12` only after backing-up `batch1`
    with tf.control_dependencies([tf.assign(batch1_backup, batch1, validate_shape=False)]):
        batch2 = iter2.get_next(name='batch2')
        batch12 = tf.concat((batch1, batch2), 0)
    return batch12
def get_batch1():
    batch1 = iter1.get_next()
    return batch1
def get_batch2():
    batch2 = iter2.get_next()
    return batch2
# this is a "control" placeholder. Its value determines whether to use `batch12`, `batch1_backup`, `batch1`, or `batch2`
which_batch = tf.placeholder(tf.int32)
batch = tf.cond(
    tf.equal(which_batch,0), # if `which_batch`==0, use `batch12`
    get_batch12,
    lambda:tf.cond(tf.equal(which_batch,1), # elif `which_batch`==1, use `batch1_backup`
        lambda:batch1_backup,
        lambda:tf.cond(tf.equal(which_batch,2), # elif `which_batch`==2, use `batch1`
            get_batch1,
            get_batch2))) # else, use `batch2`
sess = tf.Session()
sess.run(tf.global_variables_initializer())
which = 0 # this value will be fed into the control placeholder
while True:
    try:
        print(sess.run(batch,feed_dict={which_batch:which}))
        # if just used `batch1_backup`, proceed with `batch1`
        if which==1:
            which = 2
    except tf.errors.OutOfRangeError as e:
        # use the error to detect which dataset was consumed, and update `which` accordingly
        if which == 0:
            if 'batch2' in e.op.name:
                which = 1
            else:
                which = 3
        else:
            break
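The backup idea can also be checked without TensorFlow. The following is a minimal plain-Python sketch of the same state machine, under the assumption that ordinary iterators behave like the dataset iterators for this purpose; batches, Exhausted, pull and merge_with_backup are hypothetical names, and the local variable backup plays the role of batch1_backup.

```python
class Exhausted(Exception):
    """Raised when a named source runs out of batches."""

    def __init__(self, name):
        super().__init__(name)
        self.name = name


def batches(values, size):
    """Yield consecutive chunks of `values` with at most `size` items each."""
    for start in range(0, len(values), size):
        yield values[start:start + size]


def pull(it, name):
    """Fetch the next batch, or report which source ran out."""
    try:
        return next(it)
    except StopIteration:
        raise Exhausted(name) from None


def merge_with_backup(values1, size1, values2, size2):
    it1, it2 = batches(values1, size1), batches(values2, size2)
    out = []
    backup = None
    which = 0  # 0: combined, 1: backup, 2: source 1 only, 3: source 2 only
    while True:
        try:
            if which == 0:
                b1 = pull(it1, "batch1")
                backup = b1  # back up before touching source 2
                b2 = pull(it2, "batch2")
                out.append(b1 + b2)
            elif which == 1:
                out.append(backup)  # emit the batch that would have been lost
                which = 2
            elif which == 2:
                out.append(pull(it1, "batch1"))
            else:
                out.append(pull(it2, "batch2"))
        except Exhausted as e:
            if which != 0:
                break
            which = 1 if e.name == "batch2" else 3
    return out


first_longer = merge_with_backup([5, 5, 5, 5, 5], 2, [4, 4], 1)
second_longer = merge_with_backup([5, 5], 2, [4, 4, 4], 1)
print(first_longer)
print(second_longer)
```

In this simulation every element of both sources is emitted exactly once, whichever source runs out first. When the first source runs out, the pull for the second source is never reached in that step, so nothing needs to be backed up on that side.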