您的位置:首页 > 科技 >

Facebook 发布 PyTorch Hub:一行代码实现经典模型调用!

作者 | Team PyTorch

译者 | Monanfei

责编 | 夕颜

出品 | AI 科技大本营(ID: rgznai100)

6 月 11 日,Facebook PyTorch 团队推出了全新 API PyTorch Hub,提供模型的基本构建模块,用于提高机器学习研究的模型复现性。PyTorch Hub 包含一个经过预训练的模型库,内置对 Colab 的支持,而且能够与 Papers With Code 集成。另外重要的一点是,它的整个工作流程大大简化。

简化到什么程度呢?Facebook 首席 AI 科学家 Yann LeCun 兼图灵奖图灵奖得主 Yann LeCun 发表 Twitter 强烈推荐,使用 PyTorch Hub,无论是 ResNet、BERT、GPT、VGG、PGAN 还是 MobileNet 等经典模型,只需输入一行代码,就能实现一键调用。

Twitter 一发,立刻引来众多网友评论点赞,并有网友表示希望看到 PyTorch Hub 与 TensorFlow Hub 的区别。

这个模型聚合中心到底如何呢?我们来一探究竟。

模型复现是许多领域的基本要求,尤其是在与机器学习相关的邻域中。然而,许多机器学习相关的出版物,要么不可复现,要么难以复现。随着出版物数量的不断增长(包括在 arXiv 上发表的成数万篇论文,以及会议提交的大量论文),模型复现比以往任何时候都更加重要。虽然这些出版物大多数都包含代码和训练好的模型,但如果用户想复现这些模型,还需要做大量的额外的工作。

今天,我们很荣幸地宣布推出 PyTorch Hub,它是一个非常简单的 API,并且具有极其简单的工作流程。它提供模型的基本构建模块,用于提高机器学习研究的模型复现性。PyTorch Hub 包含一个经过预训练的模型库,专门用于促进研究的可重复性和快速开展新的研究。PyTorch Hub 内置了对 Colab 的 支持,并且能够与 Papers With Code 集成。目前 PyTorch Hub 已包含一系列广泛的模型,包括分类器和分割器、生成器、变换器等。

【开发者】发布模型

通过添加简单 hubconf.py 文件,开发者能够将预训练的模型(模型定义和预训练的权重)发布到 GitHub 仓库中。该文件提供了所支持模型的枚举,以及运行这些模型的依赖环境列表。相关的例子可以参见 torchvision、 huggingface-bert 和 gan-model-zoo 仓库。

让我们看看最简单的例子:torchvision ’ s hubconf.py:

1# Optional list of dependencies required by the package

2dependencies = [ "torch" ]

4from torchvision.models.alexnet import alexnet

5from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161

6from torchvision.models.inception import inception_v3

7from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152,

8resnext50_32x4d, resnext101_32x8d

9from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1

10from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn

11from torchvision.models.segmentation import fcn_resnet101, deeplabv3_resnet101

12from torchvision.models.googlenet import googlenet

13from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0

14from torchvision.models.mobilenet import mobilenet_v2

在 torchvision 中,各模型具有如下性质:

每个模型文件都能作为函数调用,或者独立执行;

除了 PyTorch 之外(在 hubconf.py 中编码为 dependencies [ "torch" ] ),它们不需要任何其他包的支持;

不需要单独的接入点,因为模型在创建时可以无缝地接入。

PyTorch Hub 将包的依赖性降到了最小,当使用者加载模型并立即进行实验时,该特性能够提高用户体验。

接下来我们看一个较为复杂的例子:HuggingFace ’ s BERT 模型,下面是该模型的 hubconf.py:

1dependencies = [ "torch", "tqdm", "boto3", "requests", "regex" ]

3from hubconfs.bert_hubconf import (

4 bertTokenizer,

5 bertModel,

6 bertForNextSentencePrediction,

7 bertForPreTraining,

8 bertForMaskedLM,

9 bertForSequenceClassification,

10 bertForMultipleChoice,

11 bertForQuestionAnswering,

12 bertForTokenClassification

13 )

每个模型都需要创建一个接入点,一下代码用于指定 bertForMaskedLM 模型的接入点,并返回预训练的模型权重。

1def bertForMaskedLM ( *args, **kwargs ) :

2 """

3 BertForMaskedLM includes the BertModel Transformer followed by the

4 pre-trained masked language modeling head.

5 Example:

6 ...

7 """

8 model = BertForMaskedLM.from_pretrained ( *args, **kwargs )

9 return model

这些接入点可以作为复杂模型的包装器,它们能够提供干净且一致的帮助文档字符串,支持使用者选择是否下载预训练权重(例如 pretrained=True),并且具有其它的特定功能,例如可视化。

创建好 hubconf.py 后,可以根据此模板发送 github 推送请求 。PyTorch Hub 的目标是为研究复现提供高质量、易于重复、高效的模型。因此,我们可能会与开发者合作完善推送请求,并在某些情况下拒绝发布一些低质量的模型。一旦我们接受了开发者的推送请求,开发者的模型将很快出现在 Pytorch 中心网页上,从而供所有的用户浏览。

【用户】工作流程

作为用户,PyTorch Hub 提供非常简单的工作流程,用户只需要按照以下三个步骤执行即可:(1)探索有价值的模型;(2)加载模型;(3)了解任何给定模型的可用方法。接下来,让我们分别看看每个步骤。

探索可用的接入点

用户可以使用 torch.hub.list ( ) 列出仓库中所有可用的接入点。

1>>> torch.hub.list ( "pytorch/vision" )

2>>>

3 [ "alexnet",

4"deeplabv3_resnet101",

5"densenet121",

6...

7"vgg16",

8"vgg16_bn",

9"vgg19",

10 "vgg19_bn" ]

值得注意的是,PyTorch Hub 还允许辅助接入点(除了预训练模型)。例如,bertTokenizer 可以用于 BERT 模型中的预处理,这使得用户的工作流程更加顺畅。

加载模型

现在,我们已经知道了 Hub 中可用的模型,那么用户便能够使用 torch.hub.load ( ) 来加载模型接入点。该命令无需安装其他依赖包,此外,torch.hub.help ( ) 提供了如何实例化模型的信息。

1print ( torch.hub.help ( "pytorch/vision", "deeplabv3_resnet101" ) )

2model = torch.hub.load ( "pytorch/vision", "deeplabv3_resnet101", pretrained=True )

由于开发者会不断修复 bug,改进模型,因此 PyTorch Hub 也提供了便捷的方法,使得用户可以非常容易地获取最新的更新:

1model = torch.hub.load ( ..., force_reload=True )

我们相信,这些功能可以让开发者更加专注于他们的研究,而不用为这些繁琐的事情浪费时间。同时,这能够确保用户享受最新的模型。

从另一个方面来看,对用户而言,稳定性是非常重要的。因此,一些开发者会在其他分支上推送稳定的模型,而不是在 mater 分支上推送,这样能够保证代码的稳定性。例如,pytorch_GAN_zoo 在 hub 分支上提供稳定的版本。

1model = torch.hub.load ( "facebookresearch/pytorch_GAN_zoo:hub", "DCGAN", pretrained=True, useGPU=False )

请注意,hub.load ( ) 中的 *args 和 **kwargs 用于实例化模型。在上面的例子中,pretrained=True 以及 useGPU=False 会被传递给模型的接入点。

探索加载的模型

从 PyTorch Hub 加载模型后,用户可以使用下面的工作流程找出模型的可用方法,并更好地了解运行该模型所需的参数。

dir ( model ) 用于查看模型的所有可用方法。接下来,让我们看看 bertForMaskedLM 可用的方法。

1>>> dir ( model )

2>>>

3 [ "forward"

4...

5"to"

6"state_dict",

help ( model.forward ) 用于展示模型运行所需的参数:

1>>> help ( model.forward )

2>>>

3Help on method forward in module pytorch_pretrained_bert.modeling:

4forward ( input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None )

5...

在 BERT 和 DeepLabV3 页面中,用户可以详细了解这些模型的使用方法。

其他探索的方式

PyTorch Hub 中提供的模型支持 Colab,并且直接链接在 Papers With Code 上,只需单击即可使用。下面是一个很好的入门示例。

其他资源

PyTorch Hub API 文档(https://pytorch.org/docs/stable/hub.html)

提交模型(https://github.com/pytorch/hub)

可用模型的更多信息(https://pytorch.org/hub)

探索更多模型(https://paperswithcode.com/)

原文链接:

https://pytorch.org/blog/towards-reproducible-research-with-pytorch-hub/

【END】

6 月 29-30 日,2019 以太坊技术及应用大会特邀以太坊创始人 V 神与以太坊基金会核心成员,以及海内外知名专家齐聚北京,聚焦前沿技术,把握时代机遇,深耕行业应用,共话以太坊 2.0 新生态。扫码即享优惠购票!

热 文 推 荐

☞ 长沙到底有没有互联网?

☞ 618 前夕,不谈促销,京东云带你聊聊技术……

☞ Google Chrome,另类的邪恶垄断?

☞ 9 年前他用 1 万个比特币买了两个披萨 , 9 年后他把当年的代码卖给了苹果,成为了 GPU 挖矿之父

☞ 17 岁的程序员告诉你关于编程的 7 个重要教训!

☞ Bert 时代的创新:Bert 在 NLP 各领域的应用进展 | 技术头条

☞ Lambda 表达式有何用处?

☞ Python 编写循环的两个建议 | 鹅厂实战

☞ 漫威金刚狼男主弃影炒币了?

☞" 是!互联网从此没有 BAT!"

你点的每个 " 在看 ",我都认真当成了喜欢

相关阅读