Volcengine LLM Speech Synthesis Bidirectional Streaming API - Python Demo

When playing back content generated by a large language model, you can stream the audio so that playback happens while the text is still being generated. Volcengine's bidirectional streaming speech synthesis supports this, but there is no official Python demo and the official documentation is not particularly clear, so after reading the Go version I wrote my own and am sharing it here.

Official documentation: https://www.volcengine.com/docs/6561/1329505

Code

You need to replace APP_KEY and ACCESS_KEY with your own values.
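The client.py further below also reads the speaker name from a settings module (settings.VOL_TTS_SPEAKER). A minimal sketch of such a module, assuming you keep the values in environment variables, might look like the following; the defaults are placeholders, and you could move the App/Access keys here as well instead of hard-coding them in the request headers:

settings.py (sketch)

# minimal config sketch -- assumed layout, adjust to your project
import os

VOL_TTS_APP_KEY = os.environ.get("VOL_TTS_APP_KEY", "your VOL_TTS_APP_KEY")
VOL_TTS_ACCESS_KEY = os.environ.get("VOL_TTS_ACCESS_KEY", "your VOL_TTS_ACCESS_KEY")
VOL_TTS_SPEAKER = os.environ.get("VOL_TTS_SPEAKER", "your VOL_TTS_SPEAKER")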

protocol.py

import json
import struct
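
# Each frame in this protocol is laid out as:
#   4-byte header (version/header size, message type/flags, serialization/compression, reserved)
#   4-byte event number (big-endian signed int)
#   optional: 4-byte session-id length + session-id bytes (session-level events only)
#   4-byte payload length + JSON payload bytes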


class Event:
    NONE = 0
    START_CONNECTION = 1
    FINISH_CONNECTION = 2
    CONNECTION_STARTED = 50
    CONNECTION_FAILED = 51
    CONNECTION_FINISHED = 52
    START_SESSION = 100
    FINISH_SESSION = 102
    SESSION_STARTED = 150
    SESSION_FINISHED = 152
    SESSION_FAILED = 153
    TASK_REQUEST = 200
    TTS_SENTENCE_START = 350
    TTS_SENTENCE_END = 351
    TTS_RESPONSE = 352


def create_start_connection_frame():
    frame = bytearray()
    frame.append(0b0001_0001)  # protocol version 1, header size 1 (x 4 bytes)
    frame.append(0b0001_0100)  # message type: full client request, flags: with event number
    frame.append(0b0001_0000)  # serialization: JSON, compression: none
    frame.append(0b0000_0000)  # reserved
    frame.extend(struct.pack(">i", Event.START_CONNECTION))  # event_type
    payload = json.dumps({}).encode()
    payload_len = struct.pack(">I", len(payload))
    return bytes(frame + payload_len + payload)


def create_finish_connection_frame() -> bytes:
    frame = bytearray()
    frame.append(0b0001_0001)  # protocol version 1, header size 1 (x 4 bytes)
    frame.append(0b0001_0100)  # message type: full client request, flags: with event number
    frame.append(0b0001_0000)  # serialization: JSON, compression: none
    frame.append(0b0000_0000)  # reserved
    frame.extend(struct.pack(">i", Event.FINISH_CONNECTION))  # event_type
    frame.extend(struct.pack(">I", len(b'{}')))  # payload_len
    frame.extend(b'{}')  # payload
    return bytes(frame)


def create_start_session_frame(session_id: str, speaker: str):
    b_meta_data_json = json.dumps({
        "event": 100,
        "req_params": {
            "speaker": speaker,
            "audio_params": {"format": "mp3", "sample_rate": 24000},
        },
    }, ensure_ascii=False).encode()
    frame = bytearray()
    frame.append(0b0001_0001)  # protocol version 1, header size 1 (x 4 bytes)
    frame.append(0b0001_0100)  # message type: full client request, flags: with event number
    frame.append(0b0001_0000)  # serialization: JSON, compression: none
    frame.append(0b0000_0000)  # reserved
    frame.extend(struct.pack(">i", Event.START_SESSION))  # event_type
    # session_id_len
    frame.extend(struct.pack(">I", len(session_id.encode())))
    frame.extend(session_id.encode())  # session_id

    # meta_data_len
    frame.extend(struct.pack(">I", len(b_meta_data_json)))
    frame.extend(b_meta_data_json)
    return bytes(frame)


def create_finish_session_frame(session_id: str):
    frame = bytearray()
    frame.append(0b0001_0001)  # protocol version 1, header size 1 (x 4 bytes)
    frame.append(0b0001_0100)  # message type: full client request, flags: with event number
    frame.append(0b0001_0000)  # serialization: JSON, compression: none
    frame.append(0b0000_0000)  # reserved
    frame.extend(struct.pack(">i", Event.FINISH_SESSION))  # event_type
    # session_id_len
    frame.extend(struct.pack(">I", len(session_id.encode())))
    frame.extend(session_id.encode())  # session_id
    frame.extend(struct.pack(">I", len(b'{}')))  # payload_len
    frame.extend(b'{}')  # payload
    return bytes(frame)


def create_task_request_frame(chunk: str, session_id: str):
    b_chunk_json = json.dumps(
        {
            "event": Event.TASK_REQUEST,
            "req_params": {
                "text": chunk,
            },
        }
    ).encode()
    frame = bytearray()
    frame.append(0b0001_0001)  # protocol version 1, header size 1 (x 4 bytes)
    frame.append(0b0001_0100)  # message type: full client request, flags: with event number
    frame.append(0b0001_0000)  # serialization: JSON, compression: none
    frame.append(0b0000_0000)  # reserved
    frame.extend(struct.pack(">i", Event.TASK_REQUEST))  # event_type
    session_id_bytes = session_id.encode()
    session_id_len = struct.pack(">I", len(session_id_bytes))
    frame.extend(session_id_len)
    frame.extend(session_id_bytes)
    frame.extend(struct.pack(">I", len(b_chunk_json)))
    frame.extend(b_chunk_json)
    return bytes(frame)


def parse_frame(frame):
    if not isinstance(frame, bytes):
        raise ValueError(f"frame is not bytes: {frame}")

    header = frame[:4]
    version = header[0] >> 4
    header_size = (header[0] & 0x0F) * 4
    message_type = header[1] >> 4
    flags = header[1] & 0x0F
    serialization_method = header[2] >> 4
    compression_method = header[2] & 0x0F

    event = struct.unpack(">I", frame[4:8])[0]

    payload_start = header_size
    if flags & 0x04:  # Check if event number is present
        payload_start += 4

    if message_type in [0b0001, 0b1001, 0b1011]:  # full client request, full server response, or audio-only response (these carry a session/connect ID)
        session_id_len = struct.unpack(
            ">I", frame[payload_start: payload_start + 4])[0]
        session_id = frame[payload_start +
                           4: payload_start + 4 + session_id_len].decode()
        payload_start += 4 + session_id_len
    else:
        session_id = None

    payload_len = struct.unpack(
        ">I", frame[payload_start: payload_start + 4])[0]
    payload = frame[payload_start + 4: payload_start + 4 + payload_len]

    return {
        "version": version,
        "message_type": message_type,
        "serialization_method": serialization_method,
        "compression_method": compression_method,
        "event": event,
        "session_id": session_id,
        "payload": payload,
    }
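
To make sure the frame layout is right, here is a quick sanity check you can append to protocol.py (not part of the official demo; the session id and speaker strings are arbitrary placeholders). It builds a StartSession frame and parses it back:

if __name__ == "__main__":
    # build a StartSession frame and parse it back to verify the layout round-trips
    demo = create_start_session_frame(session_id="demo-session", speaker="your VOL_TTS_SPEAKER")
    parsed = parse_frame(demo)
    print(parsed["event"])       # 100 -> Event.START_SESSION
    print(parsed["session_id"])  # "demo-session"
    print(parsed["payload"])     # the JSON request metadata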

client.py

import asyncio
import logging
import uuid
from typing import AsyncGenerator

import websockets

import protocol  # the protocol.py module above; adjust the import path to your package layout
import settings  # your own config module (see the settings.py sketch above)


class TtsClient:
    DEFAULT_API_ENDPOINT = "wss://openspeech.bytedance.com/api/v3/tts/bidirection"

    def get_headers(self):
        return {
            "X-Api-App-Key": "your VOL_TTS_APP_KEY",
            "X-Api-Access-Key": "your VOL_TTS_ACCESS_KEY",
            "X-Api-Resource-Id": "volc.service_type.10029",
            "X-Api-Request-Id": str(uuid.uuid4()),
        }

    async def send_task_frame(
        self,
        session_id: str,
        text_generator: AsyncGenerator[str, None],
        ws: websockets.WebSocketClientProtocol,
    ):
        async for chunk in text_generator:
            task_frame = protocol.create_task_request_frame(chunk=chunk, session_id=session_id)
            await ws.send(task_frame)

        # tell the server the text stream has ended
        await ws.send(protocol.create_finish_session_frame(session_id))

    async def receive_response(self, ws: websockets.WebSocketClientProtocol):
        while True:
            response = await ws.recv()
            frame = protocol.parse_frame(response)
            match frame["event"]:
                case protocol.Event.TTS_RESPONSE:
                    yield frame["payload"]
                case protocol.Event.SESSION_FINISHED | protocol.Event.SESSION_FAILED:
                    break

    async def a_duplex_tts(
        self, message_id: str, text_generator: AsyncGenerator[str, None]
    ) -> AsyncGenerator[bytes, None]:
        async with websockets.connect(
            self.DEFAULT_API_ENDPOINT, extra_headers=self.get_headers()
        ) as ws:
            try:
                await ws.send(protocol.create_start_connection_frame())
                response = await ws.recv()
                logging.debug(protocol.parse_frame(response))

                start_session_frame = protocol.create_start_session_frame(
                    session_id=message_id, speaker=settings.VOL_TTS_SPEAKER
                )
                await ws.send(start_session_frame)
                response = await ws.recv()
                logging.debug(protocol.parse_frame(response))

                send_task = asyncio.create_task(
                    self.send_task_frame(message_id, text_generator, ws)
                )

                async for audio_chunk in self.receive_response(ws):
                    yield audio_chunk

                # wait for the send task to finish (it already sends the FinishSession frame)
                await send_task

            except Exception as e:
                logging.error(e, exc_info=True)
            finally:
                await ws.send(protocol.create_finish_connection_frame())
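
If you just want to try the client without pulling in an LLM, a minimal driver with a hand-written async generator also works. This is only a sketch; the text chunks, session id, and output file name are arbitrary:

run_demo.py (sketch)

import asyncio

from client import TtsClient  # adjust the import to your package layout


async def main():
    async def gen():
        # any async generator of text chunks will do
        for chunk in ["你好，", "这是一个", "双向流式合成的测试。"]:
            yield chunk

    audio = bytearray()
    async for piece in TtsClient().a_duplex_tts(message_id="demo-session", text_generator=gen()):
        audio.extend(piece)

    with open("demo_audio.mp3", "wb") as f:
        f.write(audio)


if __name__ == "__main__":
    asyncio.run(main())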

test

from typing import AsyncGenerator
import pytest
from ai_bot.tts.client import TtsClient
from langchain_openai import ChatOpenAI


llm = ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0.1,
)


@pytest.mark.asyncio
async def test_run():
    client = TtsClient()

    async def a_text_generator() -> AsyncGenerator[str, None]:
        async for chunk in llm.astream("你好"):
            yield str(chunk.content)

    combined_audio = bytearray()
    async for chunk in client.a_duplex_tts(
        message_id="test_session_id", text_generator=a_text_generator()
    ):
        combined_audio.extend(chunk)
    
    with open("combined_audio.wav", "wb") as audio_file:
        audio_file.write(combined_audio)