【问题标题】:PyTorch tensor slice and memory usagePyTorch 张量切片和内存使用情况
【发布时间】:2020-05-22 21:27:08
【问题描述】:
import torch
T = torch.FloatTensor(range(0,10 ** 6)) # 1M

#case 1:
torch.save(T, 'junk.pt')
# results in a 4 MB file size

#case 2:
torch.save(T[-20:], 'junk2.pt')
# results in a 4 MB file size

#case 3:
torch.save(torch.FloatTensor(T[-20:]), 'junk3.pt')
# results in a 4 MB file size

#case 4:
torch.save(torch.FloatTensor(T[-20:].tolist()), 'junk4.pt')
# results in a 405 Bytes file size

我的问题是:

  1. 在案例 3 中,生成的文件大小似乎令人惊讶,因为我们正在创建一个新张量。为什么这个新张量不只是切片?

  2. 案例 4 是仅保存张量的一部分(切片)的最佳方法吗?

  3. 更一般地说,如果我想通过删除其值的前半部分来“修剪”一个非常大的一维张量以节省内存,我是否必须像案例 4 那样继续,或者是否有更直接的方法?并且不涉及创建 python 列表的计算成本更低的方式。

【问题讨论】:

    标签: python pytorch


    【解决方案1】:

    (i) 在案例 3 中,生成的文件大小似乎令人惊讶,因为我们正在创建一个新张量。为什么这个新张量不只是切片?

    切片创建张量的视图,该视图共享基础数据,但包含有关用于可见数据的内存偏移的信息。这避免了必须频繁复制数据,从而使许多操作更加高效。有关受影响操作的列表,请参阅 PyTorch - Tensor Views

    您正在处理基础数据很重要的少数情况之一。保存张量需要保存底层数据,否则偏移量将不再有效。

    torch.FloatTensor 不会创建张量的副本,如果没有必要的话。您可以验证它们的底层数据是否仍然相同(它们具有完全相同的内存位置):

    torch.FloatTensor(T[-20:]).storage().data_ptr() == T.storage().data_ptr()
    # => True
    

    (ii) 情况 4 是只保存张量的一部分(切片)的最佳方法吗?

    (iii) 更一般地说,如果我想通过删除其值的前半部分来“修剪”一个非常大的一维张量以节省内存,我是否必须像案例 4 那样继续,或者是否存在一种更直接且计算成本更低的方法,不涉及创建 python 列表。

    您很可能无法避免复制切片的数据,但至少您可以避免从它创建一个 Python 列表并从列表中创建一个新的张量,而是使用 torch.Tensor.clone

    torch.save(T[-20:].clone(), 'junk5.pt')
    

    【讨论】:

      猜你喜欢
      • 2020-10-17
      • 1970-01-01
      • 2020-08-01
      • 1970-01-01
      • 2020-05-02
      • 2011-06-24
      • 2020-02-16
      • 2017-12-01
      相关资源
      最近更新 更多