Merge pull request #14370 from daw/feat/add-azure-openai-embeddings-option

feat:Add Azure OpenAI embedding support
This commit is contained in:
Tim Jaeryang Baek 2025-05-30 00:18:55 +04:00 committed by GitHub
commit ff353578db
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 315 additions and 51 deletions

View File

@ -2184,6 +2184,27 @@ RAG_OPENAI_API_KEY = PersistentConfig(
os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY),
)
RAG_AZURE_OPENAI_BASE_URL = PersistentConfig(
"RAG_AZURE_OPENAI_BASE_URL",
"rag.azure_openai.base_url",
os.getenv("RAG_AZURE_OPENAI_BASE_URL", ""),
)
RAG_AZURE_OPENAI_API_KEY = PersistentConfig(
"RAG_AZURE_OPENAI_API_KEY",
"rag.azure_openai.api_key",
os.getenv("RAG_AZURE_OPENAI_API_KEY", ""),
)
RAG_AZURE_OPENAI_DEPLOYMENT = PersistentConfig(
"RAG_AZURE_OPENAI_DEPLOYMENT",
"rag.azure_openai.deployment",
os.getenv("RAG_AZURE_OPENAI_DEPLOYMENT", ""),
)
RAG_AZURE_OPENAI_VERSION = PersistentConfig(
"RAG_AZURE_OPENAI_VERSION",
"rag.azure_openai.version",
os.getenv("RAG_AZURE_OPENAI_VERSION", ""),
)
RAG_OLLAMA_BASE_URL = PersistentConfig(
"RAG_OLLAMA_BASE_URL",
"rag.ollama.url",

View File

@ -207,6 +207,10 @@ from open_webui.config import (
RAG_FILE_MAX_SIZE,
RAG_OPENAI_API_BASE_URL,
RAG_OPENAI_API_KEY,
RAG_AZURE_OPENAI_BASE_URL,
RAG_AZURE_OPENAI_API_KEY,
RAG_AZURE_OPENAI_DEPLOYMENT,
RAG_AZURE_OPENAI_VERSION,
RAG_OLLAMA_BASE_URL,
RAG_OLLAMA_API_KEY,
CHUNK_OVERLAP,
@ -717,6 +721,11 @@ app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
app.state.config.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.config.RAG_OPENAI_API_KEY = RAG_OPENAI_API_KEY
app.state.config.RAG_AZURE_OPENAI_BASE_URL = RAG_AZURE_OPENAI_BASE_URL
app.state.config.RAG_AZURE_OPENAI_API_KEY = RAG_AZURE_OPENAI_API_KEY
app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT = RAG_AZURE_OPENAI_DEPLOYMENT
app.state.config.RAG_AZURE_OPENAI_VERSION = RAG_AZURE_OPENAI_VERSION
app.state.config.RAG_OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL
app.state.config.RAG_OLLAMA_API_KEY = RAG_OLLAMA_API_KEY
@ -811,14 +820,32 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
(
app.state.config.RAG_OPENAI_API_BASE_URL
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else app.state.config.RAG_OLLAMA_BASE_URL
else (
app.state.config.RAG_OLLAMA_BASE_URL
if app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
else app.state.config.RAG_AZURE_OPENAI_BASE_URL
)
),
(
app.state.config.RAG_OPENAI_API_KEY
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else app.state.config.RAG_OLLAMA_API_KEY
else (
app.state.config.RAG_OLLAMA_API_KEY
if app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
else app.state.config.RAG_AZURE_OPENAI_API_KEY
)
),
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
(
app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT
if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
else None
),
(
app.state.config.RAG_AZURE_OPENAI_VERSION
if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
else None
),
)
########################################

View File

@ -5,6 +5,7 @@ from typing import Optional, Union
import requests
import hashlib
from concurrent.futures import ThreadPoolExecutor
import time
from huggingface_hub import snapshot_download
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
@ -400,12 +401,14 @@ def get_embedding_function(
url,
key,
embedding_batch_size,
deployment=None,
version=None,
):
if embedding_engine == "":
return lambda query, prefix=None, user=None: embedding_function.encode(
query, **({"prompt": prefix} if prefix else {})
).tolist()
elif embedding_engine in ["ollama", "openai"]:
elif embedding_engine in ["ollama", "openai", "azure_openai"]:
func = lambda query, prefix=None, user=None: generate_embeddings(
engine=embedding_engine,
model=embedding_model,
@ -414,6 +417,8 @@ def get_embedding_function(
url=url,
key=key,
user=user,
deployment=deployment,
version=version,
)
def generate_multiple(query, prefix, user, func):
@ -697,6 +702,61 @@ def generate_openai_batch_embeddings(
return None
def generate_azure_openai_batch_embeddings(
deployment: str,
texts: list[str],
url: str,
key: str = "",
model: str = "",
version: str = "",
prefix: str = None,
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
log.debug(
f"generate_azure_openai_batch_embeddings:deployment {deployment} batch size: {len(texts)}"
)
json_data = {"input": texts, "model": model}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
url = f"{url}/openai/deployments/{deployment}/embeddings?api-version={version}"
for _ in range(5):
r = requests.post(
url,
headers={
"Content-Type": "application/json",
"api-key": key,
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
json=json_data,
)
if r.status_code == 429:
retry = float(r.headers.get("Retry-After", "1"))
time.sleep(retry)
continue
r.raise_for_status()
data = r.json()
if "data" in data:
return [elem["embedding"] for elem in data["data"]]
else:
raise Exception("Something went wrong :/")
return None
except Exception as e:
log.exception(f"Error generating azure openai batch embeddings: {e}")
return None
def generate_ollama_batch_embeddings(
model: str,
texts: list[str],
@ -794,6 +854,32 @@ def generate_embeddings(
model, [text], url, key, prefix, user
)
return embeddings[0] if isinstance(text, str) else embeddings
elif engine == "azure_openai":
deployment = kwargs.get("deployment", "")
version = kwargs.get("version", "")
if isinstance(text, list):
embeddings = generate_azure_openai_batch_embeddings(
deployment,
text,
url,
key,
model,
version,
prefix,
user,
)
else:
embeddings = generate_azure_openai_batch_embeddings(
deployment,
[text],
url,
key,
model,
version,
prefix,
user,
)
return embeddings[0] if isinstance(text, str) else embeddings
import operator

View File

@ -239,6 +239,12 @@ async def get_embedding_config(request: Request, user=Depends(get_admin_user)):
"url": request.app.state.config.RAG_OLLAMA_BASE_URL,
"key": request.app.state.config.RAG_OLLAMA_API_KEY,
},
"azure_openai_config": {
"url": request.app.state.config.RAG_AZURE_OPENAI_BASE_URL,
"key": request.app.state.config.RAG_AZURE_OPENAI_API_KEY,
"deployment": request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT,
"version": request.app.state.config.RAG_AZURE_OPENAI_VERSION,
},
}
@ -252,9 +258,17 @@ class OllamaConfigForm(BaseModel):
key: str
class AzureOpenAIConfigForm(BaseModel):
url: str
key: str
deployment: str
version: str
class EmbeddingModelUpdateForm(BaseModel):
openai_config: Optional[OpenAIConfigForm] = None
ollama_config: Optional[OllamaConfigForm] = None
azure_openai_config: Optional[AzureOpenAIConfigForm] = None
embedding_engine: str
embedding_model: str
embedding_batch_size: Optional[int] = 1
@ -271,7 +285,7 @@ async def update_embedding_config(
request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai", "azure_openai"]:
if form_data.openai_config is not None:
request.app.state.config.RAG_OPENAI_API_BASE_URL = (
form_data.openai_config.url
@ -288,6 +302,20 @@ async def update_embedding_config(
form_data.ollama_config.key
)
if form_data.azure_openai_config is not None:
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
form_data.azure_openai_config.url
)
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
form_data.azure_openai_config.key
)
request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT = (
form_data.azure_openai_config.deployment
)
request.app.state.config.RAG_AZURE_OPENAI_VERSION = (
form_data.azure_openai_config.version
)
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = (
form_data.embedding_batch_size
)
@ -304,14 +332,32 @@ async def update_embedding_config(
(
request.app.state.config.RAG_OPENAI_API_BASE_URL
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else request.app.state.config.RAG_OLLAMA_BASE_URL
else (
request.app.state.config.RAG_OLLAMA_BASE_URL
if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
else request.app.state.config.RAG_AZURE_OPENAI_BASE_URL
)
),
(
request.app.state.config.RAG_OPENAI_API_KEY
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else request.app.state.config.RAG_OLLAMA_API_KEY
else (
request.app.state.config.RAG_OLLAMA_API_KEY
if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
else request.app.state.config.RAG_AZURE_OPENAI_API_KEY
)
),
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
(
request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT
if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
else None
),
(
request.app.state.config.RAG_AZURE_OPENAI_VERSION
if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
else None
),
)
return {
@ -327,6 +373,12 @@ async def update_embedding_config(
"url": request.app.state.config.RAG_OLLAMA_BASE_URL,
"key": request.app.state.config.RAG_OLLAMA_API_KEY,
},
"azure_openai_config": {
"url": request.app.state.config.RAG_AZURE_OPENAI_BASE_URL,
"key": request.app.state.config.RAG_AZURE_OPENAI_API_KEY,
"deployment": request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT,
"version": request.app.state.config.RAG_AZURE_OPENAI_VERSION,
},
}
except Exception as e:
log.exception(f"Problem updating embedding model: {e}")
@ -1129,14 +1181,32 @@ def save_docs_to_vector_db(
(
request.app.state.config.RAG_OPENAI_API_BASE_URL
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else request.app.state.config.RAG_OLLAMA_BASE_URL
else (
request.app.state.config.RAG_OLLAMA_BASE_URL
if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
else request.app.state.config.RAG_AZURE_OPENAI_BASE_URL
)
),
(
request.app.state.config.RAG_OPENAI_API_KEY
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else request.app.state.config.RAG_OLLAMA_API_KEY
else (
request.app.state.config.RAG_OLLAMA_API_KEY
if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
else request.app.state.config.RAG_AZURE_OPENAI_API_KEY
)
),
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
(
request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT
if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
else None
),
(
request.app.state.config.RAG_AZURE_OPENAI_VERSION
if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
else None
),
)
embeddings = embedding_function(

View File

@ -180,15 +180,23 @@ export const getEmbeddingConfig = async (token: string) => {
};
type OpenAIConfigForm = {
key: string;
url: string;
key: string;
url: string;
};
type AzureOpenAIConfigForm = {
key: string;
url: string;
deployment: string;
version: string;
};
type EmbeddingModelUpdateForm = {
openai_config?: OpenAIConfigForm;
embedding_engine: string;
embedding_model: string;
embedding_batch_size?: number;
openai_config?: OpenAIConfigForm;
azure_openai_config?: AzureOpenAIConfigForm;
embedding_engine: string;
embedding_model: string;
embedding_batch_size?: number;
};
export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => {

View File

@ -43,8 +43,13 @@
let embeddingBatchSize = 1;
let rerankingModel = '';
let OpenAIUrl = '';
let OpenAIKey = '';
let OpenAIUrl = '';
let OpenAIKey = '';
let AzureOpenAIUrl = '';
let AzureOpenAIKey = '';
let AzureOpenAIDeployment = '';
let AzureOpenAIVersion = '';
let OllamaUrl = '';
let OllamaKey = '';
@ -86,27 +91,40 @@
return;
}
if ((embeddingEngine === 'openai' && OpenAIKey === '') || OpenAIUrl === '') {
toast.error($i18n.t('OpenAI URL/Key required.'));
return;
}
if (embeddingEngine === 'openai' && (OpenAIKey === '' || OpenAIUrl === '')) {
toast.error($i18n.t('OpenAI URL/Key required.'));
return;
}
if (
embeddingEngine === 'azure_openai' &&
(AzureOpenAIKey === '' || AzureOpenAIUrl === '' || AzureOpenAIDeployment === '' || AzureOpenAIVersion === '')
) {
toast.error($i18n.t('OpenAI URL/Key required.'));
return;
}
console.debug('Update embedding model attempt:', embeddingModel);
updateEmbeddingModelLoading = true;
const res = await updateEmbeddingConfig(localStorage.token, {
embedding_engine: embeddingEngine,
embedding_model: embeddingModel,
embedding_batch_size: embeddingBatchSize,
ollama_config: {
key: OllamaKey,
url: OllamaUrl
},
openai_config: {
key: OpenAIKey,
url: OpenAIUrl
}
}).catch(async (error) => {
const res = await updateEmbeddingConfig(localStorage.token, {
embedding_engine: embeddingEngine,
embedding_model: embeddingModel,
embedding_batch_size: embeddingBatchSize,
ollama_config: {
key: OllamaKey,
url: OllamaUrl
},
openai_config: {
key: OpenAIKey,
url: OpenAIUrl
},
azure_openai_config: {
key: AzureOpenAIKey,
url: AzureOpenAIUrl,
deployment: AzureOpenAIDeployment,
version: AzureOpenAIVersion
}
}).catch(async (error) => {
toast.error(`${error}`);
await setEmbeddingConfig();
return null;
@ -200,13 +218,18 @@
embeddingModel = embeddingConfig.embedding_model;
embeddingBatchSize = embeddingConfig.embedding_batch_size ?? 1;
OpenAIKey = embeddingConfig.openai_config.key;
OpenAIUrl = embeddingConfig.openai_config.url;
OpenAIKey = embeddingConfig.openai_config.key;
OpenAIUrl = embeddingConfig.openai_config.url;
OllamaKey = embeddingConfig.ollama_config.key;
OllamaUrl = embeddingConfig.ollama_config.url;
}
};
OllamaKey = embeddingConfig.ollama_config.key;
OllamaUrl = embeddingConfig.ollama_config.url;
AzureOpenAIKey = embeddingConfig.azure_openai_config.key;
AzureOpenAIUrl = embeddingConfig.azure_openai_config.url;
AzureOpenAIDeployment = embeddingConfig.azure_openai_config.deployment;
AzureOpenAIVersion = embeddingConfig.azure_openai_config.version;
}
};
onMount(async () => {
await setEmbeddingConfig();
@ -603,23 +626,26 @@
bind:value={embeddingEngine}
placeholder="Select an embedding model engine"
on:change={(e) => {
if (e.target.value === 'ollama') {
embeddingModel = '';
} else if (e.target.value === 'openai') {
embeddingModel = 'text-embedding-3-small';
} else if (e.target.value === '') {
embeddingModel = 'sentence-transformers/all-MiniLM-L6-v2';
}
if (e.target.value === 'ollama') {
embeddingModel = '';
} else if (e.target.value === 'openai') {
embeddingModel = 'text-embedding-3-small';
} else if (e.target.value === 'azure_openai') {
embeddingModel = 'text-embedding-3-small';
} else if (e.target.value === '') {
embeddingModel = 'sentence-transformers/all-MiniLM-L6-v2';
}
}}
>
<option value="">{$i18n.t('Default (SentenceTransformers)')}</option>
<option value="ollama">{$i18n.t('Ollama')}</option>
<option value="openai">{$i18n.t('OpenAI')}</option>
<option value="openai">{$i18n.t('OpenAI')}</option>
<option value="azure_openai">Azure OpenAI</option>
</select>
</div>
</div>
{#if embeddingEngine === 'openai'}
{#if embeddingEngine === 'openai'}
<div class="my-0.5 flex gap-2 pr-2">
<input
class="flex-1 w-full text-sm bg-transparent outline-hidden"
@ -630,7 +656,7 @@
<SensitiveInput placeholder={$i18n.t('API Key')} bind:value={OpenAIKey} />
</div>
{:else if embeddingEngine === 'ollama'}
{:else if embeddingEngine === 'ollama'}
<div class="my-0.5 flex gap-2 pr-2">
<input
class="flex-1 w-full text-sm bg-transparent outline-hidden"
@ -645,7 +671,33 @@
required={false}
/>
</div>
{/if}
{:else if embeddingEngine === 'azure_openai'}
<div class="my-0.5 flex flex-col gap-2 pr-2 w-full">
<div class="flex gap-2">
<input
class="flex-1 w-full text-sm bg-transparent outline-hidden"
placeholder={$i18n.t('API Base URL')}
bind:value={AzureOpenAIUrl}
required
/>
<SensitiveInput placeholder={$i18n.t('API Key')} bind:value={AzureOpenAIKey} />
</div>
<div class="flex gap-2">
<input
class="flex-1 w-full text-sm bg-transparent outline-hidden"
placeholder="Deployment"
bind:value={AzureOpenAIDeployment}
required
/>
<input
class="flex-1 w-full text-sm bg-transparent outline-hidden"
placeholder="Version"
bind:value={AzureOpenAIVersion}
required
/>
</div>
</div>
{/if}
</div>
<div class=" mb-2.5 flex flex-col w-full">
@ -741,7 +793,7 @@
</div>
</div>
{#if embeddingEngine === 'ollama' || embeddingEngine === 'openai'}
{#if embeddingEngine === 'ollama' || embeddingEngine === 'openai' || embeddingEngine === 'azure_openai'}
<div class=" mb-2.5 flex w-full justify-between">
<div class=" self-center text-xs font-medium">
{$i18n.t('Embedding Batch Size')}