gemma4 实现ASR语音识别

1.下载模型

python 复制代码
https://huggingface.co/google/gemma-4-E2B-it
# https://huggingface.co/google/gemma-4-E4B-it

2.下载测试音频文件

来自paddlespeech(https://github.com/PaddlePaddle/PaddleSpeech/tree/develop

python 复制代码
wget -c https://paddlespeech.cdn.bcebos.com/PaddleAudio/zh.wav
wget -c https://paddlespeech.cdn.bcebos.com/PaddleAudio/en.wav

3.代码

python 复制代码
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Gemma4 ASR (audio -> text)

参考官方思路:Gemma 4 E2B/E4B 通过 chat template 输入 audio。
默认音频路径使用你给的: /Users/wanghq/workspce/zh.wav

安装依赖:
  pip install -U torch accelerate transformers librosa soundfile

示例:
  python gemma4_asr.py
  python gemma4_asr.py --audio /Users/wanghq/workspce/zh.wav --model google/gemma-4-E4B-it
  python gemma4_asr.py --audio /Users/wanghq/workspce/zh.wav /Users/wanghq/workspce/en.wav
"""

import argparse
import sys
import time
from pathlib import Path

import librosa
import torch
from transformers import AutoModelForMultimodalLM, AutoProcessor


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Gemma4 ASR CLI")
    parser.add_argument(
        "--audio",
        type=str,
        nargs="+",
        default=["/Users/wanghq/workspce/en.wav"],
        help="本地音频路径,支持 mp3/wav 等(由 librosa 解码),可一次传多个",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="google/gemma-4-E2B-it",
        help="Gemma4 模型名或本地模型目录,推荐 E2B/E4B 做 ASR",
    )
    parser.add_argument("--max-new-tokens", type=int, default=128, help="最大生成 token")
    parser.add_argument(
        "--prompt",
        type=str,
        default=(
            "请将下面音频转写为原始语言文本。"
            "只输出转写结果,不要解释,不要换行。"
            "数字请用阿拉伯数字。"
        ),
        help="ASR 指令",
    )
    parser.add_argument(
        "--save",
        type=str,
        default="",
        help="可选:将转写结果保存到此文件路径",
    )
    return parser


def validate_audio(path: str) -> tuple[str, float]:
    audio_path = Path(path)
    if not audio_path.exists():
        raise FileNotFoundError(f"音频文件不存在: {audio_path}")

    # 仅读取时长做前置检查,真正解码交给 processor/librosa。
    duration = librosa.get_duration(path=str(audio_path))
    if duration > 30.0:
        print(f"警告: 音频时长 {duration:.2f}s,Gemma4 官方建议单段不超过 30s。")
    return str(audio_path.resolve()), duration


def load_model_and_processor(model_name: str):
    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForMultimodalLM.from_pretrained(
        model_name,
        dtype="auto",
        device_map="auto",
        trust_remote_code=True,
    )
    model.eval()
    return model, processor


def transcribe(audio_path: str, prompt: str, model, processor, max_new_tokens: int) -> str:
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "audio", "audio": audio_path},
            ],
        }
    ]

    model_inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    model_inputs = model_inputs.to(model.device, dtype=model.dtype)

    with torch.no_grad():
        outputs = model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    prompt_len = model_inputs["input_ids"].shape[-1]
    new_tokens = outputs[0][prompt_len:]
    text = processor.decode(new_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    return text.strip()


def main():
    args = build_parser().parse_args()

    audio_items = []
    for audio in args.audio:
        try:
            audio_path, duration = validate_audio(audio)
            audio_items.append((audio_path, duration))
        except Exception as exc:
            print(f"音频检查失败: {exc}")
            print("提示: 你当前目录有 /Users/wanghq/workspce/zh.wav 和 /Users/wanghq/workspce/en.wav。")
            sys.exit(1)

    try:
        load_start = time.perf_counter()
        model, processor = load_model_and_processor(args.model)
        load_elapsed = time.perf_counter() - load_start
    except Exception as exc:
        print(f"模型加载失败: {exc}")
        print("请确认你有模型访问权限,并已安装依赖。")
        sys.exit(1)

    print(f"模型: {args.model}")
    print("音频文件:")
    for audio_path, duration in audio_items:
        print(f"  - {audio_path} ({duration:.2f}s)")
    print(f"模型加载耗时: {load_elapsed:.2f}s")
    print("开始转写...\n")

    results = []
    for idx, (audio_path, _) in enumerate(audio_items, start=1):
        try:
            transcribe_start = time.perf_counter()
            transcript = transcribe(
                audio_path=audio_path,
                prompt=args.prompt,
                model=model,
                processor=processor,
                max_new_tokens=args.max_new_tokens,
            )
            transcribe_elapsed = time.perf_counter() - transcribe_start
        except torch.cuda.OutOfMemoryError:
            print(f"[{idx}] 转写失败: 显存不足。可尝试 E2B、小一点 max-new-tokens,或更短音频。")
            continue
        except Exception as exc:
            print(f"[{idx}] 转写失败: {exc}")
            continue

        results.append((audio_path, transcript, transcribe_elapsed))
        print(f"[{idx}] 转写文件: {audio_path}")
        print("转写结果:")
        print(transcript)
        print(f"音频转文字耗时: {transcribe_elapsed:.2f}s\n")

    if args.save:
        out = Path(args.save)
        out.parent.mkdir(parents=True, exist_ok=True)
        lines = []
        for audio_path, transcript, elapsed in results:
            lines.append(f"[file] {audio_path}")
            lines.append(f"[time] {elapsed:.2f}s")
            lines.append(transcript)
            lines.append("")
        out.write_text("\n".join(lines).rstrip() + "\n", encoding="utf-8")
        print(f"\n已保存到: {out.resolve()}")


if __name__ == "__main__":
    main()
相关推荐
Hesionberger几秒前
LeetCode105:前序中序构建二叉树(三解法)
java·数据结构·python·算法·leetcode·深度优先
jianwuhuang821 分钟前
智谱清言怎么导出pdf
人工智能·chatgpt·pdf·豆包·deepseek·ai导出鸭
数智前线2 分钟前
腾讯云融合创新产品矩阵全面升级,首次发布专有云版“龙虾”
大数据·人工智能
Chase_______2 分钟前
【Java杂项】为什么 long 可以自动转 float?宽化基本类型转换与精度丢失详解
java·开发语言·python
青云计划3 分钟前
给 AI 写一份老厨师的菜谱:从传统文档到 Skill 知识体系
人工智能
invicinble3 分钟前
java数组相关的信息量
java·开发语言·python
小江的记录本4 分钟前
【Java基础】Java 8-21新特性 :JDK17:密封类、模式匹配、Record类(附《思维导图》+《面试高频考点清单》)
java·数据结构·后端·python·mysql·面试·职场和发展
Luminbox紫创测控4 分钟前
基于环境舱的新能源汽车三高试验方法与热响应评估
大数据·人工智能·测试工具·汽车·安全性测试·测试标准
码小猿的CPP工坊4 分钟前
AI时代C++软件开发工程师的思考
c++·人工智能
敲上瘾5 分钟前
LangChain 消息机制与提示词模板指南
大数据·python·langchain