1.下载模型
```text
https://huggingface.co/google/gemma-4-E2B-it
# https://huggingface.co/google/gemma-4-E4B-it
```
2.下载测试音频文件
来自 PaddleSpeech(https://github.com/PaddlePaddle/PaddleSpeech/tree/develop)
```bash
wget -c https://paddlespeech.cdn.bcebos.com/PaddleAudio/zh.wav
wget -c https://paddlespeech.cdn.bcebos.com/PaddleAudio/en.wav
```
3.代码
```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Gemma4 ASR (audio -> text)
参考官方思路:Gemma 4 E2B/E4B 通过 chat template 输入 audio。
默认音频路径使用你给的: /Users/wanghq/workspce/zh.wav
安装依赖:
pip install -U torch accelerate transformers librosa soundfile
示例:
python gemma4_asr.py
python gemma4_asr.py --audio /Users/wanghq/workspce/zh.wav --model google/gemma-4-E4B-it
python gemma4_asr.py --audio /Users/wanghq/workspce/zh.wav /Users/wanghq/workspce/en.wav
"""
import argparse
import sys
import time
from pathlib import Path
import librosa
import torch
from transformers import AutoModelForMultimodalLM, AutoProcessor
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the Gemma4 ASR CLI.

    Returns:
        argparse.ArgumentParser: parser exposing ``--audio``, ``--model``,
        ``--max-new-tokens``, ``--prompt`` and ``--save``.
    """
    # Default ASR instruction sent alongside every audio clip.
    default_prompt = "请将下面音频转写为原始语言文本。只输出转写结果,不要解释,不要换行。数字请用阿拉伯数字。"

    cli = argparse.ArgumentParser(description="Gemma4 ASR CLI")
    cli.add_argument(
        "--audio",
        type=str,
        nargs="+",
        default=["/Users/wanghq/workspce/en.wav"],
        help="本地音频路径,支持 mp3/wav 等(由 librosa 解码),可一次传多个",
    )
    cli.add_argument(
        "--model",
        type=str,
        default="google/gemma-4-E2B-it",
        help="Gemma4 模型名或本地模型目录,推荐 E2B/E4B 做 ASR",
    )
    cli.add_argument("--max-new-tokens", type=int, default=128, help="最大生成 token")
    cli.add_argument("--prompt", type=str, default=default_prompt, help="ASR 指令")
    cli.add_argument(
        "--save",
        type=str,
        default="",
        help="可选:将转写结果保存到此文件路径",
    )
    return cli
def validate_audio(path: str) -> tuple[str, float]:
    """Check that an audio file exists and probe its duration.

    Args:
        path: local audio file path.

    Returns:
        Tuple of (resolved absolute path, duration in seconds).

    Raises:
        FileNotFoundError: if the file does not exist.
    """
    candidate = Path(path)
    if not candidate.exists():
        raise FileNotFoundError(f"音频文件不存在: {candidate}")
    # Only probe the duration as a pre-flight check; real decoding is left
    # to the processor / librosa at transcription time.
    seconds = librosa.get_duration(path=str(candidate))
    if seconds > 30.0:
        # Long clips are still processed — this is just a warning.
        print(f"警告: 音频时长 {seconds:.2f}s,Gemma4 官方建议单段不超过 30s。")
    return str(candidate.resolve()), seconds
def load_model_and_processor(model_name: str):
    """Load the processor and multimodal model for ASR.

    Args:
        model_name: Hugging Face hub id or a local model directory.

    Returns:
        ``(model, processor)`` with the model switched to eval mode.

    NOTE(review): relies on ``AutoModelForMultimodalLM`` being available in
    the installed transformers version — confirm against its docs.
    """
    proc = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
    mdl = AutoModelForMultimodalLM.from_pretrained(
        model_name,
        dtype="auto",  # take the weight dtype from the checkpoint
        device_map="auto",  # let accelerate place/shard across devices
        trust_remote_code=True,
    )
    mdl.eval()
    return mdl, proc
def transcribe(audio_path: str, prompt: str, model, processor, max_new_tokens: int) -> str:
    """Transcribe one audio file via the chat template and return the text.

    Args:
        audio_path: resolved local path of the audio file.
        prompt: ASR instruction placed before the audio in the user turn.
        model: loaded multimodal model.
        processor: matching processor with a chat template.
        max_new_tokens: generation budget.

    Returns:
        The decoded transcript with surrounding whitespace stripped.
    """
    chat = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "audio", "audio": audio_path},
            ],
        }
    ]
    inputs = processor.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    # NOTE(review): assumes the returned batch's .to() casts only floating
    # tensors to model.dtype and leaves integer ids alone — confirm with the
    # installed transformers version.
    inputs = inputs.to(model.device, dtype=model.dtype)
    with torch.no_grad():
        generated = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Drop the prompt tokens; keep only what the model newly generated.
    prompt_len = inputs["input_ids"].shape[-1]
    completion = generated[0][prompt_len:]
    text = processor.decode(completion, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    return text.strip()
def main() -> None:
    """CLI entry point: validate audio paths, load the model, transcribe
    each file, print the results, and optionally save them to a file."""
    args = build_parser().parse_args()
    # Pre-flight every audio path (existence + duration) BEFORE paying the
    # slow model-load cost; abort on the first bad path.
    audio_items = []
    for audio in args.audio:
        try:
            audio_path, duration = validate_audio(audio)
            audio_items.append((audio_path, duration))
        except Exception as exc:
            print(f"音频检查失败: {exc}")
            print("提示: 你当前目录有 /Users/wanghq/workspce/zh.wav 和 /Users/wanghq/workspce/en.wav。")
            sys.exit(1)
    try:
        load_start = time.perf_counter()
        model, processor = load_model_and_processor(args.model)
        load_elapsed = time.perf_counter() - load_start
    except Exception as exc:
        # Broad catch is deliberate here: any load failure (auth, missing
        # deps, bad model id) exits with a hint instead of a traceback.
        print(f"模型加载失败: {exc}")
        print("请确认你有模型访问权限,并已安装依赖。")
        sys.exit(1)
    print(f"模型: {args.model}")
    print("音频文件:")
    for audio_path, duration in audio_items:
        print(f" - {audio_path} ({duration:.2f}s)")
    print(f"模型加载耗时: {load_elapsed:.2f}s")
    print("开始转写...\n")
    results = []
    # Transcribe each file independently; a failure on one file is reported
    # and skipped so the remaining files still get processed.
    for idx, (audio_path, _) in enumerate(audio_items, start=1):
        try:
            transcribe_start = time.perf_counter()
            transcript = transcribe(
                audio_path=audio_path,
                prompt=args.prompt,
                model=model,
                processor=processor,
                max_new_tokens=args.max_new_tokens,
            )
            transcribe_elapsed = time.perf_counter() - transcribe_start
        except torch.cuda.OutOfMemoryError:
            # CUDA OOM gets a dedicated, actionable message.
            print(f"[{idx}] 转写失败: 显存不足。可尝试 E2B、小一点 max-new-tokens,或更短音频。")
            continue
        except Exception as exc:
            print(f"[{idx}] 转写失败: {exc}")
            continue
        results.append((audio_path, transcript, transcribe_elapsed))
        print(f"[{idx}] 转写文件: {audio_path}")
        print("转写结果:")
        print(transcript)
        print(f"音频转文字耗时: {transcribe_elapsed:.2f}s\n")
    if args.save:
        # Only successful transcriptions are written; each entry is a
        # [file]/[time]/transcript group separated by a blank line.
        out = Path(args.save)
        out.parent.mkdir(parents=True, exist_ok=True)
        lines = []
        for audio_path, transcript, elapsed in results:
            lines.append(f"[file] {audio_path}")
            lines.append(f"[time] {elapsed:.2f}s")
            lines.append(transcript)
            lines.append("")
        out.write_text("\n".join(lines).rstrip() + "\n", encoding="utf-8")
        print(f"\n已保存到: {out.resolve()}")


if __name__ == "__main__":
    main()