本地部署 whisper-medusa

本地部署 whisper-medusa

  • [0. 引言](#0. 引言)
  • [1. 本地部署](#1. 本地部署)
    • [1-1. 创建虚拟环境](#1-1. 创建虚拟环境)
    • [1-2. 克隆代码](#1-2. 克隆代码)
    • [1-3. 安装依赖模块](#1-3. 安装依赖模块)
    • [1-4. 创建 Web UI](#1-4. 创建 Web UI)
    • [1-5. 启动 Web UI](#1-5. 启动 Web UI)
    • [1-5. 访问 Web UI](#1-5. 访问 Web UI)

0. 引言

Whisper 是一种用于语音转录和翻译的高级编码器-解码器模型,通过编码和解码阶段处理音频。鉴于其尺寸大和推理速度慢,人们提出了各种优化策略(例如 Faster-Whisper 和 Speculative Decoding)来提高性能。我们的 Medusa 模型建立在 Whisper 的基础上,通过每次迭代预测多个标记,这显着提高了速度,同时 WER 略有下降。我们在 LibriSpeech 数据集上训练和评估我们的模型,与普通 Whisper 模型相比,展示了强大的性能速度改进和同等准确度。

1. 本地部署

1-1. 创建虚拟环境

conda create -n whisper-medusa python=3.11 -y
conda activate whisper-medusa

1-2. 克隆代码

git clone https://github.com/aiola-lab/whisper-medusa.git
cd whisper-medusa

1-3. 安装依赖模块

pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118
pip install -e .
conda install matplotlib
pip install gradio

1-4. 创建 Web UI

# webui.py
import torch
import torchaudio
import gradio as gr
from whisper_medusa import WhisperMedusaModel
from transformers import WhisperProcessor

# Load model and processor
model_name = "aiola/whisper-medusa-v1"
model = WhisperMedusaModel.from_pretrained(model_name)
processor = WhisperProcessor.from_pretrained(model_name)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Constants
SAMPLING_RATE = 16000

def transcribe_audio(audio_file, language):
    # Load and preprocess audio
    input_speech, sr = torchaudio.load(audio_file)
    if input_speech.shape[0] > 1:  # If stereo, average the channels
        input_speech = input_speech.mean(dim=0, keepdim=True)
    if sr != SAMPLING_RATE:
        input_speech = torchaudio.transforms.Resample(sr, SAMPLING_RATE)(input_speech)
    
    # Process input
    input_features = processor(input_speech.squeeze(), return_tensors="pt", sampling_rate=SAMPLING_RATE).input_features
    input_features = input_features.to(device)
    
    # Generate transcription
    model_output = model.generate(
        input_features,
        language=language,
    )
    predict_ids = model_output[0]
    transcription = processor.decode(predict_ids, skip_special_tokens=True)
    
    return transcription

# Define Gradio interface
iface = gr.Interface(
    fn=transcribe_audio,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio"),
        gr.Dropdown(["en", "zh", "ja"], label="Select Language", value="en")
    ],
    outputs="text",
    title="Audio Transcription with Whisper Medusa",
    description="Upload an audio file and select the language to transcribe the audio to text."
)

# Launch the interface
iface.launch()

1-5. 启动 Web UI

python webui.py

1-5. 访问 Web UI

使用浏览器访问 http://localhost:7860

相关推荐
再不会python就不礼貌了2 分钟前
Ollama 0.4 发布!支持 Llama 3.2 Vision,实现多模态 RAG
人工智能·学习·机器学习·ai·开源·产品经理·llama
大大大反派5 分钟前
ONLYOFFICE 8.2深度测评:集成PDF编辑、数据可视化与AI功能的强大办公套件
人工智能·信息可视化·pdf
DK2215120 分钟前
机器学习系列-----主成分分析(PCA)
人工智能·算法·机器学习
SmallBambooCode2 小时前
【人工智能】阿里云PAI平台DSW实例一键安装Python脚本
linux·人工智能·python·阿里云·debian·脚本·模型训练
顾京2 小时前
基于扩散模型的表单插补
人工智能·深度学习·算法
NoneCoder2 小时前
AI时代IDE解析
ide·人工智能
狂奔solar2 小时前
yelp数据集上试验SVD,SVDPP,PMF,NMF 推荐算法
人工智能·机器学习·推荐算法
武子康2 小时前
大数据-216 数据挖掘 机器学习理论 - KMeans 基于轮廓系数来选择 n_clusters
大数据·人工智能·机器学习·数据挖掘·回归·scikit-learn·kmeans
liupenglove2 小时前
ElasticSearch向量检索技术方案介绍
大数据·人工智能·深度学习·elasticsearch·搜索引擎·自动驾驶
黄焖鸡能干四碗3 小时前
【系统文档】系统安全保障措施,安全运营保障,系统应急预案,系统验收相关资料(word原件)
大数据·人工智能·需求分析·软件需求·规格说明书