diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index e7433f649..3d12084c8 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -13,7 +13,6 @@ import os, shutil, logging, re from pathlib import Path from typing import List -from sentence_transformers import SentenceTransformer from chromadb.utils import embedding_functions from langchain_community.document_loaders import ( @@ -45,7 +44,7 @@ from apps.web.models.documents import ( DocumentResponse, ) -from apps.rag.utils import query_doc, query_collection +from apps.rag.utils import query_doc, query_collection, embedding_model_get_path from utils.misc import ( calculate_sha256, @@ -60,6 +59,7 @@ from config import ( DOCS_DIR, RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_DEVICE_TYPE, + RAG_EMBEDDING_MODEL_AUTO_UPDATE, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP, @@ -71,15 +71,6 @@ from constants import ERROR_MESSAGES log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) -# -# if RAG_EMBEDDING_MODEL: -# sentence_transformer_ef = SentenceTransformer( -# model_name_or_path=RAG_EMBEDDING_MODEL, -# cache_folder=RAG_EMBEDDING_MODEL_DIR, -# device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, -# ) - - app = FastAPI() app.state.PDF_EXTRACT_IMAGES = False @@ -87,11 +78,12 @@ app.state.CHUNK_SIZE = CHUNK_SIZE app.state.CHUNK_OVERLAP = CHUNK_OVERLAP app.state.RAG_TEMPLATE = RAG_TEMPLATE app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL +app.state.RAG_EMBEDDING_MODEL_PATH = embedding_model_get_path(app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE) app.state.TOP_K = 4 app.state.sentence_transformer_ef = ( embedding_functions.SentenceTransformerEmbeddingFunction( - model_name=app.state.RAG_EMBEDDING_MODEL, + model_name=app.state.RAG_EMBEDDING_MODEL_PATH, device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, ) ) @@ -132,6 +124,7 @@ async def get_embedding_model(user=Depends(get_admin_user)): return { "status": True, "embedding_model": app.state.RAG_EMBEDDING_MODEL, + "embedding_model_path": 
app.state.RAG_EMBEDDING_MODEL_PATH, } @@ -143,17 +136,39 @@ class EmbeddingModelUpdateForm(BaseModel): async def update_embedding_model( form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) ): + updated = True + old_model_path = app.state.RAG_EMBEDDING_MODEL_PATH app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model - app.state.sentence_transformer_ef = ( - embedding_functions.SentenceTransformerEmbeddingFunction( - model_name=app.state.RAG_EMBEDDING_MODEL, - device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, + + log.debug(f"form_data.embedding_model: {form_data.embedding_model}") + log.info(f"Updating embedding model to: {form_data.embedding_model}") + + try: + app.state.RAG_EMBEDDING_MODEL_PATH = embedding_model_get_path(app.state.RAG_EMBEDDING_MODEL, True) + app.state.sentence_transformer_ef = ( + embedding_functions.SentenceTransformerEmbeddingFunction( + model_name=app.state.RAG_EMBEDDING_MODEL_PATH, + device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, + ) ) - ) + except Exception as e: + log.exception(f"Problem updating embedding model: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(e), + ) + + if app.state.RAG_EMBEDDING_MODEL_PATH == old_model_path: + updated = False + + log.debug(f"app.state.RAG_EMBEDDING_MODEL_PATH: {app.state.RAG_EMBEDDING_MODEL_PATH}") + log.debug(f"old_model_path: {old_model_path}") + log.debug(f"updated: {updated}") return { - "status": True, + "status": updated, "embedding_model": app.state.RAG_EMBEDDING_MODEL, + "embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH, } diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index 7b9e6628c..249048b3e 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -1,6 +1,8 @@ +import os import re import logging from typing import List +from huggingface_hub import snapshot_download from config import SRC_LOG_LEVELS, CHROMA_CLIENT @@ -188,3 +190,36 @@ def rag_messages(docs, messages, template, k, 
embedding_function): messages[last_user_message_idx] = new_user_message return messages + +def embedding_model_get_path(embedding_model: str, update_embedding_model: bool = False): + # Construct huggingface_hub kwargs with local_files_only to return the snapshot path + cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME") + local_files_only = not update_embedding_model + snapshot_kwargs = { + "cache_dir": cache_dir, + "local_files_only": local_files_only, + } + + log.debug(f"SENTENCE_TRANSFORMERS_HOME cache_dir: {cache_dir}") + log.debug(f"embedding_model: {embedding_model}") + log.debug(f"update_embedding_model: {update_embedding_model}") + log.debug(f"local_files_only: {local_files_only}") + + # Inspiration from upstream sentence_transformers + if (os.path.exists(embedding_model) or ("\\" in embedding_model or embedding_model.count("/") > 1) and local_files_only): + # If fully qualified path exists, return input, else set repo_id + return embedding_model + elif "/" not in embedding_model: + # Set valid repo_id for model short-name + embedding_model = "sentence-transformers" + "/" + embedding_model + + snapshot_kwargs["repo_id"] = embedding_model + + # Attempt to query the huggingface_hub library to determine the local path and/or to update + try: + embedding_model_repo_path = snapshot_download(**snapshot_kwargs) + log.debug(f"embedding_model_repo_path: {embedding_model_repo_path}") + return embedding_model_repo_path + except Exception as e: + log.exception(f"Cannot determine embedding model snapshot path: {e}") + return embedding_model diff --git a/backend/config.py b/backend/config.py index 39411d255..b3299a97f 100644 --- a/backend/config.py +++ b/backend/config.py @@ -395,6 +395,9 @@ RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2") RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get( "RAG_EMBEDDING_MODEL_DEVICE_TYPE", "cpu" ) +RAG_EMBEDDING_MODEL_AUTO_UPDATE = False +if os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == 
"true": + RAG_EMBEDDING_MODEL_AUTO_UPDATE = True CHROMA_CLIENT = chromadb.PersistentClient( path=CHROMA_DATA_PATH, settings=Settings(allow_reset=True, anonymized_telemetry=False), diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts index 668fe227b..33c70e2b1 100644 --- a/src/lib/apis/rag/index.ts +++ b/src/lib/apis/rag/index.ts @@ -345,3 +345,64 @@ export const resetVectorDB = async (token: string) => { return res; }; + +export const getEmbeddingModel = async (token: string) => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/embedding/model`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +type EmbeddingModelUpdateForm = { + embedding_model: string; +}; + +export const updateEmbeddingModel = async (token: string, payload: EmbeddingModelUpdateForm) => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/embedding/model/update`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...payload + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/components/documents/Settings/General.svelte b/src/lib/components/documents/Settings/General.svelte index f00038de4..c20f20422 100644 --- a/src/lib/components/documents/Settings/General.svelte +++ b/src/lib/components/documents/Settings/General.svelte @@ -6,7 +6,9 @@ getQuerySettings, scanDocs, updateQuerySettings, - resetVectorDB + resetVectorDB, + getEmbeddingModel, + updateEmbeddingModel } 
from '$lib/apis/rag'; import { documents } from '$lib/stores'; @@ -18,6 +20,7 @@ export let saveHandler: Function; let loading = false; + let loading1 = false; let showResetConfirm = false; @@ -30,6 +33,10 @@ k: 4 }; + let embeddingModel = { + embedding_model: '', + }; + const scanHandler = async () => { loading = true; const res = await scanDocs(localStorage.token); @@ -41,6 +48,21 @@ } }; + const embeddingModelUpdateHandler = async () => { + loading1 = true; + const res = await updateEmbeddingModel(localStorage.token, embeddingModel); + loading1 = false; + + if (res) { + console.log('embeddingModelUpdateHandler:', res); + if (res.status == true) { + toast.success($i18n.t('Model {{embedding_model}} update complete!', res)); + } else { + toast.error($i18n.t('Model {{embedding_model}} update failed or not required!', res)); + } + } + }; + const submitHandler = async () => { const res = await updateRAGConfig(localStorage.token, { pdf_extract_images: pdfExtractImages, @@ -62,6 +84,8 @@ chunkOverlap = res.chunk.chunk_overlap; } + embeddingModel = await getEmbeddingModel(localStorage.token); + querySettings = await getQuerySettings(localStorage.token); }); @@ -137,6 +161,67 @@ {/if} + +
+
+ {$i18n.t('Update embedding model {{embedding_model}}', embeddingModel)} +
+ + +