diff --git a/backend/open_webui/apps/rag/main.py b/backend/open_webui/apps/rag/main.py index c0495f3fa..f6b57fe31 100644 --- a/backend/open_webui/apps/rag/main.py +++ b/backend/open_webui/apps/rag/main.py @@ -180,13 +180,13 @@ app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_ def update_embedding_model( embedding_model: str, - update_model: bool = False, + auto_update: bool = False, ): if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "": import sentence_transformers app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer( - get_model_path(embedding_model, update_model), + get_model_path(embedding_model, auto_update), device=DEVICE_TYPE, trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, ) @@ -196,17 +196,18 @@ def update_embedding_model( def update_reranking_model( reranking_model: str, - update_model: bool = False, + auto_update: bool = False, ): if reranking_model: if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]): - class Colbert: + class ColBERT: def __init__(self, name) -> None: self.device = "cuda" if torch.cuda.is_available() else "cpu" - self.ckpt = Checkpoint(name, colbert_config=ColBERTConfig()).to( - self.device - ) + self.ckpt = Checkpoint( + get_model_path(name, auto_update), + colbert_config=ColBERTConfig(), + ).to(self.device) pass def calculate_similarity_scores( @@ -264,13 +265,13 @@ def update_reranking_model( return scores - app.state.sentence_transformer_rf = Colbert(reranking_model) + app.state.sentence_transformer_rf = ColBERT(reranking_model) else: import sentence_transformers try: app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder( - get_model_path(reranking_model, update_model), + get_model_path(reranking_model, auto_update), device=DEVICE_TYPE, trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, )