Notes on Training a Proprietary ASR Model for Call Centers

1. Why Train a Custom Model

A company that builds call-center systems once approached me asking how to improve a speech recognition model's transcription accuracy on call recordings. Open-source models on the market are generally trained on general-purpose data, so they perform poorly on telephone speech: call recordings are typically stored as 8 kHz audio, and a large share of them contain severe noise, which is one of the reasons recognition ends up inaccurate. I therefore trained on a large volume of call recordings, mixed with other, more accurately labeled datasets: over 200 hours in total, consisting of roughly 100,000 weakly labeled utterances and 60,000 strongly labeled utterances, trained together.
A more detailed version of this article is available here.

A companion video can be viewed in your browser at:

https://www.bilibili.com/video/BV1TpSzByEeJ

2. Data Preparation

I prepared a bit over 160,000 audio clips. Note that every clip is a segmented short piece of audio, not a long recording.

I used a Python script I wrote myself to segment the input audio and label it automatically, followed by manual review. Another option is to call a commercial ASR API for automatic labeling; commercial-grade ASR is generally accurate enough, and here I called iFLYTEK's high-speed large-model speech recognition API. Either way, labeling produces an Excel spreadsheet with three columns: column A holds each clip's unique ID, column B its transcript, and column C its absolute path.
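
For illustration, a few rows might look like this (the IDs, transcripts, and paths are made up):

text
phone_0001_1    喂您好请问是张先生吗    E:\Datasets\Telephone\segments\phone_0001_1.wav
phone_0001_2    是的请问哪位    E:\Datasets\Telephone\segments\phone_0001_2.wav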

2.1 Option 1: Auto-labeling with another open-source ASR model

Here is the implementation:

python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time: 2025/11/23 2:24
# @Author: Luke Ewin
# @Blog: https://blog.lukeewin.top
import os
import argparse
from datetime import datetime
import ffmpeg
from funasr import AutoModel
import openpyxl
from openpyxl import Workbook

asr_model_path = os.path.join("models", "asr")
asr_model_revision = "v2.0.4"
vad_model_path = os.path.join("models", "vad")
vad_model_revision = "v2.0.4"
punc_model_path = os.path.join("models", "punc")
punc_model_revision = "v2.0.4"
hotword_path = "hotword.txt"
ngpu = 1
device = "cuda"
ncpu = 4

model = AutoModel(model=asr_model_path,
                  model_revision=asr_model_revision,
                  vad_model=vad_model_path,
                  vad_model_revision=vad_model_revision,
                  punc_model=punc_model_path,
                  punc_model_revision=punc_model_revision,
                  ngpu=ngpu,
                  ncpu=ncpu,
                  device=device,
                  disable_pbar=True,
                  disable_log=True,
                  disable_update=True
                  )

param_dict = {"sentence_timestamp": True, "batch_size_s": 300}
if os.path.exists(hotword_path):
    with open(hotword_path, "r", encoding="utf-8") as f:
        lines = f.readlines()
        lines = [line.strip() for line in lines]
    hotword = " ".join(lines)
    param_dict["hotword"] = hotword

# Extensions are compared case-insensitively below, so lowercase entries suffice
support_audio_format = ['.mp3', '.m4a', '.aac', '.ogg', '.wav', '.flac', '.wma', '.aif', '.webm']


def get_audio_files(directory):
    """获取目录下所有支持的音频文件"""
    audio_files = []
    if not directory or not os.path.exists(directory):
        return audio_files

    for root, dirs, files in os.walk(directory):
        for file in files:
            if any(file.lower().endswith(fmt) for fmt in support_audio_format):
                audio_files.append(os.path.join(root, file))
    return audio_files


def to_date(milliseconds):
    """将时间戳转换为SRT格式的时间"""
    from datetime import timedelta
    time_obj = timedelta(milliseconds=milliseconds)
    return f"{time_obj.seconds // 3600:02d}:{(time_obj.seconds // 60) % 60:02d}:{time_obj.seconds % 60:02d},{time_obj.microseconds // 1000:03d}"


def to_seconds(milliseconds):
    """将毫秒转换为秒"""
    return milliseconds / 1000.0


def merge_short_segments(segments, min_duration=3000):
    """合并短于指定时长的音频片段"""
    merged_segments = []
    current_segment = None

    for segment in segments:
        start_ms, end_ms, text = segment
        duration = end_ms - start_ms

        # No running segment yet: start one
        if current_segment is None:
            current_segment = {
                'start': start_ms,
                'end': end_ms,
                'texts': [text]
            }
        else:
            # Running segment shorter than the minimum: merge this one into it
            if (current_segment['end'] - current_segment['start']) < min_duration:
                current_segment['end'] = end_ms
                current_segment['texts'].append(text)
            else:
                # Running segment is long enough: flush it and start a new one
                merged_segments.append((
                    current_segment['start'],
                    current_segment['end'],
                    ''.join(current_segment['texts'])
                ))
                current_segment = {
                    'start': start_ms,
                    'end': end_ms,
                    'texts': [text]
                }

    # Flush the final segment
    if current_segment is not None:
        merged_segments.append((
            current_segment['start'],
            current_segment['end'],
            ''.join(current_segment['texts'])
        ))

    return merged_segments


def init_excel_file(excel_path):
    """初始化Excel文件并创建表头"""
    wb = Workbook()
    ws = wb.active
    ws.title = "音频片段信息"
    # 添加表头
    ws.append(["音频ID", "文本内容", "文件路径"])
    wb.save(excel_path)
    return excel_path


def append_to_excel(excel_path, data_rows):
    """向Excel文件追加数据"""
    wb = openpyxl.load_workbook(excel_path)
    ws = wb.active
    
    for row in data_rows:
        ws.append(row)
    
    wb.save(excel_path)


def transcript_only(input_dir, output_dir):
    """Transcribe only, without splitting the audio."""
    if not os.path.exists(input_dir):
        print(f"Error: input directory does not exist: {input_dir}")
        return

    audio_files = get_audio_files(input_dir)
    if not audio_files:
        print(f"No supported audio files in the input directory: {input_dir}")
        return

    total_files = len(audio_files)
    print(f"Found {total_files} audio files")

    for index, audio in enumerate(audio_files, 1):
        print(f"Progress: {index}/{total_files} - {os.path.basename(audio)}")
        
        if os.path.exists(audio):
            audio_name = os.path.splitext(os.path.basename(audio))[0]
            current_date = datetime.now().strftime("%Y-%m-%d")
            target_dir = os.path.join(output_dir, current_date, audio_name)
            os.makedirs(target_dir, exist_ok=True)

            try:
                audio_bytes, _ = (
                    ffmpeg.input(audio, threads=0, hwaccel='cuda')
                    .output("-", format="wav", acodec="pcm_s16le", ac=1, ar=16000)
                    .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
                )
                res = model.generate(input=audio_bytes, **param_dict)
                rec_result = res[0]
                asr_result_text = rec_result['text']
                
                if asr_result_text != '':
                    txt_list = []
                    for sentence in rec_result["sentence_info"]:
                        start = to_date(sentence["start"])
                        end = to_date(sentence["end"])
                        text = sentence["text"]
                        txt_list.append((start, end, text))
                    
                    count = 1
                    srt_path = os.path.join(target_dir, f"{audio_name}.srt")
                    with open(srt_path, 'w', encoding='utf-8') as f:
                        for start, end, text in txt_list:
                            f.write(f"{count}\n{start} --> {end}\n{text}\n\n")
                            count += 1
                    print(f"  已生成SRT文件: {srt_path}")
                else:
                    print(f"  警告: 未识别到文本内容")
                    
            except Exception as e:
                print(f"  处理文件时出错: {e}")

    print("处理完成!")


def transcript_and_split(input_dir, output_dir, excel_path):
    """Transcribe and split the audio, writing to Excel after each file."""
    if not os.path.exists(input_dir):
        print(f"Error: input directory does not exist: {input_dir}")
        return

    # Initialize the Excel file
    if not os.path.exists(excel_path):
        init_excel_file(excel_path)
        print(f"Created Excel file: {excel_path}")
    else:
        print(f"Using existing Excel file: {excel_path}")

    audio_files = get_audio_files(input_dir)
    if not audio_files:
        print(f"No supported audio files in the input directory: {input_dir}")
        return

    total_files = len(audio_files)
    print(f"Found {total_files} audio files")

    for index, audio in enumerate(audio_files, 1):
        print(f"Progress: {index}/{total_files} - {os.path.basename(audio)}")
        
        if os.path.exists(audio):
            audio_name = os.path.splitext(os.path.basename(audio))[0]
            audio_ext = os.path.splitext(audio)[1]
            current_date = datetime.now().strftime("%Y-%m-%d")
            target_dir = os.path.join(output_dir, current_date, audio_name)
            os.makedirs(target_dir, exist_ok=True)

            try:
                # Run ASR recognition
                res = model.generate(input=audio, **param_dict)
                if not res or len(res) == 0:
                    asr_result_text = ""
                else:
                    rec_result = res[0]
                    if 'text' in rec_result:
                        asr_result_text = rec_result['text']
                    else:
                        asr_result_text = ""
                
                data_rows = []
                if asr_result_text != '':
                    # Collect the original segment info
                    original_segments = []
                    for sentence in rec_result["sentence_info"]:
                        start_ms = sentence["start"]
                        end_ms = sentence["end"]
                        text = sentence["text"]
                        original_segments.append((start_ms, end_ms, text))

                    # Merge short segments
                    merged_segments = merge_short_segments(original_segments, min_duration=3000)

                    # Split the audio and collect rows for Excel
                    count = 1
                    for start_ms, end_ms, text in merged_segments:
                        # Cut this segment from the audio
                        segment_filename = f"{audio_name}_{count}{audio_ext}"
                        segment_path = os.path.join(target_dir, segment_filename)

                        try:
                            (
                                ffmpeg.input(audio)
                                .output(segment_path,
                                        ss=to_seconds(start_ms),
                                        to=to_seconds(end_ms))
                                .run(overwrite_output=True, quiet=True)
                            )

                            # Build the Excel row
                            audio_id = f"{audio_name}_{count}"
                            data_rows.append([audio_id, text, os.path.abspath(segment_path)])

                        except Exception as e:
                            print(f"  切分音频失败: {e}")

                        count += 1
                    
                    # Write this file's rows to Excel
                    if data_rows:
                        append_to_excel(excel_path, data_rows)
                        print(f"  Processed {len(data_rows)} segments and wrote them to Excel")
                else:
                    print("  Warning: no text recognized")

            except Exception as e:
                print(f"  Error while processing file: {e}")

    print(f"Done! Excel file saved: {excel_path}")


def main():
    parser = argparse.ArgumentParser(description='Speech recognition tool - https://blog.lukeewin.top')
    parser.add_argument('-i', '--input', required=True, help='Input directory path')
    parser.add_argument('-o', '--output', required=True, help='Output directory path')
    parser.add_argument('-m', '--mode', choices=['transcript', 'split'], default='split',
                       help='Mode: transcript (recognize only) or split (recognize and split)')
    parser.add_argument('-e', '--excel', default='asr_results.xlsx',
                       help='Excel output file path (split mode only)')
    
    args = parser.parse_args()
    
    if args.mode == 'transcript':
        transcript_only(args.input, args.output)
    else:
        transcript_and_split(args.input, args.output, args.excel)


if __name__ == '__main__':
    main()

Before running, make sure your Python environment has the funasr and ffmpeg-python packages installed (openpyxl as well, since the script writes Excel files).
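
For reference, the packages as published on PyPI can be installed like this:

shell
pip install funasr ffmpeg-python openpyxl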

To make this usable for people who don't know Python, I packaged it into a small GUI tool with PyInstaller. Download:

Files shared via netdisk: models.zip and one other file
Link: https://pan.baidu.com/s/18QXq2u118es48zCg6b9D-Q?pwd=nkv2  Access code: nkv2

On first use, the tool asks for authorization; follow the prompt and send the machine code it displays to customer support.

If you run the script from the CLI instead, you can download the models.zip archive from the same netdisk link; it contains the model files. Unzip it into the models directory next to the script, which yields three subdirectories: asr, vad, and punc. Hotwords are also supported: create a file named hotword.txt in the script's directory with one hotword per line. Keep hotword.txt at 1 KB or less; too many hotwords actually hurt recognition accuracy.
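
For illustration, a hotword.txt might look like this (the terms are made up; use domain terms from your own calls):

text
工单
回访
套餐变更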

Note: the tool only supports auto-labeling Mandarin. For dialect audio, consider the second option below.

2.2 Option 2: Auto-labeling via the iFLYTEK API

If you have never used iFLYTEK's high-speed speech recognition API, you get 10 hours of free quota; after that you have to pay for more. Even so, it is far more efficient than paying people to label by hand.

Here is the code:

python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time: 2025/11/26 2:10
# @Author: Luke Ewin
# @Blog: https://blog.lukeewin.top
import ffmpeg
from fileupload import seve_file
import requests
import datetime
import hashlib
import base64
import hmac
import json
import os
import re
import time
from openpyxl import Workbook

path_pwd = os.path.split(os.path.realpath(__file__))[0]
os.chdir(path_pwd)


def to_seconds(milliseconds):
    """将毫秒转换为秒"""
    milliseconds = float(milliseconds)
    return milliseconds / 1000.0


# Create and query transcription tasks
class get_result(object):
    def __init__(self, appid, apikey, apisecret):
        # POST endpoints
        self.Host = "ost-api.xfyun.cn"
        self.RequestUriCreate = "/v2/ost/pro_create"
        self.RequestUriQuery = "/v2/ost/query"
        # Build the request URLs
        if re.match(r"^\d", self.Host):
            self.urlCreate = "http://" + self.Host + self.RequestUriCreate
            self.urlQuery = "http://" + self.Host + self.RequestUriQuery
        else:
            self.urlCreate = "https://" + self.Host + self.RequestUriCreate
            self.urlQuery = "https://" + self.Host + self.RequestUriQuery
        self.HttpMethod = "POST"
        self.APPID = appid
        self.Algorithm = "hmac-sha256"
        self.HttpProto = "HTTP/1.1"
        self.UserName = apikey
        self.Secret = apisecret

        # Current time (UTC)
        cur_time_utc = datetime.datetime.utcnow()
        self.Date = self.httpdate(cur_time_utc)
        # Business arguments for task creation
        self.BusinessArgsCreate = {
            "language": "zh_cn",
            "accent": "mandarin",
            "domain": "pro_ost_ed",
            # "callback_url": "http://IP:port/xxx/"
        }

    def img_read(self, path):
        with open(path, 'rb') as fo:
            return fo.read()

    def hashlib_256(self, res):
        m = hashlib.sha256(bytes(res.encode(encoding='utf-8'))).digest()
        result = "SHA-256=" + base64.b64encode(m).decode(encoding='utf-8')
        return result

    def httpdate(self, dt):
        """
        Return a string representation of a date according to RFC 1123
        (HTTP/1.1).
        The supplied date must be in UTC.
        """
        weekday = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"][dt.weekday()]
        month = ["Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep",
                 "Oct", "Nov", "Dec"][dt.month - 1]
        return "%s, %02d %s %04d %02d:%02d:%02d GMT" % (weekday, dt.day, month,
                                                        dt.year, dt.hour, dt.minute, dt.second)

    def generateSignature(self, digest, uri):
        signature_str = "host: " + self.Host + "\n"
        signature_str += "date: " + self.Date + "\n"
        signature_str += self.HttpMethod + " " + uri \
                         + " " + self.HttpProto + "\n"
        signature_str += "digest: " + digest
        signature = hmac.new(bytes(self.Secret.encode('utf-8')),
                             bytes(signature_str.encode('utf-8')),
                             digestmod=hashlib.sha256).digest()
        result = base64.b64encode(signature)
        return result.decode(encoding='utf-8')

    def init_header(self, data, uri):
        digest = self.hashlib_256(data)
        sign = self.generateSignature(digest, uri)
        auth_header = 'api_key="%s",algorithm="%s", ' \
                      'headers="host date request-line digest", ' \
                      'signature="%s"' \
                      % (self.UserName, self.Algorithm, sign)
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "Method": "POST",
            "Host": self.Host,
            "Date": self.Date,
            "Digest": digest,
            "Authorization": auth_header
        }
        return headers

    def get_create_body(self, fileurl):
        post_data = {
            "common": {"app_id": self.APPID},
            "business": self.BusinessArgsCreate,
            "data": {
                "audio_src": "http",
                "audio_url": fileurl,
                "encoding": "raw"
            }
        }
        body = json.dumps(post_data)
        return body

    def get_query_body(self, task_id):
        post_data = {
            "common": {"app_id": self.APPID},
            "business": {
                "task_id": task_id,
            },
        }
        body = json.dumps(post_data)
        return body

    def call(self, url, body, headers):
        try:
            response = requests.post(url, data=body, headers=headers, timeout=8)
            status_code = response.status_code
            interval = response.elapsed.total_seconds()
            if status_code != 200:
                info = response.content
                return info
            else:
                resp_data = json.loads(response.text)
                return resp_data
        except Exception as e:
            print("Exception :%s" % e)

    def task_create(self, fileurl):
        body = self.get_create_body(fileurl)
        headers_create = self.init_header(body, self.RequestUriCreate)
        task_id = self.call(self.urlCreate, body, headers_create)
        print(task_id)
        return task_id

    def task_query(self, task_id, fileurl):
        if task_id:
            query_body = self.get_query_body(task_id)
            headers_query = self.init_header(query_body, self.RequestUriQuery)
            result = self.call(self.urlQuery, query_body, headers_query)
            return result

    def get_fileurl(self, file_path):
        # Upload the file; 31457280 bytes = 30 MB threshold for chunked upload
        api = seve_file.SeveFile(app_id=self.APPID, api_key=self.UserName, api_secret=self.Secret,
                                 upload_file_path=file_path)
        file_total_size = os.path.getsize(file_path)
        if file_total_size < 31457280:
            print("-----不使用分块上传-----")
            fileurl = api.gene_params('/upload')['data']['url']
        else:
            print("-----使用分块上传-----")
            fileurl = api.gene_params('/mpupload/upload')
        return fileurl

    def process_result(self, result, file_path, output_dir, ws):
        """Parse the ASR result and cut the audio into segments."""
        if result.get("code") != 0:
            print(f'API returned an error: {result.get("message")}')
            return

        lattice_data = result.get("data", {}).get("result", {}).get("lattice", [])
        file_name = os.path.splitext(os.path.basename(file_path))[0]
        file_ext = os.path.splitext(os.path.basename(file_path))[1]

        # Make sure the output directory exists
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        count = 1
        for segment in lattice_data:
            # Timing for the whole paragraph
            begin_time = segment.get("begin")
            end_time = segment.get("end")

            # Sentence content
            json_1best = segment.get("json_1best", {})
            st_data = json_1best.get("st", {})
            rt_data = st_data.get("rt", [])

            for rt_item in rt_data:
                ws_data = rt_item.get("ws", [])

                # Assemble the sentence text
                sentence_text = ""
                for word_segment in ws_data:
                    cw_data = word_segment.get("cw", [])
                    for word in cw_data:
                        word_text = word.get("w", "")
                        # Skip empty strings (and punctuation if desired)
                        if word_text and word_text.strip():
                            sentence_text += word_text

                # Only cut audio for non-empty sentences
                if sentence_text.strip():
                    # Skip sentences shorter than 3 characters
                    if len(sentence_text.strip()) < 3:
                        print(f"Skipping sentence shorter than 3 characters: '{sentence_text}'")
                        continue

                    audio_id = f"{file_name}_{count}"
                    segment_path = os.path.join(output_dir, audio_id + file_ext)

                    try:
                        # Cut the segment out of the source audio
                        ffmpeg.input(file_path).output(
                            segment_path,
                            ss=to_seconds(begin_time),
                            to=to_seconds(end_time)
                        ).run(overwrite_output=True, quiet=True)

                        # Add the row to the Excel sheet
                        ws.append([audio_id, sentence_text, os.path.abspath(segment_path)])
                        print(f"Processed segment {count}: {sentence_text}")
                        count += 1
                    except Exception as e:
                        print(f"Failed to cut audio: {e}")
                        # On failure, remove any partial file that was created
                        if os.path.exists(segment_path):
                            try:
                                os.remove(segment_path)
                            except OSError:
                                pass

    def get_result(self, file_path, output_dir, ws):
        # Upload the file and get its URL
        print("\n------ Uploading file -------")
        fileurl = self.get_fileurl(file_path)

        # Create the transcription task
        print("\n------ Creating task -------")
        task_response = self.task_create(fileurl)
        if not task_response or 'data' not in task_response:
            print("Failed to create task")
            return

        task_id = task_response['data']['task_id']

        # Poll for the result
        print("\n------ Querying task -------")
        print("Transcribing...")
        while True:
            result = self.task_query(task_id, fileurl)
            if not result:
                print("Empty query result")
                break

            if isinstance(result, dict):
                task_status = result['data']['task_status']
                if task_status == '1':  # processing
                    print("Task processing, retrying in 5 seconds...")
                    time.sleep(5)
                    continue
                elif task_status == '2':  # queued
                    print("Task queued, retrying in 5 seconds...")
                    time.sleep(5)
                    continue
                else:  # finished or other status
                    print("Transcription finished")
                    self.process_result(result, file_path, output_dir, ws)
                    break
            elif isinstance(result, bytes):
                print("Error:", result.decode('utf-8'))
                break
            else:
                print("Unknown response type")
                break


if __name__ == '__main__':
    # Fill in your appid, key, and secret from the iFLYTEK open platform
    appid = ""
    apikey = ""
    apisecret = ""

    file_dir = r'E:\Datasets\Telephone\audio\2025-11-25'
    output_dir = r'E:\Datasets\Telephone\audio_process_after_by_funasrGUI\2025-11-25'

    # Create the Excel workbook
    wb = Workbook()
    ws = wb.active
    ws.title = "Audio Segments"
    # Header row
    ws.append(["Audio ID", "Sentence", "Segment absolute path"])

    # Walk the input directory
    for root, dirs, files in os.walk(file_dir):
        for file in files:
            file_path = os.path.join(root, file)
            print(f"\nProcessing file: {file_path}")

            gClass = get_result(appid, apikey, apisecret)
            gClass.get_result(file_path, output_dir, ws)

    # Save the Excel file
    current_date = datetime.datetime.now().strftime("%Y-%m-%d")
    excel_dir = os.path.join(output_dir, current_date)
    os.makedirs(excel_dir, exist_ok=True)
    excel_path = os.path.join(excel_dir, f"segments_{datetime.datetime.now().strftime('%H%M%S')}.xlsx")
    wb.save(excel_path)
    print(f"\nExcel file saved: {excel_path}")

Then create a directory named fileupload and save the following as seve_file.py inside it:

python
#!/usr/bin/python3
# -*- coding:utf-8 -*-

import json
import math
import os
import time
from datetime import datetime
from wsgiref.handlers import format_date_time
from time import mktime
import hashlib
import base64
import hmac
from urllib.parse import urlparse
import requests
from urllib3 import encode_multipart_formdata

lfasr_host = 'http://upload-ost-api.xfyun.cn/file'
# API endpoint paths
api_init = '/mpupload/init'
api_upload = '/upload'
api_cut = '/mpupload/upload'
api_cut_complete = '/mpupload/complete'
api_cut_cancel = '/mpupload/cancel'
# Chunk size: 5 MB
file_piece_sice = 5242880


# File upload helper
class SeveFile:
    def __init__(self, app_id, api_key, api_secret, upload_file_path):
        self.app_id = app_id
        self.api_key = api_key
        self.api_secret = api_secret
        self.request_id = '0'
        self.upload_file_path = upload_file_path
        self.cloud_id = '0'

    # request_id generation
    def get_request_id(self):
        return time.strftime("%Y%m%d%H%M")

    # Body digest for the headers
    def hashlib_256(self, data):
        m = hashlib.sha256(bytes(data.encode(encoding='utf-8'))).digest()
        digest = "SHA-256=" + base64.b64encode(m).decode(encoding='utf-8')
        return digest

    # Assemble the signed auth headers
    def assemble_auth_header(self, request_url, file_data_type, method="", api_key="", api_secret="", body=""):
        u = urlparse(request_url)
        host = u.hostname
        path = u.path
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))
        digest = "SHA256=" + self.hashlib_256('')
        signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1\ndigest: {}".format(host, date, method, path, digest)
        signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()
        signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
        authorization = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
            api_key, "hmac-sha256", "host date request-line digest", signature_sha)
        headers = {
            "host": host,
            "date": date,
            "authorization": authorization,
            "digest": digest,
            'content-type': file_data_type,
        }
        return headers

    # POST to the API
    def call(self, url, file_data, file_data_type):
        api_key = self.api_key
        api_secret = self.api_secret
        headers = self.assemble_auth_header(url, file_data_type, method="POST",
                                            api_key=api_key, api_secret=api_secret, body=file_data)
        try:
            resp = requests.post(url, headers=headers, data=file_data, timeout=8)
            print("Chunk uploaded. Status:", resp.status_code, resp.text)
            return resp.json()
        except Exception as e:
            print("Chunk upload failed! Exception: %s" % e)
            return False


    # Finish a chunked upload
    def upload_cut_complete(self, body_dict):
        file_data_type = 'application/json'
        url = lfasr_host + api_cut_complete
        fileurl = self.call(url, json.dumps(body_dict), file_data_type)
        fileurl = fileurl['data']['url']
        print("Upload complete")
        return fileurl



    # Generate parameters for each API name. This example does not use every parameter; see the official docs (https://aidocs.xfyun.cn/docs/ost/%E5%A4%9A%E7%A7%9F%E6%88%B7%E6%96%87%E4%BB%B6%E4%B8%8A%E4%BC%A0%E6%8E%A5%E5%8F%A3%E6%96%87%E6%A1%A3.html) and adapt to your use case
    def gene_params(self, apiname):
        appid = self.app_id
        request_id = self.get_request_id()
        upload_file_path = self.upload_file_path
        cloud_id = self.cloud_id
        body_dict = {}
        # Simple upload API
        if apiname == api_upload:
            try:
                with open(upload_file_path, mode='rb') as f:
                    file = {
                        "data": (upload_file_path, f.read()),
                        "app_id": appid,
                        "request_id": request_id,
                    }
                    print('File:', upload_file_path, ' size:', os.path.getsize(upload_file_path))
                    encode_data = encode_multipart_formdata(file)
                    file_data = encode_data[0]
                    file_data_type = encode_data[1]
                url = lfasr_host + api_upload
                fileurl = self.call(url, file_data, file_data_type)
                return fileurl
            except FileNotFoundError:
                print("Sorry! The file " + upload_file_path + " can't be found.")
        # Init API (pre-upload handshake)
        elif apiname == api_init:
            body_dict['app_id'] = appid
            body_dict['request_id'] = request_id
            body_dict['cloud_id'] = cloud_id
            url = lfasr_host + api_init
            file_data_type = 'application/json'
            return self.call(url, json.dumps(body_dict), file_data_type)
        elif apiname == api_cut:
            # Init
            upload_prepare = self.prepare_request()
            if upload_prepare:
                upload_id = upload_prepare['data']['upload_id']
            # Upload the chunks
            self.do_upload(upload_file_path, upload_id)
            body_dict['app_id'] = appid
            body_dict['request_id'] = request_id
            body_dict['upload_id'] = upload_id
            # Finish the chunked upload
            fileurl = self.upload_cut_complete(body_dict)
            print("Chunked upload URL:", fileurl)
            return fileurl



    # Pre-upload init request
    def prepare_request(self):
        return self.gene_params(apiname=api_init)


    # Chunked upload
    def do_upload(self, file_path, upload_id):
        file_total_size = os.path.getsize(file_path)
        chunk_size = file_piece_sice
        chunks = math.ceil(file_total_size / chunk_size)
        appid = self.app_id
        request_id = self.get_request_id()
        upload_file_path = self.upload_file_path
        slice_id = 1

        print('File:', file_path, ' size:', file_total_size, ' chunk size:', chunk_size, ' chunks:', chunks)

        with open(file_path, mode='rb') as content:
            while slice_id <= chunks:
                print('chunk', slice_id)
                if slice_id == chunks:
                    # Last chunk: whatever is left (also correct when the size is an exact multiple)
                    current_size = file_total_size - chunk_size * (chunks - 1)
                else:
                    current_size = chunk_size

                file = {
                    "data": (upload_file_path, content.read(current_size)),
                    "app_id": appid,
                    "request_id": request_id,
                    "upload_id": upload_id,
                    "slice_id": slice_id,
                }

                encode_data = encode_multipart_formdata(file)
                file_data = encode_data[0]
                file_data_type = encode_data[1]
                url = lfasr_host + api_cut

                resp = self.call(url, file_data, file_data_type)
                count = 0
                while not resp and (count < 3):
                    print("Retrying upload")
                    resp = self.call(url, file_data, file_data_type)
                    count = count + 1
                    time.sleep(1)
                if not resp:
                    quit()
                slice_id = slice_id + 1

Running this code likewise produces an Excel spreadsheet with the same three columns.

Note that you need to delete the header row (row 1) before moving on.
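
If you'd rather do that programmatically, openpyxl can remove the first row (the file name below is illustrative):

python
import openpyxl

wb = openpyxl.load_workbook("asr_results.xlsx")  # your spreadsheet from above
ws = wb.active
ws.delete_rows(1)  # drop the header row
wb.save("asr_results.xlsx")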

2.3 Generating the scp and txt files

With the Excel spreadsheet generated above, the next step is to convert it into two files: train_wav.scp and train_text.txt. train_wav.scp holds the audio paths and train_text.txt holds the transcripts. The first column of both files is the audio's unique identifier, which I call the audio ID; it is what links each audio file to its text.

For the concrete code, see xlsx_to_funasr.py in this repository:

shell
git@github.com:lukeewin/Audio_Segment.git
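
For reference, here is a minimal sketch of what that conversion involves. It is not the repository's exact code, just the column mapping described above, written as Kaldi-style "ID path" and "ID text" lines:

python
import openpyxl

wb = openpyxl.load_workbook("asr_results.xlsx")  # spreadsheet from section 2.1/2.2, header row removed
ws = wb.active

with open("train_wav.scp", "w", encoding="utf-8") as scp_f, \
        open("train_text.txt", "w", encoding="utf-8") as txt_f:
    for audio_id, text, path in ws.iter_rows(values_only=True):
        scp_f.write(f"{audio_id} {path}\n")  # audio ID -> audio path
        txt_f.write(f"{audio_id} {text}\n")  # audio ID -> transcript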

2.4 Generating train.jsonl

Next, use the sensevoice2jsonl command provided by FunASR to generate a train.jsonl file:

shell
sensevoice2jsonl \
++scp_file_list='["/root/autodl-tmp/datasets/Telephone/train_data/train_wav.scp", "/root/autodl-tmp/datasets/Telephone/train_data/train_text.txt", "/root/autodl-tmp/datasets/Telephone/train_data/train_text_language.txt", "/root/autodl-tmp/datasets/Telephone/train_data/train_emo.txt", "/root/autodl-tmp/datasets/Telephone/train_data/train_event.txt"]' \
++data_type_list='["source", "target", "text_language", "emo_target", "event_target"]' \
++jsonl_file_out="/root/autodl-tmp/datasets/Telephone/train_data/train.jsonl"

The command above is for when you can provide source, target, text_language, emo_target, and event_target files. If you only have source and target files, the remaining fields can be auto-labeled with SenseVoiceSmall:

shell
sensevoice2jsonl \
++scp_file_list='["/root/autodl-tmp/datasets/Telephone/train_data/train_wav.scp", "/root/autodl-tmp/datasets/Telephone/train_data/train_text.txt"]' \
++data_type_list='["source", "target"]' \
++jsonl_file_out="/root/autodl-tmp/datasets/Telephone/train_data/train.jsonl" \
++model_dir='iic/SenseVoiceSmall'
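
Each line of the resulting train.jsonl is one training sample as a JSON object, roughly of the shape below. The exact field names depend on your FunASR version and this line is purely illustrative, so inspect the file the command actually writes:

json
{"key": "phone_0001_3", "source": "/root/autodl-tmp/datasets/Telephone/segments/phone_0001_3.wav", "source_len": 345, "target": "您好请问有什么可以帮您", "target_len": 11}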

2.5 Generating val.jsonl

The command above produced train.jsonl. Next, randomly extract 20% of its lines as val.jsonl and keep the remaining 80% as train.jsonl:

shell
mv train.jsonl train_all.jsonl
python random_process_jsonl.py

Here is the code for random_process_jsonl.py:

python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time: 2025/11/28 16:16
# @Author: Luke Ewin
# @Blog: https://blog.lukeewin.top
import random


def process_jsonl_file(input_file, num_lines, output_file1, output_file2):
    """
    Randomly sample lines from a JSONL file, saving the sampled lines and the
    remaining lines to separate files.

    Args:
        input_file: path of the input JSONL file
        num_lines: number of lines to sample at random
        output_file1: path for the sampled lines
        output_file2: path for the remaining lines
    """

    try:
        # Read all lines of the original file
        with open(input_file, 'r', encoding='utf-8') as f:
            all_lines = f.readlines()

        total_lines = len(all_lines)
        print(f"Input file has {total_lines} lines")

        # Don't sample more lines than the file has
        if num_lines > total_lines:
            print(f"Warning: requested {num_lines} lines but the file only has {total_lines}; sampling all of them")
            num_lines = total_lines

        # Pick random line indices
        selected_indices = random.sample(range(total_lines), num_lines)
        selected_indices_set = set(selected_indices)

        # Separate the sampled lines from the rest
        selected_lines = []
        remaining_lines = []

        for i, line in enumerate(all_lines):
            if i in selected_indices_set:
                selected_lines.append(line)
            else:
                remaining_lines.append(line)

        # Save the sampled lines
        with open(output_file1, 'w', encoding='utf-8') as f:
            f.writelines(selected_lines)
            f.flush()

        # Save the remaining lines
        with open(output_file2, 'w', encoding='utf-8') as f:
            f.writelines(remaining_lines)
            f.flush()

        print("Done!")
        print(f"Sampled {len(selected_lines)} lines into: {output_file1}")
        print(f"Remaining {len(remaining_lines)} lines saved to: {output_file2}")

    except FileNotFoundError:
        print(f"Error: input file '{input_file}' does not exist")
    except Exception as e:
        print(f"Error during processing: {str(e)}")


def main():
    input_file = input("Path of the input train.jsonl file: ")
    num_lines = int(input("How many lines to extract: "))
    output_file1 = input("Output path for the extracted lines: ")
    output_file2 = input("Output path for the remaining lines: ")
    process_jsonl_file(input_file, num_lines, output_file1, output_file2)


if __name__ == "__main__":
    main()

The script asks for four inputs: the path of the train_all.jsonl to process, how many lines to randomly extract for the validation set, the output path for val.jsonl, and the output path for the remaining train.jsonl.
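
An example session, assuming train_all.jsonl has 160,000 lines so that a 20% validation split means extracting 32,000 of them:

shell
$ python random_process_jsonl.py
Path of the input train.jsonl file: train_all.jsonl
How many lines to extract: 32000
Output path for the extracted lines: val.jsonl
Output path for the remaining lines: train.jsonl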

3. Training

With the preparation above done, training can start. I used a single RTX 4090, so the training shell script needs one tweak: remove the ",1" from the GPU list.
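
Concretely, this is the GPU list in the training script. Assuming your copy sets it like the first line below, the single-GPU change is:

shell
# before: two GPUs
export CUDA_VISIBLE_DEVICES="0,1"
# after: a single 4090
export CUDA_VISIBLE_DEVICES="0"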

Also remember to change the following paths to your own:

shell
train_data="${data_dir}/train.jsonl"
val_data="${data_dir}/val.jsonl"

Then adjust the training parameters:

shell
++dataset_conf.batch_size=6000
++dataset_conf.num_workers=4
++train_conf.max_epoch=50

Adjust these to your server's hardware. If you have GPU memory to spare, raise batch_size; if your CPU has 16 cores, set num_workers to 16. max_epoch is the number of training epochs; keep the default at first and watch the loss during training. If the loss is not decreasing, adjust the learning rate (lr); if it is decreasing, training is working. If after 50 epochs the loss is still high, or is still falling, increase max_epoch.
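
The learning rate can be overridden in the same ++key=value style as the parameters above. The exact config key depends on your FunASR version (in the configs I have seen, it sits under optim_conf), so treat this as a sketch and check your config YAML:

shell
++optim_conf.lr=0.0002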

4. Testing

Once training finishes, use the following code to compare the model before and after fine-tuning.

python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time: 2025/11/27 14:23
# @Author: Luke Ewin
# @Blog: https://blog.lukeewin.top
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess

model_dir = "iic/SenseVoiceSmall"
model_dir1 = "/data/funasr/models/hangkong"


model = AutoModel(
    model=model_dir,
    vad_model="fsmn-vad",
    vad_kwargs={"max_single_segment_time": 30000},
    device="cuda:0",
)

model1 = AutoModel(
    model=model_dir1,
    vad_model="fsmn-vad",
    vad_kwargs={"max_single_segment_time": 30000},
    device="cuda:0",
)

while True:
    audio = input("请输入要处理的音频,如输入exit,则退出程序")
    if audio == 'exit':
        exit(0)
    else:
        res = model.generate(
            input=audio,
            cache={},
            language="auto",
            use_itn=True,
            batch_size_s=60,
            merge_vad=True,
            merge_length_s=15,
            output_timestamp=True,
        )

        res1 = model1.generate(
            input=audio,
            cache={},
            language="auto",
            use_itn=True,
            batch_size_s=60,
            merge_vad=True,
            merge_length_s=15,
            output_timestamp=True,
        )

        text = rich_transcription_postprocess(res[0]["text"])
        print(f"训练前:{text}")

        text1 = rich_transcription_postprocess(res1[0]["text"])
        print(f"训练后:{text1}")

Note: change model_dir1 to the path of the fine-tuned model on your own server.

Running this script compares recognition quality before and after fine-tuning.

5. CER Evaluation

For character error rate (CER) evaluation, see my earlier article on CER evaluation. CER is the character-level edit distance normalized by reference length: (substitutions + deletions + insertions) divided by the number of characters in the reference.

Below is part of the evaluation code from that article.

python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time: 2025/3/30 20:24
# @Author: Luke Ewin
# @Blog: https://blog.lukeewin.top
import time
import evaluate
from funasr.utils.postprocess_utils import rich_transcription_postprocess
from funasr import AutoModel
from pydub import AudioSegment
import re

chunk_size = [0, 10, 5]  # [0, 10, 5] 600ms, [0, 8, 4] 480ms
encoder_chunk_look_back = 4  # number of chunks to lookback for encoder self-attention
decoder_chunk_look_back = 1  # number of encoder chunks to lookback for decoder cross-attention
model_dir = ""

model = AutoModel(
    model=model_dir,
    device="cuda:0",
)


def sense_voice_infer(audio_path: str):
    res = model.generate(
        input=audio_path,
        cache={},
        language="auto",  # "zh", "en", "yue", "ja", "ko", "nospeech"
        use_itn=False,
        batch_size_s=60,
    )
    text = rich_transcription_postprocess(res[0]["text"])
    return text


def paraformer_infer(audio_path: str):
    res = model.generate(
        input=audio_path,
        cache={},
        use_itn=False,
        batch_size_s=60
    )
    text = rich_transcription_postprocess(res[0]["text"])
    return text


def paraformer_streaming_infer(audio_path: str):
    res = model.generate(
        input=audio_path,
        chunk_size=chunk_size,
        encoder_chunk_look_back=encoder_chunk_look_back,
        decoder_chunk_look_back=decoder_chunk_look_back,
    )
    return res[0]["text"]


language = 'chinese'  # chinese, japanese, korean, english, multi_language (languages other than zh/ja/ko/en)

# Pick the metric: CER for Chinese/Japanese/Korean, WER for other languages
if language in ['chinese', 'japanese', 'korean']:
    metric = evaluate.load('utils/cer.py')
else:
    metric = evaluate.load('utils/wer.py')

# Path of the test-set scp file
test_wav = ""
wav_dict = {}
with open(test_wav, 'r', encoding='utf-8') as f:
    # Map each audio ID to its path
    for line in f:
        tmp_line = line.strip().split()
        audio_id = tmp_line[0].strip()
        audio_path = ''.join(tmp_line[1:]).strip()
        wav_dict[audio_id] = audio_path

test_text = ""  # test_text.txt
eval_result = ""  # file to write evaluation results to
sum_cer = 0  # accumulated CER
avg_cer = 0  # average CER
count_line = 0  # number of evaluated lines
with open(test_text, 'r', encoding='utf-8') as f, open(eval_result, 'a', encoding='utf-8') as eval_f:
    for line in f:
        tmp_line = line.strip().split()
        audio_id = tmp_line[0].strip()
        reference_text = ''.join(tmp_line[1:]).strip()
        if audio_id in wav_dict:
            wav_path = wav_dict[audio_id]  # look up the audio path
            start_time = time.perf_counter()
            infer_text = sense_voice_infer(audio_path=wav_path)  # run inference
            infer_text = re.sub(r'[^a-zA-Z\u4e00-\u9fff]', '', infer_text)
            reference_text = re.sub(r'[^a-zA-Z\u4e00-\u9fff]', '', reference_text)
            duration = time.perf_counter() - start_time
            audio_duration = AudioSegment.from_file(wav_path).duration_seconds
            rtf = duration / audio_duration
            # Score the hypothesis against the reference
            metric.add(prediction=infer_text, references=reference_text)
            cer = metric.compute()
            print(f'{audio_id} hypothesis: {infer_text} , reference: {reference_text} , cer: {cer} , rtf: {rtf}')
            eval_f.write(f'{audio_id} hypothesis: {infer_text} , reference: {reference_text} , cer: {cer} , rtf: {rtf}\n')
            sum_cer = sum_cer + cer
            count_line = count_line + 1
    avg_cer = sum_cer / count_line
    print(f"Average CER: {avg_cer}")
    eval_f.write(f'Average CER: {avg_cer}')

The code requires test_wav, test_text, and eval_result to be set: test_wav is the path to test_wav.scp, test_text the path to test_text.txt, and eval_result the path of the file where evaluation results are saved.

6. Exporting an ONNX Model

Run the code below to export the model to ONNX along with a quantized version.

python
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
from funasr import AutoModel

model = AutoModel(
    model=r"D:\Works\ASR\sichuan\model\train_50"
)

res = model.export(type="onnx", quantize=True)
print(res)

Change the model value to the directory where your trained model.pt is saved.
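
As a quick sanity check that the export succeeded, you can load the resulting file with onnxruntime and list its inputs. The file name model_quant.onnx is an assumption about what the exporter writes; check the output directory for the actual names:

python
import onnxruntime as ort

# Path is illustrative; point it at the .onnx file the export actually produced
sess = ort.InferenceSession(r"D:\Works\ASR\sichuan\model\train_50\model_quant.onnx")
for inp in sess.get_inputs():
    print(inp.name, inp.shape, inp.type)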

7. Other Resources

For compiling and installing FunASR from source on a Linux server, see my article on building FunASR from source.

To deploy a FunASR API with Docker, see my article on deploying a speech recognition API on an intranet.
