Merge pull request #33 from wuhongsheng/dev

feat:增加device 参数
This commit is contained in:
mini-omni 2024-09-07 14:04:59 +03:00 committed by GitHub
commit 56a416f222
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 3 additions and 3 deletions

View File

@ -399,7 +399,7 @@ class OmniInference:
model = self.model
with self.fabric.init_tensor():
model.set_kv_cache(batch_size=2)
model.set_kv_cache(batch_size=2,device=self.device)
mel, leng = load_audio(audio_path)
audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)

View File

@ -46,9 +46,9 @@ def create_app():
return server.server
def serve(ip='0.0.0.0', port=60808):
def serve(ip='0.0.0.0', port=60808, device='cuda:0'):
OmniChatServer(ip, port=port, run_app=True)
OmniChatServer(ip, port=port,run_app=True, device=device)
if __name__ == "__main__":