【发布时间】:2019-10-11 02:08:19
【问题描述】:
下载 pytorch 模型路径的方法不在我的控制范围内,我正在尝试找出一种将下载的字符串数据转换为字节数据的方法。下面的代码从 Dropbox 下载我保存的模型,并使用带有 utf-8 编码的字节对字符串进行编码。问题是,当我将 torch.load 与 BytesIO 一起使用时,我得到一个 UnpicklingError 和无效的加载键,'
data = bytes(self.Download("https://www.dropbox.com/s/exampleurl/checkpoint.pth?dl=1"), 'utf-8')
self.agent.local.load_state_dict(torch.load(BytesIO(data ), map_location=lambda storage, loc: storage))
在禁用请求之前,下面的代码运行良好,我现在正在尝试使用上面的方法。
dropbox_url = "https://www.dropbox.com/s/exampleurl/checkpoint.pth?dl=1"
data = requests.get(dropbox_url )
self.agent.local.load_state_dict(torch.load(BytesIO(data.content), map_location=lambda storage, loc: storage))
我只需要想办法将字符串以正确的方式转换为字节数据。
【问题讨论】:
-
是的,我必须将字节数据转换为 base64 并将文件保存在 Dropbox 上。使用内置方法下载后,我将 base64 转换回字节并且它工作了!
标签: python load byte pytorch encode