diff --git a/backend/apps/audio/main.py b/backend/apps/audio/main.py index 8843f376f..f866d867f 100644 --- a/backend/apps/audio/main.py +++ b/backend/apps/audio/main.py @@ -14,7 +14,6 @@ from fastapi import ( from fastapi.responses import StreamingResponse, JSONResponse, FileResponse from fastapi.middleware.cors import CORSMiddleware -from faster_whisper import WhisperModel from pydantic import BaseModel import uuid @@ -277,6 +276,8 @@ def transcribe( f.close() if app.state.config.STT_ENGINE == "": + from faster_whisper import WhisperModel + whisper_kwargs = { "model_size_or_path": WHISPER_MODEL, "device": whisper_device_type, diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index 8f1a08e04..24542ee93 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -12,7 +12,6 @@ from fastapi import ( Form, ) from fastapi.middleware.cors import CORSMiddleware -from faster_whisper import WhisperModel from constants import ERROR_MESSAGES from utils.utils import ( diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 7c6974535..5e4ec03c3 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -48,8 +48,6 @@ import mimetypes import uuid import json -import sentence_transformers - from apps.webui.models.documents import ( Documents, DocumentForm, @@ -190,6 +188,8 @@ def update_embedding_model( update_model: bool = False, ): if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "": + import sentence_transformers + app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer( get_model_path(embedding_model, update_model), device=DEVICE_TYPE, @@ -204,6 +204,8 @@ def update_reranking_model( update_model: bool = False, ): if reranking_model: + import sentence_transformers + app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder( get_model_path(reranking_model, update_model), device=DEVICE_TYPE, diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index 7b4324d9a..3a3dad4a2 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -442,8 +442,6 @@ from langchain_core.documents import BaseDocumentCompressor, Document from langchain_core.callbacks import Callbacks from langchain_core.pydantic_v1 import Extra -from sentence_transformers import util - class RerankCompressor(BaseDocumentCompressor): embedding_function: Any @@ -468,6 +466,8 @@ class RerankCompressor(BaseDocumentCompressor): [(query, doc.page_content) for doc in documents] ) else: + from sentence_transformers import util + query_embedding = self.embedding_function(query) document_embedding = self.embedding_function( [doc.page_content for doc in documents]