好的，我有一个可以运行的示例，它基本实现了我想要的效果。下面的代码能按我想要的方式生成批次，但它需要通过占位符（placeholder）把数据送入 TF 会话，再从会话中取回。我希望能完全在 TF 计算图内部构建这些批次。
希望只是我自己犯傻了，有人能指出一个显而易见的解决方案。另外，请原谅我使用驼峰命名法（camelCase）。
import tensorflow as tf
def buildBatch(seqLength, batchSize, filenames=('./exampleFile.txt',)):
    """Build graph ops that produce one batch of fixed-length token rows.

    Every row starts from a caller-fed prefix (a ``[1, None]`` string
    placeholder). Whole lines are then appended to the row until it holds
    at least ``seqLength`` tokens. Each line is split on spaces and
    terminated by a ``'<GO>'`` token. The row is clipped to exactly
    ``seqLength`` tokens. The surplus tokens are exposed so the caller can
    feed them back as that row's prefix in the next batch.

    Args:
        seqLength: number of tokens in each row of the batch.
        batchSize: number of rows in the batch.
        filenames: text file(s) read line by line. Defaults to the
            previously hard-coded ``'./exampleFile.txt'``.

    Returns:
        A tuple ``(thisBatch, startOfThisBatch, startOfNextBatch, init)``:

        * ``thisBatch``: ``[batchSize, seqLength]`` string tensor.
        * ``startOfThisBatch``: list of ``batchSize`` placeholders the
          caller must feed with each row's starting tokens.
        * ``startOfNextBatch``: list of ``batchSize`` ``[1, ?]`` tensors
          holding each row's leftover tokens; fetch them and feed them back
          next time.
        * ``init``: the dataset iterator's initializer op.

    NOTE(review): all rows pull from one shared iterator, so which row
    receives which new line presumably depends on the order TF schedules
    the per-row loops.
    """
    def lineToSequence(line):
        # string_split needs a rank-1 input, so wrap the scalar line in a
        # batch of one. The dense result is a [1, numTokens] string matrix.
        line = tf.expand_dims(line, axis=0)
        line = tf.sparse_tensor_to_dense(tf.string_split(line), '_')
        # Terminate each sentence with '<GO>' so the next one follows it.
        return tf.concat([line, [['<GO>']]], 1)

    # NOTE: on TF >= 1.4 this is available as tf.data.TextLineDataset.
    data = tf.contrib.data.TextLineDataset(filenames)
    data = data.map(lineToSequence)
    iterator = data.make_initializable_iterator()

    def getFixedLengthSequence(start):
        """Grow ``start`` with whole lines until it has >= seqLength tokens.

        Returns ``(clippedToLength, leftover)``: the first ``seqLength``
        tokens and any tokens read past them, both shaped ``[1, ?]``.
        """
        def tooShort(seq):
            return tf.shape(seq)[1] < seqLength

        def appendNextLine(seq):
            return tf.concat([seq, iterator.get_next()], 1)

        sentences = tf.while_loop(
            tooShort, appendNextLine, [start],
            back_prop=False, parallel_iterations=1,
            # The token count grows each iteration, so relax axis 1.
            shape_invariants=[tf.TensorShape([1, None])])
        clippedToLength = tf.expand_dims(sentences[0, :seqLength], axis=0)
        leftover = tf.expand_dims(sentences[0, seqLength:], axis=0)
        return clippedToLength, leftover

    # The caller feeds the start of each row: '<GO>' for the first batch,
    # then the leftover fetched from the previous batch.
    startOfThisBatch = [tf.placeholder(tf.string, shape=[1, None])
                        for _ in range(batchSize)]

    rows = [getFixedLengthSequence(seqStart) for seqStart in startOfThisBatch]
    thisBatch = tf.concat([seq for seq, _ in rows], axis=0)
    # The leftovers are returned directly so the caller can fetch them.
    startOfNextBatch = [leftover for _, leftover in rows]
    return thisBatch, startOfThisBatch, startOfNextBatch, iterator.initializer
def printBatch(sequenceLength=10, batchSize=3, numBatches=4):
    """Print ``numBatches`` consecutive batches built by ``buildBatch``.

    The leftover tokens fetched for each row are fed back as that row's
    start in the next run. As a result, every row continues its own text
    across batches.

    Args:
        sequenceLength: tokens per row (default 10).
        batchSize: rows per batch (default 3).
        numBatches: number of batches to print (default 4).
    """
    def toText(tok):
        # tf.string fetches come back as bytes under Python 3; decode them
        # so str.join works. Under Python 2 they are already str.
        return tok if isinstance(tok, str) else tok.decode('utf-8')

    batch, startOfThisBatch, startOfNextBatch, iteratorInit = buildBatch(
        sequenceLength, batchSize)
    # The very first batch starts every row with a <GO> token.
    batchStarts = [[['<GO>']] for _ in range(batchSize)]
    sv = tf.train.Supervisor()
    with sv.managed_session() as sess:
        sess.run(iteratorInit)
        for b in range(numBatches):
            # Feed each row's start: <GO>, or the previous run's leftover.
            feed = dict(zip(startOfThisBatch, batchStarts))
            # One run yields this batch plus the starts of the next batch.
            out, batchStarts = sess.run([batch, startOfNextBatch],
                                        feed_dict=feed)
            # Single-argument print(...) behaves the same on Python 2 and 3.
            print('Batch %d :' % b)
            for seq in out:
                print(' '.join(toText(tok) for tok in seq))
            print('')
# Run the demo only when executed as a script, not on import.
if __name__ == '__main__':
    printBatch()
结果:
Batch 0 :
<GO> A spokesman said the company has been affected by
<GO> Having a little flexibility on that issue would go
<GO> Long before the advent of e-commerce , Wal-Mart 's
Batch 1 :
the credit crunch in the United States . <GO> Abu
a long way to putting together a final package .
founder Sam Walton set out his vision for a successful
Batch 2 :
Dhabi is going ahead to build solar city and no
<GO> Her back was torn open , her liver was
retail operation : " We let folks know we 're
Batch 3 :
pollution city . <GO> Now it has 175 staging centers
ruptured , one of her lungs had collapsed and the
interested in them and that they 're vital to us--
请注意，每个句子都会在下一批的同一行中继续。所用的示例文本文件来自 One Billion Word Benchmark 数据集，每行包含一个句子。