简介
这个有点开放,但让我们试试,如果我在某个地方错了也请纠正我。
到目前为止,我一直是先将数据导出到文件系统,将子文件夹命名为文档的类。
IMO 这是不明智的,因为:
- 您实际上是在复制数据
- 只要您想训练一个新的只给定代码和数据库的操作,就必须重复此操作
- 您可以一次访问多个数据点并将它们缓存在 RAM 中以供以后重复使用,而无需从硬盘驱动器多次读取(这很繁重)
我说的对吗?直接连接MongoDB有意义吗?
鉴于上述情况,可能是的(尤其是在明确和可移植的实现方面)
或者是否有理由不这样做(例如,数据库通常会变慢等)?
AFAIK DB 在这种情况下不应该变慢,因为它会缓存对它的访问,但不幸的是我不是数据库专家。许多提高访问速度的技巧都是开箱即用的数据库。
可以以某种方式预取数据吗?
是的,如果您只想获取数据,您可以一次性加载大部分数据(例如 1024 记录)并从中返回批量数据(例如 batch_size=128)
实施
如何实现 PyTorch DataLoader?我在网上只找到了很少的代码 sn-ps([1] 和 [2]),这让我对我的方法产生了怀疑。
我不确定你为什么要这样做。您应该选择torch.utils.data.Dataset,如您列出的示例所示。
我将从类似于here 的简单非优化方法开始,所以:
- 在
__init__ 中打开与数据库的连接,并在使用期间一直保持(我将从torch.utils.data.Dataset 创建一个上下文管理器,以便在时期结束后关闭连接)
-
我不会将结果转换为
list(尤其是因为明显的原因你不能将它放入 RAM 中),因为它没有考虑生成器的意义
- 我会在这个数据集中执行批处理(有一个参数
batch_sizehere)。
- 我不确定
__getitem__ 函数,但它似乎可以一次返回多个数据点,因此我会使用它,它应该允许我们使用num_workers>0(假设mycol.find(query) 返回相同的数据每次都订购)
鉴于此,我会做一些类似的事情:
class DatabaseDataset(torch.utils.data.Dataset):
def __init__(self, query, batch_size, path: str, database: str):
self.batch_size = batch_size
client = pymongo.MongoClient(path)
self.db = client[database]
self.query = query
# Or non-approximate method, if the approximate method
# returns smaller number of items you should be fine
self.length = self.db.estimated_document_count()
self.cursor = None
def __enter__(self):
# Ensure that this find returns the same order of query every time
# If not, you might get duplicated data
# It is rather unlikely (depending on batch size), shouldn't be a problem
# for 20 million samples anyway
self.cursor = self.db.find(self.query)
return self
def shuffle(self):
# Find a way to shuffle data so it is returned in different order
# If that happens out of the box you might be fine without it actually
pass
def __exit__(self, *_, **__):
# Or anything else how to close the connection
self.cursor.close()
def __len__(self):
return len(self.examples)
def __getitem__(self, index):
# Read takes long, hence if you can load a batch of documents it should speed things up
examples = self.cursor[index * batch_size : (index + 1) * batch_size]
# Do something with this data
...
# Return the whole batch
return data, labels
现在批处理由DatabaseDataset 负责,因此torch.utils.data.DataLoader 可以拥有batch_size=1。您可能需要压缩额外的维度。
由于MongoDB 使用锁(这并不奇怪,但请参阅here)num_workers>0 应该不是问题。
可能的用法(示意图):
with DatabaseDataset(...) as e:
dataloader = torch.utils.data.DataLoader(e, batch_size=1)
for epoch in epochs:
for batch in dataloader:
# And all the stuff
...
dataset.shuffle() # after each epoch
记住这种情况下的混洗实现!(混洗也可以在上下文管理器中完成,您可能希望手动关闭连接或类似的东西)。