AI人工智能(二十一)pt模型转onnx sensvoice—东方仙盟练气期

核心代码

完整代码

复制代码
# coding=utf-8
import os
import sys
import torch
import torch.nn as nn  # 核心修复:导入nn模块并简写
import numpy as np
import onnx
import onnxruntime as ort

# ===================== 全局配置(和你的路径完全一致) =====================
BASE_DIR = "./cyberwinmodel"
SV_NEW_DIR = os.path.join(BASE_DIR, "sensevoicefp16newpak")
SV_PT_PATH = os.path.join(SV_NEW_DIR, "model.pt")  # 你的453M完整权重文件
SV_ONNX_PATH = os.path.join(BASE_DIR, "sensevoice_model.onnx")
SAMPLE_RATE = 16000
DEVICE = "cpu"
OPSET_VERSION = 12

# ===================== 核心:SenseVoice完整模型结构(匹配权重) =====================
class SenseVoiceComplete(nn.Module):
    """适配官方SenseVoice权重的完整模型结构"""
    def __init__(self):
        super(SenseVoiceComplete, self).__init__()
        
        # 基础音频特征提取层
        self.feature_extract = nn.Sequential(
            nn.Conv1d(1, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Conv1d(256, 512, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm1d(512)
        )
        
        # 核心编码层(匹配官方权重的LSTM结构)
        self.encoder = nn.LSTM(
            input_size=512,
            hidden_size=1024,
            num_layers=6,
            batch_first=True,
            bidirectional=True
        )
        
        # 输出层(语音识别词表)
        self.decoder = nn.Linear(2048, 5000)  # 双向LSTM输出维度=1024*2
    
    def forward(self, audio_input):
        """
        输入:16kHz单通道float32音频(shape: [seq_len])
        输出:语音识别概率(适配C#调用)
        """
        # 1. 音频归一化(和C#逻辑一致)
        audio_input = audio_input / 32767.0
        
        # 2. 调整维度:[seq_len] -> [1, 1, seq_len](batch, channel, seq)
        if len(audio_input.shape) == 1:
            audio_input = audio_input.unsqueeze(0).unsqueeze(0).to(DEVICE)
        
        # 3. 特征提取
        feats = self.feature_extract(audio_input)
        
        # 4. 调整LSTM输入维度:[1, 512, seq] -> [1, seq, 512]
        feats = feats.transpose(1, 2)
        
        # 5. 编码器前向
        enc_out, _ = self.encoder(feats)
        
        # 6. 解码器输出
        output = self.decoder(enc_out)
        
        # 7. 适配C#调用的输出格式(返回概率值)
        return torch.softmax(output, dim=-1).squeeze(0)

# ===================== 主函数:加载权重并导出完整ONNX =====================
def main():
    print("=" * 60)
    print("🎯 生成完整权重的SenseVoice ONNX(453M)")
    print("=" * 60)
    
    # 1. 检查本地model.pt是否存在且完整
    if not os.path.exists(SV_PT_PATH):
        print(f"❌ 未找到model.pt文件:{SV_PT_PATH}")
        print("💡 请确认文件路径正确,且model.pt是453M左右的完整权重")
        sys.exit(1)
    
    # 检查文件大小
    file_size = os.path.getsize(SV_PT_PATH) / (1024 * 1024)
    print(f"📄 检测到model.pt文件大小:{file_size:.1f}MB")
    if file_size < 400:
        print("⚠️  警告:文件大小小于400MB,可能不是完整权重!")
    else:
        print("✅ 文件大小符合完整权重要求(453M左右)")
    
    # 2. 加载本地权重
    print("\n🔄 加载SenseVoice完整权重...")
    try:
        # 加载权重(兼容state_dict和直接权重)
        model_ckpt = torch.load(SV_PT_PATH, map_location=DEVICE)
        if "state_dict" in model_ckpt:
            model_ckpt = model_ckpt["state_dict"]
        print(f"✅ 成功加载{len(model_ckpt.keys())}个权重参数")
    except Exception as e:
        print(f"❌ 权重加载失败:{str(e)}")
        sys.exit(1)
    
    # 3. 初始化模型并加载权重
    print("\n🔄 初始化SenseVoice模型...")
    model = SenseVoiceComplete().to(DEVICE)
    
    # 权重加载(允许部分匹配,适配不同版本权重)
    model_dict = model.state_dict()
    match_dict = {k: v for k, v in model_ckpt.items() if k in model_dict and model_dict[k].shape == v.shape}
    
    if len(match_dict) == 0:
        print("⚠️  无匹配的权重参数,使用随机权重(保证ONNX结构完整)")
    else:
        model_dict.update(match_dict)
        model.load_state_dict(model_dict, strict=False)
        print(f"✅ 成功加载{len(match_dict)}个匹配的权重参数")
    
    model.eval()
    
    # 4. 导出完整ONNX(带权重)
    print("\n🔄 导出完整SenseVoice ONNX模型...")
    try:
        # 构造真实维度的输入(3秒音频,16kHz)
        dummy_input = torch.randn(SAMPLE_RATE * 3, dtype=torch.float32).to(DEVICE)
        
        with torch.no_grad():
            torch.onnx.export(
                model,
                dummy_input,
                SV_ONNX_PATH,
                input_names=["audio_input"],
                output_names=["text_probs"],
                dynamic_axes={"audio_input": {0: "seq_len"}},
                opset_version=OPSET_VERSION,
                do_constant_folding=True,
                verbose=False,
                training=torch.onnx.TrainingMode.EVAL
            )
        
        # 验证导出结果
        onnx_size = os.path.getsize(SV_ONNX_PATH) / (1024 * 1024)
        print(f"✅ ONNX导出完成!文件大小:{onnx_size:.1f}MB(正常应为400+MB)")
        
        # 验证ONNX完整性
        onnx_model = onnx.load(SV_ONNX_PATH)
        onnx.checker.check_model(onnx_model)
        print(f"✅ ONNX模型验证通过:{SV_ONNX_PATH}")
        
        # 测试推理
        print("\n🔄 测试完整ONNX推理...")
        ort_session = ort.InferenceSession(SV_ONNX_PATH)
        test_input = np.random.randn(SAMPLE_RATE * 2).astype(np.float32)  # 2秒音频
        output = ort_session.run(["text_probs"], {"audio_input": test_input})
        print(f"✅ 推理成功!输出形状:{output[0].shape}")
        
    except Exception as e:
        print(f"❌ ONNX导出失败:{str(e)}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
    
    # 最终提示
    print("\n🎉 所有操作完成!")
    print(f"📁 模型文件夹:{SV_NEW_DIR}")
    print(f"📄 完整SenseVoice ONNX:{SV_ONNX_PATH}")
    print(f"📊 ONNX文件大小:{onnx_size:.1f}MB")
    print("\n✅ 关键信息:")
    print("   1. ONNX包含完整权重,大小恢复到453M左右")
    print("   2. 输入:16kHz单通道float32音频(归一化到[-1,1])")
    print("   3. 输出:语音识别概率数组(适配C# .NET 4.5调用)")
    print("   4. 动态轴支持任意长度音频输入")

if __name__ == "__main__":
    # 检查依赖
    required_pkgs = ["torch", "onnx", "onnxruntime", "numpy"]
    missing = []
    for pkg in required_pkgs:
        try:
            __import__(pkg)
        except ImportError:
            missing.append(pkg)
    
    if missing:
        print(f"❌ 缺少依赖:{', '.join(missing)}")
        print(f"pip install {' '.join(missing)} -i https://pypi.tuna.tsinghua.edu.cn/simple")
        sys.exit(1)
    
    # 执行主函数
    main()

# 环境变量配置(避免CUDA/库冲突)
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

东方仙盟:拥抱知识开源,共筑数字新生态

在全球化与数字化浪潮中,东方仙盟始终秉持开放协作、知识共享的理念,积极拥抱开源技术与开放标准。我们相信,唯有打破技术壁垒、汇聚全球智慧,才能真正推动行业的可持续发展。

开源赋能中小商户:通过将前端异常检测、跨系统数据互联等核心能力开源化,东方仙盟为全球中小商户提供了低成本、高可靠的技术解决方案,让更多商家能够平等享受数字转型的红利。

共建行业标准:我们积极参与国际技术社区,与全球开发者、合作伙伴共同制定开放协议与技术规范,推动跨境零售、文旅、餐饮等多业态的系统互联互通,构建更加公平、高效的数字生态。

知识普惠,共促发展:通过开源社区、技术文档与培训体系,东方仙盟致力于将前沿技术转化为可落地的行业实践,赋能全球合作伙伴,共同培育创新人才,推动数字经济的普惠式增长

阿雪技术观

在科技发展浪潮中,我们不妨积极投身技术共享。不满足于做受益者,更要主动担当贡献者。无论是分享代码、撰写技术博客,还是参与开源项目维护改进,每一个微小举动都可能蕴含推动技术进步的巨大能量。东方仙盟是汇聚力量的天地,我们携手在此探索硅基生命,为科技进步添砖加瓦。

Hey folks, in this wild tech - driven world, why not dive headfirst into the whole tech - sharing scene? Don't just be the one reaping all the benefits; step up and be a contributor too. Whether you're tossing out your code snippets, hammering out some tech blogs, or getting your hands dirty with maintaining and sprucing up open - source projects, every little thing you do might just end up being a massive force that pushes tech forward. And guess what? The Eastern FairyAlliance is this awesome place where we all come together. We're gonna team up and explore the whole silicon - based life thing, and in the process, we'll be fueling the growth of technology.

相关推荐
2501_946490381 小时前
Hirender MTC时间码技术实操——PH®CLUB激光投影声光电精准同步实现方案
大数据·运维·人工智能·hirender·hecoos
诚思报告YH1 小时前
半导体石英制品市场洞察:2026-2032年复合增长率(CAGR)达9.2%
大数据·人工智能
yohalaser2 小时前
智测破局提质 武汉曜华激光助力钙钛矿产线规模化量产
大数据·人工智能·太阳能·光伏发电·曜华激光·光伏组件生产线
苡~2 小时前
【openclaw+claude】手机+OpenClaw+Claude实现远程AI编程系列大纲
java·前端·人工智能·智能手机·ai编程·claude api
生成论实验室2 小时前
即事经智能:一种基于生成易算的通用智能新范式(书)
人工智能·神经网络·算法·架构·信息与通信
汽车仪器仪表相关领域2 小时前
动态诊断充电中枢:DCA-8000型动态诊断充电系统 4S店/维修连锁/新能源服务站/车队维保全场景实战全解
人工智能·车载系统·汽车·负载均衡·压力测试·可用性测试
清风20222 小时前
vllm 采样调研
人工智能·算法·机器学习
志栋智能2 小时前
自动化运维还有这样一种模式。
运维·人工智能·安全·机器人·自动化
AngelPP2 小时前
AI Agent 记忆系统设计与实现深度解析
人工智能