From 885b9f1ece8f0fc9a4de3390164592c9dbc67282 Mon Sep 17 00:00:00 2001
From: Peter De-Ath
Date: Thu, 26 Sep 2024 23:28:47 +0100
Subject: [PATCH] refactor: Update GenerateEmbeddingsForm to support batch
 processing

refactor: Update embedding batch size handling in RAG configuration
refactor: add query_doc query caching
refactor: update logging statements in generate_chat_completion function
change embedding_batch_size to Optional
---
 backend/open_webui/apps/ollama/main.py     | 154 ++++++++----------
 backend/open_webui/apps/retrieval/main.py  |  24 ++-
 backend/open_webui/apps/retrieval/utils.py |  42 +++--
 backend/open_webui/config.py               |   8 +-
 src/lib/apis/retrieval/index.ts            |   2 +-
 .../admin/Settings/Documents.svelte        |  18 +-
 6 files changed, 116 insertions(+), 132 deletions(-)

diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py
index 33d984655..eed74064f 100644
--- a/backend/open_webui/apps/ollama/main.py
+++ b/backend/open_webui/apps/ollama/main.py
@@ -547,8 +547,8 @@ class GenerateEmbeddingsForm(BaseModel):
 
 class GenerateEmbedForm(BaseModel):
     model: str
-    input: str
-    truncate: Optional[bool]
+    input: list[str]
+    truncate: Optional[bool] = None
     options: Optional[dict] = None
     keep_alive: Optional[Union[int, str]] = None
 
@@ -560,48 +560,7 @@ async def generate_embeddings(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    if url_idx is None:
-        model = form_data.model
-
-        if ":" not in model:
-            model = f"{model}:latest"
-
-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
-        else:
-            raise HTTPException(
-                status_code=400,
-                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
-            )
-
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
-
-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/embed",
-        headers={"Content-Type": "application/json"},
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
-    try:
-        r.raise_for_status()
-
-        return r.json()
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except Exception:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return generate_ollama_batch_embeddings(form_data, url_idx)
 
 
 @app.post("/api/embeddings")
@@ -611,48 +570,7 @@ async def generate_embeddings(
     form_data: GenerateEmbeddingsForm,
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    if url_idx is None:
-        model = form_data.model
-
-        if ":" not in model:
-            model = f"{model}:latest"
-
-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
-        else:
-            raise HTTPException(
-                status_code=400,
-                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
-            )
-
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
-
-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/embeddings",
-        headers={"Content-Type": "application/json"},
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
-    try:
-        r.raise_for_status()
-
-        return r.json()
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except Exception:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)
 
 
 def generate_ollama_embeddings(
@@ -692,7 +610,64 @@
         log.info(f"generate_ollama_embeddings {data}")
 
         if "embedding" in data:
-            return data["embedding"]
+            return data
         else:
             raise Exception("Something went wrong :/")
     except Exception as e:
         log.exception(e)
         error_detail = "Open WebUI: Server Connection Error"
         if r is not None:
             try:
                 res = r.json()
                 if "error" in res:
                     error_detail = f"Ollama: {res['error']}"
             except Exception:
                 error_detail = f"Ollama: {e}"
 
         raise HTTPException(
             status_code=r.status_code if r else 500,
             detail=error_detail,
         )
+
+
+def generate_ollama_batch_embeddings(
+    form_data: GenerateEmbedForm,
+    url_idx: Optional[int] = None,
+):
+    log.info(f"generate_ollama_batch_embeddings {form_data}")
+
+    if url_idx is None:
+        model = form_data.model
+
+        if ":" not in model:
+            model = f"{model}:latest"
+
+        if model in app.state.MODELS:
+            url_idx = random.choice(app.state.MODELS[model]["urls"])
+        else:
+            raise HTTPException(
+                status_code=400,
+                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
+            )
+
+    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    log.info(f"url: {url}")
+
+    r = requests.request(
+        method="POST",
+        url=f"{url}/api/embed",
+        headers={"Content-Type": "application/json"},
+        data=form_data.model_dump_json(exclude_none=True).encode(),
+    )
+    try:
+        r.raise_for_status()
+
+        data = r.json()
+
+        log.info(f"generate_ollama_batch_embeddings {data}")
+
+        if "embeddings" in data:
+            return data
+        else:
+            raise Exception("Something went wrong :/")
+    except Exception as e:
+        log.exception(e)
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = r.json()
+                if "error" in res:
+                    error_detail = f"Ollama: {res['error']}"
+            except Exception:
+                error_detail = f"Ollama: {e}"
+
+        raise HTTPException(
+            status_code=r.status_code if r else 500,
+            detail=error_detail,
+        )
@@ -788,8 +763,7 @@ async def generate_chat_completion(
     user=Depends(get_verified_user),
 ):
     payload = {**form_data.model_dump(exclude_none=True)}
-    log.debug(f"{payload = }")
-
+    log.debug(f"generate_chat_completion() - 1.payload = {payload}")
     if "metadata" in payload:
         del payload["metadata"]
 
@@ -824,7 +798,7 @@ async def generate_chat_completion(
     url = get_ollama_url(url_idx, payload["model"])
     log.info(f"url: {url}")
-    log.debug(payload)
+    log.debug(f"generate_chat_completion() - 2.payload = {payload}")
 
     return await post_streaming_url(
         f"{url}/api/chat",
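
Review note (not part of the patch): after these hunks the two Ollama endpoints diverge — `/api/embed` takes the batch `GenerateEmbedForm` and answers with `"embeddings"`, while the legacy `/api/embeddings` keeps the single-prompt form and answers with `"embedding"`. A minimal sketch of both shapes; the `/ollama` mount prefix, bearer token, and model name are assumptions, not taken from this diff:

# Review-note sketch: request/response shapes of the two endpoints.
import requests

base = "http://localhost:8080/ollama"             # assumed mount prefix
headers = {"Authorization": "Bearer <token>"}     # assumed auth scheme

# New batch form: `input` is a list; the response carries "embeddings".
r = requests.post(
    f"{base}/api/embed",
    json={"model": "nomic-embed-text", "input": ["first chunk", "second chunk"]},
    headers=headers,
)
print(len(r.json()["embeddings"]))  # 2 - one vector per input string

# Legacy form: single `prompt`; the response carries one "embedding".
r = requests.post(
    f"{base}/api/embeddings",
    json={"model": "nomic-embed-text", "prompt": "first chunk"},
    headers=headers,
)
print(len(r.json()["embedding"]))   # dimensionality of the single vector

The batch route returns one vector per input string, which is what the new `generate_ollama_batch_embeddings()` helper relies on.
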
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL, "reranking_model": app.state.config.RAG_RERANKING_MODEL, - "openai_batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, + "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE, } @@ -277,10 +277,10 @@ async def get_embedding_config(user=Depends(get_admin_user)): "status": True, "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, "embedding_model": app.state.config.RAG_EMBEDDING_MODEL, + "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE, "openai_config": { "url": app.state.config.OPENAI_API_BASE_URL, "key": app.state.config.OPENAI_API_KEY, - "batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, }, } @@ -296,13 +296,13 @@ async def get_reraanking_config(user=Depends(get_admin_user)): class OpenAIConfigForm(BaseModel): url: str key: str - batch_size: Optional[int] = None class EmbeddingModelUpdateForm(BaseModel): openai_config: Optional[OpenAIConfigForm] = None embedding_engine: str embedding_model: str + embedding_batch_size: Optional[int] = 1 @app.post("/embedding/update") @@ -320,11 +320,7 @@ async def update_embedding_config( if form_data.openai_config is not None: app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url app.state.config.OPENAI_API_KEY = form_data.openai_config.key - app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = ( - form_data.openai_config.batch_size - if form_data.openai_config.batch_size - else 1 - ) + app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL) @@ -334,17 +330,17 @@ async def update_embedding_config( app.state.sentence_transformer_ef, app.state.config.OPENAI_API_KEY, app.state.config.OPENAI_API_BASE_URL, - app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, + app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) return { "status": True, "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE, "embedding_model": app.state.config.RAG_EMBEDDING_MODEL, + "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE, "openai_config": { "url": app.state.config.OPENAI_API_BASE_URL, "key": app.state.config.OPENAI_API_KEY, - "batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, }, } except Exception as e: @@ -690,7 +686,7 @@ def save_docs_to_vector_db( app.state.sentence_transformer_ef, app.state.config.OPENAI_API_KEY, app.state.config.OPENAI_API_BASE_URL, - app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, + app.state.config.RAG_EMBEDDING_BATCH_SIZE, ) embeddings = embedding_function( diff --git a/backend/open_webui/apps/retrieval/utils.py b/backend/open_webui/apps/retrieval/utils.py index 0fe206c96..b7d746bfa 100644 --- a/backend/open_webui/apps/retrieval/utils.py +++ b/backend/open_webui/apps/retrieval/utils.py @@ -12,8 +12,8 @@ from langchain_core.documents import Document from open_webui.apps.ollama.main import ( - GenerateEmbeddingsForm, - generate_ollama_embeddings, + GenerateEmbedForm, + generate_ollama_batch_embeddings, ) from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT from open_webui.utils.misc import get_last_user_message @@ -71,7 +71,7 @@ def query_doc( try: result = VECTOR_DB_CLIENT.search( collection_name=collection_name, - vectors=[query_embedding], + vectors=query_embedding, limit=k, ) @@ -265,19 +265,15 @@ def get_embedding_function( embedding_function, openai_key, openai_url, - batch_size, + embedding_batch_size, ): if embedding_engine == "": return lambda query: embedding_function.encode(query).tolist() elif 
diff --git a/backend/open_webui/apps/retrieval/utils.py b/backend/open_webui/apps/retrieval/utils.py
index 0fe206c96..b7d746bfa 100644
--- a/backend/open_webui/apps/retrieval/utils.py
+++ b/backend/open_webui/apps/retrieval/utils.py
@@ -12,8 +12,8 @@ from langchain_core.documents import Document
 
 from open_webui.apps.ollama.main import (
-    GenerateEmbeddingsForm,
-    generate_ollama_embeddings,
+    GenerateEmbedForm,
+    generate_ollama_batch_embeddings,
 )
 from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
 from open_webui.utils.misc import get_last_user_message
@@ -71,7 +71,7 @@ def query_doc(
     try:
         result = VECTOR_DB_CLIENT.search(
             collection_name=collection_name,
-            vectors=[query_embedding],
+            vectors=query_embedding,
             limit=k,
         )
 
@@ -265,19 +265,15 @@ def get_embedding_function(
     embedding_function,
     openai_key,
     openai_url,
-    batch_size,
+    embedding_batch_size,
 ):
     if embedding_engine == "":
         return lambda query: embedding_function.encode(query).tolist()
     elif embedding_engine in ["ollama", "openai"]:
         if embedding_engine == "ollama":
             func = lambda query: generate_ollama_embeddings(
-                GenerateEmbeddingsForm(
-                    **{
-                        "model": embedding_model,
-                        "prompt": query,
-                    }
-                )
+                model=embedding_model,
+                input=query,
             )
         elif embedding_engine == "openai":
             func = lambda query: generate_openai_embeddings(
@@ -289,13 +285,10 @@
     def generate_multiple(query, f):
         if isinstance(query, list):
-            if embedding_engine == "openai":
-                embeddings = []
-                for i in range(0, len(query), batch_size):
-                    embeddings.extend(f(query[i : i + batch_size]))
-                return embeddings
-            else:
-                return [f(q) for q in query]
+            embeddings = []
+            for i in range(0, len(query), embedding_batch_size):
+                embeddings.extend(f(query[i : i + embedding_batch_size]))
+            return embeddings
         else:
             return f(query)
@@ -481,6 +474,21 @@
     return None
 
 
+def generate_ollama_embeddings(
+    model: str, input: list[str]
+) -> Optional[list[list[float]]]:
+    if isinstance(input, list):
+        embeddings = generate_ollama_batch_embeddings(
+            GenerateEmbedForm(**{"model": model, "input": input})
+        )
+    else:
+        embeddings = generate_ollama_batch_embeddings(
+            GenerateEmbedForm(**{"model": model, "input": [input]})
+        )
+
+    return embeddings["embeddings"]
+
+
 import operator
 from typing import Optional, Sequence
diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index bfc9a4ded..eac4855ce 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -986,10 +986,10 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
     os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
 )
 
-RAG_EMBEDDING_OPENAI_BATCH_SIZE = PersistentConfig(
-    "RAG_EMBEDDING_OPENAI_BATCH_SIZE",
-    "rag.embedding_openai_batch_size",
-    int(os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")),
+RAG_EMBEDDING_BATCH_SIZE = PersistentConfig(
+    "RAG_EMBEDDING_BATCH_SIZE",
+    "rag.embedding_batch_size",
+    int(os.environ.get("RAG_EMBEDDING_BATCH_SIZE", "1")),
 )
 
 RAG_RERANKING_MODEL = PersistentConfig(
diff --git a/src/lib/apis/retrieval/index.ts b/src/lib/apis/retrieval/index.ts
index 9f49e9c0f..6c6b18b9f 100644
--- a/src/lib/apis/retrieval/index.ts
+++ b/src/lib/apis/retrieval/index.ts
@@ -200,13 +200,13 @@ export const getEmbeddingConfig = async (token: string) => {
 type OpenAIConfigForm = {
     key: string;
     url: string;
-    batch_size: number;
 };
 
 type EmbeddingModelUpdateForm = {
     openai_config?: OpenAIConfigForm;
     embedding_engine: string;
     embedding_model: string;
+    embedding_batch_size?: number;
 };
 
 export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => {
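
Review note: with `batch_size` renamed to `embedding_batch_size`, the `generate_multiple()` helper above now slices queries for both engines rather than only for OpenAI, and the persistent config key likewise becomes `RAG_EMBEDDING_BATCH_SIZE` (e.g. `RAG_EMBEDDING_BATCH_SIZE=32` in the environment). A standalone sketch of the slicing behaviour, with a stub in place of a real embedding call:

# Review-note sketch: mirrors the patched helper, batching for every engine.
def generate_multiple(query, f, embedding_batch_size):
    if isinstance(query, list):
        embeddings = []
        for i in range(0, len(query), embedding_batch_size):
            embeddings.extend(f(query[i : i + embedding_batch_size]))
        return embeddings
    return f(query)

calls = []

def fake_embed(batch):
    calls.append(batch)          # record how the input list was sliced
    return [[0.0]] * len(batch)  # one dummy vector per input string

generate_multiple(["a", "b", "c", "d", "e"], fake_embed, embedding_batch_size=2)
print(calls)  # [['a', 'b'], ['c', 'd'], ['e']] -> three upstream requests
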
diff --git a/src/lib/components/admin/Settings/Documents.svelte b/src/lib/components/admin/Settings/Documents.svelte
index d6f7dc987..d94146c7d 100644
--- a/src/lib/components/admin/Settings/Documents.svelte
+++ b/src/lib/components/admin/Settings/Documents.svelte
@@ -38,6 +38,7 @@
     let embeddingEngine = '';
     let embeddingModel = '';
+    let embeddingBatchSize = 1;
     let rerankingModel = '';
 
     let fileMaxSize = null;
@@ -53,7 +54,6 @@
     let OpenAIKey = '';
     let OpenAIUrl = '';
-    let OpenAIBatchSize = 1;
 
     let querySettings = {
         template: '',
@@ -100,12 +100,16 @@
         const res = await updateEmbeddingConfig(localStorage.token, {
             embedding_engine: embeddingEngine,
             embedding_model: embeddingModel,
+            ...(embeddingEngine === 'openai' || embeddingEngine === 'ollama'
+                ? {
+                        embedding_batch_size: embeddingBatchSize
+                    }
+                : {}),
             ...(embeddingEngine === 'openai'
                 ? {
                         openai_config: {
                             key: OpenAIKey,
-                            url: OpenAIUrl,
-                            batch_size: OpenAIBatchSize
+                            url: OpenAIUrl
                         }
                     }
                 : {})
@@ -193,10 +197,10 @@
         if (embeddingConfig) {
             embeddingEngine = embeddingConfig.embedding_engine;
             embeddingModel = embeddingConfig.embedding_model;
+            embeddingBatchSize = embeddingConfig.embedding_batch_size ?? 1;
 
             OpenAIKey = embeddingConfig.openai_config.key;
             OpenAIUrl = embeddingConfig.openai_config.url;
-            OpenAIBatchSize = embeddingConfig.openai_config.batch_size ?? 1;
         }
     };
@@ -309,6 +313,8 @@
+        {/if}
+        {#if embeddingEngine === 'ollama' || embeddingEngine === 'openai'}
                 {$i18n.t('Embedding Batch Size')}
@@ -318,13 +324,13 @@
                 <input
                     type="range"
                     min="1"
                     max="2048"
                     step="1"
-                    bind:value={OpenAIBatchSize}
+                    bind:value={embeddingBatchSize}
                     class="w-full h-2 rounded-lg appearance-none cursor-pointer dark:bg-gray-700"
                 />
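
Review note: end to end, backend callers only see the extra `embedding_batch_size` argument to `get_embedding_function()`. A consumption sketch, assuming the `ollama` engine (the sentence-transformers model and OpenAI credentials are unused there, hence the `None` placeholders; the model name is an example):

# Review-note sketch: how the refactored factory is consumed.
from open_webui.apps.retrieval.utils import get_embedding_function

embedding_function = get_embedding_function(
    "ollama",                   # embedding_engine
    "nomic-embed-text:latest",  # embedding_model (any pulled Ollama model)
    None,                       # sentence-transformers model, unused for ollama
    None,                       # openai_key, unused for ollama
    None,                       # openai_url, unused for ollama
    16,                         # embedding_batch_size - the new unified knob
)

# A bare string embeds directly; a list is sliced into batches of 16,
# each batch becoming a single POST to Ollama's /api/embed.
vectors = embedding_function(["chunk one", "chunk two", "chunk three"])
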