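[Robot Perception] An Environmental Sound Recognition System Built on YAMNet TFLite + FastAPI (with Bilingual English/Chinese Labels)

As smart robots and AIoT devices become widespread, **environmental sound recognition (audio classification)** has become a core capability for human-machine interaction.

The targets are everyday ambient sounds: typing on a keyboard, a dog barking, birdsong outside the window, and the many other sounds around us.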

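This project implements a complete environmental sound recognition system:

✔ Backend: FastAPI + TFLite inference

✔ Model: Google YAMNet (classifies 521 environmental sound classes)

✔ Frontend: a web page that uploads an audio file and shows the result as soon as the backend responds

✔ Labels: English-to-Chinese mapping for bilingual output

✔ Suitable for embedded devices (such as robots)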

Demo 👇: recognition result for an uploaded birdsong clip (screenshot not reproduced here).

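📌 1. How YAMNet Works

YAMNet (commonly expanded as "Yet Another Mobile Network") is built on the lightweight MobileNetV1 architecture.

It is mainly used for environmental sound classification. Key characteristics: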

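| Property | Description |
|---|---|
| Number of classes | 521 (from the AudioSet dataset) |
| Input format | 16 kHz mono audio (uploaded as WAV or a similar file, decoded to float32 samples in [-1, 1]) |
| Deployment targets | CPU / ARM / mobile / Android |
| Typical applications | Robot environment perception, IoT security sensing, smart speakers |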
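To try the system with a file in exactly this input format, a short test clip can be synthesized with the Python standard library alone. The sketch below writes a 1-second, 16 kHz, mono, 16-bit sine tone; the function name `write_test_tone` and the 440 Hz frequency are illustrative choices, not part of the original project.

```python
import math
import struct
import wave


def write_test_tone(path: str, seconds: float = 1.0, freq_hz: float = 440.0,
                    sample_rate: int = 16000) -> int:
    """Write a mono 16-bit PCM sine tone and return the number of samples written."""
    n_samples = int(round(seconds * sample_rate))
    amplitude = 0.5 * 32767  # half of full scale, to stay clear of clipping
    pcm = bytearray()
    for n in range(n_samples):
        value = int(amplitude * math.sin(2.0 * math.pi * freq_hz * n / sample_rate))
        pcm += struct.pack("<h", value)  # little-endian signed 16-bit sample

    with wave.open(path, "wb") as wf:
        wf.setnchannels(1)            # mono
        wf.setsampwidth(2)            # 2 bytes per sample = 16-bit
        wf.setframerate(sample_rate)  # 16 kHz, matching the model input rate
        wf.writeframes(bytes(pcm))
    return n_samples
```

Calling `write_test_tone("tone.wav")` produces a file that is only meant for checking that upload, decoding, and inference run end to end.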

Core pipeline:

mermaid
graph LR
A[Audio input] --> B[Resample to 16 kHz]
B --> C[Split into 0.975 s frames]
C --> D[TFLite model]
D --> E[521-dim class score vector]
E --> F[Top-K labels + English-Chinese mapping]
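The framing and Top-K steps of this pipeline come down to a little arithmetic. The following is a minimal, standard-library-only sketch rather than the backend code itself: `frame_count` assumes non-overlapping 0.975 s frames with the last partial frame zero-padded, and `top_k` averages per-frame scores before ranking. Both function names are illustrative.

```python
import math
from typing import List, Tuple

SAMPLE_RATE = 16000
FRAME_SAMPLES = int(round(0.975 * SAMPLE_RATE))  # 15600 samples per frame


def frame_count(num_samples: int) -> int:
    """Number of non-overlapping frames when the last partial frame is zero-padded."""
    if num_samples <= 0:
        raise ValueError("audio is empty")
    return math.ceil(num_samples / FRAME_SAMPLES)


def top_k(frame_scores: List[List[float]], k: int = 5) -> List[Tuple[int, float]]:
    """Average scores over frames, then return the k best (class_index, score) pairs."""
    num_frames = len(frame_scores)
    clip_scores = [sum(col) / num_frames for col in zip(*frame_scores)]
    order = sorted(range(len(clip_scores)), key=lambda i: clip_scores[i], reverse=True)
    return [(i, clip_scores[i]) for i in order[:k]]


# A 3-second clip at 16 kHz has 48000 samples: three full frames plus a padded tail.
print(frame_count(3 * SAMPLE_RATE))  # 4
```

Averaging over frames gives one score vector per clip; a sound that occurs in only one frame of a long clip is diluted by this choice, which is worth keeping in mind when picking a confidence threshold.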

🛠 2. Project Layout

text
audio_robot/
│
├── main.py              # FastAPI + TFLite backend
├── model/               # model directory
│    ├── model.tflite
│    ├── labels.txt
│    └── labels_zh.json  # English -> Chinese mapping
└── frontend/
     └── index.html      # test page
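Before starting the backend it can help to confirm that the hand-placed files are where this layout expects them. The helper below is a small sketch; `missing_files` and the `REQUIRED` list are hypothetical names, and the list deliberately leaves out labels.txt (generated on first run) and labels_zh.json (optional).

```python
from pathlib import Path
from typing import List

# Files that have to be put in place by hand, relative to the project root.
REQUIRED = ["main.py", "model/model.tflite", "frontend/index.html"]


def missing_files(root: str) -> List[str]:
    """Return the entries of REQUIRED that do not exist under `root`."""
    base = Path(root)
    return [rel for rel in REQUIRED if not (base / rel).is_file()]
```

For example, `missing_files("audio_robot")` returns an empty list once all three files are present.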

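📦 3. Conda Environment Setup

A dedicated environment is recommended 🍰

bash
conda create -n EnvSound python=3.10 -y
conda activate EnvSound
# main.py imports these packages; FastAPI file uploads also need python-multipart
pip install fastapi uvicorn python-multipart numpy soundfile librosa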

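Optional: install the lightweight TFLite Runtime (recommended on servers and embedded devices)

📌 Windows CPU example. The wheel is tagged cp39 (CPython 3.9), so pip rejects it in a Python 3.10 environment such as the one created above; install it in a Python 3.9 environment, or pick a wheel that matches your interpreter and platform:

bash
pip install https://github.com/google-coral/pycoral/releases/download/release-frogfish/tflite_runtime-2.14.0-cp39-cp39-win_amd64.whl

If that install does not succeed, the backend falls back to the interpreter bundled with tensorflow:

bash
pip install tensorflow==2.13.0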
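To see which of the two interpreters the environment actually provides before launching the server, a quick check along these lines can be run. It is a sketch: `detect_backend` is an illustrative name, and `importlib.util.find_spec` only confirms that the package can be located, not that it imports cleanly.

```python
import importlib.util
from typing import Optional


def detect_backend() -> Optional[str]:
    """Report which TFLite interpreter package is installed, without importing it."""
    if importlib.util.find_spec("tflite_runtime") is not None:
        return "tflite-runtime"
    if importlib.util.find_spec("tensorflow") is not None:
        return "tensorflow"
    return None


print(f"TFLite backend available: {detect_backend()}")
```

A result of `None` means neither package is installed, in which case the backend cannot load the model.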

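🔥 4. Preparing the Model Files

The model file is hard to track down; after a long search I found it on Kaggle. The link below is a signed download URL with an `Expires` parameter, so it stops working after that time. If it has expired, look up the YAMNet TFLite classification model on Kaggle Models (the archive is named yamnet-tflite-classification-tflite-v1.tar.gz). Extract the .tflite file from the archive and save it as model/model.tflite.

text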
https://storage.googleapis.com/kaggle-models-data/633/766/bundle/archive.tar.gz?GoogleAccessId=web-data@kaggle-161607.iam.gserviceaccount.com&Expires=1764792950&Signature=n8EWxJrX48zpv8FV5Q82wUiMQkvguShM7wtjNxzEQJVzyAMAx05GMB7WlscPRMwiKVFmsnMumtGJpAbHfUoGz0pUXZKdvC3wx%2FVYSYPEMkjNf1hkNckk7EfvsomEMPcZIwbGr6o26lGb3njGeAlgLv0AeVWcEWj9%2BAVcnJq2lIG6PmEFpV7uPz%2BOeyg5Fh%2Fe4zqumrMfnJcGH2EjhhU%2FPs4xxmTc3CvhNEjk8jCWom%2FS4AH%2BdLCdzA1cITazQYem5D6cHjqetnLhgl1Dqs0UbXh53TCVI9VJkOEEjEaQKbgG1hgLWEKLRfItzvfu3uo5YGFhaBZpbPTdOq%2BAjcHswA%3D%3D&response-content-disposition=attachment%3B+filename%3Dyamnet-tflite-classification-tflite-v1.tar.gz

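The label file is generated automatically: if model/labels.txt is missing, the backend extracts yamnet_label_list.txt from inside model.tflite and writes it out as labels.txt.

If labels_zh.json exists, the bilingual labels are loaded automatically;

otherwise only English labels are shown.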
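Since labels_zh.json has to be written by hand, it can be convenient to start from a skeleton that lists every English label. The sketch below is one way to build it, assuming the mapping format used in this project (English label as key, "English 中文" as value); `build_zh_skeleton` and `to_json_text` are hypothetical helper names. Untranslated entries map to the English label itself, which the backend treats as "no Chinese translation yet".

```python
import json
from typing import Dict, List, Optional


def build_zh_skeleton(labels: List[str],
                      existing: Optional[Dict[str, str]] = None) -> Dict[str, str]:
    """Map every English label to its known translation, or to itself if untranslated."""
    existing = existing or {}
    return {label: existing.get(label, label) for label in labels}


def to_json_text(mapping: Dict[str, str]) -> str:
    """Serialize with ensure_ascii=False so Chinese text stays readable in the file."""
    return json.dumps(mapping, ensure_ascii=False, indent=4)


skeleton = build_zh_skeleton(["Speech", "Bird"], {"Speech": "Speech 人声"})
print(to_json_text(skeleton))
```

In practice the `labels` argument would be the lines of model/labels.txt, and the output would be saved as model/labels_zh.json with UTF-8 encoding.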

⚡ 5. Running the FastAPI Backend

The complete backend code:

python
# main.py
# -*- coding: utf-8 -*-
"""
FastAPI backend: accepts audio uploaded by the frontend (WAV / FLAC / OGG, etc.)
and returns classification results from the YAMNet TFLite model.

Model / file conventions:
    - Model file:      model/model.tflite
    - English labels:  model/labels.txt       (521 lines, one English label per line)
    - EN->ZH mapping:  model/labels_zh.json   (key: English label, value: "English 中文翻译")

Example labels_zh.json:
{
    "Speech": "Speech 人声",
    "Child speech, kid speaking": "Child speech 儿童说话",
    "Conversation": "Conversation 对话",
    ...
}

In the response:
    - top1.label_en / item.label_en : English
    - top1.label_zh / item.label_zh : Chinese
    - top1.label / item.label       : still English (for compatibility with older frontends)
"""

import io
import os
import json
import zipfile
from typing import List, Dict, Optional

from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import uvicorn

import numpy as np
import soundfile as sf
import librosa

# ================= Interpreter backend selection =================

BACKEND = None  # "tflite-runtime" or "tensorflow"

try:
    # Preferred: tflite-runtime (lightweight, suited to embedded devices / servers)
    from tflite_runtime.interpreter import Interpreter as TFLiteInterpreter
    BACKEND = "tflite-runtime"
except Exception:
    # Fall back to tf.lite.Interpreter bundled with tensorflow
    try:
        import tensorflow as tf
        TFLiteInterpreter = tf.lite.Interpreter
        BACKEND = "tensorflow"
    except Exception:
        TFLiteInterpreter = None
        BACKEND = None

print(f"[INFO] TFLite backend: {BACKEND}")

# ================= Configuration =================

MODEL_PATH = os.environ.get("MODEL_PATH", "model/model.tflite")
LABEL_PATH = os.environ.get("LABEL_PATH", "model/labels.txt")
LABEL_ZH_PATH = os.environ.get("LABEL_ZH_PATH", "model/labels_zh.json")

SAMPLE_RATE = 16000
YAMNET_FRAME_SECONDS = 0.975
NUM_SAMPLES = int(round(YAMNET_FRAME_SECONDS * SAMPLE_RATE))  # 15600

# A simple confidence threshold used to flag a result as reliable
CONFIDENCE_THRESHOLD = float(os.environ.get("YAMNET_CONF_THRESH", "0.3"))

# The variables below are assigned during initialization
labels: List[str] = []
labels_zh_map: Dict[str, str] = {}
interpreter = None
_input_index: int = -1
_output_index: int = -1
_input_shape = None


# ================= Labels =================

def load_labels_from_file(path: str) -> List[str]:
    """Read the English labels from labels.txt, one per line."""
    with open(path, "r", encoding="utf-8") as f:
        lines = [line.strip() for line in f if line.strip()]
    print(f"[INFO] Loaded {len(lines)} labels from {path}")
    return lines


def load_labels_from_tflite(model_path: str) -> List[str]:
    """
    Following the Kaggle example, open the TFLite file as a zip archive
    and read the bundled yamnet_label_list.txt.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")

    with zipfile.ZipFile(model_path, "r") as zf:
        with zf.open("yamnet_label_list.txt") as f:
            lines = [l.decode("utf-8").strip() for l in f.readlines() if l.strip()]

    print(f"[INFO] Loaded {len(lines)} labels from yamnet_label_list.txt inside {model_path}")
    return lines


def ensure_labels(model_path: str, label_path: str) -> List[str]:
    """
    Use labels.txt if it exists; otherwise extract yamnet_label_list.txt
    from model.tflite and write a copy to labels.txt.
    """
    if os.path.exists(label_path):
        return load_labels_from_file(label_path)

    labels = load_labels_from_tflite(model_path)
    # dirname is "" for a bare filename and os.makedirs("") raises, so fall back to "."
    os.makedirs(os.path.dirname(label_path) or ".", exist_ok=True)
    with open(label_path, "w", encoding="utf-8") as f:
        f.write("\n".join(labels))
    print(f"[INFO] Wrote labels to {label_path}")
    return labels


def load_labels_zh(path: str) -> Dict[str, str]:
    """
    Read labels_zh.json: English label -> "English 中文翻译".
    Returns an empty dict if the file is missing or invalid.
    """
    if not os.path.exists(path):
        print(f"[WARN] labels_zh.json not found at {path}; returning English labels only.")
        return {}
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if not isinstance(data, dict):
            print("[WARN] labels_zh.json does not contain a JSON object; ignoring it.")
            return {}
        print(f"[INFO] Loaded {len(data)} EN->ZH mappings from {path}.")
        # Coerce everything to str in case keys/values have other types
        return {str(k): str(v) for k, v in data.items()}
    except Exception as e:
        print(f"[WARN] Failed to parse labels_zh.json: {e}")
        return {}


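def split_en_zh(en: str, value: Optional[str]) -> Dict[str, str]:
    """
    Convert (English label, mapped value) into {'label_en': ..., 'label_zh': ...}.

    Example mapped values:
        "Speech 人声"
        "Child speech 儿童说话"

    Strategy:
        - If the mapped value starts with the English label, strip that prefix
          plus any separator and treat the remainder as the Chinese label;
        - otherwise treat the whole value as the Chinese label;
        - if there is no mapping, the Chinese label falls back to the English one.
    """
    label_en = en
    label_zh = ""

    if isinstance(value, str) and value:
        v = value.strip()
        if v.startswith(en):
            # Strip spaces, colon, full-width colon (U+FF1A), ideographic comma (U+3001), hyphen
            rest = v[len(en):].strip(" :\uff1a\u3001-")
            label_zh = rest or en
        else:
            label_zh = v
    else:
        label_zh = en

    return {"label_en": label_en, "label_zh": label_zh}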


# ================= Model loading & audio processing =================

def load_interpreter(model_path: str):
    """Load the TFLite model."""
    if TFLiteInterpreter is None or BACKEND is None:
        raise RuntimeError("No TFLite interpreter available. Install tflite-runtime or tensorflow.")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")
    print(f"[INFO] Backend: {BACKEND}, loading model: {model_path}")
    interp = TFLiteInterpreter(model_path=model_path)
    return interp


def read_audio_from_bytes(file_bytes: bytes):
    """
    Decode audio from bytes with soundfile (handles wav/flac/ogg/...; webm is not guaranteed).
    Returns a mono float32 ndarray and the sample rate sr.
    """
    bio = io.BytesIO(file_bytes)
    try:
        data, sr = sf.read(bio, dtype="float32")
    except Exception as e:
        raise ValueError(f"Could not decode audio: {e}")

    # Downmix multi-channel audio to mono
    if data.ndim > 1:
        data = np.mean(data, axis=1)

    if data.size == 0:
        raise ValueError("Audio data is empty.")

    data = np.nan_to_num(data, nan=0.0, posinf=0.0, neginf=0.0)
    data = np.clip(data, -1.0, 1.0).astype(np.float32)
    return data, sr


def resample_to_16k(wave: np.ndarray, sr: int, target_sr: int = SAMPLE_RATE) -> np.ndarray:
    """Resample to 16 kHz."""
    if sr <= 0:
        raise ValueError(f"Invalid sample rate: {sr}")
    if sr == target_sr:
        return wave
    return librosa.resample(wave, orig_sr=sr, target_sr=target_sr)


def make_frames_975ms(wave_16k: np.ndarray) -> np.ndarray:
    """
    Split audio into non-overlapping 0.975 s frames, zero-padding the final partial frame.
    Returns an array of shape (num_frames, NUM_SAMPLES).
    """
    if len(wave_16k) <= 0:
        raise ValueError("Audio is empty after resampling.")

    frames = []
    hop = NUM_SAMPLES

    if len(wave_16k) <= NUM_SAMPLES:
        pad_len = NUM_SAMPLES - len(wave_16k)
        frame = np.pad(wave_16k, (0, pad_len), mode="constant")
        frames.append(frame)
    else:
        num_full = len(wave_16k) // hop
        for i in range(num_full):
            start = i * hop
            end = start + NUM_SAMPLES
            frame = wave_16k[start:end]
            if len(frame) < NUM_SAMPLES:
                frame = np.pad(frame, (0, NUM_SAMPLES - len(frame)), mode="constant")
            frames.append(frame)

        # Trailing partial frame (remove this block to drop the tail)
        tail_start = num_full * hop
        tail = wave_16k[tail_start:]
        if len(tail) > 0:
            if len(tail) < NUM_SAMPLES:
                tail = np.pad(tail, (0, NUM_SAMPLES - len(tail)), mode="constant")
            else:
                tail = tail[:NUM_SAMPLES]
            frames.append(tail)

    frames = np.stack(frames, axis=0).astype(np.float32)
    return np.clip(frames, -1.0, 1.0)


# ================= Initialization: load model & labels =================

labels = ensure_labels(MODEL_PATH, LABEL_PATH)
labels_zh_map = load_labels_zh(LABEL_ZH_PATH)
interpreter = load_interpreter(MODEL_PATH)

# allocate tensors
interpreter.allocate_tensors()
_input_details = interpreter.get_input_details()
_output_details = interpreter.get_output_details()

_input_index = _input_details[0]["index"]
_output_index = _output_details[0]["index"]
_input_shape = tuple(_input_details[0]["shape"])

print(f"[INFO] Model input shape: {_input_shape}")
print(f"[INFO] Model output shape: {_output_details[0]['shape']} (typically [1, 521])")


# ================= Inference =================

def run_tflite_on_frame(frame_1d: np.ndarray) -> np.ndarray:
    """
    Run inference on a single frame of shape (15600,) and return (521,) class scores.
    """
    x = np.asarray(frame_1d, dtype=np.float32)
    if x.ndim != 1 or x.shape[0] != NUM_SAMPLES:
        raise ValueError(f"frame must have shape ({NUM_SAMPLES},), got {x.shape}")

    # Handles input shapes such as [15600], [1, 15600], or [1, 15600, 1]
    x = x.reshape(_input_shape).astype(np.float32)

    interpreter.set_tensor(_input_index, x)
    interpreter.invoke()
    out = interpreter.get_tensor(_output_index)

    out = np.squeeze(out)  # flatten to 1-D (521,)
    out = np.clip(out, 0.0, 1.0).astype(np.float32)
    return out


def run_tflite_multiframe(frames: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """
    Run inference on multiple frames; frames has shape (num_frames, 15600).
    Returns (clip_scores, all_scores): the clip-level (521,) scores averaged
    over frames, and the per-frame scores of shape (num_frames, 521).
    """
    num_frames = frames.shape[0]
    all_scores = []

    for i in range(num_frames):
        scores_i = run_tflite_on_frame(frames[i])
        all_scores.append(scores_i)

    all_scores = np.stack(all_scores, axis=0)  # (num_frames, 521)
    clip_scores = np.mean(all_scores, axis=0)
    return clip_scores, all_scores


def build_label_item(idx: int, score: float) -> Dict:
    """Build one top-K result entry (with English and Chinese labels) for the frontend."""
    if 0 <= idx < len(labels):
        label_en = labels[idx]
    else:
        label_en = str(idx)

    pair = split_en_zh(label_en, labels_zh_map.get(label_en))

    return {
        "index": idx,
        # For compatibility with older frontends, label stays English
        "label": pair["label_en"],
        "label_en": pair["label_en"],
        "label_zh": pair["label_zh"],
        "score": float(score),
        "is_confident": float(score) >= CONFIDENCE_THRESHOLD,
    }


# ================= FastAPI =================

app = FastAPI(
    title="EnvSound YAMNet Classifier",
    version="1.2",
    description="Upload an audio clip and get environment sound classification (YAMNet, EN+ZH labels)."
)

# CORS (tighten allow_origins in production)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    # The test page sends no cookies or auth headers, and browsers reject
    # a wildcard origin on credentialed requests, so credentials stay off.
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health")
async def health():
    return {"status": "ok", "backend": BACKEND, "num_classes": len(labels)}


@app.post("/predict_audio")
async def predict_audio(file: UploadFile = File(...)):
    """
    Accept an uploaded audio file and return top-K predictions (bilingual labels).
    """
    content = await file.read()
    if not content:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")

    try:
        waveform, sr = read_audio_from_bytes(content)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    duration_sec = float(len(waveform) / sr) if sr > 0 else None

    # 1. Resample to 16 kHz
    try:
        wave_16k = resample_to_16k(waveform, sr, target_sr=SAMPLE_RATE)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    # 2. Split into 0.975 s frames
    try:
        frames = make_frames_975ms(wave_16k)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    num_frames = int(frames.shape[0])

    # 3. Per-frame inference + averaging over frames
    try:
        clip_scores, frame_scores = run_tflite_multiframe(frames)
    except Exception as e:
        print(f"[ERROR] Model inference failed: {e}")
        raise HTTPException(status_code=500, detail=f"Model inference failed: {e}")

    scores = clip_scores.astype(float)
    top_idx = int(np.argmax(scores))
    top_score = float(scores[top_idx])

    # top-1 (with English and Chinese labels)
    top1_item = build_label_item(top_idx, top_score)

    # top-K list
    top_k = 5
    topk_indices = list(np.argsort(-scores)[:top_k])
    topk = [build_label_item(int(i), float(scores[int(i)])) for i in topk_indices]

    return {
        "top1": top1_item,
        "topk": topk,
        "threshold": CONFIDENCE_THRESHOLD,
        "meta": {
            "backend": BACKEND,
            "sample_rate": SAMPLE_RATE,
            "input_sample_rate": sr,
            "duration_sec": duration_sec,
            "num_frames": num_frames,
            "frame_seconds": YAMNET_FRAME_SECONDS,
            "num_classes": len(labels),
        },
    }


# ================= Entry point =================

if __name__ == "__main__":
    # With reload=True uvicorn runs a reloader process plus a worker, so on Windows
    # the startup output appears twice; that is expected
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)

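Health check:

📌 GET → http://localhost:8000/health

Example response:

json
{"status":"ok","backend":"tensorflow","num_classes":521}

Upload audio for recognition:

📌 POST → http://localhost:8000/predict_audio (multipart/form-data, with the audio in a form field named file)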
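The endpoint can also be exercised without the browser. The sketch below uses only the Python standard library and assumes the backend is running locally on port 8000; `encode_multipart` and `predict` are illustrative helper names, and the form field is called `file` to match the FastAPI parameter.

```python
import json
import os
import urllib.request
import uuid
from typing import Tuple


def encode_multipart(field: str, filename: str, payload: bytes,
                     content_type: str = "audio/wav") -> Tuple[bytes, str]:
    """Build a single-file multipart/form-data body; returns (body, Content-Type header)."""
    boundary = uuid.uuid4().hex
    head = (
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="{field}"; filename="{filename}"\r\n'
        f"Content-Type: {content_type}\r\n\r\n"
    ).encode("utf-8")
    tail = f"\r\n--{boundary}--\r\n".encode("utf-8")
    return head + payload + tail, f"multipart/form-data; boundary={boundary}"


def predict(path: str, url: str = "http://localhost:8000/predict_audio") -> dict:
    """POST an audio file to the backend and return the parsed JSON response."""
    with open(path, "rb") as f:
        body, ctype = encode_multipart("file", os.path.basename(path), f.read())
    req = urllib.request.Request(url, data=body,
                                 headers={"Content-Type": ctype}, method="POST")
    with urllib.request.urlopen(req, timeout=30) as resp:
        return json.load(resp)
```

With the server up, `predict("bird.wav")["top1"]` gives the best-scoring class for a local file (the file name here is a placeholder).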

💻 6. Frontend Test Page

The complete test frontend (HTML):

html
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <title>EnvSound Environmental Sound Recognition Test</title>
  <meta name="viewport" content="width=device-width,initial-scale=1">
  <style>
    :root {
      --bg: #0f172a;
      --bg-soft: #020617;
      --card: #020617;
      --card-soft: #020617;
      --accent: #4f46e5;
      --accent-soft: rgba(79,70,229,0.14);
      --accent-soft-2: rgba(56,189,248,0.12);
      --border: #1f2937;
      --text: #e5e7eb;
      --text-soft: #9ca3af;
      --danger: #b91c1c;
      --radius-lg: 18px;
      --radius-md: 10px;
      --radius-pill: 999px;
      --shadow-card: 0 24px 80px rgba(0,0,0,0.55);
    }

    * {
      box-sizing: border-box;
    }

    body {
      margin: 0;
      min-height: 100vh;
      font-family: system-ui, -apple-system, BlinkMacSystemFont, "SF Pro Text",
        "Segoe UI", sans-serif;
      background:
        radial-gradient(circle at top left, rgba(56,189,248,0.16), transparent 55%),
        radial-gradient(circle at top right, rgba(129,140,248,0.18), transparent 55%),
        radial-gradient(circle at bottom, rgba(52,211,153,0.12), transparent 60%),
        linear-gradient(160deg,#020617 0%,#020617 55%,#020617 100%);
      color: var(--text);
      display: flex;
      align-items: flex-start;
      justify-content: center;
      padding: 32px 16px 40px;
    }

    .shell {
      width: 100%;
      max-width: 880px;
    }

    .app-header {
      display: flex;
      align-items: center;
      justify-content: space-between;
      margin-bottom: 18px;
      gap: 12px;
    }

    .title-block {
      display: flex;
      align-items: center;
      gap: 12px;
    }

    .title-icon {
      width: 34px;
      height: 34px;
      border-radius: 18px;
      background: radial-gradient(circle at 20% 20%, #a5b4fc 0, #4f46e5 34%, #0f172a 100%);
      display: inline-flex;
      align-items: center;
      justify-content: center;
      font-size: 18px;
      color: #e5e7eb;
      box-shadow: 0 16px 40px rgba(79,70,229,0.55);
    }

    h1 {
      margin: 0;
      font-size: 22px;
      letter-spacing: 0.02em;
    }

    .subtitle {
      margin-top: 3px;
      font-size: 13px;
      color: var(--text-soft);
    }

    .badge {
      font-size: 11px;
      padding: 6px 10px;
      border-radius: var(--radius-pill);
      border: 1px solid rgba(148,163,184,0.4);
      color: #cbd5f5;
      display: inline-flex;
      align-items: center;
      gap: 6px;
      background: rgba(15,23,42,0.7);
      backdrop-filter: blur(12px);
      white-space: nowrap;
    }

    .badge-dot {
      width: 7px;
      height: 7px;
      border-radius: 999px;
      background: #22c55e;
      box-shadow: 0 0 0 4px rgba(34,197,94,0.28);
    }

    .card {
      background: radial-gradient(circle at top left, rgba(56,189,248,0.16), transparent 55%),
                  radial-gradient(circle at bottom right, rgba(79,70,229,0.16), transparent 55%),
                  var(--card);
      border-radius: var(--radius-lg);
      padding: 22px 22px 24px;
      border: 1px solid rgba(148,163,184,0.25);
      box-shadow: var(--shadow-card);
      display: grid;
      grid-template-columns: minmax(0, 1.05fr) minmax(0, 1fr);
      gap: 20px;
    }

    @media (max-width: 800px) {
      .card {
        grid-template-columns: minmax(0, 1fr);
        padding: 18px 18px 20px;
      }
    }

    .section-title {
      font-size: 13px;
      text-transform: uppercase;
      letter-spacing: 0.12em;
      color: #9ca3af;
      margin-bottom: 10px;
    }

    .field-group {
      margin-bottom: 14px;
    }

    label {
      display: block;
      font-size: 13px;
      margin-bottom: 6px;
      color: #d1d5db;
    }

    .input-text {
      width: 100%;
      padding: 8px 10px;
      border-radius: 9px;
      border: 1px solid rgba(148,163,184,0.45);
      background: radial-gradient(circle at top left, rgba(30,64,175,0.26), transparent 60%),
                  rgba(15,23,42,0.9);
      color: var(--text);
      font-size: 13px;
      outline: none;
      transition: border-color 0.15s ease, box-shadow 0.15s ease, background 0.15s ease;
    }

    .input-text::placeholder {
      color: #64748b;
    }

    .input-text:focus {
      border-color: #60a5fa;
      box-shadow: 0 0 0 1px rgba(59,130,246,0.7);
      background: radial-gradient(circle at top left, rgba(96,165,250,0.22), transparent 60%),
                  rgba(15,23,42,0.94);
    }

    input[type="file"] {
      width: 100%;
      font-size: 13px;
      color: var(--text-soft);
    }

    .hint {
      margin-top: 4px;
      font-size: 11px;
      color: var(--text-soft);
    }

    .btn-row {
      margin-top: 10px;
      display: flex;
      gap: 10px;
      align-items: center;
      flex-wrap: wrap;
    }

    button {
      border: none;
      border-radius: 999px;
      padding: 8px 18px;
      font-size: 13px;
      cursor: pointer;
      display: inline-flex;
      align-items: center;
      justify-content: center;
      gap: 6px;
      transition: transform 0.08s ease, box-shadow 0.12s ease, background 0.12s ease, opacity 0.1s ease;
      white-space: nowrap;
    }

    button.primary {
      background: linear-gradient(135deg,#4f46e5,#22c55e);
      color: #f9fafb;
      box-shadow: 0 14px 30px rgba(79,70,229,0.5);
    }

    button.primary:hover:not(:disabled) {
      transform: translateY(-1px);
      box-shadow: 0 20px 44px rgba(79,70,229,0.65);
    }

    button.primary:active:not(:disabled) {
      transform: translateY(0);
      box-shadow: 0 8px 20px rgba(79,70,229,0.45);
    }

    button.primary:disabled {
      opacity: 0.55;
      cursor: not-allowed;
      box-shadow: none;
    }

    button.secondary {
      background: rgba(15,23,42,0.9);
      border: 1px solid rgba(148,163,184,0.4);
      color: #e5e7eb;
    }

    button.secondary:hover {
      background: rgba(30,64,175,0.35);
      border-color: #60a5fa;
    }

    .btn-icon {
      font-size: 14px;
    }

    .status {
      margin-top: 10px;
      font-size: 12px;
      color: var(--text-soft);
      min-height: 16px;
    }

    .status.error {
      color: var(--danger);
    }

    /* Result card */
    .result-card {
      background: rgba(15,23,42,0.92);
      border-radius: var(--radius-lg);
      border: 1px solid rgba(148,163,184,0.35);
      padding: 14px 14px 16px;
      display: none;
    }

    .result-header {
      display: flex;
      align-items: center;
      justify-content: space-between;
      margin-bottom: 10px;
      gap: 8px;
    }

    .result-title {
      font-size: 14px;
      font-weight: 600;
      display: flex;
      align-items: center;
      gap: 6px;
    }

    .pill {
      font-size: 11px;
      padding: 4px 9px;
      border-radius: var(--radius-pill);
      background: var(--accent-soft);
      border: 1px solid rgba(129,140,248,0.75);
      color: #c7d2fe;
      display: inline-flex;
      align-items: center;
      gap: 4px;
    }

    .pill-dot {
      width: 6px;
      height: 6px;
      border-radius: 999px;
      background: #22c55e;
    }

    .top1-box {
      padding: 10px 11px;
      border-radius: var(--radius-md);
      border: 1px solid rgba(96,165,250,0.5);
      background: radial-gradient(circle at top left, rgba(96,165,250,0.38), transparent 60%),
                  rgba(15,23,42,0.96);
      margin-bottom: 10px;
      display: grid;
      grid-template-columns: minmax(0, 1.5fr) minmax(0, 1fr);
      gap: 8px 18px;
    }

    @media (max-width: 800px) {
      .top1-box {
        grid-template-columns: minmax(0, 1fr);
      }
    }

    .top1-labels {
      font-size: 13px;
    }

    .top1-label-line {
      margin-bottom: 4px;
    }

    .top1-label-tag {
      display: inline-flex;
      align-items: center;
      gap: 6px;
      padding: 2px 7px;
      border-radius: var(--radius-pill);
      background: rgba(15,23,42,0.9);
      border: 1px solid rgba(148,163,184,0.4);
      font-size: 11px;
      color: #e5e7eb;
    }

    .lang-tag {
      font-size: 10px;
      padding: 1px 6px;
      border-radius: var(--radius-pill);
      background: rgba(39,39,42,0.9);
      color: #9ca3af;
      text-transform: uppercase;
    }

    .top1-extra {
      font-size: 12px;
      color: var(--text-soft);
      display: flex;
      flex-direction: column;
      gap: 3px;
      justify-content: center;
    }

    .top1-score {
      font-size: 13px;
      color: #fbbf24;
    }

    .meta {
      margin-top: 8px;
      font-size: 11px;
      color: var(--text-soft);
    }

    .meta span {
      margin-right: 10px;
    }

    table {
      width: 100%;
      border-collapse: collapse;
      font-size: 12px;
      margin-top: 6px;
    }

    th, td {
      padding: 6px 6px;
      border-bottom: 1px solid rgba(31,41,55,0.9);
      text-align: left;
    }

    th {
      background: rgba(15,23,42,0.98);
      font-weight: 500;
      color: #9ca3af;
      position: sticky;
      top: 0;
      z-index: 1;
    }

    tbody tr:nth-child(odd) {
      background: rgba(15,23,42,0.96);
    }

    tbody tr:nth-child(even) {
      background: rgba(15,23,42,0.9);
    }

    .score-cell {
      font-variant-numeric: tabular-nums;
    }

    .json-block {
      margin-top: 10px;
      font-size: 11px;
    }

    details {
      background: rgba(15,23,42,0.98);
      border-radius: var(--radius-md);
      border: 1px solid rgba(55,65,81,0.85);
      padding: 6px 8px 6px;
    }

    summary {
      cursor: pointer;
      list-style: none;
      outline: none;
      font-size: 11px;
      color: var(--text-soft);
    }

    summary::marker,
    summary::-webkit-details-marker {
      display: none;
    }

    pre {
      margin: 8px 0 0;
      padding: 8px 8px;
      border-radius: 8px;
      background: #020617;
      color: #e5e7eb;
      font-size: 11px;
      overflow-x: auto;
      max-height: 260px;
      border: 1px solid rgba(31,41,55,0.95);
    }

    code {
      font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;
    }

    .status-chip {
      font-size: 11px;
      padding: 2px 7px;
      border-radius: 999px;
      background: rgba(15,23,42,0.95);
      border: 1px solid rgba(75,85,99,0.8);
      color: #9ca3af;
    }

    .status-chip strong {
      color: #e5e7eb;
    }
  </style>
</head>
<body>
<div class="shell">
  <header class="app-header">
    <div class="title-block">
      <div class="title-icon">🎧</div>
      <div>
        <h1>EnvSound · Environmental Sound Recognition</h1>
        <div class="subtitle">Environmental sound classification with YAMNet TFLite · bilingual EN/ZH labels</div>
      </div>
    </div>
    <div class="badge">
      <span class="badge-dot"></span>
      <span>FastAPI · TFLite · 16kHz</span>
    </div>
  </header>

  <main class="card">
    <!-- Left: input area -->
    <section>
      <div class="section-title">Input</div>

      <div class="field-group">
        <label for="backend-url">Backend URL (/predict_audio)</label>
        <input id="backend-url" class="input-text" type="text"
               value="http://localhost:8000/predict_audio"
               placeholder="e.g. http://localhost:8000/predict_audio">
        <div class="hint">Change this address if you use a different port or deploy to a server.</div>
      </div>

      <div class="field-group">
        <label for="audio-file">Choose an audio file</label>
        <input id="audio-file" type="file" accept="audio/*">
        <div class="hint">
          A roughly 1-second, 16 kHz mono WAV works best; FLAC, OGG, and other formats that
          <code>soundfile</code> can decode are also supported.
        </div>
      </div>

      <div class="btn-row">
        <button id="btn-send" class="primary">
          <span class="btn-icon">🚀</span>
          <span>Upload and recognize</span>
        </button>
        <button id="btn-clear" class="secondary">
          <span class="btn-icon">🧹</span>
          <span>Clear results</span>
        </button>
      </div>

      <div id="status" class="status"></div>
    </section>

    <!-- Right: result area -->
    <section>
      <div class="section-title">Result</div>

      <div id="result" class="result-card">
        <div class="result-header">
          <div class="result-title">
            <span>Recognition result</span>
          </div>
          <div id="latency-chip" class="status-chip" style="display:none;"></div>
        </div>

        <div class="top1-box">
          <div class="top1-labels">
            <div class="top1-label-line">
              <span class="top1-label-tag">
                <span class="lang-tag">EN</span>
                <span id="top1-label-en"></span>
              </span>
            </div>
            <div class="top1-label-line">
              <span class="top1-label-tag">
                <span class="lang-tag">ZH</span>
                <span id="top1-label-zh"></span>
              </span>
            </div>
          </div>
          <div class="top1-extra">
            <div>Confidence:
              <span class="top1-score" id="top1-score"></span>
            </div>
            <div id="top1-flag"></div>
          </div>
        </div>

        <div>
          <div class="pill">
            <span class="pill-dot"></span>
            <span>Top-K class details</span>
          </div>
          <table>
            <thead>
            <tr>
              <th>#</th>
              <th>Label (EN)</th>
              <th>Label (ZH)</th>
              <th>Score</th>
            </tr>
            </thead>
            <tbody id="topk-body"></tbody>
          </table>
        </div>

        <div class="meta" id="meta-info"></div>

        <div class="json-block">
          <details>
            <summary>Show raw JSON response</summary>
            <pre id="raw-json"><code></code></pre>
          </details>
        </div>
      </div>
    </section>
  </main>
</div>

<script>
  const btnSend = document.getElementById('btn-send');
  const btnClear = document.getElementById('btn-clear');
  const fileInput = document.getElementById('audio-file');
  const backendInput = document.getElementById('backend-url');
  const statusEl = document.getElementById('status');
  const resultCard = document.getElementById('result');
  const latencyChip = document.getElementById('latency-chip');

  const top1LabelEnEl = document.getElementById('top1-label-en');
  const top1LabelZhEl = document.getElementById('top1-label-zh');
  const top1ScoreEl = document.getElementById('top1-score');
  const top1FlagEl = document.getElementById('top1-flag');
  const topkBodyEl = document.getElementById('topk-body');
  const rawJsonEl = document.getElementById('raw-json').querySelector('code');
  const metaInfoEl = document.getElementById('meta-info');

  function setStatus(text, isError = false) {
    statusEl.textContent = text;
    statusEl.className = 'status' + (isError ? ' error' : '');
  }

  function clearResult() {
    resultCard.style.display = 'none';
    latencyChip.style.display = 'none';
    top1LabelEnEl.textContent = '';
    top1LabelZhEl.textContent = '';
    top1ScoreEl.textContent = '';
    top1FlagEl.textContent = '';
    topkBodyEl.innerHTML = '';
    rawJsonEl.textContent = '';
    metaInfoEl.textContent = '';
    setStatus('');
  }

  btnClear.addEventListener('click', () => {
    clearResult();
    fileInput.value = '';
  });

  btnSend.addEventListener('click', async () => {
    clearResult();

    const file = fileInput.files[0];
    if (!file) {
      setStatus('Please choose an audio file first.', true);
      return;
    }

    const url = backendInput.value.trim();
    if (!url) {
      setStatus('The backend URL must not be empty.', true);
      return;
    }

    const formData = new FormData();
    formData.append('file', file); // must match the FastAPI parameter name: file

    // Update only the label span so the icon span inside the button is kept
    const btnSendLabel = btnSend.querySelector('span:last-child') ?? btnSend;
    btnSend.disabled = true;
    btnSendLabel.textContent = 'Recognizing...';
    setStatus('Uploading and recognizing, please wait...');

    const t0 = performance.now();

    try {
      const resp = await fetch(url, {
        method: 'POST',
        body: formData
      });

      const t1 = performance.now();
      const costMs = (t1 - t0).toFixed(0);

      if (!resp.ok) {
        let errText = '';
        try {
          errText = await resp.text();
        } catch {
          errText = 'Unknown error';
        }
        setStatus(`Request failed, HTTP ${resp.status}: ${errText}`, true);
      } else {
        const data = await resp.json();
        console.log('Response:', data);
        renderResult(data, costMs);
        setStatus(`Recognition finished ✅ (took about ${costMs} ms)`);
      }
    } catch (e) {
      console.error(e);
      setStatus('Request error: ' + e.message, true);
    } finally {
      btnSend.disabled = false;
      btnSendLabel.textContent = 'Upload and recognize';
    }
  });

  function renderResult(data, costMs) {
    if (!data) return;

    // top1
    if (data.top1) {
      const t = data.top1;
      top1LabelEnEl.textContent = t.label_en ?? t.label ?? '';
      top1LabelZhEl.textContent = t.label_zh ?? '';

      if (typeof t.score === 'number') {
        const p = (t.score * 100).toFixed(2);
        top1ScoreEl.textContent = `${p}%`;
      } else {
        top1ScoreEl.textContent = '';
      }

      if (t.is_confident === true) {
        top1FlagEl.textContent = 'Above threshold: result is fairly reliable ✓';
      } else if (typeof t.score === 'number') {
        top1FlagEl.textContent = 'Below threshold: treat the result as a hint only';
      } else {
        top1FlagEl.textContent = '';
      }
    }

    // top-K list
    topkBodyEl.innerHTML = '';
    if (Array.isArray(data.topk)) {
      data.topk.forEach((item, idx) => {
        const tr = document.createElement('tr');
        const tdIdx = document.createElement('td');
        const tdLabelEn = document.createElement('td');
        const tdLabelZh = document.createElement('td');
        const tdScore = document.createElement('td');

        tdIdx.textContent = String(idx + 1);
        tdLabelEn.textContent = item.label_en ?? item.label ?? '';
        tdLabelZh.textContent = item.label_zh ?? '';
        tdScore.className = 'score-cell';

        if (typeof item.score === 'number') {
          tdScore.textContent = (item.score * 100).toFixed(2) + '%';
        } else {
          tdScore.textContent = '';
        }

        tr.appendChild(tdIdx);
        tr.appendChild(tdLabelEn);
        tr.appendChild(tdLabelZh);
        tr.appendChild(tdScore);
        topkBodyEl.appendChild(tr);
      });
    }

    // meta info
    const parts = [];
    if (data.meta) {
      if (data.meta.backend) {
        parts.push(`Backend: ${data.meta.backend}`);
      }
      if (typeof data.meta.duration_sec === 'number') {
        parts.push(`Duration: ${data.meta.duration_sec.toFixed(2)} s`);
      }
      if (typeof data.meta.sample_rate === 'number') {
        parts.push(`Model sample rate: ${data.meta.sample_rate} Hz`);
      }
      if (typeof data.meta.input_sample_rate === 'number') {
        parts.push(`Original sample rate: ${data.meta.input_sample_rate} Hz`);
      }
      if (typeof data.meta.num_classes === 'number') {
        parts.push(`Classes: ${data.meta.num_classes}`);
      }
      if (typeof data.meta.num_frames === 'number') {
        parts.push(`Frames: ${data.meta.num_frames} × ${data.meta.frame_seconds ?? 0.975}s`);
      }
    }
    metaInfoEl.textContent = parts.join(' | ');

    // latency chip
    if (costMs !== undefined) {
      latencyChip.style.display = 'inline-flex';
      latencyChip.innerHTML = `<strong>${costMs} ms</strong> · end-to-end latency`;
    } else {
      latencyChip.style.display = 'none';
    }

    // raw JSON
    rawJsonEl.textContent = JSON.stringify(data, null, 2);

    resultCard.style.display = 'block';
  }
</script>
</body>
</html>
