diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py
index 7140cad9d..387ff05da 100644
--- a/backend/apps/ollama/main.py
+++ b/backend/apps/ollama/main.py
@@ -654,6 +654,60 @@ async def generate_embeddings(
     )


+def generate_ollama_embeddings(
+    form_data: GenerateEmbeddingsForm,
+    url_idx: Optional[int] = None,
+):
+    log.info(f"generate_ollama_embeddings {form_data}")
+
+    if url_idx is None:
+        model = form_data.model
+
+        if ":" not in model:
+            model = f"{model}:latest"
+
+        if model in app.state.MODELS:
+            url_idx = random.choice(app.state.MODELS[model]["urls"])
+        else:
+            raise HTTPException(
+                status_code=400,
+                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
+            )
+
+    url = app.state.OLLAMA_BASE_URLS[url_idx]
+    log.info(f"url: {url}")
+
+    r = None
+    try:
+        r = requests.request(
+            method="POST",
+            url=f"{url}/api/embeddings",
+            data=form_data.model_dump_json(exclude_none=True).encode(),
+        )
+        r.raise_for_status()
+
+        data = r.json()
+
+        log.info(f"generate_ollama_embeddings {data}")
+
+        if "embedding" in data:
+            return data["embedding"]
+        else:
+            raise ValueError("Something went wrong :/")
+    except Exception as e:
+        log.exception(e)
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = r.json()
+                if "error" in res:
+                    error_detail = f"Ollama: {res['error']}"
+            except Exception:
+                error_detail = f"Ollama: {e}"
+
+        raise Exception(error_detail)
+
+
 class GenerateCompletionForm(BaseModel):
     model: str
     prompt: str
diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py
index f03aa4b7f..976c7735b 100644
--- a/backend/apps/rag/main.py
+++ b/backend/apps/rag/main.py
@@ -39,13 +39,21 @@
 import uuid
 import json

+from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
+
 from apps.web.models.documents import (
     Documents,
     DocumentForm,
     DocumentResponse,
 )

-from apps.rag.utils import query_doc, query_collection, get_embedding_model_path
+from apps.rag.utils import (
+    query_doc,
+    query_embeddings_doc,
+    query_collection,
+    query_embeddings_collection,
+    get_embedding_model_path,
+)

 from utils.misc import (
     calculate_sha256,
@@ -58,6 +66,7 @@ from config import (
     SRC_LOG_LEVELS,
     UPLOAD_DIR,
     DOCS_DIR,
+    RAG_EMBEDDING_ENGINE,
     RAG_EMBEDDING_MODEL,
     RAG_EMBEDDING_MODEL_AUTO_UPDATE,
     DEVICE_TYPE,
@@ -74,17 +83,20 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])

 app = FastAPI()

-app.state.PDF_EXTRACT_IMAGES = False
+
+app.state.TOP_K = 4
 app.state.CHUNK_SIZE = CHUNK_SIZE
 app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
+
+
+app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
+app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
 app.state.RAG_TEMPLATE = RAG_TEMPLATE

-app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
+app.state.PDF_EXTRACT_IMAGES = False

-app.state.TOP_K = 4
-
 app.state.sentence_transformer_ef = (
     embedding_functions.SentenceTransformerEmbeddingFunction(
         model_name=get_embedding_model_path(
@@ -121,24 +133,27 @@ async def get_status():
         "chunk_size": app.state.CHUNK_SIZE,
         "chunk_overlap": app.state.CHUNK_OVERLAP,
         "template": app.state.RAG_TEMPLATE,
+        "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
         "embedding_model": app.state.RAG_EMBEDDING_MODEL,
     }


-@app.get("/embedding/model")
-async def get_embedding_model(user=Depends(get_admin_user)):
+@app.get("/embedding")
+async def get_embedding_config(user=Depends(get_admin_user)):
     return {
         "status": True,
+        "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
         "embedding_model": app.state.RAG_EMBEDDING_MODEL,
     }


 class EmbeddingModelUpdateForm(BaseModel):
+    embedding_engine: str
     embedding_model: str


-@app.post("/embedding/model/update")
-async def update_embedding_model(
+@app.post("/embedding/update")
+async def update_embedding_config(
     form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
 ):
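Reviewer note: a minimal sketch of the request the new `generate_ollama_embeddings` helper issues, useful for exercising the contract in isolation. The `/api/embeddings` request and response shape follow Ollama's public API; the base URL and model name below are illustrative placeholders, not values fixed by this patch.

```python
# Standalone check of the Ollama embeddings contract used by the new helper.
# OLLAMA_URL and the model name are assumptions for the example.
import requests

OLLAMA_URL = "http://localhost:11434"

r = requests.post(
    f"{OLLAMA_URL}/api/embeddings",
    json={"model": "all-minilm:latest", "prompt": "hello world"},
)
r.raise_for_status()
embedding = r.json()["embedding"]  # list[float]; length depends on the model
print(len(embedding))
```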
-@app.post("/embedding/model/update") -async def update_embedding_model( +@app.post("/embedding/update") +async def update_embedding_config( form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) ): @@ -147,18 +162,26 @@ async def update_embedding_model( ) try: - sentence_transformer_ef = ( - embedding_functions.SentenceTransformerEmbeddingFunction( - model_name=get_embedding_model_path(form_data.embedding_model, True), - device=DEVICE_TYPE, - ) - ) + app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine - app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model - app.state.sentence_transformer_ef = sentence_transformer_ef + if app.state.RAG_EMBEDDING_ENGINE == "ollama": + app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model + app.state.sentence_transformer_ef = None + else: + sentence_transformer_ef = ( + embedding_functions.SentenceTransformerEmbeddingFunction( + model_name=get_embedding_model_path( + form_data.embedding_model, True + ), + device=DEVICE_TYPE, + ) + ) + app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model + app.state.sentence_transformer_ef = sentence_transformer_ef return { "status": True, + "embedding_engine": app.state.RAG_EMBEDDING_ENGINE, "embedding_model": app.state.RAG_EMBEDDING_MODEL, } @@ -252,12 +275,28 @@ def query_doc_handler( ): try: - return query_doc( - collection_name=form_data.collection_name, - query=form_data.query, - k=form_data.k if form_data.k else app.state.TOP_K, - embedding_function=app.state.sentence_transformer_ef, - ) + if app.state.RAG_EMBEDDING_ENGINE == "ollama": + query_embeddings = generate_ollama_embeddings( + GenerateEmbeddingsForm( + **{ + "model": app.state.RAG_EMBEDDING_MODEL, + "prompt": form_data.query, + } + ) + ) + + return query_embeddings_doc( + collection_name=form_data.collection_name, + query_embeddings=query_embeddings, + k=form_data.k if form_data.k else app.state.TOP_K, + ) + else: + return query_doc( + collection_name=form_data.collection_name, + query=form_data.query, + k=form_data.k if form_data.k else app.state.TOP_K, + embedding_function=app.state.sentence_transformer_ef, + ) except Exception as e: log.exception(e) raise HTTPException( @@ -277,12 +316,35 @@ def query_collection_handler( form_data: QueryCollectionsForm, user=Depends(get_current_user), ): - return query_collection( - collection_names=form_data.collection_names, - query=form_data.query, - k=form_data.k if form_data.k else app.state.TOP_K, - embedding_function=app.state.sentence_transformer_ef, - ) + try: + if app.state.RAG_EMBEDDING_ENGINE == "ollama": + query_embeddings = generate_ollama_embeddings( + GenerateEmbeddingsForm( + **{ + "model": app.state.RAG_EMBEDDING_MODEL, + "prompt": form_data.query, + } + ) + ) + + return query_embeddings_collection( + collection_names=form_data.collection_names, + query_embeddings=query_embeddings, + k=form_data.k if form_data.k else app.state.TOP_K, + ) + else: + return query_collection( + collection_names=form_data.collection_names, + query=form_data.query, + k=form_data.k if form_data.k else app.state.TOP_K, + embedding_function=app.state.sentence_transformer_ef, + ) + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) @app.post("/web") @@ -317,9 +379,11 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> b chunk_overlap=app.state.CHUNK_OVERLAP, add_start_index=True, ) + docs = text_splitter.split_documents(data) if len(docs) > 0: + 
log.info("store_data_in_vector_db", "store_docs_in_vector_db") return store_docs_in_vector_db(docs, collection_name, overwrite), None else: raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT) @@ -338,6 +402,7 @@ def store_text_in_vector_db( def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool: + log.info("store_docs_in_vector_db", docs, collection_name) texts = [doc.page_content for doc in docs] metadatas = [doc.metadata for doc in docs] @@ -349,20 +414,39 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b log.info(f"deleting existing collection {collection_name}") CHROMA_CLIENT.delete_collection(name=collection_name) - collection = CHROMA_CLIENT.create_collection( - name=collection_name, - embedding_function=app.state.sentence_transformer_ef, - ) + if app.state.RAG_EMBEDDING_ENGINE == "ollama": + collection = CHROMA_CLIENT.create_collection(name=collection_name) - for batch in create_batches( - api=CHROMA_CLIENT, - ids=[str(uuid.uuid1()) for _ in texts], - metadatas=metadatas, - documents=texts, - ): - collection.add(*batch) + for batch in create_batches( + api=CHROMA_CLIENT, + ids=[str(uuid.uuid1()) for _ in texts], + metadatas=metadatas, + embeddings=[ + generate_ollama_embeddings( + GenerateEmbeddingsForm( + **{"model": RAG_EMBEDDING_MODEL, "prompt": text} + ) + ) + for text in texts + ], + ): + collection.add(*batch) + else: - return True + collection = CHROMA_CLIENT.create_collection( + name=collection_name, + embedding_function=app.state.sentence_transformer_ef, + ) + + for batch in create_batches( + api=CHROMA_CLIENT, + ids=[str(uuid.uuid1()) for _ in texts], + metadatas=metadatas, + documents=texts, + ): + collection.add(*batch) + + return True except Exception as e: log.exception(e) if e.__class__.__name__ == "UniqueConstraintError": diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index 7bbfe0b88..17d8e4a9a 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -2,6 +2,9 @@ import os import re import logging from typing import List +import requests + + from huggingface_hub import snapshot_download from config import SRC_LOG_LEVELS, CHROMA_CLIENT @@ -26,6 +29,22 @@ def query_doc(collection_name: str, query: str, k: int, embedding_function): raise e +def query_embeddings_doc(collection_name: str, query_embeddings, k: int): + try: + # if you use docker use the model from the environment variable + log.info("query_embeddings_doc", query_embeddings) + collection = CHROMA_CLIENT.get_collection( + name=collection_name, + ) + result = collection.query( + query_embeddings=[query_embeddings], + n_results=k, + ) + return result + except Exception as e: + raise e + + def merge_and_sort_query_results(query_results, k): # Initialize lists to store combined data combined_ids = [] @@ -96,6 +115,26 @@ def query_collection( return merge_and_sort_query_results(results, k) +def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int): + + results = [] + log.info("query_embeddings_collection", query_embeddings) + + for collection_name in collection_names: + try: + collection = CHROMA_CLIENT.get_collection(name=collection_name) + + result = collection.query( + query_embeddings=[query_embeddings], + n_results=k, + ) + results.append(result) + except: + pass + + return merge_and_sort_query_results(results, k) + + def rag_template(template: str, context: str, query: str): template = template.replace("[context]", context) template = template.replace("[query]", query) diff --git 
diff --git a/backend/config.py b/backend/config.py
index 6d93115bb..938df9961 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -405,6 +405,9 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
 CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
 # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2)
+
+RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")
+
 RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2")
 log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"),
diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts
index 4618acc4d..a94aceace 100644
--- a/src/lib/apis/ollama/index.ts
+++ b/src/lib/apis/ollama/index.ts
@@ -220,6 +220,32 @@ export const generatePrompt = async (token: string = '', model: string, conversa
 	return res;
 };

+export const generateEmbeddings = async (token: string = '', model: string, text: string) => {
+	let error = null;
+
+	const res = await fetch(`${OLLAMA_API_BASE_URL}/api/embeddings`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			model: model,
+			prompt: text
+		})
+	}).catch((err) => {
+		error = err;
+		return null;
+	});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
 export const generateTextCompletion = async (token: string = '', model: string, text: string) => {
 	let error = null;
diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts
index 33c70e2b1..bfcee55fc 100644
--- a/src/lib/apis/rag/index.ts
+++ b/src/lib/apis/rag/index.ts
@@ -346,10 +346,10 @@ export const resetVectorDB = async (token: string) => {
 	return res;
 };

-export const getEmbeddingModel = async (token: string) => {
+export const getEmbeddingConfig = async (token: string) => {
 	let error = null;

-	const res = await fetch(`${RAG_API_BASE_URL}/embedding/model`, {
+	const res = await fetch(`${RAG_API_BASE_URL}/embedding`, {
 		method: 'GET',
 		headers: {
 			'Content-Type': 'application/json',
@@ -374,13 +374,14 @@ export const getEmbeddingModel = async (token: string) => {
 };

 type EmbeddingModelUpdateForm = {
+	embedding_engine: string;
 	embedding_model: string;
 };

-export const updateEmbeddingModel = async (token: string, payload: EmbeddingModelUpdateForm) => {
+export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => {
 	let error = null;

-	const res = await fetch(`${RAG_API_BASE_URL}/embedding/model/update`, {
+	const res = await fetch(`${RAG_API_BASE_URL}/embedding/update`, {
 		method: 'POST',
 		headers: {
 			'Content-Type': 'application/json',
diff --git a/src/lib/components/documents/Settings/General.svelte b/src/lib/components/documents/Settings/General.svelte
index c94c1250b..c9142fbe5 100644
--- a/src/lib/components/documents/Settings/General.svelte
+++ b/src/lib/components/documents/Settings/General.svelte
@@ -7,11 +7,11 @@
 		scanDocs,
 		updateQuerySettings,
 		resetVectorDB,
-		getEmbeddingModel,
-		updateEmbeddingModel
+		getEmbeddingConfig,
+		updateEmbeddingConfig
 	} from '$lib/apis/rag';

-	import { documents } from '$lib/stores';
+	import { documents, models } from '$lib/stores';
 	import { onMount, getContext } from 'svelte';
 	import { toast } from 'svelte-sonner';

@@ -26,6 +26,9 @@

 	let showResetConfirm = false;

+	let embeddingEngine = '';
+	let embeddingModel = '';
+
 	let chunkSize = 0;
 	let chunkOverlap = 0;
 	let pdfExtractImages = true;
@@ -35,8 +38,6 @@
 		k: 4
 	};

-	let embeddingModel = '';
-
 	const scanHandler = async () => {
 		scanDirLoading = true;
 		const res = await scanDocs(localStorage.token);
@@ -49,7 +50,16 @@
 	};

 	const embeddingModelUpdateHandler = async () => {
-		if (embeddingModel.split('/').length - 1 > 1) {
+		if (embeddingModel === '') {
+			toast.error(
+				$i18n.t(
+					'Embedding model is required for update, cannot continue.'
+				)
+			);
+			return;
+		}
+
+		if (embeddingEngine === '' && embeddingModel.split('/').length - 1 > 1) {
 			toast.error(
 				$i18n.t(
 					'Model filesystem path detected. Model shortname is required for update, cannot continue.'
@@ -61,11 +71,17 @@
 		console.log('Update embedding model attempt:', embeddingModel);

 		updateEmbeddingModelLoading = true;
-		const res = await updateEmbeddingModel(localStorage.token, {
+		const res = await updateEmbeddingConfig(localStorage.token, {
+			embedding_engine: embeddingEngine,
 			embedding_model: embeddingModel
 		}).catch(async (error) => {
 			toast.error(error);
-			embeddingModel = (await getEmbeddingModel(localStorage.token)).embedding_model;
+
+			const embeddingConfig = await getEmbeddingConfig(localStorage.token);
+			if (embeddingConfig) {
+				embeddingEngine = embeddingConfig.embedding_engine;
+				embeddingModel = embeddingConfig.embedding_model;
+			}
 			return null;
 		});
 		updateEmbeddingModelLoading = false;
@@ -101,7 +117,12 @@
 			chunkOverlap = res.chunk.chunk_overlap;
 		}

-		embeddingModel = (await getEmbeddingModel(localStorage.token)).embedding_model;
+		const embeddingConfig = await getEmbeddingConfig(localStorage.token);
+
+		if (embeddingConfig) {
+			embeddingEngine = embeddingConfig.embedding_engine;
+			embeddingModel = embeddingConfig.embedding_model;
+		}

 		querySettings = await getQuerySettings(localStorage.token);
 	});
@@ -118,81 +139,189 @@
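Taken together, the flow `General.svelte` now implements is: fetch the current config on mount, post engine plus model back on save, and re-fetch on failure to resync the fields. A compact sketch of that round trip against the renamed endpoints; the URL and token are placeholders.

```python
# Hypothetical replay of the settings round trip; mirrors getEmbeddingConfig /
# updateEmbeddingConfig rather than defining any new API.
import requests

BASE = "http://localhost:8080/rag/api/v1"  # assumed RAG API mount point
HEADERS = {"Authorization": "Bearer <admin-token>"}

config = requests.get(f"{BASE}/embedding", headers=HEADERS).json()
print(config["embedding_engine"], config["embedding_model"])

update = {
    "embedding_engine": config["embedding_engine"],
    "embedding_model": config["embedding_model"],
}
r = requests.post(f"{BASE}/embedding/update", json=update, headers=HEADERS)
if not r.ok:
    # on failure the UI re-reads the server-side config to resync its fields
    config = requests.get(f"{BASE}/embedding", headers=HEADERS).json()
```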