diff --git a/inference.py b/inference.py index 4d721d0..d184925 100644 --- a/inference.py +++ b/inference.py @@ -399,7 +399,7 @@ class OmniInference: model = self.model with self.fabric.init_tensor(): - model.set_kv_cache(batch_size=2) + model.set_kv_cache(batch_size=2,device=self.device) mel, leng = load_audio(audio_path) audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device) diff --git a/server.py b/server.py index 5740613..c6e5d98 100644 --- a/server.py +++ b/server.py @@ -46,9 +46,9 @@ def create_app(): return server.server -def serve(ip='0.0.0.0', port=60808): +def serve(ip='0.0.0.0', port=60808, device='cuda:0'): - OmniChatServer(ip, port=port, run_app=True) + OmniChatServer(ip, port=port,run_app=True, device=device) if __name__ == "__main__":