前言
- 在跑
Faspect
代码时,对transformer
系列的预训练模型加载方式比较好奇,因此记录
from transformers import AutoConfig, FlaxAutoModelForVision2Seq# Download configuration from huggingface.co and cache.config = AutoConfig.from_pretrained("bert-base-cased")model = FlaxAutoModelForVision2Seq.from_config(config)
在使用Huggingface
提供的transformer
系列模型时,会通过model.from_pretrained
函数来加载预训练模型。
from_pretrainde
函数原型为
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):# 从预训练的模型配置实例化预训练的 Pytorch 模型
对加载预训练模型地址的介绍
pretrained_model_name_or_path
:- 一个字符串,模型id,该模型在
huggingface.co
的模型仓库中存在。有效的模型id可以是在 root-level 的,比如bert-base-uncased
,或者是在一个用户或者组织名的命名空间下的,比如dbmdz/bert-base-german-cased
- 一个文件夹路径,该文件夹包含使用
save_pretrained()
保存的模型权重,比如./my_model_dir
- 指向
tensorflow index checkpoint file
的路径,eg../tf_model/model.ckpt.index
- 包含.
msgpack
格式的flax checkpoint file
的模型文件夹的路径
- 一个字符串,模型id,该模型在
提示:如果服务器上无法通过第一种形式访问 huggingface,可以先将在 huggingface 上找到对应的repo,下载下来之后,使用第二种方式加载模型。