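# NOTE: the two lines that originally stood here ("python" and "复制代码",
# i.e. "copy code") were code-viewer residue from a web paste, not program text.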
import gradio as gr
import os
import random
import json
import requests
import time
from openai import AzureOpenAI
# audio to text here
def audio_to_text(audio_path):
    """Transcribe an audio file to text (currently via Azure OpenAI Whisper).

    Parameters:
        audio_path: str or None, path of the audio file to transcribe.

    Returns:
        str: the transcription text, or None when no audio path was given
        (None or an empty string).
    """
    # The original test `audio_path == None or ""` never caught the empty
    # string, because `or ""` is always falsy; a plain truthiness check
    # covers both None and "".
    if not audio_path:
        return None
    print(f"正在处理audio_path:{audio_path}")
    client = AzureOpenAI(
        api_key='',
        api_version="",
        azure_endpoint="https://speech-01.openai.azure.com/",
    )
    # Use a context manager so the file handle is closed even if the
    # transcription request raises.
    with open(audio_path, "rb") as audio_file:
        transcription = client.audio.transcriptions.create(
            model="whisper",
            file=audio_file,
        )
    print(transcription.text)
    return transcription.text
def chat_completions(messages, gr_states, history):
    """Request a chat completion and record the reply in the conversation.

    Currently targets a "kimi free api" deployment that speaks the OpenAI
    chat-completions format.

    Parameters:
        messages: list of OpenAI-format message dicts, or a falsy value when
            there is nothing to send.
        gr_states: dict whose "history" list holds one list per turn; the
            reply is appended to the last turn.
        history: list of turns shown in the chatbot; its last entry is
            replaced by the updated last turn of gr_states["history"].

    Returns:
        tuple: (gr_states, history), both updated in place. When every
        attempt fails or yields empty content, the recorded reply is the
        "Connection Error" message instead.
    """
    if not messages:
        return gr_states, history
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + "{your refresh token here}"
    }
    max_retry = 5
    content = None
    for _ in range(max_retry):
        try:
            response = requests.post(
                url="{your free kimi api deploy url here}",
                headers=headers,
                data=json.dumps({
                    "model": "kimi",
                    "messages": messages,
                    "stream": False,
                    # "temperature": 0.8,
                }),
                # Without a timeout a stalled server would block this
                # handler forever; a timeout is retried like any failure.
                timeout=120,
            )
            body = response.json()
            print(body)
            content = body['choices'][0]['message']['content']
        except Exception as e:
            # Deliberate best-effort: any transport or parsing problem is
            # logged and the request is retried.
            print(e)
            continue
        if content:
            break
    # Decide on failure by whether a reply was obtained, not by the attempt
    # counter: the original also flagged a success on the final attempt as a
    # connection error and appended the error text after the real reply.
    if not content:
        content = "Connection Error: 请求失败,请重试"
        print(history)
    gr_states["history"][-1].append(content)
    history.pop()
    history.append(gr_states["history"][-1])
    return gr_states, history
def process_tts(text):
    """Convert text to speech through the TTS service and save it as a WAV.

    Parameters:
        text: str, the text to synthesize.

    Returns:
        str: path of the saved audio file under ./audio_cache/, or None when
        the service did not answer with HTTP 200 (no file is written then).
    """
    url = '{your tts model url here}'
    headers = {'Content-Type': 'application/json'}
    data = {
        "text": text,
        "text_language": "zh"
    }
    directory = './audio_cache/'
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(directory, exist_ok=True)
    time_stamp = time.strftime("%Y%m%d-%H%M%S")
    path = directory + 'audio_' + time_stamp + '.wav'
    # A timeout keeps a stalled TTS server from blocking the handler forever.
    response = requests.post(
        url, headers=headers, data=json.dumps(data), timeout=120
    )
    print("Status Code:", response.status_code)
    if response.status_code != 200:
        print('Request failed.')
        # The original returned `path` here although nothing had been
        # written, handing the caller a file that does not exist.
        return None
    with open(path, 'wb') as f:
        f.write(response.content)
    return path
def get_audio(gr_states, audio):
    """Produce the value for the output audio component after a chat turn.

    Parameters:
        gr_states: dict whose "history" list holds [user_message, reply]
            turns.
        audio: the current value of the audio component.

    Returns:
        The unchanged `audio` when there is no usable reply (no turns yet, a
        missing/None reply, or the connection-error message); otherwise the
        result of process_tts() for the latest reply. A last turn without a
        usable reply is removed from gr_states["history"].
    """
    print(gr_states)
    turns = gr_states["history"]
    # Nothing has been said yet (e.g. an empty first submission): the
    # original indexed [-1] unconditionally and raised IndexError here.
    if not turns:
        return audio
    last_turn = turns[-1]
    # A turn that never received a reply has no second slot; treat it the
    # same as a None reply instead of raising IndexError.
    response = last_turn[1] if len(last_turn) > 1 else None
    if response is None or response == "Connection Error: 请求失败,请重试":
        turns.pop()
        return audio
    return process_tts(response)
def init_default_role():
    """Return the default demo role.

    Returns:
        tuple: (role, role_description, system_prompt) for the demo frog
        character; the system prompt is the text that defines the role.
    """
    return (
        "一只用于演示的青蛙",
        "它是一只会说话的青蛙,但无论说什么都爱在最后加上'呱唧呱唧'。",
        "你是一只会说话的青蛙,但无论说什么都爱在最后加上'呱唧呱唧'。",
    )
def get_random_role():
    """Return a randomly numbered demo role (example implementation only).

    Returns:
        tuple: (role, role_description, system_prompt). Only the role name
        varies: it embeds a random integer from 0 to 10 inclusive.
    """
    frog_number = random.randint(0, 10)
    description = "它也是一只会说话的青蛙,但无论说什么都爱在最后加上'呱唧呱唧'。"
    prompt = "你是一只会说话的青蛙,但无论说什么都爱在最后加上'呱唧呱唧'。"
    return f"另一只用于演示的{frog_number}号青蛙", description, prompt
def format_messages(user_message, gr_states, history):
    """Prepare the request data (`messages`) for the chatbot.

    Parameters:
        user_message: str, the message typed by the user.
        gr_states: dict with keys "system_prompt" (str) and "history" (list
            of per-turn lists).
        history: list, the chat log as nested pairs:
            [["user message", "bot reply"], ["user message", "bot reply"]].

    Returns:
        tuple: (messages, gr_states, history). `messages` is the
        OpenAI-format list (system prompt, then every turn of `history`
        including the new one), or None when the user message is empty, in
        which case nothing is modified.
    """
    # Reject empty input before touching any state. The original appended
    # ["", None] to `history` first, which left a blank row in the chat log
    # and an empty user message in every later request. `not` also covers
    # None, where len() would raise.
    if not user_message:
        return None, gr_states, history

    history.append([user_message, None])
    gr_states["history"].append([user_message])

    messages = [
        {
            "role": "system",
            "content": gr_states["system_prompt"],
        },
    ]
    for usr, bot in history:
        messages.append({"role": "user", "content": usr})
        # The newest turn has no reply yet, so it adds only the user part.
        if bot:
            messages.append({"role": "assistant", "content": bot})
    return messages, gr_states, history
def set_up(gr_states):
    """Switch to a randomly chosen role and reset the conversation.

    Parameters:
        gr_states: the current session state; it is not read, a fresh state
            is built instead.

    Returns:
        tuple: (history, gr_states, role_info_display, None) — an empty chat
        history, a new state dict holding the role's system prompt and an
        empty history, the Markdown text describing the role, and None.
    """
    name, description, prompt = get_random_role()
    fresh_states = {"system_prompt": prompt, "history": []}
    role_markdown = f" # {name}\n{description}\n"
    return [], fresh_states, role_markdown, None
# Gradio UI: layout, event wiring and server launch.
# NOTE(review): the nesting below was reconstructed from a paste that had lost
# its indentation — confirm the intended row/column layout.
with gr.Blocks(gr.themes.Soft()) as demo:
    demo.title = 'Takway.AI'
    gr.Markdown('''<center><font size=6>Takway.AI </font></center>''')

    # Start every session with the default role and an empty conversation.
    role_name, role_description, system_prompt = init_default_role()
    gr_states = gr.State({"system_prompt": system_prompt, "history": []})
    # Holds the request messages passed from format_messages to
    # chat_completions.
    messages = gr.State(None)

    with gr.Tab(label='demo'):
        with gr.Row():
            role_info_display = gr.Markdown(
                f" # {role_name}\n{role_description}\n"
            )
        with gr.Row():
            # Left column: chat log plus text and voice input.
            with gr.Column(scale=7):
                with gr.Row():
                    chatbot = gr.Chatbot(
                        label='聊天界面',
                        value=[],
                        render_markdown=False,
                        height=500,
                        visible=True,
                    )
                with gr.Row():
                    user_prompt = gr.Textbox(
                        label='对话输入框(按Enter发送消息)',
                        interactive=True,
                        visible=True,
                    )
                    # input_audio = gr.Audio(sources=['microphone'])
                    input_audio = gr.Audio(
                        label="语音输入框",
                        sources=['microphone', 'upload'],
                        type="filepath",
                    )
            # Right column: role switch button and the spoken reply.
            with gr.Column(scale=3):
                with gr.Row():
                    change_btn = gr.Button("随机换一个角色")
                with gr.Row():
                    audio = gr.Audio(
                        label="output", interactive=False, autoplay=True
                    )

    # Submit pipeline: build the request, fetch the reply, then voice it.
    user_prompt.submit(
        format_messages,
        inputs=[user_prompt, gr_states, chatbot],
        outputs=[messages, gr_states, chatbot],
    ).then(
        chat_completions,
        inputs=[messages, gr_states, chatbot],
        outputs=[gr_states, chatbot],
    ).then(
        get_audio,
        inputs=[gr_states, audio],
        outputs=audio,
    )
    # A recorded or uploaded clip is transcribed into the text box.
    input_audio.change(audio_to_text, inputs=input_audio, outputs=user_prompt)
    change_btn.click(
        set_up,
        inputs=gr_states,
        outputs=[chatbot, gr_states, role_info_display, audio],
    )

demo.launch(server_name='0.0.0.0', server_port=9877, share=True)