diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 7e5e06004..f03aa4b7f 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -13,7 +13,6 @@ import os, shutil, logging, re from pathlib import Path from typing import List -from sentence_transformers import SentenceTransformer from chromadb.utils import embedding_functions from chromadb.utils.batch_utils import create_batches @@ -46,7 +45,7 @@ from apps.web.models.documents import ( DocumentResponse, ) -from apps.rag.utils import query_doc, query_collection +from apps.rag.utils import query_doc, query_collection, get_embedding_model_path from utils.misc import ( calculate_sha256, @@ -60,6 +59,7 @@ from config import ( UPLOAD_DIR, DOCS_DIR, RAG_EMBEDDING_MODEL, + RAG_EMBEDDING_MODEL_AUTO_UPDATE, DEVICE_TYPE, CHROMA_CLIENT, CHUNK_SIZE, @@ -78,12 +78,18 @@ app.state.PDF_EXTRACT_IMAGES = False app.state.CHUNK_SIZE = CHUNK_SIZE app.state.CHUNK_OVERLAP = CHUNK_OVERLAP app.state.RAG_TEMPLATE = RAG_TEMPLATE + + app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL + + app.state.TOP_K = 4 app.state.sentence_transformer_ef = ( embedding_functions.SentenceTransformerEmbeddingFunction( - model_name=app.state.RAG_EMBEDDING_MODEL, + model_name=get_embedding_model_path( + app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE + ), device=DEVICE_TYPE, ) ) @@ -135,17 +141,33 @@ class EmbeddingModelUpdateForm(BaseModel): async def update_embedding_model( form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) ): - app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model - app.state.sentence_transformer_ef = ( - embedding_functions.SentenceTransformerEmbeddingFunction( - model_name=app.state.RAG_EMBEDDING_MODEL, - device=DEVICE_TYPE, - ) + + log.info( + f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" ) - return { - "status": True, - "embedding_model": app.state.RAG_EMBEDDING_MODEL, - } + + try: + sentence_transformer_ef = ( + embedding_functions.SentenceTransformerEmbeddingFunction( + model_name=get_embedding_model_path(form_data.embedding_model, True), + device=DEVICE_TYPE, + ) + ) + + app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model + app.state.sentence_transformer_ef = sentence_transformer_ef + + return { + "status": True, + "embedding_model": app.state.RAG_EMBEDDING_MODEL, + } + + except Exception as e: + log.exception(f"Problem updating embedding model: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=ERROR_MESSAGES.DEFAULT(e), + ) @app.get("/config") diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index 7b9e6628c..7bbfe0b88 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -1,6 +1,8 @@ +import os import re import logging from typing import List +from huggingface_hub import snapshot_download from config import SRC_LOG_LEVELS, CHROMA_CLIENT @@ -188,3 +190,43 @@ def rag_messages(docs, messages, template, k, embedding_function): messages[last_user_message_idx] = new_user_message return messages + + +def get_embedding_model_path( + embedding_model: str, update_embedding_model: bool = False +): + # Construct huggingface_hub kwargs with local_files_only to return the snapshot path + cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME") + + local_files_only = not update_embedding_model + + snapshot_kwargs = { + "cache_dir": cache_dir, + "local_files_only": local_files_only, + } + + log.debug(f"embedding_model: {embedding_model}") + log.debug(f"snapshot_kwargs: {snapshot_kwargs}") + + # Inspiration from upstream sentence_transformers + if ( + os.path.exists(embedding_model) + or ("\\" in embedding_model or embedding_model.count("/") > 1) + and local_files_only + ): + # If fully qualified path exists, return input, else set repo_id + return embedding_model + elif "/" not in embedding_model: + # Set valid repo_id for model short-name + embedding_model = "sentence-transformers" + "/" + embedding_model + + snapshot_kwargs["repo_id"] = embedding_model + + # Attempt to query the huggingface_hub library to determine the local path and/or to update + try: + embedding_model_repo_path = snapshot_download(**snapshot_kwargs) + log.debug(f"embedding_model_repo_path: {embedding_model_repo_path}") + return embedding_model_repo_path + except Exception as e: + log.exception(f"Cannot determine embedding model snapshot path: {e}") + return embedding_model diff --git a/backend/config.py b/backend/config.py index 2a71b1895..6e3cf92a9 100644 --- a/backend/config.py +++ b/backend/config.py @@ -403,6 +403,12 @@ CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2) RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2") log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"), + +RAG_EMBEDDING_MODEL_AUTO_UPDATE = ( + os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true" +) + + # device type ebbeding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false") diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts index 668fe227b..33c70e2b1 100644 --- a/src/lib/apis/rag/index.ts +++ b/src/lib/apis/rag/index.ts @@ -345,3 +345,64 @@ export const resetVectorDB = async (token: string) => { return res; }; + +export const getEmbeddingModel = async (token: string) => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/embedding/model`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +type EmbeddingModelUpdateForm = { + embedding_model: string; +}; + +export const updateEmbeddingModel = async (token: string, payload: EmbeddingModelUpdateForm) => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/embedding/model/update`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...payload + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/components/documents/Settings/General.svelte b/src/lib/components/documents/Settings/General.svelte index f00038de4..c94c1250b 100644 --- a/src/lib/components/documents/Settings/General.svelte +++ b/src/lib/components/documents/Settings/General.svelte @@ -6,18 +6,23 @@ getQuerySettings, scanDocs, updateQuerySettings, - resetVectorDB + resetVectorDB, + getEmbeddingModel, + updateEmbeddingModel } from '$lib/apis/rag'; import { documents } from '$lib/stores'; import { onMount, getContext } from 'svelte'; import { toast } from 'svelte-sonner'; + import Tooltip from '$lib/components/common/Tooltip.svelte'; + const i18n = getContext('i18n'); export let saveHandler: Function; - let loading = false; + let scanDirLoading = false; + let updateEmbeddingModelLoading = false; let showResetConfirm = false; @@ -30,10 +35,12 @@ k: 4 }; + let embeddingModel = ''; + const scanHandler = async () => { - loading = true; + scanDirLoading = true; const res = await scanDocs(localStorage.token); - loading = false; + scanDirLoading = false; if (res) { await documents.set(await getDocs(localStorage.token)); @@ -41,6 +48,38 @@ } }; + const embeddingModelUpdateHandler = async () => { + if (embeddingModel.split('/').length - 1 > 1) { + toast.error( + $i18n.t( + 'Model filesystem path detected. Model shortname is required for update, cannot continue.' + ) + ); + return; + } + + console.log('Update embedding model attempt:', embeddingModel); + + updateEmbeddingModelLoading = true; + const res = await updateEmbeddingModel(localStorage.token, { + embedding_model: embeddingModel + }).catch(async (error) => { + toast.error(error); + embeddingModel = (await getEmbeddingModel(localStorage.token)).embedding_model; + return null; + }); + updateEmbeddingModelLoading = false; + + if (res) { + console.log('embeddingModelUpdateHandler:', res); + if (res.status === true) { + toast.success($i18n.t('Model {{embedding_model}} update complete!', res), { + duration: 1000 * 10 + }); + } + } + }; + const submitHandler = async () => { const res = await updateRAGConfig(localStorage.token, { pdf_extract_images: pdfExtractImages, @@ -62,6 +101,8 @@ chunkOverlap = res.chunk.chunk_overlap; } + embeddingModel = (await getEmbeddingModel(localStorage.token)).embedding_model; + querySettings = await getQuerySettings(localStorage.token); }); @@ -73,7 +114,7 @@ saveHandler(); }} > -