# Mirror of https://github.com/gpt-omni/mini-omni
# Synced 2024-11-28 23:17:38 +00:00
# 82 lines, 2.4 KiB, Python
"""A simple web interactive chat demo based on gradio."""
|
|
|
|
import os
|
|
import time
|
|
import gradio as gr
|
|
import base64
|
|
import numpy as np
|
|
import requests
|
|
|
|
|
|
# Optional URL of a remote inference server, taken from the environment.
# When the variable is absent the model is loaded into this process instead.
API_URL = os.environ.get("API_URL")
client = None

if API_URL is None:
    # Imported here so that remote-only mode never imports the inference module.
    from inference import OmniInference

    omni_client = OmniInference("./checkpoint", "cuda:0")
    omni_client.warm_up()

# Audio streaming parameters used by process_audio.
OUT_CHUNK = 4096  # chunk_size passed to response.iter_content
OUT_RATE = 24000  # first element of every yielded (rate, samples) tuple
OUT_CHANNELS = 1  # column count the sample array is reshaped to
|
|
|
|
|
|
def _chunk_to_frames(chunk):
    """Decode one raw audio chunk into an int16 array shaped (samples, OUT_CHANNELS)."""
    samples = np.frombuffer(chunk, dtype=np.int16)
    # astype() returns a fresh copy rather than a view onto the chunk's buffer.
    return samples.reshape(-1, OUT_CHANNELS).astype(np.int16)


def process_audio(audio):
    """Stream the model's audio reply for a recorded audio file.

    This generator is the callback handed to the Gradio interface. When
    ``API_URL`` is set, the file is base64-encoded, POSTed as JSON under the
    ``"audio"`` key, and the response body is read in ``OUT_CHUNK``-byte
    pieces. Otherwise the chunks come from
    ``omni_client.run_AT_batch_stream``.

    Args:
        audio: Path of the recorded audio file, or ``None`` when there is
            nothing to process (the generator then ends immediately).

    Yields:
        ``(OUT_RATE, samples)`` tuples, where ``samples`` is an ``int16``
        array of shape ``(n, OUT_CHANNELS)``.
    """
    filepath = audio
    print(f"filepath: {filepath}")
    if filepath is None:
        return

    cnt = 0
    if API_URL is not None:
        # Read the whole file first so the handle is closed before the
        # (potentially long-running) streaming request starts.
        with open(filepath, "rb") as f:
            data = f.read()
        payload = {"audio": base64.b64encode(data).decode("utf-8")}

        tik = time.time()
        with requests.post(API_URL, json=payload, stream=True) as response:
            try:
                for chunk in response.iter_content(chunk_size=OUT_CHUNK):
                    if not chunk:
                        continue
                    if cnt == 0:
                        print(f"first chunk time cost: {time.time() - tik:.3f}")
                    cnt += 1
                    yield OUT_RATE, _chunk_to_frames(chunk)
            except Exception as e:
                # Best effort: report the failure and end the stream instead
                # of propagating the error into the UI callback.
                print(f"error: {e}")
    else:
        # The timer must be started in this branch as well; it was previously
        # only set on the remote path, so the latency report below raised
        # UnboundLocalError on the first locally generated chunk.
        tik = time.time()
        for chunk in omni_client.run_AT_batch_stream(filepath):
            if cnt == 0:
                print(f"first chunk time cost: {time.time() - tik:.3f}")
            cnt += 1
            yield OUT_RATE, _chunk_to_frames(chunk)
|
|
|
|
|
|
def main(port=None):
    """Build the Gradio interface around process_audio and launch it.

    Args:
        port: Optional server port. When given, the app is launched on
            ``0.0.0.0`` at that port with sharing disabled; when omitted,
            ``launch()`` is called with no arguments.
    """
    microphone = gr.Audio(type="filepath", label="Microphone")
    response_player = gr.Audio(label="Response", streaming=True, autoplay=True)
    demo = gr.Interface(
        process_audio,
        inputs=microphone,
        outputs=[response_player],
        title="Chat Mini-Omni Demo",
        live=True,
    )

    # Collect the launch arguments once so there is a single launch call.
    launch_options = {}
    if port is not None:
        launch_options = {
            "share": False,
            "server_name": "0.0.0.0",
            "server_port": port,
        }
    demo.queue().launch(**launch_options)
|
|
|
|
|
|
if __name__ == "__main__":
    # fire is only needed for the command-line entry point, so it is imported
    # here rather than at module import time.
    from fire import Fire

    Fire(main)
|