【发布时间】:2019-06-23 00:22:22
【问题描述】:
我正在尝试在 CPU 和 CUDA 上运行代码。
创建对象时会出现问题,因为我需要知道预期的内容。
在创建之前,我需要确定计算机是否需要 CUDA 或 CPU 张量。
代码:
def initilize(self, input):
self.x = torch.nn.Parameter(torch.zeros((1,M))
def run(self,x,state):
B = torch.cat((self.x,h)
这个输出:
Error: 'Expected object of backend CUDA but got backend CPU for argument #1'
代码思路:
def initilize(self, input):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if (expecting_cuda == True):
self.x = torch.nn.Parameter(torch.zeros((1,M)).to(device))
else
self.x = torch.nn.Parameter(torch.zeros((1,M))
def run(self,h):
B = torch.cat((self.x,h)
问题: 如何弄清楚计算机期望什么?
限制: 我正在运行一个预定义的“检查”过程,所以我无法将参数发送到函数“initilize”中,其中包含有关 CUDA 或 CPU 的信息。
【问题讨论】:
标签: python-3.x pytorch