From a97e7cd41acc9d20e48528a030d976092f5dd26d Mon Sep 17 00:00:00 2001 From: wuhongsheng <664116298@qq.com> Date: Fri, 6 Sep 2024 10:56:48 +0800 Subject: [PATCH] =?UTF-8?q?feat:=E5=A2=9E=E5=8A=A0device=20=E5=8F=82?= =?UTF-8?q?=E6=95=B0=20fix:=20set=5Fkv=5Fcache=E4=BD=BF=E7=94=A8=E9=BB=98?= =?UTF-8?q?=E8=AE=A4device=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- inference.py | 2 +- server.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/inference.py b/inference.py index 4d721d0..d184925 100644 --- a/inference.py +++ b/inference.py @@ -399,7 +399,7 @@ class OmniInference: model = self.model with self.fabric.init_tensor(): - model.set_kv_cache(batch_size=2) + model.set_kv_cache(batch_size=2,device=self.device) mel, leng = load_audio(audio_path) audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device) diff --git a/server.py b/server.py index 5740613..c6e5d98 100644 --- a/server.py +++ b/server.py @@ -46,9 +46,9 @@ def create_app(): return server.server -def serve(ip='0.0.0.0', port=60808): +def serve(ip='0.0.0.0', port=60808, device='cuda:0'): - OmniChatServer(ip, port=port, run_app=True) + OmniChatServer(ip, port=port,run_app=True, device=device) if __name__ == "__main__":