Audio classification model notes

Contents

audio-mamba-aum

Small ImageNet

kimia_infer audio understanding

Audio classification

saurabhati/DASS_medium_AudioSet_50.2

ast

AudioClassification-Pytorch has no pretrained models, train your own


AudioSet mAP comparison (values as reported):

| Model | mAP |
|----------------------------------|-------|
| OmniVec2.0 | 0.558 |
| OmniVec | 0.548 |
| EquiAV | 0.546 |
| MAViL (Audio-Visual, single) | 0.533 |
| PaSST-S / ConvNeXt-Tiny | 0.471 |
| PSLA (Ensemble EfficientNet) | 0.474 |
| AST | 0.485 |
| Audio-MAE (SOTA self-supervised) | surpasses prior supervised methods (exact figure not given) |
| DASS_medium_AudioSet (covered below) | 0.502 |

https://huggingface.co/hongroklim/omniverse-github/tree/main

https://github.com/JongSuk1/EquiAV

audio-mamba-aum

https://github.com/kaistmm/Audio-Mamba-AuM?tab=readme-ov-file

The code below is untested.

python
import torch
import torchaudio
from pathlib import Path

# Assumes the model class lives in src/model.py (adjust the import to the repo's actual layout)
from src.model import AudioMambaModel
from src.config import get_config  # if the config lives in a separate module

def load_model(checkpoint_path, config_path=None, device='cuda'):
    # Build the model config (can be simplified if everything is stored in a single checkpoint file)
    config = get_config(config_path) if config_path else None
    model = AudioMambaModel(config) if config else AudioMambaModel()
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Checkpoint layout varies; fall back to treating the file itself as a raw state_dict
    state_dict = checkpoint.get('model_state_dict', checkpoint) if isinstance(checkpoint, dict) else checkpoint
    model.load_state_dict(state_dict)
    model.to(device).eval()
    return model

def preprocess_audio(audio_path, target_sample_rate=16000):
    waveform, sr = torchaudio.load(audio_path)
    if sr != target_sample_rate:
        waveform = torchaudio.transforms.Resample(sr, target_sample_rate)(waveform)
    # Additional preprocessing (framing, gain normalization, etc.) could be added here
    return waveform

def infer(model, waveform, device='cuda'):
    waveform = waveform.to(device)
    with torch.no_grad():
        outputs = model(waveform.unsqueeze(0))  # add a batch dimension
        probs = torch.softmax(outputs, dim=1)
        top_prob, top_label = torch.max(probs, dim=1)
    return top_label.item(), top_prob.item()

if __name__ == "__main__":
    checkpoint_path = "path/to/your/checkpoint.pth"  # replace with the downloaded checkpoint
    config_path = None  # supply a config file path if required
    audio_path = "path/to/your/audio.wav"
    label_map = {0: "class_a", 1: "class_b", ...}  # fill in the label mapping for your task

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model(checkpoint_path, config_path, device)
    waveform = preprocess_audio(audio_path)
    label_idx, confidence = infer(model, waveform, device)

    print(f"Predicted label: {label_map[label_idx]} (confidence: {confidence:.4f})")

Small ImageNet

These are the checkpoints for the small models with the variant Bi-Bi (c), initialized with ImageNet pretrained weights.

| Dataset | #Params | Performance | Checkpoint |
|--------------------------|---------|-------------|------------|
| AudioSet (mAP) | 25.5M | 39.74 | Link |
| AS-20K (mAP) | 25.5M | 29.17 | Link |
| VGGSound (Acc) | 25.5M | 49.61 | Link |
| VoxCeleb (Acc) | 25.8M | 41.78 | Link |
| Speech Commands V2 (Acc) | 25.2M | 97.61 | Link |
| Epic Sounds (Acc) | 25.4M | 53.45 | Link |
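
Before wiring one of these checkpoints into the untested loader above, it helps to check what the file actually contains; a minimal sketch (the file name is a placeholder for whichever checkpoint you download):

python
import torch

# Placeholder name: replace with the checkpoint downloaded from the table above
ckpt_path = "aum_small_bibi_c.pth"

ckpt = torch.load(ckpt_path, map_location="cpu")

# Releases differ in layout: some store a raw state_dict, others wrap it under a key
state_dict = ckpt.get("model_state_dict", ckpt) if isinstance(ckpt, dict) else ckpt

# Print a few parameter names and shapes to see which model class/config the weights expect
for name, tensor in list(state_dict.items())[:10]:
    print(name, tuple(tensor.shape))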

kimia_infer audio understanding

bash
git clone https://github.com/MoonshotAI/Kimi-Audio
git submodule update --init
cd Kimi-Audio
docker build -t kimi-audio:v0.1 .

Alternatively, you can use the pre-built image:

docker pull moonshotai/kimi-audio:v0.1

The docker pull approach worked.

docker run -it a49762d13a3d bash

python
import soundfile as sf
import torch
from kimia_infer.api.kimia import KimiAudio

def main():
    # 1. Load the model
    model_id = "moonshotai/Kimi-Audio-7B-Instruct"  # or "Kimi/Kimi-Audio-7B"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    model = KimiAudio(model_path=model_id, load_detokenizer=True)
    model.to(device)
    
    # 2. Sampling parameters that control generation behaviour
    sampling_params = {
        "audio_temperature": 0.8,
        "audio_top_k": 10,
        "text_temperature": 0.0,
        "text_top_k": 5,
        "audio_repetition_penalty": 1.0,
        "audio_repetition_window_size": 64,
        "text_repetition_penalty": 1.0,
        "text_repetition_window_size": 16,
    }
    
    # 3. Example task A: audio to text (ASR)
    asr_audio = "path/to/your_asr_audio.wav"  # replace with a real audio path
    messages_asr = [
        # Chinese prompt: "Please transcribe the following audio:"
        {"role": "user", "message_type": "text", "content": "请转录下面这段音频:"},
        {"role": "user", "message_type": "audio", "content": asr_audio}
    ]
    
    _, text_out = model.generate(messages_asr, **sampling_params, output_type="text")
    print("ASR 输出内容:", text_out)
    
    # 4. Example task B: audio Q&A (audio-to-audio/text)
    qa_audio = "path/to/your_qa_audio.wav"  # replace with a real audio path
    messages_qa = [
        {"role": "user", "message_type": "audio", "content": qa_audio}
    ]
    
    wav_out, text_qa = model.generate(messages_qa, **sampling_params, output_type="both")
    output_wav = "output_generated.wav"
    
    sf.write(output_wav, wav_out.detach().cpu().view(-1).numpy(), 24000)
    print("问答生成音频已保存至:", output_wav)
    print("问答输出文本:", text_qa)

if __name__ == "__main__":
    main()

Audio classification:

python
# coding=utf-8
import glob
import sys
import os
current_dir = os.path.dirname(os.path.abspath(__file__))
os.chdir(current_dir)
print('current_dir', current_dir)
paths = [current_dir, current_dir+'/../']
paths.append(os.path.join(current_dir, 'src'))
for path in paths:
    sys.path.insert(0, path)
    os.environ['PYTHONPATH'] = (os.environ.get('PYTHONPATH', '') + ':' + path).strip(':')

from kimia_infer.api.kimia import KimiAudio
import os
import soundfile as sf
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # parser.add_argument("--model_path", type=str, default="moonshotai/Kimi-Audio-7B-Instruct")
    parser.add_argument("--model_path", type=str, default="/nas/lbg/models/Kimi-Audio-7B-Instruct")
    args = parser.parse_args()

    model = KimiAudio(
        model_path=args.model_path,
        load_detokenizer=True,
    )

    sampling_params = {
        "audio_temperature": 0.8,
        "audio_top_k": 10,
        "text_temperature": 0.0,
        "text_top_k": 5,
        "audio_repetition_penalty": 1.0,
        "audio_repetition_window_size": 64,
        "text_repetition_penalty": 1.0,
        "text_repetition_window_size": 16,
    }

    # messages = [
    #     {
    #         "role": "user",
    #         "message_type": "audio",
    #         "content": "test_audios/asr_example.wav",
    #     }
    # ]

    # result = model.generate(messages, output_type="sec")  # 声音事件分类
    # print(">>> SEC 分类结果: ", result)

    # result = model.generate(messages, output_type="asc")  # 声学场景分类
    # print(">>> ASC 分类结果: ", result)
    
    base_dir = r"/nas/lbg/project/Kimi-Audio/test_audios/music/Music"

    files = glob.glob(base_dir + "/*.mp3")

    for file in files:

        messages = [
            # Chinese prompt: "Which of these classes does this audio belong to: no voice & no music,
            # speech, pure music, singing? Output only the class."
            {"role": "user", "message_type": "text", "content": "请判断这个音频属于以下哪一类: 无人声无音乐、说话、纯音乐、唱歌。只输出类别。"},
            {
                "role": "user",
                "message_type": "audio",
                "content": file,
            },
        ]

        wav, text = model.generate(messages, **sampling_params, output_type="text")
        
        file_name = os.path.basename(file)
        print(">>> Classification result: ", text, file_name)

saurabhati/DASS_medium_AudioSet_50.2

python
import torch
import librosa
from transformers import AutoConfig, AutoModelForAudioClassification, AutoFeatureExtractor

config = AutoConfig.from_pretrained('saurabhati/DASS_medium_AudioSet_50.2',trust_remote_code=True)
audio_model = AutoModelForAudioClassification.from_pretrained('saurabhati/DASS_medium_AudioSet_50.2',trust_remote_code=True)
feature_extractor = AutoFeatureExtractor.from_pretrained('saurabhati/DASS_medium_AudioSet_50.2',trust_remote_code=True)

waveform, sr = librosa.load("audio/eval/_/_/--4gqARaEJE_0.000.flac", sr=16000)
inputs = feature_extractor(waveform,sr, return_tensors='pt')

with torch.no_grad():
    logits = torch.sigmoid(audio_model(**inputs).logits)

predicted_class_ids = torch.where(logits[0] > 0.5)[0]
predicted_label = [audio_model.config.id2label[i.item()] for i in predicted_class_ids]
print(predicted_label)
# ['Animal', 'Domestic animals, pets', 'Dog']
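
When nothing clears the 0.5 threshold, it is often more informative to look at the highest-scoring classes directly; a short continuation reusing `logits` (already passed through the sigmoid) and `audio_model` from the snippet above:

python
# Show the five highest-scoring AudioSet classes regardless of the 0.5 threshold
top_scores, top_ids = torch.topk(logits[0], k=5)
for score, idx in zip(top_scores, top_ids):
    print(f"{audio_model.config.id2label[idx.item()]}: {score.item():.3f}")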

ast

Note that the 16-16 / 14-14 / 12-12 in the model names are the frequency/time patch strides used when splitting the spectrogram, not layer or head counts; smaller strides mean more overlapping patches and higher mAP at higher compute cost.

| Model | Dataset | Performance | Notes |
|-------|---------|-------------|-------|
| MIT/ast-finetuned-audioset-16-16-0.442 | AudioSet | 0.442 mAP | 16-16 stride (non-overlapping patches), cheapest of the three |
| MIT/ast-finetuned-audioset-14-14-0.443 | AudioSet | 0.443 mAP | 14-14 stride, slightly better mAP |
| MIT/ast-finetuned-audioset-12-12-0.447 | AudioSet | 0.447 mAP | 12-12 stride (more patch overlap), best mAP of the three |
| MIT/ast-finetuned-speech-commands-v2 | Speech Commands V2 | 98.1% accuracy | fine-tuned for speech-command (keyword) recognition |

https://github.com/YuanGongND/ast

python
import os 
import torch
from models import ASTModel 
# download pretrained model in this directory
os.environ['TORCH_HOME'] = '../pretrained_models'  
# assume each input spectrogram has 100 time frames
input_tdim = 100
# assume the task has 527 classes
label_dim = 527
# create a pseudo input: a batch of 10 spectrogram, each with 100 time frames and 128 frequency bins 
test_input = torch.rand([10, input_tdim, 128]) 
# create an AST model
ast_mdl = ASTModel(label_dim=label_dim, input_tdim=input_tdim, imagenet_pretrain=True)
test_output = ast_mdl(test_input) 
# output should be in shape [10, 527], i.e., 10 samples, each with prediction of 527 classes. 
print(test_output.shape)  

GPT-generated code:

python
import torch
import torchaudio
from models import ASTModel  # as in the official example above; run from the repo's src/ directory

# 1. Load the pretrained AST model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ASTModel(label_dim=527, input_fdim=128, input_tdim=1024, 
                 imagenet_pretrain=True, audioset_pretrain=True)
model = model.to(device)
model.eval()

# 2. Load and preprocess the audio file
audio_path = "example.wav"
waveform, sample_rate = torchaudio.load(audio_path)

# Convert to mono and resample to 16 kHz
waveform = torch.mean(waveform, dim=0, keepdim=True)
if sample_rate != 16000:
    resampler = torchaudio.transforms.Resample(sample_rate, 16000)
    waveform = resampler(waveform)

# 3. Extract log-Mel filterbank (Kaldi-style fbank) features
fbank = torchaudio.compliance.kaldi.fbank(
    waveform, htk_compat=True, sample_frequency=16000, 
    use_energy=False, window_type='hanning', 
    num_mel_bins=128, dither=0.0, frame_shift=10
)

# Pad or truncate the spectrogram to the model's expected input length (1024 frames)
n_frames = fbank.shape[0]
p = 1024 - n_frames
if p > 0:
    # pad with zeros if too short
    m = torch.nn.ZeroPad2d((0, 0, 0, p))
    fbank = m(fbank)
elif p < 0:
    # truncate if too long
    fbank = fbank[0:1024, :]

# 4. Normalize with the AudioSet mean/std used by AST
fbank = (fbank - (-4.2677393)) / (4.5689974 * 2)

# 5. Prepare the input tensor
input_tensor = fbank.unsqueeze(0).to(device)

# 6. Run inference
with torch.no_grad():
    output = model(input_tensor)

# 7. Get the predictions
probabilities = torch.sigmoid(output)
top5_prob, top5_labels = torch.topk(probabilities, 5)

# Label list (load your class label names here, e.g. the 527 AudioSet labels)
labels = [...]  # your list of class label names

print("Top 5 predictions:")
for i in range(5):
    print(f"{labels[top5_labels[0][i]]}: {top5_prob[0][i]*100:.2f}%")

AudioClassification-Pytorch has no pretrained models, train your own:

https://github.com/yeyupiaoling/AudioClassification-Pytorch
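
Since this repo ships no pretrained weights, you would train a classifier yourself. The sketch below is not this library's API, just a generic minimal PyTorch/torchaudio example of a mel-spectrogram CNN and one training step, to show what that involves (dataset wiring is omitted):

python
import torch
import torch.nn as nn
import torchaudio

# Generic sketch, not AudioClassification-Pytorch's API: tiny mel-spectrogram CNN classifier
class TinyAudioClassifier(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()
        self.melspec = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=64)
        self.to_db = torchaudio.transforms.AmplitudeToDB()
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d((8, 8)),
            nn.Flatten(), nn.Linear(16 * 8 * 8, num_classes),
        )

    def forward(self, waveform):  # waveform: [batch, samples]
        feats = self.to_db(self.melspec(waveform)).unsqueeze(1)  # [batch, 1, n_mels, frames]
        return self.net(feats)

model = TinyAudioClassifier(num_classes=4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# One training step on dummy data; replace with a DataLoader over labelled audio clips
waveforms = torch.randn(8, 16000)  # 8 clips of 1 s at 16 kHz
targets = torch.randint(0, 4, (8,))
loss = criterion(model(waveforms), targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"loss: {loss.item():.4f}")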

人工智能·gpt