pytorch中的torch.hub.load():以vggish为例

pytorch提供了torch.hub.load()函数加载模型,该方法可以从网上直接下载模型或是从本地加载模型。官方文档

cpp 复制代码
torch.hub.load(repo_or_dir, model, *args, source='github', trust_repo=None, force_reload=False, verbose=True, skip_validation=False, **kwargs)

参数说明:
repo_or_dir( string ) 如果是 'github', 这应该对应于格式为可选的ref(标记或分支),例如 'pytorch/vision:0.10'。 如果是"local",则它应该是本地目录的路径。sourcerepo_owner/repo_name[:ref]refmainmastersource
model ( string ) 在dir的hubconf.py
*args(可选)callable 的相应参数。
source ( string , optional ) 'github' 或 'local'。指定如何解释repo_or_dir。
force_reload ( bool , optional ) 是否无条件强制重新下载github repo。默认为False,即下一次直接从本地读取。
verbose ( bool , optional ) 如果False,静音有关命中本地缓存的消息。请注意,有关首次下载的消息无法静音。如果source = 'local'没有任何影响。默认为True。
skip_validation ( bool , optional ) 如果False,torchhub 将检查github参数指定的分支或提交是否正确属于 repo 所有者。这将向 GitHub API 发出请求;您可以通过设置GITHUB_TOKEN环境变量来指定非默认 GitHub 令牌 。默认为False。
**kwargs(可选) 可调用的对应kwargs。

加载vggish预训练模型

vggish模型用于音频分类模型的特征嵌入,预训练的pytorch版本:harritaylor/torchvggish,该版本的权重直接从tensorflow模型移植,因此使用"torchvggish"创建的嵌入将是相同的。

官方的加载模型示例代码:

cpp 复制代码
import torch

model = torch.hub.load('harritaylor/torchvggish', 'vggish')
model.eval()

# Download an example audio file
import urllib
url, filename = ("http://soundbible.com/grab.php?id=1698&type=wav", "bus_chatter.wav")
try: urllib.URLopener().retrieve(url, filename)
except: urllib.request.urlretrieve(url, filename)

model.forward(filename)

运行这个代码块会自动从github上加载预训练的torchvggish模型和与训练权重,如果在hub里提示缺少什么包直接装就可以。

这里要提一个问题:如果在下载模型的中途中断下载,那么下次运行这个代码的时候可能会报错:

cpp 复制代码
RuntimeError: unexpected EOF, expected 198783261 more bytes. The file might be corrupted.

这个问题是说从本地加载的文件是残缺的,不完整。因为该方法会首先从本地load文件,而你之前下载的文件没下完,设置force_reload为True也没啥用,需要找到你本地下载下来的预训练模型pth文件并且删掉,就可以重新下载了。

相关推荐
AI医影跨模态组学8 小时前
Lancet Digital Health(IF=24.1)德国德累斯顿工业大学医学院:深度学习评估结直肠癌的基因型-表型相关性
人工智能·深度学习·论文·医学影像·影像组学
星恒随风8 小时前
从零开始理解 CNN(上):为什么图像任务需要卷积神经网络?
人工智能·笔记·神经网络·学习·cnn
YOLO数据集集合8 小时前
滑坡智能识别|遥感卫星无人机多源影像数据集|深度学习语义分割开源基准
人工智能·深度学习·yolo·目标检测·视觉检测·无人机
星恒随风8 小时前
从零开始理解 CNN(下):拆开卷积层、池化层、通道数和训练流程
人工智能·笔记·深度学习·神经网络·学习·cnn
蔡俊锋8 小时前
AI时代,是时候越狱了
人工智能·ai 越狱
有为少年8 小时前
深度学习中的隐式层
人工智能·深度学习·神经网络·线性代数·机器学习·优化算法·深度隐式层
草莓啵啵~8 小时前
pywinauto-打开程序+连接已打开的程序
开发语言·python
羊羊小栈8 小时前
基于多时间序列模型和大语言模型的航海轨迹预测分析预警系统( LSTM、GRU、Transformer、CNN-LSTM、DLinear)
人工智能·语言模型·cnn·gru·毕业设计·lstm·transformer
chatexcel9 小时前
AI PPT 教程:基于旅游生活场景的提示词设计与生成流程
人工智能·ppt
寻道码路10 小时前
LangChain4j Java AI 应用开发实战(四):提示词工程进阶 - 模板化与结构化 Prompt 设计
java·人工智能·ai·prompt·aigc