fldev


什么是 hub

hub(modelzoo)主要用来调用其他人训练好的模型和参数
Facebook官方博客表示,PyTorch Hub是一个简易API和工作流程,为复现研究提供了基本构建模块,包含预训练模型库。
并且,PyTorch Hub还支持Colab,能与论文代码结合网站Papers With Code集成,用于更广泛的研究。

github:https://github.com/pytorch/hub
模型:https://pytorch.org/hub/research-models


使用示例

import torch
model = torch.hub.load(\'pytorch/vision:v0.4.2\', \'deeplabv3_resnet101\', pretrained=True)
model.eval()
# 下载会显示下载数据和进度
# Downloading: "https://github.com/pytorch/vision/archive/v0.4.2.zip" to /Users/xx/.cache/torch/hub/v0.4.2.zip
# Downloading: "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth" to /Users/xx/.cache/torch/hub/checkpoints/resnet101-5d3b4d8f.pth
Downloading: "https://github.com/pytorch/vision/archive/v0.4.2.zip" to /Users/shushu/.cache/torch/hub/v0.4.2.zip
Downloading: "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth" to /Users/shushu/.cache/torch/hub/checkpoints/resnet101-5d3b4d8f.pth

HBox(children=(IntProgress(value=0, max=178728960), HTML(value=\'\')))
torch.hub.list(\'pytorch/vision:v0.4.2\')
Using cache found in C:\Users\Administrator/.cache\torch\hub\pytorch_vision_v0.4.2

[\'alexnet\',
 \'deeplabv3_resnet101\',
 \'densenet121\',
 \'densenet161\',
 \'densenet169\',
 \'densenet201\',
 \'fcn_resnet101\',
 \'googlenet\',
 \'inception_v3\',
 \'mobilenet_v2\',
 \'resnet101\',
 \'resnet152\',
 \'resnet18\',
 \'resnet34\',
 \'resnet50\',
 \'resnext101_32x8d\',
 \'resnext50_32x4d\',
 \'shufflenet_v2_x0_5\',
 \'shufflenet_v2_x1_0\',
 \'squeezenet1_0\',
 \'squeezenet1_1\',
 \'vgg11\',
 \'vgg11_bn\',
 \'vgg13\',
 \'vgg13_bn\',
 \'vgg16\',
 \'vgg16_bn\',
 \'vgg19\',
 \'vgg19_bn\',
 \'wide_resnet101_2\',
 \'wide_resnet50_2\']

# Download an example image from the pytorch website
import urllib
url, filename = ("https://github.com/pytorch/hub/raw/master/dog.jpg", "dog.jpg")
try: urllib.URLopener().retrieve(url, filename)
except: urllib.request.urlretrieve(url, filename)
# sample execution (requires torchvision)
from PIL import Image
from torchvision import transforms

input_image = Image.open(filename)
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model

# move the input and model to GPU for speed if available
if torch.cuda.is_available():
    input_batch = input_batch.to(\'cuda\')
    model.to(\'cuda\')

with torch.no_grad():
    output = model(input_batch)[\'out\'][0]
output_predictions = output.argmax(0)
# create a color pallette, selecting a color for each class
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
colors = torch.as_tensor([i for i in range(21)])[:, None] * palette
colors = (colors % 255).numpy().astype("uint8")

# plot the semantic segmentation predictions of 21 classes in each color
r = Image.fromarray(output_predictions.byte().cpu().numpy()).resize(input_image.size)
r.putpalette(colors)

import matplotlib.pyplot as plt
plt.imshow(r)
plt.show()


1、查询可用的模型

用户可以使用torch.hub.list()这个API列出repo中所有可用的入口点。比如你想知道PyTorch Hub中有哪些可用的计算机视觉模型:

>>> torch.hub.list(\'pytorch/vision\')
>>>
[\'alexnet\',
\'deeplabv3_resnet101\',
\'densenet121\',
...
\'vgg16\',
\'vgg16_bn\',
\'vgg19\',
 \'vgg19_bn\']

2、加载模型

在上一步中能看到所有可用的计算机视觉模型,如果想调用其中的一个,也不必安装,只需一句话就能加载模型。

model = torch.hub.load(\'pytorch/vision\', \'deeplabv3_resnet101\', pretrained=True)

至于如何获得此模型的详细帮助信息,可以使用下面的API:

print(torch.hub.help(\'pytorch/vision\', \'deeplabv3_resnet101\'))

如果模型的发布者后续加入错误修复和性能改进,用户也可以非常简单地获取更新,确保自己用到的是最新版本:

model = torch.hub.load(..., force_reload=True)
对于另外一部分用户来说,稳定性更加重要,他们有时候需要调用特定分支的代码。例如pytorch_GAN_zoo的hub分支:

model = torch.hub.load(\'facebookresearch/pytorch_GAN_zoo:hub\', \'DCGAN\', pretrained=True, useGPU=False)

3、查看模型可用方法

从PyTorch Hub加载模型后,你可以用dir(model)查看模型的所有可用方法。以bertForMaskedLM模型为例:

>>> dir(model)
>>>
[\'forward\'
...
\'to\'
\'state_dict\',
]

forward

如果你对forward方法感兴趣,使用help(model.forward) 了解运行运行该方法所需的参数。

>>> help(model.forward)
>>>
Help on method forward in module pytorch_pretrained_bert.modeling:
forward(input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None)
...

支持 Colab

PyTorch Hub中提供的模型也支持Colab。

进入每个模型的介绍页面后,你不仅可以看到GitHub代码页的入口,甚至可以一键进入Colab运行模型Demo。



对于模型发布者

如果你希望把自己的模型发布到PyTorch Hub上供所有用户使用,可以去PyTorch Hub的GitHub页发送拉取请求。若你的模型符合高质量、易重复、最有利的要求,Facebook官方将会与你合作。

一旦拉取请求被接受,你的模型将很快出现在PyTorch Hub官方网页上,供所有用户浏览。

目前该网站上已经有18个提交的模型,英伟达率先提供支持,他们在PyTorch Hub已经发布了Tacotron2和WaveGlow两个TTS模型。

图片

发布模型的方法也是比较简单的,开发者只需在自己的GitHub存储库中添加一个简单的hubconf.py文件,在其中枚举运行模型所需的依赖项列表即可。

比如,torchvision中的hubconf.py文件是这样的:

# Optional list of dependencies required by the package
dependencies = [\'torch\']

from torchvision.models.alexnet import alexnet
from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161
from torchvision.models.inception import inception_v3
from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152,\
resnext50_32x4d, resnext101_32x8d
from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1
from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
from torchvision.models.segmentation import fcn_resnet101, deeplabv3_resnet101
from torchvision.models.googlenet import googlenet
from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
from torchvision.models.mobilenet import mobilenet_v2

Facebook官方向模型发布者提出了以下三点要求:

1、每个模型文件都可以独立运行和执行
2、不需要PyTorch以外的任何包
3、不需要单独的入口点,让模型在创建时可以无缝地开箱即用

Facebook还建议发布者最小化对包的依赖性,减少用户加载模型进行实验的阻力。


更多资料

分类:

技术点:

相关文章:

  • 2022-12-23
  • 2021-11-21
  • 2021-11-18
  • 2021-07-24
  • 2021-04-13
  • 2021-12-28
  • 2022-12-23
  • 2022-12-23
猜你喜欢
  • 2021-12-31
  • 2021-07-13
  • 2022-12-23
  • 2022-12-23
  • 2023-03-21
  • 2021-11-18
  • 2021-04-26
相关资源
相似解决方案