【发布时间】:2019-12-23 16:13:43
【问题描述】:
在 TensorFlow 2.0 中，我想训练一个使用 NCE 损失（nce_loss）的 skip-gram 模型。由于输入文件非常大，tf.data.Dataset.from_tensor_slices() 并不适用，所以我写了一个如下的数据集生成器类：
class DataSet:
    """Stream skip-gram or CBOW training samples from a text corpus.

    Samples are produced lazily by ``generator`` and wrapped in a
    ``tf.data.Dataset`` by ``dataset``, so the corpus is never materialized
    in memory as a whole.

    Attributes read from ``args``:
        input: path of the corpus file; each line is split on whitespace
            into tokens.
        window_size: number of tokens taken on each side of the target.
        cbow: when > 0 emit CBOW samples, otherwise skip-gram pairs.

    ``vocab`` must provide ``indices(tokens)``; its result is enumerated and
    sliced as a sequence of ints (assumed -- confirm against the vocab class).

    NOTE(review): ``from_generator`` runs this Python code single-threaded
    under the GIL, so throughput is bounded by the generator itself. For
    large corpora, converting to index form offline (e.g. TFRecord) and
    building pairs with tf.data ops is usually much faster.
    """

    def __init__(self, args, vocab):
        self.args = args
        self.vocab = vocab

    def _use_cbow(self):
        """Return True when CBOW samples should be produced.

        ``generator`` and ``dataset`` share this single predicate so that the
        emitted element structure always matches the declared output shapes.
        """
        return self.args.cbow > 0

    @staticmethod
    def _context_words(tokens_indices, index, window_size):
        """Return the tokens within ``window_size`` positions of ``index``.

        The target token itself is excluded, and the window is clipped at
        both ends of the sentence.
        """
        begin = max(0, index - window_size)
        end = min(len(tokens_indices), index + window_size + 1)
        return list(tokens_indices[begin:index]) + list(tokens_indices[index + 1:end])

    def generator(self):
        """Yield training samples one at a time.

        CBOW mode yields ``(context_words, target_word)``, where
        ``context_words`` is a list of ints. Targets with no context
        (one-token lines) are skipped, mirroring skip-gram mode, which
        naturally emits nothing for them. Skip-gram mode yields one
        ``(target_word, context_word)`` pair per context word.
        """
        cbow = self._use_cbow()
        window_size = self.args.window_size
        # Explicit encoding: the platform default varies by locale.
        with open(self.args.input, encoding="utf-8") as f_input:
            # Iterate the file object lazily rather than via readlines():
            # readlines() reads the whole corpus into memory before the first
            # sample is yielded. (tqdm then shows a count/rate, not a percent.)
            for line in tqdm.tqdm(f_input):
                tokens_indices = self.vocab.indices(line.strip().split())
                for index, target_word in enumerate(tokens_indices):
                    context_words = self._context_words(tokens_indices, index, window_size)
                    if cbow:
                        if context_words:
                            yield context_words, target_word
                    else:
                        for context_word in context_words:
                            yield target_word, context_word

    def dataset(self):
        """Wrap ``generator`` in a ``tf.data.Dataset`` of int32 pairs.

        The first component is a variable-length vector in CBOW mode and a
        scalar in skip-gram mode. The second component is always a scalar.
        """
        first_shape = tf.TensorShape([None]) if self._use_cbow() else tf.TensorShape([])
        return tf.data.Dataset.from_generator(
            self.generator,
            (tf.int32, tf.int32),
            (first_shape, tf.TensorShape([])),
        )
然后我用下面的代码进行测试：
# Smoke test: iterate the whole dataset once in batches of 128 and report the
# index and shapes of the last batch.
# NOTE: `make_one_shot_iterator()` is not available on TF 2.x `tf.data.Dataset`
# (a plain eager `for` loop replaces it), and its result was never used.
# NOTE(review): in CBOW mode the first component has variable length, so
# `.batch` would fail; `padded_batch(128, padded_shapes=([None], []))` is
# likely needed there -- confirm.
dataset = DataSet(args, vocab).dataset()
# prefetch lets the next batch be produced while the current one is consumed.
batched = dataset.batch(128).prefetch(tf.data.experimental.AUTOTUNE)
batch, x, y = -1, None, None
for batch, (x, y) in enumerate(batched):
    pass
if batch < 0:
    # Guard: the loop variables would otherwise be unbound on an empty dataset.
    print("dataset is empty")
else:
    print(batch, x.shape, y.shape)
但是迭代所有行需要花费大量时间（在 2012 款 MacBook Pro 上，处理 15000 行大约需要 10 分钟）。有没有什么方法可以加快这段代码的速度？
【问题讨论】:
标签: python tensorflow2.0 data-generation