From 17c684369e9a1432ce21037d8f0975cfd753457e Mon Sep 17 00:00:00 2001 From: Jun Siang Cheah Date: Mon, 1 Jul 2024 08:13:02 +0800 Subject: [PATCH 1/2] refac: lazily load faster_whisper to reduce start up memory usage --- backend/apps/audio/main.py | 3 ++- backend/apps/images/main.py | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/apps/audio/main.py b/backend/apps/audio/main.py index 8843f376f..f866d867f 100644 --- a/backend/apps/audio/main.py +++ b/backend/apps/audio/main.py @@ -14,7 +14,6 @@ from fastapi import ( from fastapi.responses import StreamingResponse, JSONResponse, FileResponse from fastapi.middleware.cors import CORSMiddleware -from faster_whisper import WhisperModel from pydantic import BaseModel import uuid @@ -277,6 +276,8 @@ def transcribe( f.close() if app.state.config.STT_ENGINE == "": + from faster_whisper import WhisperModel + whisper_kwargs = { "model_size_or_path": WHISPER_MODEL, "device": whisper_device_type, diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index 8f1a08e04..24542ee93 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -12,7 +12,6 @@ from fastapi import ( Form, ) from fastapi.middleware.cors import CORSMiddleware -from faster_whisper import WhisperModel from constants import ERROR_MESSAGES from utils.utils import ( From a48ac6a20973d0d433e49ff5176e7d0704d84d14 Mon Sep 17 00:00:00 2001 From: Jun Siang Cheah Date: Mon, 1 Jul 2024 08:13:56 +0800 Subject: [PATCH 2/2] refac: lazily load sentence_transformers to reduce start up memory usage --- backend/apps/rag/main.py | 6 ++++-- backend/apps/rag/utils.py | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 7c6974535..5e4ec03c3 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -48,8 +48,6 @@ import mimetypes import uuid import json -import sentence_transformers - from apps.webui.models.documents import ( Documents, DocumentForm, @@ -190,6 +188,8 @@ def update_embedding_model( update_model: bool = False, ): if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "": + import sentence_transformers + app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer( get_model_path(embedding_model, update_model), device=DEVICE_TYPE, @@ -204,6 +204,8 @@ def update_reranking_model( update_model: bool = False, ): if reranking_model: + import sentence_transformers + app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder( get_model_path(reranking_model, update_model), device=DEVICE_TYPE, diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index 7b4324d9a..3a3dad4a2 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -442,8 +442,6 @@ from langchain_core.documents import BaseDocumentCompressor, Document from langchain_core.callbacks import Callbacks from langchain_core.pydantic_v1 import Extra -from sentence_transformers import util - class RerankCompressor(BaseDocumentCompressor): embedding_function: Any @@ -468,6 +466,8 @@ class RerankCompressor(BaseDocumentCompressor): [(query, doc.page_content) for doc in documents] ) else: + from sentence_transformers import util + query_embedding = self.embedding_function(query) document_embedding = self.embedding_function( [doc.page_content for doc in documents]