diff --git a/inference.py b/inference.py index 793f1b2..12844b9 100644 --- a/inference.py +++ b/inference.py @@ -480,6 +480,7 @@ class OmniInference: with self.fabric.init_tensor(): model.set_kv_cache(batch_size=2, device=self.device) + # Load the audio and process it using Whisper mel, leng = load_audio(audio_path) audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device) diff --git a/server.py b/server.py index 7bff98b..0e01b6e 100644 --- a/server.py +++ b/server.py @@ -72,6 +72,7 @@ def serve(ip='0.0.0.0', port=60808, device=None): OmniChatServer(ip, port=port, run_app=True, device=device) + if __name__ == "__main__": import fire