要阅读它,我建议您将其转换为 tsv 文件,其中示例由空行分隔(a.k.a conll 格式),如下所示:
src_fp, tgt_fp = "source/file/path.txt", "target/file/path.tsv"
with open(src_fp) as src_f:
with open(tgt_fp, 'w') as tgt_f:
for line in src_f:
words = list(line.split(' '))[0::2]
tags = list(line.split(' '))[1::2]
for w, t in zip(words, tags):
tgt_f.write(w+'\t'+t+'\n')
tgt_f.write('\n')
然后,您将能够使用来自torchtext.datasets 的SequenceTaggingDataset 读取它,如下所示:
text_field, label_field = data.Field(), data.Field()
pos_dataset = torchtext.datasets.SequenceTaggingDataset(
path='data/pos/pos_wsj_train.tsv',
fields=[('text', text_field),
('labels', label_field)])
最后一步是创建词汇表并获取数据的迭代器:
text_field.build_vocab(pos_dataset)
train_iter = data.BucketIterator.splits(
(unsup_train, unsup_val, unsup_test), batch_size=MY_BATCH_SIZE, device=MY_DEVICE)
# using the iterator
for ex in self train_iter:
train(ex.text, ex.labels)
我建议您花点时间阅读有关上述函数的文档,以便您能够根据自己的需要调整它们(最大词汇量、是否打乱您的示例、序列长度等)。
对于构建用于分类的 RNN,the official pytorch tutorial 非常容易学习。所以我建议你从那里开始,将网络输入和输出从序列分类(每个文本跨度 1 个标签)调整为序列标记(每个标记 1 个标签)。