mirror of
https://github.com/open-webui/open-webui
synced 2025-06-14 10:20:52 +00:00
feat: azure openai support
This commit is contained in:
parent
47f8b3500b
commit
caeb822cdc
@ -463,60 +463,88 @@ async def get_models(
|
|||||||
url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
|
url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
|
||||||
key = request.app.state.config.OPENAI_API_KEYS[url_idx]
|
key = request.app.state.config.OPENAI_API_KEYS[url_idx]
|
||||||
|
|
||||||
|
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
|
||||||
|
str(url_idx),
|
||||||
|
request.app.state.config.OPENAI_API_CONFIGS.get(url, {}), # Legacy support
|
||||||
|
)
|
||||||
|
|
||||||
r = None
|
r = None
|
||||||
async with aiohttp.ClientSession(
|
async with aiohttp.ClientSession(
|
||||||
trust_env=True,
|
trust_env=True,
|
||||||
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
|
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
|
||||||
) as session:
|
) as session:
|
||||||
try:
|
try:
|
||||||
async with session.get(
|
headers = {
|
||||||
f"{url}/models",
|
"Content-Type": "application/json",
|
||||||
headers={
|
**(
|
||||||
"Authorization": f"Bearer {key}",
|
{
|
||||||
"Content-Type": "application/json",
|
"X-OpenWebUI-User-Name": user.name,
|
||||||
**(
|
"X-OpenWebUI-User-Id": user.id,
|
||||||
{
|
"X-OpenWebUI-User-Email": user.email,
|
||||||
"X-OpenWebUI-User-Name": user.name,
|
"X-OpenWebUI-User-Role": user.role,
|
||||||
"X-OpenWebUI-User-Id": user.id,
|
}
|
||||||
"X-OpenWebUI-User-Email": user.email,
|
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||||
"X-OpenWebUI-User-Role": user.role,
|
else {}
|
||||||
}
|
),
|
||||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
}
|
||||||
else {}
|
|
||||||
),
|
|
||||||
},
|
|
||||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
|
||||||
) as r:
|
|
||||||
if r.status != 200:
|
|
||||||
# Extract response error details if available
|
|
||||||
error_detail = f"HTTP Error: {r.status}"
|
|
||||||
res = await r.json()
|
|
||||||
if "error" in res:
|
|
||||||
error_detail = f"External Error: {res['error']}"
|
|
||||||
raise Exception(error_detail)
|
|
||||||
|
|
||||||
response_data = await r.json()
|
if api_config.get("azure", False):
|
||||||
|
headers["api-key"] = key
|
||||||
|
|
||||||
# Check if we're calling OpenAI API based on the URL
|
api_version = api_config.get("api_version", "2023-03-15-preview")
|
||||||
if "api.openai.com" in url:
|
async with session.get(
|
||||||
# Filter models according to the specified conditions
|
f"{url}/openai/deployments?api-version={api_version}",
|
||||||
response_data["data"] = [
|
headers=headers,
|
||||||
model
|
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||||
for model in response_data.get("data", [])
|
) as r:
|
||||||
if not any(
|
if r.status != 200:
|
||||||
name in model["id"]
|
# Extract response error details if available
|
||||||
for name in [
|
error_detail = f"HTTP Error: {r.status}"
|
||||||
"babbage",
|
res = await r.json()
|
||||||
"dall-e",
|
if "error" in res:
|
||||||
"davinci",
|
error_detail = f"External Error: {res['error']}"
|
||||||
"embedding",
|
raise Exception(error_detail)
|
||||||
"tts",
|
|
||||||
"whisper",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
models = response_data
|
response_data = await r.json()
|
||||||
|
models = response_data
|
||||||
|
else:
|
||||||
|
headers["Authorization"] = f"Bearer {key}"
|
||||||
|
|
||||||
|
async with session.get(
|
||||||
|
f"{url}/models",
|
||||||
|
headers=headers,
|
||||||
|
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||||
|
) as r:
|
||||||
|
if r.status != 200:
|
||||||
|
# Extract response error details if available
|
||||||
|
error_detail = f"HTTP Error: {r.status}"
|
||||||
|
res = await r.json()
|
||||||
|
if "error" in res:
|
||||||
|
error_detail = f"External Error: {res['error']}"
|
||||||
|
raise Exception(error_detail)
|
||||||
|
|
||||||
|
response_data = await r.json()
|
||||||
|
|
||||||
|
# Check if we're calling OpenAI API based on the URL
|
||||||
|
if "api.openai.com" in url:
|
||||||
|
# Filter models according to the specified conditions
|
||||||
|
response_data["data"] = [
|
||||||
|
model
|
||||||
|
for model in response_data.get("data", [])
|
||||||
|
if not any(
|
||||||
|
name in model["id"]
|
||||||
|
for name in [
|
||||||
|
"babbage",
|
||||||
|
"dall-e",
|
||||||
|
"davinci",
|
||||||
|
"embedding",
|
||||||
|
"tts",
|
||||||
|
"whisper",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
models = response_data
|
||||||
except aiohttp.ClientError as e:
|
except aiohttp.ClientError as e:
|
||||||
# ClientError covers all aiohttp requests issues
|
# ClientError covers all aiohttp requests issues
|
||||||
log.exception(f"Client error: {str(e)}")
|
log.exception(f"Client error: {str(e)}")
|
||||||
@ -538,6 +566,8 @@ class ConnectionVerificationForm(BaseModel):
|
|||||||
url: str
|
url: str
|
||||||
key: str
|
key: str
|
||||||
|
|
||||||
|
config: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
@router.post("/verify")
|
@router.post("/verify")
|
||||||
async def verify_connection(
|
async def verify_connection(
|
||||||
@ -546,39 +576,64 @@ async def verify_connection(
|
|||||||
url = form_data.url
|
url = form_data.url
|
||||||
key = form_data.key
|
key = form_data.key
|
||||||
|
|
||||||
|
api_config = form_data.config or {}
|
||||||
|
|
||||||
async with aiohttp.ClientSession(
|
async with aiohttp.ClientSession(
|
||||||
trust_env=True,
|
trust_env=True,
|
||||||
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
|
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
|
||||||
) as session:
|
) as session:
|
||||||
try:
|
try:
|
||||||
async with session.get(
|
headers = {
|
||||||
f"{url}/models",
|
"Content-Type": "application/json",
|
||||||
headers={
|
**(
|
||||||
"Authorization": f"Bearer {key}",
|
{
|
||||||
"Content-Type": "application/json",
|
"X-OpenWebUI-User-Name": user.name,
|
||||||
**(
|
"X-OpenWebUI-User-Id": user.id,
|
||||||
{
|
"X-OpenWebUI-User-Email": user.email,
|
||||||
"X-OpenWebUI-User-Name": user.name,
|
"X-OpenWebUI-User-Role": user.role,
|
||||||
"X-OpenWebUI-User-Id": user.id,
|
}
|
||||||
"X-OpenWebUI-User-Email": user.email,
|
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||||
"X-OpenWebUI-User-Role": user.role,
|
else {}
|
||||||
}
|
),
|
||||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
}
|
||||||
else {}
|
|
||||||
),
|
|
||||||
},
|
|
||||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
|
||||||
) as r:
|
|
||||||
if r.status != 200:
|
|
||||||
# Extract response error details if available
|
|
||||||
error_detail = f"HTTP Error: {r.status}"
|
|
||||||
res = await r.json()
|
|
||||||
if "error" in res:
|
|
||||||
error_detail = f"External Error: {res['error']}"
|
|
||||||
raise Exception(error_detail)
|
|
||||||
|
|
||||||
response_data = await r.json()
|
if api_config.get("azure", False):
|
||||||
return response_data
|
headers["api-key"] = key
|
||||||
|
|
||||||
|
api_version = api_config.get("api_version", "2023-03-15-preview")
|
||||||
|
async with session.get(
|
||||||
|
f"{url}/openai/deployments?api-version={api_version}",
|
||||||
|
headers=headers,
|
||||||
|
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||||
|
) as r:
|
||||||
|
if r.status != 200:
|
||||||
|
# Extract response error details if available
|
||||||
|
error_detail = f"HTTP Error: {r.status}"
|
||||||
|
res = await r.json()
|
||||||
|
if "error" in res:
|
||||||
|
error_detail = f"External Error: {res['error']}"
|
||||||
|
raise Exception(error_detail)
|
||||||
|
|
||||||
|
response_data = await r.json()
|
||||||
|
return response_data
|
||||||
|
else:
|
||||||
|
headers["Authorization"] = f"Bearer {key}"
|
||||||
|
|
||||||
|
async with session.get(
|
||||||
|
f"{url}/models",
|
||||||
|
headers=headers,
|
||||||
|
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||||
|
) as r:
|
||||||
|
if r.status != 200:
|
||||||
|
# Extract response error details if available
|
||||||
|
error_detail = f"HTTP Error: {r.status}"
|
||||||
|
res = await r.json()
|
||||||
|
if "error" in res:
|
||||||
|
error_detail = f"External Error: {res['error']}"
|
||||||
|
raise Exception(error_detail)
|
||||||
|
|
||||||
|
response_data = await r.json()
|
||||||
|
return response_data
|
||||||
|
|
||||||
except aiohttp.ClientError as e:
|
except aiohttp.ClientError as e:
|
||||||
# ClientError covers all aiohttp requests issues
|
# ClientError covers all aiohttp requests issues
|
||||||
|
@ -1,10 +1,6 @@
|
|||||||
import { OLLAMA_API_BASE_URL } from '$lib/constants';
|
import { OLLAMA_API_BASE_URL } from '$lib/constants';
|
||||||
|
|
||||||
export const verifyOllamaConnection = async (
|
export const verifyOllamaConnection = async (token: string = '', connection: dict = {}) => {
|
||||||
token: string = '',
|
|
||||||
url: string = '',
|
|
||||||
key: string = ''
|
|
||||||
) => {
|
|
||||||
let error = null;
|
let error = null;
|
||||||
|
|
||||||
const res = await fetch(`${OLLAMA_API_BASE_URL}/verify`, {
|
const res = await fetch(`${OLLAMA_API_BASE_URL}/verify`, {
|
||||||
@ -15,8 +11,7 @@ export const verifyOllamaConnection = async (
|
|||||||
'Content-Type': 'application/json'
|
'Content-Type': 'application/json'
|
||||||
},
|
},
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
url,
|
...connection
|
||||||
key
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.then(async (res) => {
|
.then(async (res) => {
|
||||||
|
@ -267,10 +267,10 @@ export const getOpenAIModels = async (token: string, urlIdx?: number) => {
|
|||||||
|
|
||||||
export const verifyOpenAIConnection = async (
|
export const verifyOpenAIConnection = async (
|
||||||
token: string = '',
|
token: string = '',
|
||||||
url: string = 'https://api.openai.com/v1',
|
connection: dict = {},
|
||||||
key: string = '',
|
|
||||||
direct: boolean = false
|
direct: boolean = false
|
||||||
) => {
|
) => {
|
||||||
|
const { url, key, config } = connection;
|
||||||
if (!url) {
|
if (!url) {
|
||||||
throw 'OpenAI: URL is required';
|
throw 'OpenAI: URL is required';
|
||||||
}
|
}
|
||||||
@ -309,7 +309,8 @@ export const verifyOpenAIConnection = async (
|
|||||||
},
|
},
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
url,
|
url,
|
||||||
key
|
key,
|
||||||
|
config
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.then(async (res) => {
|
.then(async (res) => {
|
||||||
|
@ -33,7 +33,9 @@
|
|||||||
let connectionType = 'external';
|
let connectionType = 'external';
|
||||||
let azure = false;
|
let azure = false;
|
||||||
$: azure =
|
$: azure =
|
||||||
url.includes('openai.azure.com') || url.includes('cognitive.microsoft.com') ? true : false;
|
(url.includes('openai.azure.com') || url.includes('cognitive.microsoft.com')) && !direct
|
||||||
|
? true
|
||||||
|
: false;
|
||||||
|
|
||||||
let prefixId = '';
|
let prefixId = '';
|
||||||
let enable = true;
|
let enable = true;
|
||||||
@ -47,7 +49,10 @@
|
|||||||
let loading = false;
|
let loading = false;
|
||||||
|
|
||||||
const verifyOllamaHandler = async () => {
|
const verifyOllamaHandler = async () => {
|
||||||
const res = await verifyOllamaConnection(localStorage.token, url, key).catch((error) => {
|
const res = await verifyOllamaConnection(localStorage.token, {
|
||||||
|
url,
|
||||||
|
key
|
||||||
|
}).catch((error) => {
|
||||||
toast.error(`${error}`);
|
toast.error(`${error}`);
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -57,11 +62,20 @@
|
|||||||
};
|
};
|
||||||
|
|
||||||
const verifyOpenAIHandler = async () => {
|
const verifyOpenAIHandler = async () => {
|
||||||
const res = await verifyOpenAIConnection(localStorage.token, url, key, direct).catch(
|
const res = await verifyOpenAIConnection(
|
||||||
(error) => {
|
localStorage.token,
|
||||||
toast.error(`${error}`);
|
{
|
||||||
}
|
url,
|
||||||
);
|
key,
|
||||||
|
config: {
|
||||||
|
azure: azure,
|
||||||
|
api_version: apiVersion
|
||||||
|
}
|
||||||
|
},
|
||||||
|
direct
|
||||||
|
).catch((error) => {
|
||||||
|
toast.error(`${error}`);
|
||||||
|
});
|
||||||
|
|
||||||
if (res) {
|
if (res) {
|
||||||
toast.success($i18n.t('Server connection verified'));
|
toast.success($i18n.t('Server connection verified'));
|
||||||
@ -187,27 +201,29 @@
|
|||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<div class="px-1">
|
<div class="px-1">
|
||||||
<div class="flex gap-2">
|
{#if !direct}
|
||||||
<div class="flex w-full justify-between items-center">
|
<div class="flex gap-2">
|
||||||
<div class=" text-xs text-gray-500">{$i18n.t('Connection Type')}</div>
|
<div class="flex w-full justify-between items-center">
|
||||||
|
<div class=" text-xs text-gray-500">{$i18n.t('Connection Type')}</div>
|
||||||
|
|
||||||
<div class="">
|
<div class="">
|
||||||
<button
|
<button
|
||||||
on:click={() => {
|
on:click={() => {
|
||||||
connectionType = connectionType === 'local' ? 'external' : 'local';
|
connectionType = connectionType === 'local' ? 'external' : 'local';
|
||||||
}}
|
}}
|
||||||
type="button"
|
type="button"
|
||||||
class=" text-xs text-gray-700 dark:text-gray-300"
|
class=" text-xs text-gray-700 dark:text-gray-300"
|
||||||
>
|
>
|
||||||
{#if connectionType === 'local'}
|
{#if connectionType === 'local'}
|
||||||
{$i18n.t('Local')}
|
{$i18n.t('Local')}
|
||||||
{:else}
|
{:else}
|
||||||
{$i18n.t('External')}
|
{$i18n.t('External')}
|
||||||
{/if}
|
{/if}
|
||||||
</button>
|
</button>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
{/if}
|
||||||
|
|
||||||
<div class="flex gap-2 mt-1.5">
|
<div class="flex gap-2 mt-1.5">
|
||||||
<div class="flex flex-col w-full">
|
<div class="flex flex-col w-full">
|
||||||
|
@ -3,10 +3,6 @@
|
|||||||
import { getContext, onMount } from 'svelte';
|
import { getContext, onMount } from 'svelte';
|
||||||
const i18n = getContext('i18n');
|
const i18n = getContext('i18n');
|
||||||
|
|
||||||
import { models } from '$lib/stores';
|
|
||||||
import { verifyOpenAIConnection } from '$lib/apis/openai';
|
|
||||||
import { verifyOllamaConnection } from '$lib/apis/ollama';
|
|
||||||
|
|
||||||
import Modal from '$lib/components/common/Modal.svelte';
|
import Modal from '$lib/components/common/Modal.svelte';
|
||||||
import Plus from '$lib/components/icons/Plus.svelte';
|
import Plus from '$lib/components/icons/Plus.svelte';
|
||||||
import Minus from '$lib/components/icons/Minus.svelte';
|
import Minus from '$lib/components/icons/Minus.svelte';
|
||||||
|
Loading…
Reference in New Issue
Block a user