From a48ac6a20973d0d433e49ff5176e7d0704d84d14 Mon Sep 17 00:00:00 2001 From: Jun Siang Cheah Date: Mon, 1 Jul 2024 08:13:56 +0800 Subject: [PATCH] refac: lazily load sentence_transformers to reduce start up memory usage --- backend/apps/rag/main.py | 6 ++++-- backend/apps/rag/utils.py | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 7c6974535..5e4ec03c3 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -48,8 +48,6 @@ import mimetypes import uuid import json -import sentence_transformers - from apps.webui.models.documents import ( Documents, DocumentForm, @@ -190,6 +188,8 @@ def update_embedding_model( update_model: bool = False, ): if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "": + import sentence_transformers + app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer( get_model_path(embedding_model, update_model), device=DEVICE_TYPE, @@ -204,6 +204,8 @@ def update_reranking_model( update_model: bool = False, ): if reranking_model: + import sentence_transformers + app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder( get_model_path(reranking_model, update_model), device=DEVICE_TYPE, diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index 7b4324d9a..3a3dad4a2 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -442,8 +442,6 @@ from langchain_core.documents import BaseDocumentCompressor, Document from langchain_core.callbacks import Callbacks from langchain_core.pydantic_v1 import Extra -from sentence_transformers import util - class RerankCompressor(BaseDocumentCompressor): embedding_function: Any @@ -468,6 +466,8 @@ class RerankCompressor(BaseDocumentCompressor): [(query, doc.page_content) for doc in documents] ) else: + from sentence_transformers import util + query_embedding = self.embedding_function(query) document_embedding = self.embedding_function( [doc.page_content for doc in documents]