方便的方式
根据 this issue 上的讨论,截至 2022 年 12 月 22 日,没有方便的方法将默认设备设置为 MPS。
不方便的方式
您可以通过拦截对 tensor constructors 的调用来实现“我不想为张量构造函数指定 device=,只需使用 MPS”的目标:
class MPSMode(torch.overrides.TorchFunctionMode):
def __init__(self):
# incomplete list; see link above for the full list
self.constructors = {getattr(torch, x) for x in "empty ones arange eye full fill linspace rand randn randint randperm range zeros tensor as_tensor".split()}
def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func in self.constructors:
if 'device' not in kwargs:
kwargs['device'] = 'mps'
return func(*args, **kwargs)
# sensible usage
with MPSMode():
print(torch.empty(1).device) # prints mps:0
# sneaky usage
MPSMode().__enter__()
print(torch.empty(1).device) # prints mps:0
推荐方式:
我倾向于将您的设备放在笔记本顶部的配置中并明确使用它:
class Conf: dev = torch.device("mps")
# ...
a = torch.randn(1, device=Conf.dev)
这需要您在整个代码中键入 device=Conf.dev。但是您可以轻松地将您的代码切换到不同的设备,并且您无需担心任何隐式全局状态。