From caeb822cdcb1973f15ff30a1c3c457095be6a0c6 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Mon, 19 May 2025 03:40:32 +0400 Subject: [PATCH] feat: azure openai support --- backend/open_webui/routers/openai.py | 199 ++++++++++++------- src/lib/apis/ollama/index.ts | 9 +- src/lib/apis/openai/index.ts | 7 +- src/lib/components/AddConnectionModal.svelte | 66 +++--- src/lib/components/AddServerModal.svelte | 4 - 5 files changed, 174 insertions(+), 111 deletions(-) diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index a196eca26..3d7f1d864 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -463,60 +463,88 @@ async def get_models( url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx] key = request.app.state.config.OPENAI_API_KEYS[url_idx] + api_config = request.app.state.config.OPENAI_API_CONFIGS.get( + str(url_idx), + request.app.state.config.OPENAI_API_CONFIGS.get(url, {}), # Legacy support + ) + r = None async with aiohttp.ClientSession( trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST), ) as session: try: - async with session.get( - f"{url}/models", - headers={ - "Authorization": f"Bearer {key}", - "Content-Type": "application/json", - **( - { - "X-OpenWebUI-User-Name": user.name, - "X-OpenWebUI-User-Id": user.id, - "X-OpenWebUI-User-Email": user.email, - "X-OpenWebUI-User-Role": user.role, - } - if ENABLE_FORWARD_USER_INFO_HEADERS - else {} - ), - }, - ssl=AIOHTTP_CLIENT_SESSION_SSL, - ) as r: - if r.status != 200: - # Extract response error details if available - error_detail = f"HTTP Error: {r.status}" - res = await r.json() - if "error" in res: - error_detail = f"External Error: {res['error']}" - raise Exception(error_detail) + headers = { + "Content-Type": "application/json", + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), + } - response_data = await r.json() + if api_config.get("azure", False): + headers["api-key"] = key - # Check if we're calling OpenAI API based on the URL - if "api.openai.com" in url: - # Filter models according to the specified conditions - response_data["data"] = [ - model - for model in response_data.get("data", []) - if not any( - name in model["id"] - for name in [ - "babbage", - "dall-e", - "davinci", - "embedding", - "tts", - "whisper", - ] - ) - ] + api_version = api_config.get("api_version", "2023-03-15-preview") + async with session.get( + f"{url}/openai/deployments?api-version={api_version}", + headers=headers, + ssl=AIOHTTP_CLIENT_SESSION_SSL, + ) as r: + if r.status != 200: + # Extract response error details if available + error_detail = f"HTTP Error: {r.status}" + res = await r.json() + if "error" in res: + error_detail = f"External Error: {res['error']}" + raise Exception(error_detail) - models = response_data + response_data = await r.json() + models = response_data + else: + headers["Authorization"] = f"Bearer {key}" + + async with session.get( + f"{url}/models", + headers=headers, + ssl=AIOHTTP_CLIENT_SESSION_SSL, + ) as r: + if r.status != 200: + # Extract response error details if available + error_detail = f"HTTP Error: {r.status}" + res = await r.json() + if "error" in res: + error_detail = f"External Error: {res['error']}" + raise Exception(error_detail) + + response_data = await r.json() + + # Check if we're calling OpenAI API based on the URL + if "api.openai.com" in url: + # Filter models according to the specified conditions + response_data["data"] = [ + model + for model in response_data.get("data", []) + if not any( + name in model["id"] + for name in [ + "babbage", + "dall-e", + "davinci", + "embedding", + "tts", + "whisper", + ] + ) + ] + + models = response_data except aiohttp.ClientError as e: # ClientError covers all aiohttp requests issues log.exception(f"Client error: {str(e)}") @@ -538,6 +566,8 @@ class ConnectionVerificationForm(BaseModel): url: str key: str + config: Optional[dict] = None + @router.post("/verify") async def verify_connection( @@ -546,39 +576,64 @@ async def verify_connection( url = form_data.url key = form_data.key + api_config = form_data.config or {} + async with aiohttp.ClientSession( trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST), ) as session: try: - async with session.get( - f"{url}/models", - headers={ - "Authorization": f"Bearer {key}", - "Content-Type": "application/json", - **( - { - "X-OpenWebUI-User-Name": user.name, - "X-OpenWebUI-User-Id": user.id, - "X-OpenWebUI-User-Email": user.email, - "X-OpenWebUI-User-Role": user.role, - } - if ENABLE_FORWARD_USER_INFO_HEADERS - else {} - ), - }, - ssl=AIOHTTP_CLIENT_SESSION_SSL, - ) as r: - if r.status != 200: - # Extract response error details if available - error_detail = f"HTTP Error: {r.status}" - res = await r.json() - if "error" in res: - error_detail = f"External Error: {res['error']}" - raise Exception(error_detail) + headers = { + "Content-Type": "application/json", + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), + } - response_data = await r.json() - return response_data + if api_config.get("azure", False): + headers["api-key"] = key + + api_version = api_config.get("api_version", "2023-03-15-preview") + async with session.get( + f"{url}/openai/deployments?api-version={api_version}", + headers=headers, + ssl=AIOHTTP_CLIENT_SESSION_SSL, + ) as r: + if r.status != 200: + # Extract response error details if available + error_detail = f"HTTP Error: {r.status}" + res = await r.json() + if "error" in res: + error_detail = f"External Error: {res['error']}" + raise Exception(error_detail) + + response_data = await r.json() + return response_data + else: + headers["Authorization"] = f"Bearer {key}" + + async with session.get( + f"{url}/models", + headers=headers, + ssl=AIOHTTP_CLIENT_SESSION_SSL, + ) as r: + if r.status != 200: + # Extract response error details if available + error_detail = f"HTTP Error: {r.status}" + res = await r.json() + if "error" in res: + error_detail = f"External Error: {res['error']}" + raise Exception(error_detail) + + response_data = await r.json() + return response_data except aiohttp.ClientError as e: # ClientError covers all aiohttp requests issues diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts index 2f6278fe9..f159555da 100644 --- a/src/lib/apis/ollama/index.ts +++ b/src/lib/apis/ollama/index.ts @@ -1,10 +1,6 @@ import { OLLAMA_API_BASE_URL } from '$lib/constants'; -export const verifyOllamaConnection = async ( - token: string = '', - url: string = '', - key: string = '' -) => { +export const verifyOllamaConnection = async (token: string = '', connection: dict = {}) => { let error = null; const res = await fetch(`${OLLAMA_API_BASE_URL}/verify`, { @@ -15,8 +11,7 @@ export const verifyOllamaConnection = async ( 'Content-Type': 'application/json' }, body: JSON.stringify({ - url, - key + ...connection }) }) .then(async (res) => { diff --git a/src/lib/apis/openai/index.ts b/src/lib/apis/openai/index.ts index f6cf76a73..070118a1a 100644 --- a/src/lib/apis/openai/index.ts +++ b/src/lib/apis/openai/index.ts @@ -267,10 +267,10 @@ export const getOpenAIModels = async (token: string, urlIdx?: number) => { export const verifyOpenAIConnection = async ( token: string = '', - url: string = 'https://api.openai.com/v1', - key: string = '', + connection: dict = {}, direct: boolean = false ) => { + const { url, key, config } = connection; if (!url) { throw 'OpenAI: URL is required'; } @@ -309,7 +309,8 @@ export const verifyOpenAIConnection = async ( }, body: JSON.stringify({ url, - key + key, + config }) }) .then(async (res) => { diff --git a/src/lib/components/AddConnectionModal.svelte b/src/lib/components/AddConnectionModal.svelte index 628adf263..27665b356 100644 --- a/src/lib/components/AddConnectionModal.svelte +++ b/src/lib/components/AddConnectionModal.svelte @@ -33,7 +33,9 @@ let connectionType = 'external'; let azure = false; $: azure = - url.includes('openai.azure.com') || url.includes('cognitive.microsoft.com') ? true : false; + (url.includes('openai.azure.com') || url.includes('cognitive.microsoft.com')) && !direct + ? true + : false; let prefixId = ''; let enable = true; @@ -47,7 +49,10 @@ let loading = false; const verifyOllamaHandler = async () => { - const res = await verifyOllamaConnection(localStorage.token, url, key).catch((error) => { + const res = await verifyOllamaConnection(localStorage.token, { + url, + key + }).catch((error) => { toast.error(`${error}`); }); @@ -57,11 +62,20 @@ }; const verifyOpenAIHandler = async () => { - const res = await verifyOpenAIConnection(localStorage.token, url, key, direct).catch( - (error) => { - toast.error(`${error}`); - } - ); + const res = await verifyOpenAIConnection( + localStorage.token, + { + url, + key, + config: { + azure: azure, + api_version: apiVersion + } + }, + direct + ).catch((error) => { + toast.error(`${error}`); + }); if (res) { toast.success($i18n.t('Server connection verified')); @@ -187,27 +201,29 @@ }} >
-
-
-
{$i18n.t('Connection Type')}
+ {#if !direct} +
+
+
{$i18n.t('Connection Type')}
-
- +
+ +
-
+ {/if}
diff --git a/src/lib/components/AddServerModal.svelte b/src/lib/components/AddServerModal.svelte index a5f0ca5c7..ff0a546fa 100644 --- a/src/lib/components/AddServerModal.svelte +++ b/src/lib/components/AddServerModal.svelte @@ -3,10 +3,6 @@ import { getContext, onMount } from 'svelte'; const i18n = getContext('i18n'); - import { models } from '$lib/stores'; - import { verifyOpenAIConnection } from '$lib/apis/openai'; - import { verifyOllamaConnection } from '$lib/apis/ollama'; - import Modal from '$lib/components/common/Modal.svelte'; import Plus from '$lib/components/icons/Plus.svelte'; import Minus from '$lib/components/icons/Minus.svelte';