【问题标题】:Convert string to byte for pytorch loader将字符串转换为 pytorch 加载程序的字节
【发布时间】:2019-10-11 02:08:19
【问题描述】:

下载 pytorch 模型路径的方法不在我的控制范围内,我正在尝试找出一种将下载的字符串数据转换为字节数据的方法。下面的代码从 Dropbox 下载我保存的模型,并使用带有 utf-8 编码的字节对字符串进行编码。问题是,当我将 torch.load 与 BytesIO 一起使用时,我得到一个 UnpicklingError 和无效的加载键,'

    data = bytes(self.Download("https://www.dropbox.com/s/exampleurl/checkpoint.pth?dl=1"), 'utf-8')

    self.agent.local.load_state_dict(torch.load(BytesIO(data ), map_location=lambda storage, loc: storage))

在禁用请求之前,下面的代码运行良好,我现在正在尝试使用上面的方法。

    dropbox_url = "https://www.dropbox.com/s/exampleurl/checkpoint.pth?dl=1"

    data = requests.get(dropbox_url )

    self.agent.local.load_state_dict(torch.load(BytesIO(data.content), map_location=lambda storage, loc: storage))

我只需要想办法将字符串以正确的方式转换为字节数据。

【问题讨论】:

  • 是的,我必须将字节数据转换为 base64 并将文件保存在 Dropbox 上。使用内置方法下载后,我将 base64 转换回字节并且它工作了!

标签: python load byte pytorch encode


【解决方案1】:

我必须将字节数据转换为 base64 并以该格式保存文件。一旦我上传到 Dropbox 并使用内置方法下载,我将 base64 文件转换回字节并且它工作了!

import base64
from io import BytesIO

with open("checkpoint.pth", "rb") as f:
    byte = f.read(1)

# Base64 Encode the bytes
data_e = base64.b64encode(byte)

filename ='base64_checkpoint.pth'

with open(filename, "wb") as output:
    output.write(data_e)

# Save file to Dropbox

# Download file on server
b64_str= self.Download('url')

# String Encode to bytes
byte_data = b64_str.encode("UTF-8")

# Decoding the Base64 bytes
str_decoded = base64.b64decode(byte_data)

# String Encode to bytes
byte_decoded = str_decoded.encode("UTF-8")

# Decoding the Base64 bytes
decoded = base64.b64decode(byte_decoded)

torch.load(BytesIO(decoded))

【讨论】:

    猜你喜欢
    • 2019-10-10
    • 1970-01-01
    • 2022-01-18
    • 2014-03-09
    相关资源
    最近更新 更多