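After digging further into the source code, it seems that view is the only function that explicitly raises an exception when it is passed a non-contiguous input.
One would expect any operation using Tensor Views to have the potential to fail on non-contiguous input. In practice, most or all of these functions appear to be either:
(a.) implemented with support for non-contiguous blocks (see the example below), i.e. the tensor iterators can handle multiple pointers to the various chunks of the data in memory, perhaps at the expense of performance, or else
(b.) wrapped in a call to .contiguous() (one such example is torch.Tensor.diagflat()). reshape is essentially the contiguous()-wrapped form of view: it returns a view when the strides allow it and copies the data otherwise.
By extension, the main benefit of view over reshape seems to be the explicit exception when a tensor is unexpectedly non-contiguous, whereas reshape silently handles the discrepancy at the cost of performance.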
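To make the difference concrete, here is a minimal sketch (my own illustration, not taken from the PyTorch source) using a small transposed tensor as the non-contiguous input. The variable names are arbitrary; the calls are the standard torch.Tensor methods.

```python
import torch

# A transposed tensor shares storage with its source but is not contiguous.
base = torch.arange(12).reshape(3, 4)
t = base.t()
assert not t.is_contiguous()

# view() raises because the flat layout cannot be expressed with strides alone.
try:
    t.view(-1)
    view_raised = False
except RuntimeError:
    view_raised = True

# reshape() falls back to a copy, so the result lives at a new address.
flat = t.reshape(-1)
copied = flat.data_ptr() != t.data_ptr()

# On an already-contiguous tensor, reshape() returns a view: same address.
aliased = base.reshape(-1).data_ptr() == base.data_ptr()

print(view_raised, copied, aliased)  # True True True
```

So swapping view for reshape removes the error, but it also hides the fact that a copy is now being made.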
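This conclusion is based on:
- Testing all Tensor View ops with non-contiguous inputs.
- Source code analysis of other non-Tensor View functions of interest (e.g. Conv1D, which includes calls to contiguous as necessary in all non-trivial input cases).
- Inference from pytorch's design philosophy as a simple, at times slow, easy-to-use framework.
- Cross-posting on Pytorch Discuss.
- An extensive review of errors reported on the web involving non-contiguous tensors, all of which revolve around problematic calls to view.
I did not comprehensively test all pytorch functions, as there are thousands.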
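The first bullet can be reproduced in miniature. The sketch below is not the full test I ran; it is a hand-picked, illustrative subset of view ops applied to a small non-contiguous tensor. It assumes PyTorch 2.x for untyped_storage() (on older versions storage() plays the same role), and the names view_ops and results are just for this example.

```python
import torch

# Non-contiguous input: strided slicing keeps the storage but skips elements.
x = torch.rand(8, 6, 4)[::2, ::2, ::2]
assert not x.is_contiguous()

# A hand-picked subset of view ops; illustrative, not exhaustive.
view_ops = {
    "transpose": lambda t: t.transpose(0, 1),
    "permute": lambda t: t.permute(2, 0, 1),
    "unsqueeze": lambda t: t.unsqueeze(-1),
    "expand": lambda t: t.unsqueeze(-1).expand(*t.shape, 3),
    "diagonal": lambda t: t.diagonal(),
    "narrow": lambda t: t.narrow(0, 0, 2),
    "select": lambda t: t.select(1, 0),
    "view(-1)": lambda t: t.view(-1),
}

results = {}
for name, op in view_ops.items():
    try:
        out = op(x)
        # Same underlying storage means no copy was made.
        shares = out.untyped_storage().data_ptr() == x.untyped_storage().data_ptr()
        results[name] = "ok, shares storage" if shares else "ok, copied"
    except RuntimeError:
        results[name] = "RuntimeError"

for name, outcome in results.items():
    print(f"{name:10s} {outcome}")
```

With this input, only the view(-1) entry should report a RuntimeError; the others return tensors backed by the original storage.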
Example for (a.):
import torch
import numpy
import time
# allocation
start = time.time()
test = torch.rand([10000,1000,100])
torch.cuda.synchronize()
end = time.time()
print("Allocation took {} sec. Data is at address {}. Contiguous: {}".format(end - start, test.storage().data_ptr(), test.is_contiguous()))
# view of a contiguous tensor
start = time.time()
test.view(-1)
torch.cuda.synchronize()
end = time.time()
print("view() took {} sec. Data is at address {}. Contiguous: {}".format(end - start, test.storage().data_ptr(), test.is_contiguous()))
# diagonal() on a contiguous tensor
start = time.time()
test.diagonal()
torch.cuda.synchronize()
end = time.time()
print("diagonal() took {} sec. Data is at address {}. Contiguous: {}".format(end - start, test.storage().data_ptr(), test.is_contiguous()))
# Diagonal and a few tensor view ops on a non-contiguous tensor
test = test[::2,::2,::2] # indexing is a Tensor View op resulting in a non-contiguous output
print(test.is_contiguous()) # False
start = time.time()
test = test.unsqueeze(-1).expand([test.shape[0],test.shape[1],test.shape[2],100]).diagonal()
torch.cuda.synchronize()
end = time.time()
print("non-contiguous tensor ops() took {} sec. Data is at address {}. Contiguous: {}".format(end - start, test.storage().data_ptr(), test.is_contiguous()))
# reshape, which requires a tensor copy operation to new memory
start = time.time()
test = test.reshape(-1) + 1.0
torch.cuda.synchronize()
end = time.time()
print("reshape() took {} sec. Data is at address {}. Contiguous: {}".format(end - start,test.storage().data_ptr(),test.is_contiguous()))
The following is the output:
Allocation took 4.269254922866821 sec. Data is at address 139863636672576. Contiguous: True
view() took 0.0002810955047607422 sec. Data is at address 139863636672576. Contiguous: True
diagonal() took 6.532669067382812e-05 sec. Data is at address 139863636672576. Contiguous: True
False
non-contiguous tensor ops() took 0.00011277198791503906 sec. Data is at address 139863636672576. Contiguous: False
reshape() took 0.13828253746032715 sec. Data is at address 94781254337664. Contiguous: True
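In the fourth block (the "non-contiguous tensor ops" step), a few tensor view ops are performed on a non-contiguous input tensor. They run without error, leave the data at the same memory address, and run considerably faster than an operation that has to copy to a new memory address (such as reshape in the fifth block). It therefore seems that these operations are implemented in a way that handles non-contiguous inputs without requiring a data copy.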
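One way to see why no copy is needed is to look at strides. The sketch below is an illustration under my own assumptions (a small 2-D tensor, arbitrary variable names): a strided slice only changes the shape and stride metadata, while contiguous() is what actually moves data.

```python
import torch

base = torch.arange(24.0).reshape(4, 6)
sliced = base[::2, ::2]  # every other row and column, shape (2, 3)

# The slice reuses the same memory; only the step sizes change.
print(base.stride())    # (6, 1)
print(sliced.stride())  # (12, 2)
same_memory = sliced.data_ptr() == base.data_ptr()

# Elementwise and reduction ops walk the strides directly, no copy required.
total = sliced.sum().item()

# contiguous() materializes a compact copy at a new address.
compact = sliced.contiguous()
print(compact.stride())  # (3, 1)
moved = compact.data_ptr() != sliced.data_ptr()

print(same_memory, moved, total)  # True True 48.0
```

This is consistent with the timings above: the view ops only rewrite metadata, while reshape on a non-contiguous tensor has to pay for the copy.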