gemma4 实现ASR语音识别

1.下载模型

python 复制代码
https://huggingface.co/google/gemma-4-E2B-it
# https://huggingface.co/google/gemma-4-E4B-it

2.下载测试音频文件

来自paddlespeech(https://github.com/PaddlePaddle/PaddleSpeech/tree/develop

python 复制代码
wget -c https://paddlespeech.cdn.bcebos.com/PaddleAudio/zh.wav
wget -c https://paddlespeech.cdn.bcebos.com/PaddleAudio/en.wav

3.代码

python 复制代码
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Gemma4 ASR (audio -> text)

参考官方思路:Gemma 4 E2B/E4B 通过 chat template 输入 audio。
默认音频路径使用你给的: /Users/wanghq/workspce/zh.wav

安装依赖:
  pip install -U torch accelerate transformers librosa soundfile

示例:
  python gemma4_asr.py
  python gemma4_asr.py --audio /Users/wanghq/workspce/zh.wav --model google/gemma-4-E4B-it
  python gemma4_asr.py --audio /Users/wanghq/workspce/zh.wav /Users/wanghq/workspce/en.wav
"""

import argparse
import sys
import time
from pathlib import Path

import librosa
import torch
from transformers import AutoModelForMultimodalLM, AutoProcessor


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Gemma4 ASR CLI")
    parser.add_argument(
        "--audio",
        type=str,
        nargs="+",
        default=["/Users/wanghq/workspce/en.wav"],
        help="本地音频路径,支持 mp3/wav 等(由 librosa 解码),可一次传多个",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="google/gemma-4-E2B-it",
        help="Gemma4 模型名或本地模型目录,推荐 E2B/E4B 做 ASR",
    )
    parser.add_argument("--max-new-tokens", type=int, default=128, help="最大生成 token")
    parser.add_argument(
        "--prompt",
        type=str,
        default=(
            "请将下面音频转写为原始语言文本。"
            "只输出转写结果,不要解释,不要换行。"
            "数字请用阿拉伯数字。"
        ),
        help="ASR 指令",
    )
    parser.add_argument(
        "--save",
        type=str,
        default="",
        help="可选:将转写结果保存到此文件路径",
    )
    return parser


def validate_audio(path: str) -> tuple[str, float]:
    audio_path = Path(path)
    if not audio_path.exists():
        raise FileNotFoundError(f"音频文件不存在: {audio_path}")

    # 仅读取时长做前置检查,真正解码交给 processor/librosa。
    duration = librosa.get_duration(path=str(audio_path))
    if duration > 30.0:
        print(f"警告: 音频时长 {duration:.2f}s,Gemma4 官方建议单段不超过 30s。")
    return str(audio_path.resolve()), duration


def load_model_and_processor(model_name: str):
    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForMultimodalLM.from_pretrained(
        model_name,
        dtype="auto",
        device_map="auto",
        trust_remote_code=True,
    )
    model.eval()
    return model, processor


def transcribe(audio_path: str, prompt: str, model, processor, max_new_tokens: int) -> str:
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "audio", "audio": audio_path},
            ],
        }
    ]

    model_inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    model_inputs = model_inputs.to(model.device, dtype=model.dtype)

    with torch.no_grad():
        outputs = model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    prompt_len = model_inputs["input_ids"].shape[-1]
    new_tokens = outputs[0][prompt_len:]
    text = processor.decode(new_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    return text.strip()


def main():
    args = build_parser().parse_args()

    audio_items = []
    for audio in args.audio:
        try:
            audio_path, duration = validate_audio(audio)
            audio_items.append((audio_path, duration))
        except Exception as exc:
            print(f"音频检查失败: {exc}")
            print("提示: 你当前目录有 /Users/wanghq/workspce/zh.wav 和 /Users/wanghq/workspce/en.wav。")
            sys.exit(1)

    try:
        load_start = time.perf_counter()
        model, processor = load_model_and_processor(args.model)
        load_elapsed = time.perf_counter() - load_start
    except Exception as exc:
        print(f"模型加载失败: {exc}")
        print("请确认你有模型访问权限,并已安装依赖。")
        sys.exit(1)

    print(f"模型: {args.model}")
    print("音频文件:")
    for audio_path, duration in audio_items:
        print(f"  - {audio_path} ({duration:.2f}s)")
    print(f"模型加载耗时: {load_elapsed:.2f}s")
    print("开始转写...\n")

    results = []
    for idx, (audio_path, _) in enumerate(audio_items, start=1):
        try:
            transcribe_start = time.perf_counter()
            transcript = transcribe(
                audio_path=audio_path,
                prompt=args.prompt,
                model=model,
                processor=processor,
                max_new_tokens=args.max_new_tokens,
            )
            transcribe_elapsed = time.perf_counter() - transcribe_start
        except torch.cuda.OutOfMemoryError:
            print(f"[{idx}] 转写失败: 显存不足。可尝试 E2B、小一点 max-new-tokens,或更短音频。")
            continue
        except Exception as exc:
            print(f"[{idx}] 转写失败: {exc}")
            continue

        results.append((audio_path, transcript, transcribe_elapsed))
        print(f"[{idx}] 转写文件: {audio_path}")
        print("转写结果:")
        print(transcript)
        print(f"音频转文字耗时: {transcribe_elapsed:.2f}s\n")

    if args.save:
        out = Path(args.save)
        out.parent.mkdir(parents=True, exist_ok=True)
        lines = []
        for audio_path, transcript, elapsed in results:
            lines.append(f"[file] {audio_path}")
            lines.append(f"[time] {elapsed:.2f}s")
            lines.append(transcript)
            lines.append("")
        out.write_text("\n".join(lines).rstrip() + "\n", encoding="utf-8")
        print(f"\n已保存到: {out.resolve()}")


if __name__ == "__main__":
    main()
相关推荐
南子北游17 小时前
Python学习(基础语法1)
开发语言·python·学习
AI木马人17 小时前
13.【多租户架构实战】如何让一个AI系统同时服务多个用户且数据完全隔离?(完整设计方案)
人工智能·架构
sjsjsbbsbsn17 小时前
大模型核心知识总结
java·人工智能·后端
步辞17 小时前
Redis如何利用LFU算法优化缓存命中率
jvm·数据库·python
forEverPlume18 小时前
golang如何实现日志按级别过滤_golang日志按级别过滤实现教程
jvm·数据库·python
qq_4112624218 小时前
四博 AI 双目智能音箱方案:把“会说话的音箱”升级成“会表达、会感知、会控制”的 AI 终端
人工智能·智能音箱
努力努力再努力FFF18 小时前
跨境电商运营想用AI优化广告和选品,该从哪里开始学?
人工智能
薛定猫AI18 小时前
【深度解析】Claude Code Skills 工作流:用知识图谱、设计规范与 Agent 工具链提升 AI 编程效率
人工智能·知识图谱·设计规范
AI自动化工坊18 小时前
Cloudflare Project Think技术实践:零成本AI Agent部署架构深度解析
人工智能·架构·agent·cloudflare
IT_陈寒19 小时前
JavaScript里这个隐式类型转换的坑,我终于爬出来了
前端·人工智能·后端