火山引擎大模型语音合成双向流式-python demo

在播放大模型生成的内容时,可以发现是边生成边播放的。这里可以使用火山引擎的双向流式语音合成来实现。官方没有提供 Python 版本的 demo,而且官方文档的表述实际上并不清晰,所以我在阅读 Go 语言版本后,自己写了一个,提供给大家。

官方文档 https://www.volcengine.com/docs/6561/1329505

代码

需要自行替换 APP_KEY和ACCESS_KEY

protocol.py

python 复制代码
import json
import struct


class Event:
    """Integer event codes stored in the 4-byte event field of a frame.

    The values are plain class attributes (not an ``enum``), so they can be
    handed directly to ``struct.pack`` and ``json.dumps`` and compared with
    the integer that ``parse_frame`` extracts.
    """

    NONE = 0

    # Codes that this module's create_*_frame builders pack into frames.
    START_CONNECTION = 1
    FINISH_CONNECTION = 2
    START_SESSION = 100
    FINISH_SESSION = 102
    TASK_REQUEST = 200

    # Connection-level codes not produced by any builder in this module
    # (presumably sent by the server -- confirm against the protocol docs).
    CONNECTION_STARTED = 50
    CONNECTION_FAILED = 51
    CONNECTION_FINISHED = 52

    # Session-level codes not produced by any builder in this module.
    SESSION_STARTED = 150
    SESSION_FINISHED = 152
    SESSION_FAILED = 153

    # Synthesis-related codes.
    TTS_SENTENCE_START = 350
    TTS_SENTENCE_END = 351
    TTS_RESPONSE = 352


def create_start_connection_frame():
    """Build the binary frame announcing ``Event.START_CONNECTION``.

    Layout, in order: the 4 header bytes, the event code as a big-endian
    signed 32-bit integer, the payload length as a big-endian unsigned
    32-bit integer, and the payload (an empty JSON object).

    Returns:
        The assembled frame as ``bytes``.
    """
    # Header nibbles, as decoded by parse_frame:
    #   0x11 -> version 1, header size 1 (counted in 4-byte units)
    #   0x14 -> message type 0b0001, flags 0b0100 (event number present)
    #   0x10 -> serialization method 1, compression method 0
    #   0x00 -> reserved
    prefix = bytes((0x11, 0x14, 0x10, 0x00))
    body = json.dumps({}).encode()
    return b"".join(
        (
            prefix,
            struct.pack(">i", Event.START_CONNECTION),
            struct.pack(">I", len(body)),
            body,
        )
    )


def create_finish_connection_frame() -> bytes:
    """Build the binary frame announcing ``Event.FINISH_CONNECTION``.

    The frame has no session id: the 4 header bytes are followed directly
    by the event code (big-endian signed 32-bit), the payload length
    (big-endian unsigned 32-bit) and the literal payload ``{}``.

    Returns:
        The assembled frame as ``bytes``.
    """
    empty_json = b"{}"
    # Header nibbles, as decoded by parse_frame: version 1 / header size 1,
    # message type 0b0001 with the event-present flag 0b0100,
    # serialization method 1 / compression method 0, then a reserved byte.
    prefix = bytes((0x11, 0x14, 0x10, 0x00))
    # ">iI" packs both integers back to back with no padding.
    fields = struct.pack(">iI", Event.FINISH_CONNECTION, len(empty_json))
    return prefix + fields + empty_json


def create_start_session_frame(
    session_id: str,
    speaker: str,
    *,
    audio_format: str = "mp3",
    sample_rate: int = 24000,
):
    """Build the binary frame announcing ``Event.START_SESSION``.

    Layout, in order: the 4 header bytes, the event code (big-endian signed
    32-bit), the session id length (big-endian unsigned 32-bit), the
    session id bytes, the metadata length (big-endian unsigned 32-bit) and
    the JSON metadata.

    Args:
        session_id: Identifier of the session being opened.
        speaker: Voice name placed in ``req_params.speaker``.
        audio_format: Value for ``req_params.audio_params.format``. The
            default reproduces the previously hard-coded ``"mp3"``.
        sample_rate: Value for ``req_params.audio_params.sample_rate``. The
            default reproduces the previously hard-coded ``24000``.

    Returns:
        The assembled frame as ``bytes``.
    """
    meta_data = json.dumps(
        {
            # Use the shared constant rather than a magic number so the JSON
            # body can never drift from the binary event field below.
            "event": Event.START_SESSION,
            "req_params": {
                "speaker": speaker,
                "audio_params": {
                    "format": audio_format,
                    "sample_rate": sample_rate,
                },
            },
        },
        ensure_ascii=False,
    ).encode()
    session_id_bytes = session_id.encode()  # encode once, reuse for length

    # Header nibbles, as decoded by parse_frame: version 1 / header size 1,
    # message type 0b0001 with the event-present flag 0b0100,
    # serialization method 1 / compression method 0, then a reserved byte.
    frame = bytearray((0b0001_0001, 0b0001_0100, 0b0001_0000, 0b0000_0000))
    frame.extend(struct.pack(">i", Event.START_SESSION))
    frame.extend(struct.pack(">I", len(session_id_bytes)))
    frame.extend(session_id_bytes)
    frame.extend(struct.pack(">I", len(meta_data)))
    frame.extend(meta_data)
    return bytes(frame)


def create_finish_session_frame(session_id: str):
    """Build the binary frame announcing ``Event.FINISH_SESSION``.

    Layout, in order: the 4 header bytes, the event code (big-endian signed
    32-bit), the session id length (big-endian unsigned 32-bit), the
    session id bytes, the payload length (big-endian unsigned 32-bit) and
    the literal payload ``{}``.

    Args:
        session_id: Identifier of the session being closed.

    Returns:
        The assembled frame as ``bytes``.
    """
    sid = session_id.encode()
    empty_json = b"{}"
    # Header nibbles, as decoded by parse_frame: version 1 / header size 1,
    # message type 0b0001 with the event-present flag 0b0100,
    # serialization method 1 / compression method 0, then a reserved byte.
    pieces = (
        bytes((0x11, 0x14, 0x10, 0x00)),
        struct.pack(">i", Event.FINISH_SESSION),
        struct.pack(">I", len(sid)),
        sid,
        struct.pack(">I", len(empty_json)),
        empty_json,
    )
    return b"".join(pieces)


def create_task_request_frame(chunk: str, session_id: str):
    """Build the binary frame carrying one text chunk to synthesize.

    Layout, in order: the 4 header bytes, the event code (big-endian signed
    32-bit), the session id length (big-endian unsigned 32-bit), the
    session id bytes, the payload length (big-endian unsigned 32-bit) and
    the JSON payload.

    Args:
        chunk: Text placed in ``req_params.text`` of the JSON payload.
        session_id: Identifier of the session the chunk belongs to.

    Returns:
        The assembled frame as ``bytes``.
    """
    request = {
        "event": Event.TASK_REQUEST,
        "req_params": {
            "text": chunk,
        },
    }
    # json.dumps defaults are kept on purpose: non-ASCII text is emitted as
    # \uXXXX escapes, so the payload bytes are pure ASCII.
    body = json.dumps(request).encode()
    sid = session_id.encode()

    # Header nibbles, as decoded by parse_frame: version 1 / header size 1,
    # message type 0b0001 with the event-present flag 0b0100,
    # serialization method 1 / compression method 0, then a reserved byte.
    return b"".join(
        (
            bytes((0x11, 0x14, 0x10, 0x00)),
            struct.pack(">i", Event.TASK_REQUEST),
            struct.pack(">I", len(sid)),
            sid,
            struct.pack(">I", len(body)),
            body,
        )
    )


def parse_frame(frame):
    """Decode one binary protocol frame into its component fields.

    The first three bytes are split into nibbles: version / header size
    (in 4-byte units), message type / flags, and serialization /
    compression method. When flag bit ``0x04`` is set, a 4-byte event code
    follows the header. For message types ``0b0001``, ``0b1001`` and
    ``0b1011`` a length-prefixed id string comes next; every frame then
    ends with a length-prefixed payload.

    Args:
        frame: The raw frame; must be a ``bytes`` object.

    Returns:
        A dict with the keys ``version``, ``message_type``,
        ``serialization_method``, ``compression_method``, ``event``,
        ``session_id`` and ``payload``. ``event`` is ``0`` (the value of
        ``Event.NONE``) when the frame carries no event code, and
        ``session_id`` is ``None`` for message types without an id field.

    Raises:
        ValueError: If ``frame`` is not ``bytes`` or is shorter than the
            four bytes needed to decode the header fields.
        struct.error: If the event code or a length prefix is truncated.
    """
    if not isinstance(frame, bytes):
        raise ValueError(f"frame is not bytes: {frame}")
    if len(frame) < 4:
        raise ValueError(f"frame is too short to contain a header: {frame}")

    version = frame[0] >> 4
    header_size = (frame[0] & 0x0F) * 4
    message_type = frame[1] >> 4
    flags = frame[1] & 0x0F
    serialization_method = frame[2] >> 4
    compression_method = frame[2] & 0x0F

    # Everything after the header is read through a moving offset, so a
    # header larger than four bytes shifts all later fields correctly.
    offset = header_size

    # The event code is optional and sits right after the header. Reading
    # it from a fixed position regardless of the flag would return bytes
    # belonging to another field.
    event = 0
    if flags & 0x04:
        (event,) = struct.unpack_from(">I", frame, offset)
        offset += 4

    session_id = None
    if message_type in (0b0001, 0b1001, 0b1011):
        (session_id_len,) = struct.unpack_from(">I", frame, offset)
        offset += 4
        session_id = frame[offset: offset + session_id_len].decode()
        offset += session_id_len

    (payload_len,) = struct.unpack_from(">I", frame, offset)
    offset += 4
    payload = frame[offset: offset + payload_len]

    return {
        "version": version,
        "message_type": message_type,
        "serialization_method": serialization_method,
        "compression_method": compression_method,
        "event": event,
        "session_id": session_id,
        "payload": payload,
    }

client.py

python 复制代码
import asyncio
import uuid
from typing import AsyncGenerator, Generator

import websockets

import protocol


class TtsClient:
    API_ENDPOINT = "wss://openspeech.bytedance.com/api/v3/tts/bidirection"
    APP_KEY = "*"
    ACCESS_KEY = "*"
    SPEAKER = "zh_female_wanqudashu_moon_bigtts"

    def get_headers(self):
        return {
            "X-Api-App-Key": self.APP_KEY,
            "X-Api-Access-Key": self.ACCESS_KEY,
            "X-Api-Resource-Id": "volc.service_type.10029",
            "X-Api-Request-Id": uuid.uuid4(),
        }

    async def send_task_frame(
        self,
        session_id: str,
        text_generator: Generator[str, None, None],
        ws: websockets.WebSocketClientProtocol,
    ):
        for chunk in text_generator:
            task_frame = protocol.create_task_request_frame(
                chunk=chunk, session_id=session_id
            )
            await ws.send(task_frame)

        await ws.send(protocol.create_finish_session_frame(session_id))

    async def receive_response(self, ws: websockets.WebSocketClientProtocol):
        while True:
            response = await ws.recv()
            frame = protocol.parse_frame(response)
            match frame["event"]:
                case protocol.Event.TTS_RESPONSE:
                    yield frame["payload"]
                case protocol.Event.SESSION_FINISHED | protocol.Event.FINISH_CONNECTION:
                    break

    async def a_bindirect_tts(
        self, session_id: str, text_generator: Generator[str, None, None]
    ) -> AsyncGenerator[bytes, None]:

        async with websockets.connect(
            self.API_ENDPOINT, extra_headers=self.get_headers()
        ) as ws:

            try:
                await ws.send(protocol.create_start_connection_frame())
                response = await ws.recv()
                print(protocol.parse_frame(response))
                session_frame = protocol.create_start_session_frame(
                    session_id=session_id, speaker=self.SPEAKER
                )
                await ws.send(session_frame)
                response = await ws.recv()
                print(protocol.parse_frame(response))
                send_task = asyncio.create_task(
                    self.send_task_frame(session_id, text_generator, ws)
                )

                async for audio_chunk in self.receive_response(ws):
                    yield audio_chunk

                await send_task
            except Exception as e:
                print(e)
            finally:
                await ws.send(protocol.create_finish_session_frame(session_id))
                await ws.send(protocol.create_finish_connection_frame())

test

python 复制代码
import pytest


def generate_text_chunks():
    """Yield the sample text chunks fed to the synthesizer.

    Deliberately not named ``test_*``: pytest would otherwise collect this
    generator function as a test of its own.
    """
    yield from ["hello world", "世界你好"]


@pytest.mark.asyncio
async def test_run():
    """Stream the sample text through the client and save the audio."""
    # Imported here so the snippet is self-contained; client.py is the
    # module that defines TtsClient.
    from client import TtsClient

    client = TtsClient()

    combined_audio = bytearray()
    async for chunk in client.a_bindirect_tts(
        session_id="test_session_id", text_generator=generate_text_chunks()
    ):
        combined_audio.extend(chunk)

    # Save the merged audio. The session is started with format "mp3", so
    # the file gets a matching extension instead of ".wav".
    with open("combined_audio.mp3", "wb") as audio_file:
        audio_file.write(combined_audio)
相关推荐
q567315239 分钟前
在 Bash 中获取 Python 模块变量列
开发语言·python·bash
是萝卜干呀10 分钟前
Backend - Python 爬取网页数据并保存在Excel文件中
python·excel·table·xlwt·爬取网页数据
代码欢乐豆11 分钟前
数据采集之selenium模拟登录
python·selenium·测试工具
许野平34 分钟前
Rust: 利用 chrono 库实现日期和字符串互相转换
开发语言·后端·rust·字符串·转换·日期·chrono
也无晴也无风雨38 分钟前
在JS中, 0 == [0] 吗
开发语言·javascript
狂奔solar1 小时前
yelp数据集上识别潜在的热门商家
开发语言·python
Tassel_YUE1 小时前
网络自动化04:python实现ACL匹配信息(主机与主机信息)
网络·python·自动化
聪明的墨菲特i1 小时前
Python爬虫学习
爬虫·python·学习
blammmp1 小时前
Java:数据结构-枚举
java·开发语言·数据结构
何曾参静谧2 小时前
「C/C++」C/C++ 指针篇 之 指针运算
c语言·开发语言·c++