From c9c9660459e9bb98b6a58e66c8123bfff53cb04e Mon Sep 17 00:00:00 2001
From: Steven Kreitzer
Date: Thu, 25 Apr 2024 07:49:59 -0500
Subject: [PATCH] fix: address comment in pr #1687

---
 backend/apps/ollama/main.py |  4 --
 backend/apps/rag/main.py    | 80 ++++++++++++++++++++-----------------
 backend/apps/rag/utils.py   | 43 +++++++++++++++++++-
 backend/config.py           |  8 ++++
 4 files changed, 92 insertions(+), 43 deletions(-)

diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py
index aeac6622d..9258efa66 100644
--- a/backend/apps/ollama/main.py
+++ b/backend/apps/ollama/main.py
@@ -92,10 +92,6 @@ async def get_ollama_api_urls(user=Depends(get_admin_user)):
     return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS}
 
 
-def get_ollama_endpoint(url_idx: int = 0):
-    return app.state.OLLAMA_BASE_URLS[url_idx]
-
-
 class UrlUpdateForm(BaseModel):
     urls: List[str]
 
diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py
index 4a7ff7baf..de77eeadb 100644
--- a/backend/apps/rag/main.py
+++ b/backend/apps/rag/main.py
@@ -39,8 +39,6 @@ import json
 
 import sentence_transformers
 
-from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
-
 from apps.web.models.documents import (
     Documents,
     DocumentForm,
@@ -48,6 +46,7 @@ from apps.web.models.documents import (
 )
 
 from apps.rag.utils import (
+    get_model_path,
     query_embeddings_doc,
     query_embeddings_function,
     query_embeddings_collection,
@@ -60,6 +59,7 @@ from utils.misc import (
     extract_folders_after_data_docs,
 )
 from utils.utils import get_current_user, get_admin_user
+
 from config import (
     SRC_LOG_LEVELS,
     UPLOAD_DIR,
@@ -68,8 +68,10 @@ from config import (
     RAG_RELEVANCE_THRESHOLD,
     RAG_EMBEDDING_ENGINE,
     RAG_EMBEDDING_MODEL,
+    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
     RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
     RAG_RERANKING_MODEL,
+    RAG_RERANKING_MODEL_AUTO_UPDATE,
     RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
     RAG_OPENAI_API_BASE_URL,
     RAG_OPENAI_API_KEY,
@@ -87,13 +89,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 app = FastAPI()
 
-
 app.state.TOP_K = RAG_TOP_K
 app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
 
 app.state.CHUNK_SIZE = CHUNK_SIZE
 app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
-
 app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
 app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
 app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
@@ -104,27 +104,48 @@ app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
 
 app.state.PDF_EXTRACT_IMAGES = False
 
-if app.state.RAG_EMBEDDING_ENGINE == "":
-    app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
-        app.state.RAG_EMBEDDING_MODEL,
-        device=DEVICE_TYPE,
-        trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
-    )
-else:
-    app.state.sentence_transformer_ef = None
 
-if not app.state.RAG_RERANKING_MODEL == "":
-    app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
-        app.state.RAG_RERANKING_MODEL,
-        device=DEVICE_TYPE,
-        trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
-    )
-else:
-    app.state.sentence_transformer_rf = None
+def update_embedding_model(
+    embedding_model: str,
+    update_model: bool = False,
+):
+    if embedding_model and app.state.RAG_EMBEDDING_ENGINE == "":
+        app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
+            get_model_path(embedding_model, update_model),
+            device=DEVICE_TYPE,
+            trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
+        )
+    else:
+        app.state.sentence_transformer_ef = None
+
+
+def update_reranking_model(
+    reranking_model: str,
+    update_model: bool = False,
+):
+    if reranking_model:
+        app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
+            get_model_path(reranking_model, update_model),
+            device=DEVICE_TYPE,
+            trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
+        )
+    else:
+        app.state.sentence_transformer_rf = None
+
+
+update_embedding_model(
+    app.state.RAG_EMBEDDING_MODEL,
+    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
+)
+
+update_reranking_model(
+    app.state.RAG_RERANKING_MODEL,
+    RAG_RERANKING_MODEL_AUTO_UPDATE,
+)
 
 
 origins = ["*"]
 
+
 app.add_middleware(
     CORSMiddleware,
     allow_origins=origins,
@@ -200,15 +221,7 @@ async def update_embedding_config(
             app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
             app.state.OPENAI_API_KEY = form_data.openai_config.key
 
-            app.state.sentence_transformer_ef = None
-        else:
-            app.state.sentence_transformer_ef = (
-                sentence_transformers.SentenceTransformer(
-                    app.state.RAG_EMBEDDING_MODEL,
-                    device=DEVICE_TYPE,
-                    trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
-                )
-            )
+        update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)
 
         return {
             "status": True,
@@ -219,7 +232,6 @@ async def update_embedding_config(
                 "key": app.state.OPENAI_API_KEY,
             },
         }
-
     except Exception as e:
         log.exception(f"Problem updating embedding model: {e}")
         raise HTTPException(
@@ -242,13 +254,7 @@ async def update_reranking_config(
     try:
         app.state.RAG_RERANKING_MODEL = form_data.reranking_model
 
-        if app.state.RAG_RERANKING_MODEL == "":
-            app.state.sentence_transformer_rf = None
-        else:
-            app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
-                app.state.RAG_RERANKING_MODEL,
-                device=DEVICE_TYPE,
-            )
+        update_reranking_model(app.state.RAG_RERANKING_MODEL, True)
 
         return {
             "status": True,
diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py
index f88335a3a..cceec5f80 100644
--- a/backend/apps/rag/utils.py
+++ b/backend/apps/rag/utils.py
@@ -1,3 +1,4 @@
+import os
 import logging
 import requests
 
@@ -8,6 +9,8 @@ from apps.ollama.main import (
     GenerateEmbeddingsForm,
 )
 
+from huggingface_hub import snapshot_download
+
 from langchain_core.documents import Document
 from langchain_community.retrievers import BM25Retriever
 from langchain.retrievers import (
@@ -282,8 +285,6 @@ def rag_messages(
 
         extracted_collections.extend(collection)
 
-    log.debug(f"relevant_contexts: {relevant_contexts}")
-
     context_string = ""
     for context in relevant_contexts:
         items = context["documents"][0]
@@ -319,6 +320,44 @@ def rag_messages(
     return messages
 
 
+def get_model_path(model: str, update_model: bool = False):
+    # Construct huggingface_hub kwargs with local_files_only to return the snapshot path
+    cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
+
+    local_files_only = not update_model
+
+    snapshot_kwargs = {
+        "cache_dir": cache_dir,
+        "local_files_only": local_files_only,
+    }
+
+    log.debug(f"embedding_model: {model}")
+    log.debug(f"snapshot_kwargs: {snapshot_kwargs}")
+
+    # Inspiration from upstream sentence_transformers
+    if (
+        os.path.exists(model)
+        or ("\\" in model or model.count("/") > 1)
+        and local_files_only
+    ):
+        # If fully qualified path exists, return input, else set repo_id
+        return model
+    elif "/" not in model:
+        # Set valid repo_id for model short-name
+        model = "sentence-transformers" + "/" + model
+
+    snapshot_kwargs["repo_id"] = model
+
+    # Attempt to query the huggingface_hub library to determine the local path and/or to update
+    try:
+        model_repo_path = snapshot_download(**snapshot_kwargs)
+        log.debug(f"model_repo_path: {model_repo_path}")
+        return model_repo_path
+    except Exception as e:
+        log.exception(f"Cannot determine model snapshot path: {e}")
+        return model
+
+
 def generate_openai_embeddings(
     model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
 ):
diff --git a/backend/config.py b/backend/config.py
index 013df5edd..622b95059 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -430,6 +430,10 @@ RAG_EMBEDDING_MODEL = os.environ.get(
 )
 log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"),
 
+RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
+    os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
+)
+
 RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
     os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
 )
@@ -438,6 +442,10 @@ RAG_RERANKING_MODEL = os.environ.get("RAG_RERANKING_MODEL", "")
 if not RAG_RERANKING_MODEL == "":
     log.info(f"Reranking model set: {RAG_RERANKING_MODEL}"),
 
+RAG_RERANKING_MODEL_AUTO_UPDATE = (
+    os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
+)
+
 RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
     os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
 )
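
Example usage (illustrative only, not part of the patch): the sketch below mirrors the model-path resolution that get_model_path() in backend/apps/rag/utils.py introduces, so the effect of the new RAG_EMBEDDING_MODEL_AUTO_UPDATE / RAG_RERANKING_MODEL_AUTO_UPDATE flags is easier to see. It assumes huggingface_hub is installed and that SENTENCE_TRANSFORMERS_HOME may point at a local cache; the helper name resolve_model_path is hypothetical.

import os

from huggingface_hub import snapshot_download


def resolve_model_path(model: str, update_model: bool = False) -> str:
    """Resolve `model` to a local snapshot path.

    With update_model=False only the local cache is consulted
    (local_files_only=True); with update_model=True the snapshot may be
    downloaded or refreshed from the Hugging Face Hub.
    """
    # A path that already exists on disk is used as-is.
    if os.path.exists(model):
        return model

    # Bare model names are assumed to live under the sentence-transformers org.
    repo_id = model if "/" in model else f"sentence-transformers/{model}"

    try:
        return snapshot_download(
            repo_id=repo_id,
            cache_dir=os.getenv("SENTENCE_TRANSFORMERS_HOME"),
            local_files_only=not update_model,
        )
    except Exception:
        # Fall back to the raw name; sentence_transformers may still resolve it.
        return model


# e.g. with RAG_EMBEDDING_MODEL_AUTO_UPDATE=true the app effectively does:
print(resolve_model_path("all-MiniLM-L6-v2", update_model=True))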