gemma4 实现ASR语音识别

1.下载模型

python 复制代码
https://huggingface.co/google/gemma-4-E2B-it
# https://huggingface.co/google/gemma-4-E4B-it

2.下载测试音频文件

来自paddlespeech(https://github.com/PaddlePaddle/PaddleSpeech/tree/develop

python 复制代码
wget -c https://paddlespeech.cdn.bcebos.com/PaddleAudio/zh.wav
wget -c https://paddlespeech.cdn.bcebos.com/PaddleAudio/en.wav

3.代码

python 复制代码
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Gemma4 ASR (audio -> text)

参考官方思路:Gemma 4 E2B/E4B 通过 chat template 输入 audio。
默认音频路径使用你给的: /Users/wanghq/workspce/zh.wav

安装依赖:
  pip install -U torch accelerate transformers librosa soundfile

示例:
  python gemma4_asr.py
  python gemma4_asr.py --audio /Users/wanghq/workspce/zh.wav --model google/gemma-4-E4B-it
  python gemma4_asr.py --audio /Users/wanghq/workspce/zh.wav /Users/wanghq/workspce/en.wav
"""

import argparse
import sys
import time
from pathlib import Path

import librosa
import torch
from transformers import AutoModelForMultimodalLM, AutoProcessor


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Gemma4 ASR CLI")
    parser.add_argument(
        "--audio",
        type=str,
        nargs="+",
        default=["/Users/wanghq/workspce/en.wav"],
        help="本地音频路径,支持 mp3/wav 等(由 librosa 解码),可一次传多个",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="google/gemma-4-E2B-it",
        help="Gemma4 模型名或本地模型目录,推荐 E2B/E4B 做 ASR",
    )
    parser.add_argument("--max-new-tokens", type=int, default=128, help="最大生成 token")
    parser.add_argument(
        "--prompt",
        type=str,
        default=(
            "请将下面音频转写为原始语言文本。"
            "只输出转写结果,不要解释,不要换行。"
            "数字请用阿拉伯数字。"
        ),
        help="ASR 指令",
    )
    parser.add_argument(
        "--save",
        type=str,
        default="",
        help="可选:将转写结果保存到此文件路径",
    )
    return parser


def validate_audio(path: str) -> tuple[str, float]:
    audio_path = Path(path)
    if not audio_path.exists():
        raise FileNotFoundError(f"音频文件不存在: {audio_path}")

    # 仅读取时长做前置检查,真正解码交给 processor/librosa。
    duration = librosa.get_duration(path=str(audio_path))
    if duration > 30.0:
        print(f"警告: 音频时长 {duration:.2f}s,Gemma4 官方建议单段不超过 30s。")
    return str(audio_path.resolve()), duration


def load_model_and_processor(model_name: str):
    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForMultimodalLM.from_pretrained(
        model_name,
        dtype="auto",
        device_map="auto",
        trust_remote_code=True,
    )
    model.eval()
    return model, processor


def transcribe(audio_path: str, prompt: str, model, processor, max_new_tokens: int) -> str:
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "audio", "audio": audio_path},
            ],
        }
    ]

    model_inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    model_inputs = model_inputs.to(model.device, dtype=model.dtype)

    with torch.no_grad():
        outputs = model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    prompt_len = model_inputs["input_ids"].shape[-1]
    new_tokens = outputs[0][prompt_len:]
    text = processor.decode(new_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    return text.strip()


def main():
    args = build_parser().parse_args()

    audio_items = []
    for audio in args.audio:
        try:
            audio_path, duration = validate_audio(audio)
            audio_items.append((audio_path, duration))
        except Exception as exc:
            print(f"音频检查失败: {exc}")
            print("提示: 你当前目录有 /Users/wanghq/workspce/zh.wav 和 /Users/wanghq/workspce/en.wav。")
            sys.exit(1)

    try:
        load_start = time.perf_counter()
        model, processor = load_model_and_processor(args.model)
        load_elapsed = time.perf_counter() - load_start
    except Exception as exc:
        print(f"模型加载失败: {exc}")
        print("请确认你有模型访问权限,并已安装依赖。")
        sys.exit(1)

    print(f"模型: {args.model}")
    print("音频文件:")
    for audio_path, duration in audio_items:
        print(f"  - {audio_path} ({duration:.2f}s)")
    print(f"模型加载耗时: {load_elapsed:.2f}s")
    print("开始转写...\n")

    results = []
    for idx, (audio_path, _) in enumerate(audio_items, start=1):
        try:
            transcribe_start = time.perf_counter()
            transcript = transcribe(
                audio_path=audio_path,
                prompt=args.prompt,
                model=model,
                processor=processor,
                max_new_tokens=args.max_new_tokens,
            )
            transcribe_elapsed = time.perf_counter() - transcribe_start
        except torch.cuda.OutOfMemoryError:
            print(f"[{idx}] 转写失败: 显存不足。可尝试 E2B、小一点 max-new-tokens,或更短音频。")
            continue
        except Exception as exc:
            print(f"[{idx}] 转写失败: {exc}")
            continue

        results.append((audio_path, transcript, transcribe_elapsed))
        print(f"[{idx}] 转写文件: {audio_path}")
        print("转写结果:")
        print(transcript)
        print(f"音频转文字耗时: {transcribe_elapsed:.2f}s\n")

    if args.save:
        out = Path(args.save)
        out.parent.mkdir(parents=True, exist_ok=True)
        lines = []
        for audio_path, transcript, elapsed in results:
            lines.append(f"[file] {audio_path}")
            lines.append(f"[time] {elapsed:.2f}s")
            lines.append(transcript)
            lines.append("")
        out.write_text("\n".join(lines).rstrip() + "\n", encoding="utf-8")
        print(f"\n已保存到: {out.resolve()}")


if __name__ == "__main__":
    main()
相关推荐
来自远方的老作者2 小时前
第8章 流程控制-8.2 选择结构
开发语言·python·选择结构
kaico20182 小时前
python常用标准库
开发语言·python
TTGGGFF2 小时前
SnapTranslate 2.0:轻量级全场景划词翻译——添加生词本以及生词本复习AI助手功能!
python·划词翻译·git开源
reasonsummer2 小时前
【教学类-160-01】20260408 AI视频培训-练习1“豆包AI视频”
人工智能·音视频
杜子不疼.2 小时前
Python + Selenium + AI 智能爬虫:自动识别反爬与数据提取
人工智能·python·selenium
Elastic 中国社区官方博客2 小时前
Elasticsearch:语义搜索,现在默认支持多语言
大数据·数据库·人工智能·elasticsearch·搜索引擎·ai·全文检索
财经资讯数据_灵砚智能2 小时前
基于全球经济类多源新闻的NLP情感分析与数据可视化(日间)2026年4月8日
大数据·人工智能·python·信息可视化·自然语言处理
枫叶林FYL2 小时前
【自然语言处理 NLP】多模态与具身智能:视觉-语言预训练技术手册
人工智能·机器学习·自然语言处理
AI获客新方案@柯望望2 小时前
GEO并非SEO的AI适配版 生成式引擎优化核心术语说明
人工智能·geo·生成式引擎优化