diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 950a379cd..c631e2609 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -2184,6 +2184,27 @@ RAG_OPENAI_API_KEY = PersistentConfig( os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY), ) +RAG_AZURE_OPENAI_BASE_URL = PersistentConfig( + "RAG_AZURE_OPENAI_BASE_URL", + "rag.azure_openai.base_url", + os.getenv("RAG_AZURE_OPENAI_BASE_URL", ""), +) +RAG_AZURE_OPENAI_API_KEY = PersistentConfig( + "RAG_AZURE_OPENAI_API_KEY", + "rag.azure_openai.api_key", + os.getenv("RAG_AZURE_OPENAI_API_KEY", ""), +) +RAG_AZURE_OPENAI_DEPLOYMENT = PersistentConfig( + "RAG_AZURE_OPENAI_DEPLOYMENT", + "rag.azure_openai.deployment", + os.getenv("RAG_AZURE_OPENAI_DEPLOYMENT", ""), +) +RAG_AZURE_OPENAI_VERSION = PersistentConfig( + "RAG_AZURE_OPENAI_VERSION", + "rag.azure_openai.version", + os.getenv("RAG_AZURE_OPENAI_VERSION", ""), +) + RAG_OLLAMA_BASE_URL = PersistentConfig( "RAG_OLLAMA_BASE_URL", "rag.ollama.url", diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index b57ed59f2..bebcdc1be 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -207,6 +207,10 @@ from open_webui.config import ( RAG_FILE_MAX_SIZE, RAG_OPENAI_API_BASE_URL, RAG_OPENAI_API_KEY, + RAG_AZURE_OPENAI_BASE_URL, + RAG_AZURE_OPENAI_API_KEY, + RAG_AZURE_OPENAI_DEPLOYMENT, + RAG_AZURE_OPENAI_VERSION, RAG_OLLAMA_BASE_URL, RAG_OLLAMA_API_KEY, CHUNK_OVERLAP, @@ -717,6 +721,11 @@ app.state.config.RAG_TEMPLATE = RAG_TEMPLATE app.state.config.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL app.state.config.RAG_OPENAI_API_KEY = RAG_OPENAI_API_KEY +app.state.config.RAG_AZURE_OPENAI_BASE_URL = RAG_AZURE_OPENAI_BASE_URL +app.state.config.RAG_AZURE_OPENAI_API_KEY = RAG_AZURE_OPENAI_API_KEY +app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT = RAG_AZURE_OPENAI_DEPLOYMENT +app.state.config.RAG_AZURE_OPENAI_VERSION = RAG_AZURE_OPENAI_VERSION + app.state.config.RAG_OLLAMA_BASE_URL = 
RAG_OLLAMA_BASE_URL app.state.config.RAG_OLLAMA_API_KEY = RAG_OLLAMA_API_KEY @@ -811,14 +820,32 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function( ( app.state.config.RAG_OPENAI_API_BASE_URL if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.RAG_OLLAMA_BASE_URL + else ( + app.state.config.RAG_OLLAMA_BASE_URL + if app.state.config.RAG_EMBEDDING_ENGINE == "ollama" + else app.state.config.RAG_AZURE_OPENAI_BASE_URL + ) ), ( app.state.config.RAG_OPENAI_API_KEY if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.RAG_OLLAMA_API_KEY + else ( + app.state.config.RAG_OLLAMA_API_KEY + if app.state.config.RAG_EMBEDDING_ENGINE == "ollama" + else app.state.config.RAG_AZURE_OPENAI_API_KEY + ) ), app.state.config.RAG_EMBEDDING_BATCH_SIZE, + ( + app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT + if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai" + else None + ), + ( + app.state.config.RAG_AZURE_OPENAI_VERSION + if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai" + else None + ), ) ######################################## diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index 97a89880c..a1216d4bd 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -5,6 +5,7 @@ from typing import Optional, Union import requests import hashlib from concurrent.futures import ThreadPoolExecutor +import time from huggingface_hub import snapshot_download from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever @@ -400,12 +401,14 @@ def get_embedding_function( url, key, embedding_batch_size, + deployment=None, + version=None, ): if embedding_engine == "": return lambda query, prefix=None, user=None: embedding_function.encode( query, **({"prompt": prefix} if prefix else {}) ).tolist() - elif embedding_engine in ["ollama", "openai"]: + elif embedding_engine in ["ollama", "openai", "azure_openai"]: func = lambda query, prefix=None, 
def generate_azure_openai_batch_embeddings(
    deployment: str,
    texts: list[str],
    url: str,
    key: str = "",
    model: str = "",
    version: str = "",
    prefix: Optional[str] = None,
    user: Optional[UserModel] = None,
) -> Optional[list[list[float]]]:
    """Generate embeddings for a batch of texts via an Azure OpenAI deployment.

    Args:
        deployment: Azure deployment name (becomes a path component of the
            request URL).
        texts: Batch of input strings to embed.
        url: Azure resource base URL (e.g. ``https://<resource>.openai.azure.com``).
        key: Azure API key, sent in the ``api-key`` header.
        model: Model identifier forwarded in the request body.
        version: Azure REST ``api-version`` query parameter.
        prefix: Optional prefix forwarded under the configured prefix field
            name, only when both the field name and the prefix are strings.
        user: When set and ``ENABLE_FORWARD_USER_INFO_HEADERS`` is enabled,
            the user's identity is forwarded via ``X-OpenWebUI-User-*`` headers.

    Returns:
        One embedding vector per input text, or ``None`` on any failure.
        Errors are logged and never raised to the caller.
    """
    try:
        log.debug(
            f"generate_azure_openai_batch_embeddings:deployment {deployment} batch size: {len(texts)}"
        )
        json_data = {"input": texts, "model": model}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        # Azure routes by deployment name, with the API version as a query param.
        request_url = (
            f"{url}/openai/deployments/{deployment}/embeddings?api-version={version}"
        )

        for _ in range(5):
            r = requests.post(
                request_url,
                headers={
                    "Content-Type": "application/json",
                    "api-key": key,
                    **(
                        {
                            "X-OpenWebUI-User-Name": user.name,
                            "X-OpenWebUI-User-Id": user.id,
                            "X-OpenWebUI-User-Email": user.email,
                            "X-OpenWebUI-User-Role": user.role,
                        }
                        if ENABLE_FORWARD_USER_INFO_HEADERS and user
                        else {}
                    ),
                },
                json=json_data,
            )
            if r.status_code == 429:
                # Retry-After may legally be an HTTP-date rather than
                # delta-seconds; fall back to a 1s backoff instead of letting
                # float() raise and abort the whole batch.
                try:
                    retry_after = float(r.headers.get("Retry-After", "1"))
                except ValueError:
                    retry_after = 1.0
                time.sleep(retry_after)
                continue
            r.raise_for_status()
            data = r.json()
            if "data" in data:
                return [elem["embedding"] for elem in data["data"]]
            raise Exception(
                "Invalid response from Azure OpenAI embeddings API: missing 'data'"
            )
        # All 5 attempts were rate-limited; surface that instead of failing silently.
        log.error(
            f"generate_azure_openai_batch_embeddings: rate limited after 5 attempts "
            f"(deployment {deployment})"
        )
        return None
    except Exception as e:
        log.exception(f"Error generating azure openai batch embeddings: {e}")
        return None
class AzureOpenAIConfigForm(BaseModel):
    """Azure OpenAI connection settings for the embedding engine.

    Mirrors ``OpenAIConfigForm``/``OllamaConfigForm`` but adds the two
    Azure-specific fields: the deployment name and the REST API version.
    """

    # Azure resource base URL, e.g. https://<resource>.openai.azure.com
    url: str
    # API key sent in the "api-key" header
    key: str
    # Deployment name (path component of the embeddings endpoint)
    deployment: str
    # api-version query parameter, e.g. "2023-05-15"
    version: str
form_data.openai_config is not None: request.app.state.config.RAG_OPENAI_API_BASE_URL = ( form_data.openai_config.url @@ -288,6 +302,20 @@ async def update_embedding_config( form_data.ollama_config.key ) + if form_data.azure_openai_config is not None: + request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = ( + form_data.azure_openai_config.url + ) + request.app.state.config.RAG_AZURE_OPENAI_API_KEY = ( + form_data.azure_openai_config.key + ) + request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT = ( + form_data.azure_openai_config.deployment + ) + request.app.state.config.RAG_AZURE_OPENAI_VERSION = ( + form_data.azure_openai_config.version + ) + request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = ( form_data.embedding_batch_size ) @@ -304,14 +332,32 @@ async def update_embedding_config( ( request.app.state.config.RAG_OPENAI_API_BASE_URL if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else request.app.state.config.RAG_OLLAMA_BASE_URL + else ( + request.app.state.config.RAG_OLLAMA_BASE_URL + if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama" + else request.app.state.config.RAG_AZURE_OPENAI_BASE_URL + ) ), ( request.app.state.config.RAG_OPENAI_API_KEY if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else request.app.state.config.RAG_OLLAMA_API_KEY + else ( + request.app.state.config.RAG_OLLAMA_API_KEY + if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama" + else request.app.state.config.RAG_AZURE_OPENAI_API_KEY + ) ), request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, + ( + request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT + if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai" + else None + ), + ( + request.app.state.config.RAG_AZURE_OPENAI_VERSION + if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai" + else None + ), ) return { @@ -327,6 +373,12 @@ async def update_embedding_config( "url": request.app.state.config.RAG_OLLAMA_BASE_URL, "key": request.app.state.config.RAG_OLLAMA_API_KEY, }, + 
"azure_openai_config": { + "url": request.app.state.config.RAG_AZURE_OPENAI_BASE_URL, + "key": request.app.state.config.RAG_AZURE_OPENAI_API_KEY, + "deployment": request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT, + "version": request.app.state.config.RAG_AZURE_OPENAI_VERSION, + }, } except Exception as e: log.exception(f"Problem updating embedding model: {e}") @@ -1129,14 +1181,32 @@ def save_docs_to_vector_db( ( request.app.state.config.RAG_OPENAI_API_BASE_URL if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else request.app.state.config.RAG_OLLAMA_BASE_URL + else ( + request.app.state.config.RAG_OLLAMA_BASE_URL + if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama" + else request.app.state.config.RAG_AZURE_OPENAI_BASE_URL + ) ), ( request.app.state.config.RAG_OPENAI_API_KEY if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else request.app.state.config.RAG_OLLAMA_API_KEY + else ( + request.app.state.config.RAG_OLLAMA_API_KEY + if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama" + else request.app.state.config.RAG_AZURE_OPENAI_API_KEY + ) ), request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, + ( + request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT + if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai" + else None + ), + ( + request.app.state.config.RAG_AZURE_OPENAI_VERSION + if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai" + else None + ), ) embeddings = embedding_function( diff --git a/src/lib/apis/retrieval/index.ts b/src/lib/apis/retrieval/index.ts index 8fa6578ed..1c26cb7e8 100644 --- a/src/lib/apis/retrieval/index.ts +++ b/src/lib/apis/retrieval/index.ts @@ -180,15 +180,23 @@ export const getEmbeddingConfig = async (token: string) => { }; type OpenAIConfigForm = { - key: string; - url: string; + key: string; + url: string; +}; + +type AzureOpenAIConfigForm = { + key: string; + url: string; + deployment: string; + version: string; }; type EmbeddingModelUpdateForm = { - 
openai_config?: OpenAIConfigForm; - embedding_engine: string; - embedding_model: string; - embedding_batch_size?: number; + openai_config?: OpenAIConfigForm; + azure_openai_config?: AzureOpenAIConfigForm; + embedding_engine: string; + embedding_model: string; + embedding_batch_size?: number; }; export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => { diff --git a/src/lib/components/admin/Settings/Documents.svelte b/src/lib/components/admin/Settings/Documents.svelte index f4f3202d7..863416b07 100644 --- a/src/lib/components/admin/Settings/Documents.svelte +++ b/src/lib/components/admin/Settings/Documents.svelte @@ -43,8 +43,13 @@ let embeddingBatchSize = 1; let rerankingModel = ''; - let OpenAIUrl = ''; - let OpenAIKey = ''; + let OpenAIUrl = ''; + let OpenAIKey = ''; + + let AzureOpenAIUrl = ''; + let AzureOpenAIKey = ''; + let AzureOpenAIDeployment = ''; + let AzureOpenAIVersion = ''; let OllamaUrl = ''; let OllamaKey = ''; @@ -86,27 +91,40 @@ return; } - if ((embeddingEngine === 'openai' && OpenAIKey === '') || OpenAIUrl === '') { - toast.error($i18n.t('OpenAI URL/Key required.')); - return; - } + if (embeddingEngine === 'openai' && (OpenAIKey === '' || OpenAIUrl === '')) { + toast.error($i18n.t('OpenAI URL/Key required.')); + return; + } + if ( + embeddingEngine === 'azure_openai' && + (AzureOpenAIKey === '' || AzureOpenAIUrl === '' || AzureOpenAIDeployment === '' || AzureOpenAIVersion === '') + ) { + toast.error($i18n.t('OpenAI URL/Key required.')); + return; + } console.debug('Update embedding model attempt:', embeddingModel); updateEmbeddingModelLoading = true; - const res = await updateEmbeddingConfig(localStorage.token, { - embedding_engine: embeddingEngine, - embedding_model: embeddingModel, - embedding_batch_size: embeddingBatchSize, - ollama_config: { - key: OllamaKey, - url: OllamaUrl - }, - openai_config: { - key: OpenAIKey, - url: OpenAIUrl - } - }).catch(async (error) => { + const res = await 
updateEmbeddingConfig(localStorage.token, { + embedding_engine: embeddingEngine, + embedding_model: embeddingModel, + embedding_batch_size: embeddingBatchSize, + ollama_config: { + key: OllamaKey, + url: OllamaUrl + }, + openai_config: { + key: OpenAIKey, + url: OpenAIUrl + }, + azure_openai_config: { + key: AzureOpenAIKey, + url: AzureOpenAIUrl, + deployment: AzureOpenAIDeployment, + version: AzureOpenAIVersion + } + }).catch(async (error) => { toast.error(`${error}`); await setEmbeddingConfig(); return null; @@ -200,13 +218,18 @@ embeddingModel = embeddingConfig.embedding_model; embeddingBatchSize = embeddingConfig.embedding_batch_size ?? 1; - OpenAIKey = embeddingConfig.openai_config.key; - OpenAIUrl = embeddingConfig.openai_config.url; + OpenAIKey = embeddingConfig.openai_config.key; + OpenAIUrl = embeddingConfig.openai_config.url; - OllamaKey = embeddingConfig.ollama_config.key; - OllamaUrl = embeddingConfig.ollama_config.url; - } - }; + OllamaKey = embeddingConfig.ollama_config.key; + OllamaUrl = embeddingConfig.ollama_config.url; + + AzureOpenAIKey = embeddingConfig.azure_openai_config.key; + AzureOpenAIUrl = embeddingConfig.azure_openai_config.url; + AzureOpenAIDeployment = embeddingConfig.azure_openai_config.deployment; + AzureOpenAIVersion = embeddingConfig.azure_openai_config.version; + } + }; onMount(async () => { await setEmbeddingConfig(); @@ -603,23 +626,26 @@ bind:value={embeddingEngine} placeholder="Select an embedding model engine" on:change={(e) => { - if (e.target.value === 'ollama') { - embeddingModel = ''; - } else if (e.target.value === 'openai') { - embeddingModel = 'text-embedding-3-small'; - } else if (e.target.value === '') { - embeddingModel = 'sentence-transformers/all-MiniLM-L6-v2'; - } + if (e.target.value === 'ollama') { + embeddingModel = ''; + } else if (e.target.value === 'openai') { + embeddingModel = 'text-embedding-3-small'; + } else if (e.target.value === 'azure_openai') { + embeddingModel = 'text-embedding-3-small'; + } else 
if (e.target.value === '') { + embeddingModel = 'sentence-transformers/all-MiniLM-L6-v2'; + } }} > - + + - {#if embeddingEngine === 'openai'} + {#if embeddingEngine === 'openai'}
- {:else if embeddingEngine === 'ollama'} + {:else if embeddingEngine === 'ollama'}
- {/if} + {:else if embeddingEngine === 'azure_openai'} +
+
+ + +
+
+ + +
+
+ {/if}
@@ -741,7 +793,7 @@
- {#if embeddingEngine === 'ollama' || embeddingEngine === 'openai'} + {#if embeddingEngine === 'ollama' || embeddingEngine === 'openai' || embeddingEngine === 'azure_openai'}
{$i18n.t('Embedding Batch Size')}