mini-omni/webui/omni_gradio.py
2024-08-29 19:10:03 +08:00

82 lines
2.4 KiB
Python

"""A simple web interactive chat demo based on gradio."""
import os
import time
import gradio as gr
import base64
import numpy as np
import requests
API_URL = os.getenv("API_URL", None)
client = None
if API_URL is None:
from inference import OmniInference
omni_client = OmniInference('./checkpoint', 'cuda:0')
omni_client.warm_up()
OUT_CHUNK = 4096
OUT_RATE = 24000
OUT_CHANNELS = 1
def process_audio(audio):
filepath = audio
print(f"filepath: {filepath}")
if filepath is None:
return
cnt = 0
if API_URL is not None:
with open(filepath, "rb") as f:
data = f.read()
base64_encoded = str(base64.b64encode(data), encoding="utf-8")
files = {"audio": base64_encoded}
tik = time.time()
with requests.post(API_URL, json=files, stream=True) as response:
try:
for chunk in response.iter_content(chunk_size=OUT_CHUNK):
if chunk:
# Convert chunk to numpy array
if cnt == 0:
print(f"first chunk time cost: {time.time() - tik:.3f}")
cnt += 1
audio_data = np.frombuffer(chunk, dtype=np.int16)
audio_data = audio_data.reshape(-1, OUT_CHANNELS)
yield OUT_RATE, audio_data.astype(np.int16)
except Exception as e:
print(f"error: {e}")
else:
for chunk in omni_client.run_AT_batch_stream(filepath):
# Convert chunk to numpy array
if cnt == 0:
print(f"first chunk time cost: {time.time() - tik:.3f}")
cnt += 1
audio_data = np.frombuffer(chunk, dtype=np.int16)
audio_data = audio_data.reshape(-1, OUT_CHANNELS)
yield OUT_RATE, audio_data.astype(np.int16)
def main(port=None):
demo = gr.Interface(
process_audio,
inputs=gr.Audio(type="filepath", label="Microphone"),
outputs=[gr.Audio(label="Response", streaming=True, autoplay=True)],
title="Chat Mini-Omni Demo",
live=True,
)
if port is not None:
demo.queue().launch(share=False, server_name="0.0.0.0", server_port=port)
else:
demo.queue().launch()
if __name__ == "__main__":
import fire
fire.Fire(main)