Merge pull request #3558 from cheahjs/refac/reduce-startup-mem-usage

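refactor: reduce startup memory usage by importing faster_whisper and sentence_transformers lazily, inside the functions that use them, instead of at module import time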
This commit is contained in:
Timothy Jaeryang Baek 2024-06-30 17:46:27 -07:00 committed by GitHub
commit 5c6e30cb5d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 8 additions and 6 deletions

View File

@@ -14,7 +14,6 @@ from fastapi import (
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel
from pydantic import BaseModel from pydantic import BaseModel
import uuid import uuid
@@ -277,6 +276,8 @@ def transcribe(
f.close() f.close()
if app.state.config.STT_ENGINE == "": if app.state.config.STT_ENGINE == "":
from faster_whisper import WhisperModel
whisper_kwargs = { whisper_kwargs = {
"model_size_or_path": WHISPER_MODEL, "model_size_or_path": WHISPER_MODEL,
"device": whisper_device_type, "device": whisper_device_type,

View File

@@ -12,7 +12,6 @@ from fastapi import (
Form, Form,
) )
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from utils.utils import ( from utils.utils import (

View File

@@ -48,8 +48,6 @@ import mimetypes
import uuid import uuid
import json import json
import sentence_transformers
from apps.webui.models.documents import ( from apps.webui.models.documents import (
Documents, Documents,
DocumentForm, DocumentForm,
@@ -190,6 +188,8 @@ def update_embedding_model(
update_model: bool = False, update_model: bool = False,
): ):
if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "": if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
import sentence_transformers
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer( app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
get_model_path(embedding_model, update_model), get_model_path(embedding_model, update_model),
device=DEVICE_TYPE, device=DEVICE_TYPE,
@@ -204,6 +204,8 @@ def update_reranking_model(
update_model: bool = False, update_model: bool = False,
): ):
if reranking_model: if reranking_model:
import sentence_transformers
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder( app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
get_model_path(reranking_model, update_model), get_model_path(reranking_model, update_model),
device=DEVICE_TYPE, device=DEVICE_TYPE,

View File

@@ -442,8 +442,6 @@ from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.callbacks import Callbacks from langchain_core.callbacks import Callbacks
from langchain_core.pydantic_v1 import Extra from langchain_core.pydantic_v1 import Extra
from sentence_transformers import util
class RerankCompressor(BaseDocumentCompressor): class RerankCompressor(BaseDocumentCompressor):
embedding_function: Any embedding_function: Any
@@ -468,6 +466,8 @@ class RerankCompressor(BaseDocumentCompressor):
[(query, doc.page_content) for doc in documents] [(query, doc.page_content) for doc in documents]
) )
else: else:
from sentence_transformers import util
query_embedding = self.embedding_function(query) query_embedding = self.embedding_function(query)
document_embedding = self.embedding_function( document_embedding = self.embedding_function(
[doc.page_content for doc in documents] [doc.page_content for doc in documents]