【Posted】: 2021-07-26 10:46:30
【Question】:
I am currently trying to implement lazy loading with allennlp, but I can't get it to work. My code is as follows.
def biencoder_training():
    params = BiEncoderExperiemntParams()
    config = params.opts
    reader = SmallJaWikiReader(config=config)
    # Loading Datasets
    train, dev, test = reader.read('train'), reader.read('dev'), reader.read('test')
    vocab = build_vocab(train)
    vocab.extend_from_instances(dev)
    # TODO: avoid memory consumption and lazy loading
    train, dev, test = list(reader.read('train')), list(reader.read('dev')), list(reader.read('test'))
    train_loader, dev_loader, test_loader = build_data_loaders(config, train, dev, test)
    train_loader.index_with(vocab)
    dev_loader.index_with(vocab)
    embedder = emb_returner()
    mention_encoder, entity_encoder = Pooler_for_mention(word_embedder=embedder), \
                                      Pooler_for_cano_and_def(word_embedder=embedder)
    model = Biencoder(mention_encoder, entity_encoder, vocab)
    trainer = build_trainer(lr=config.lr,
                            num_epochs=config.num_epochs,
                            model=model,
                            train_loader=train_loader,
                            dev_loader=dev_loader)
    trainer.train()
    return model
When I comment out train, dev, test = list(reader.read('train')), list(reader.read('dev')), list(reader.read('test')), the iterator stops working and training runs on 0 samples:
Building the vocabulary
100it [00:00, 442.15it/s]01, 133.57it/s]
building vocab: 100it [00:01, 95.84it/s]
100it [00:00, 413.40it/s]
100it [00:00, 138.38it/s]
You provided a validation dataset but patience was set to None, meaning that early stopping is disabled
0it [00:00, ?it/s]
0it [00:00, ?it/s]
Is there any way to avoid this? Thanks.
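A likely explanation, offered as a hedge: a Python generator can be iterated only once. If reader.read('train') returns a one-shot generator (which is what the _read method shown in the addendum produces), then build_vocab(train) consumes it, and the loader later receives an already exhausted iterator, which would match the 0it lines in the log. The sketch below reproduces that behaviour in plain Python; read_instances and build_vocab_stub are hypothetical stand-ins, not allennlp APIs.

```python
from typing import Iterator, List


def read_instances(n: int) -> Iterator[int]:
    """Hypothetical stand-in for reader.read(...): a generator that yields lazily."""
    for i in range(n):
        yield i


def build_vocab_stub(instances) -> set:
    """Hypothetical stand-in for build_vocab: it must iterate over every instance."""
    return {x % 10 for x in instances}


train = read_instances(100)
vocab = build_vocab_stub(train)     # first pass: consumes the generator completely
remaining: List[int] = list(train)  # second pass: nothing is left to yield
print(len(vocab), len(remaining))   # 10 0
```

Wrapping each call in list(...), as in the first version of biencoder_training, avoids this only because every consumer then gets its own fully materialised copy.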
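Addendum (May 5):
What I am trying to do is avoid loading all of the sample data into memory before training the model.
So I implemented the _read method as a generator. My understanding was that calling this method and wrapping the result in SimpleDataLoader would let me pass the data to the model batch by batch.
In my DatasetReader, the _read method looks like this. As I understand it, it is a generator and therefore avoids the memory consumption.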
import random
from typing import Iterator

from allennlp.data import Instance
from overrides import overrides
from tqdm import tqdm


@overrides
def _read(self, train_dev_test_flag: str) -> Iterator[Instance]:
    '''
    :param train_dev_test_flag: 'train', 'dev', 'test'
    :return: generator that yields Instance objects one at a time
    '''
    if train_dev_test_flag == 'train':
        dataset = self._train_loader()
        random.shuffle(dataset)
    elif train_dev_test_flag == 'dev':
        dataset = self._dev_loader()
    elif train_dev_test_flag == 'test':
        dataset = self._test_loader()
    else:
        raise NotImplementedError(
            "{} is not a valid flag. Choose from train, dev and test".format(train_dev_test_flag))
    if self.config.debug:
        dataset = dataset[:self.config.debug_data_num]
    for data in tqdm(enumerate(dataset)):
        data = self._one_line_parser(data=data, train_dev_test_flag=train_dev_test_flag)
        yield self.text_to_instance(data)
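Note that this _read defers only the construction of the Instance objects. random.shuffle(dataset) and dataset[:self.config.debug_data_num] both require dataset to be a fully materialised list, so whatever _train_loader() returns is already in memory before the first yield. If the raw records could be streamed instead (for example, line by line from a file), the debug cap could be applied with itertools.islice, which works on any iterable. A minimal sketch, assuming such a stream exists; iter_parsed and lines are hypothetical names, .strip() stands in for _one_line_parser, and lazy shuffling is not covered here.

```python
import itertools
from typing import Iterable, Iterator, Optional, Tuple


def iter_parsed(dataset: Iterable[str],
                debug_limit: Optional[int] = None) -> Iterator[Tuple[int, str]]:
    """Yield (index, record) pairs lazily, optionally capped at debug_limit.

    Unlike dataset[:n], itertools.islice works on any iterable, including a
    generator, so the records are never collected into a list.
    """
    stream = enumerate(dataset)
    if debug_limit is not None:
        stream = itertools.islice(stream, debug_limit)
    for idx, record in stream:
        yield idx, record.strip()  # placeholder for a real per-line parser


def lines() -> Iterator[str]:
    """Hypothetical stand-in for reading a large file one line at a time."""
    for i in range(1000):
        yield f"record {i}\n"


print(list(iter_parsed(lines(), debug_limit=3)))
# [(0, 'record 0'), (1, 'record 1'), (2, 'record 2')]
```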
And build_data_loaders looks like this:
from typing import List, Tuple

from allennlp.data import Instance
from allennlp.data.data_loaders import DataLoader, SimpleDataLoader


def build_data_loaders(config,
                       train_data: List[Instance],
                       dev_data: List[Instance],
                       test_data: List[Instance]) -> Tuple[DataLoader, DataLoader, DataLoader]:
    # Each loader batches the instances it was given; shuffling is disabled.
    train_loader = SimpleDataLoader(train_data, config.batch_size_for_train, shuffle=False)
    dev_loader = SimpleDataLoader(dev_data, config.batch_size_for_eval, shuffle=False)
    test_loader = SimpleDataLoader(test_data, config.batch_size_for_eval, shuffle=False)
    return train_loader, dev_loader, test_loader
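build_data_loaders is typed as taking List[Instance], and the data has to be iterated again for the vocabulary pass and for every epoch. One way to get repeated lazy passes is to give the loader a zero-argument factory, such as lambda: reader.read('train'), instead of a single generator, and to build batches on the fly. The class below is a hypothetical plain-Python sketch of that idea, not allennlp's SimpleDataLoader; it does no indexing or tensorisation.

```python
from typing import Callable, Generic, Iterable, Iterator, List, TypeVar

T = TypeVar("T")


class LazyBatches(Generic[T]):
    """Hypothetical loader that re-creates its instance stream on every pass
    and groups it into batches, buffering only one batch at a time."""

    def __init__(self, make_stream: Callable[[], Iterable[T]], batch_size: int) -> None:
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1")
        self.make_stream = make_stream
        self.batch_size = batch_size

    def __iter__(self) -> Iterator[List[T]]:
        batch: List[T] = []
        for item in self.make_stream():  # a fresh generator for every epoch
            batch.append(item)
            if len(batch) == self.batch_size:
                yield batch
                batch = []  # start a new list so the yielded batch is not mutated
        if batch:  # final, possibly smaller, batch
            yield batch


loader = LazyBatches(lambda: (i for i in range(5)), batch_size=2)
for epoch in range(2):
    print(epoch, list(loader))
# 0 [[0, 1], [2, 3], [4]]
# 1 [[0, 1], [2, 3], [4]]
```

Because __iter__ calls make_stream() again, each for loop starts from a new generator, so an earlier pass (such as building the vocabulary) cannot leave later passes empty.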
However, for some reason I don't understand, the following code does not work:
def biencoder_training():
    params = BiEncoderExperiemntParams()
    config = params.opts
    reader = SmallJaWikiReader(config=config)
    # Loading Datasets
    train, dev, test = reader.read('train'), reader.read('dev'), reader.read('test')
    vocab = build_vocab(train)
    vocab.extend_from_instances(dev)
    train_loader, dev_loader, test_loader = build_data_loaders(config, train, dev, test)
    train_loader.index_with(vocab)
    dev_loader.index_with(vocab)
    embedder = emb_returner()
    mention_encoder, entity_encoder = Pooler_for_mention(word_embedder=embedder), \
                                      Pooler_for_cano_and_def(word_embedder=embedder)
    model = Biencoder(mention_encoder, entity_encoder, vocab)
    trainer = build_trainer(lr=config.lr,
                            num_epochs=config.num_epochs,
                            model=model,
                            train_loader=train_loader,
                            dev_loader=dev_loader)
    trainer.train()
    return model
In this code, SimpleDataLoader wraps the generator as-is. I want the kind of lazy loading that allennlp offered in version 0.9.
But with this code, training iterates over 0 instances, so for now I have added
train, dev, test = list(reader.read('train')), list(reader.read('dev')), list(reader.read('test'))
before
train_loader, dev_loader, test_loader = build_data_loaders(config, train, dev, test).
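With that change it works. But it means I cannot train or evaluate the model until every instance is in memory. Instead, I want each batch to be loaded into memory only when it is needed for training.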
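The requirement here amounts to bounded-memory iteration: keep at most N instances in memory at once, optionally shuffle within that window, and emit batches from it. In allennlp 2.x, MultiProcessDataLoader takes a reader and a data path together with a max_instances_in_memory argument aimed at this use case; check the documentation for the installed version before relying on it. The function below is only a plain-Python sketch of the windowing idea; chunked_batches and its parameters are hypothetical names.

```python
import itertools
import random
from typing import Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")


def chunked_batches(stream: Iterable[T], batch_size: int, max_in_memory: int,
                    shuffle: bool = False, seed: Optional[int] = None) -> Iterator[List[T]]:
    """Pull at most max_in_memory items from stream, optionally shuffle that
    window, cut it into batches, then move on to the next window."""
    rng = random.Random(seed)
    it = iter(stream)
    while True:
        window = list(itertools.islice(it, max_in_memory))  # bounded buffer
        if not window:
            return
        if shuffle:
            rng.shuffle(window)  # mixes items only within this window
        for start in range(0, len(window), batch_size):
            yield window[start:start + batch_size]


batches = list(chunked_batches(range(7), batch_size=2, max_in_memory=4))
print(batches)  # [[0, 1], [2, 3], [4, 5], [6]]
```

When max_in_memory is not a multiple of batch_size, the last batch of each window is smaller, and shuffling is local to a window rather than global across the dataset.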
【Discussion】:
Tags: nlp pytorch bert-language-model allennlp