diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py index 2038f4302..67268a95b 100644 --- a/backend/open_webui/apps/ollama/main.py +++ b/backend/open_webui/apps/ollama/main.py @@ -17,10 +17,14 @@ from open_webui.config import ( ENABLE_OLLAMA_API, MODEL_FILTER_LIST, OLLAMA_BASE_URLS, + OLLAMA_API_CONFIGS, UPLOAD_DIR, AppConfig, ) -from open_webui.env import AIOHTTP_CLIENT_TIMEOUT +from open_webui.env import ( + AIOHTTP_CLIENT_TIMEOUT, + AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST, +) from open_webui.constants import ERROR_MESSAGES @@ -67,6 +71,8 @@ app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS +app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS + app.state.MODELS = {} @@ -92,17 +98,64 @@ async def get_status(): return {"status": True} +class ConnectionVerificationForm(BaseModel): + url: str + key: Optional[str] = None + + +@app.post("/verify") +async def verify_connection( + form_data: ConnectionVerificationForm, user=Depends(get_admin_user) +): + url = form_data.url + key = form_data.key + + headers = {} + if key: + headers["Authorization"] = f"Bearer {key}" + + timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) + async with aiohttp.ClientSession(timeout=timeout) as session: + try: + async with session.get(f"{url}/api/version", headers=headers) as r: + if r.status != 200: + # Extract response error details if available + error_detail = f"HTTP Error: {r.status}" + res = await r.json() + if "error" in res: + error_detail = f"External Error: {res['error']}" + raise Exception(error_detail) + + response_data = await r.json() + return response_data + + except aiohttp.ClientError as e: + # ClientError covers all aiohttp requests issues + log.exception(f"Client error: {str(e)}") + # Handle aiohttp-specific connection issues, timeout etc. + raise HTTPException( + status_code=500, detail="Open WebUI: Server Connection Error" + ) + except Exception as e: + log.exception(f"Unexpected error: {e}") + # Generic error handler in case parsing JSON or other steps fail + error_detail = f"Unexpected error: {str(e)}" + raise HTTPException(status_code=500, detail=error_detail) + + @app.get("/config") async def get_config(user=Depends(get_admin_user)): return { "ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API, "OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS, + "OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS, } class OllamaConfigForm(BaseModel): ENABLE_OLLAMA_API: Optional[bool] = None OLLAMA_BASE_URLS: list[str] + OLLAMA_API_CONFIGS: dict @app.post("/config/update") @@ -110,17 +163,27 @@ async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS + app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS + + # Remove any extra configs + config_urls = app.state.config.OLLAMA_API_CONFIGS.keys() + for url in list(app.state.config.OLLAMA_BASE_URLS): + if url not in config_urls: + app.state.config.OLLAMA_API_CONFIGS.pop(url, None) + return { "ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API, "OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS, + "OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS, } -async def fetch_url(url): - timeout = aiohttp.ClientTimeout(total=3) +async def aiohttp_get(url, key=None): + timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) try: + headers = {"Authorization": f"Bearer {key}"} if key else {} async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: - async with session.get(url) as response: + async with session.get(url, headers=headers) as response: return await response.json() except Exception as e: # Handle connection error here @@ -204,13 +267,42 @@ def merge_models_lists(model_lists): async def get_all_models(): log.info("get_all_models()") - if app.state.config.ENABLE_OLLAMA_API: - tasks = [ - fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS - ] + tasks = [] + for idx, url in enumerate(app.state.config.OLLAMA_BASE_URLS): + if url not in app.state.config.OLLAMA_API_CONFIGS: + tasks.append(aiohttp_get(f"{url}/api/tags")) + else: + api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + enable = api_config.get("enable", True) + + if enable: + tasks.append(aiohttp_get(f"{url}/api/tags")) + else: + tasks.append(None) + responses = await asyncio.gather(*tasks) + for idx, response in enumerate(responses): + if response: + url = app.state.config.OLLAMA_BASE_URLS[idx] + api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + + prefix_id = api_config.get("prefix_id", None) + model_ids = api_config.get("model_ids", []) + + if len(model_ids) != 0: + response["models"] = list( + filter( + lambda model: model["model"] in model_ids, + response["models"], + ) + ) + + if prefix_id: + for model in response["models"]: + model["model"] = f"{prefix_id}.{model['model']}" + models = { "models": merge_models_lists( map( @@ -279,7 +371,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None): if url_idx is None: # returns lowest version tasks = [ - fetch_url(f"{url}/api/version") + aiohttp_get(f"{url}/api/version") for url in app.state.config.OLLAMA_BASE_URLS ] responses = await asyncio.gather(*tasks) @@ -718,6 +810,10 @@ async def generate_completion( ) url = app.state.config.OLLAMA_BASE_URLS[url_idx] + api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + prefix_id = api_config.get("prefix_id", None) + if prefix_id: + form_data.model = form_data.model.replace(f"{prefix_id}.", "") log.info(f"url: {url}") return await post_streaming_url( @@ -799,6 +895,11 @@ async def generate_chat_completion( log.info(f"url: {url}") log.debug(f"generate_chat_completion() - 2.payload = {payload}") + api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + prefix_id = api_config.get("prefix_id", None) + if prefix_id: + payload["model"] = payload["model"].replace(f"{prefix_id}.", "") + return await post_streaming_url( f"{url}/api/chat", json.dumps(payload), @@ -874,6 +975,11 @@ async def generate_openai_chat_completion( url = get_ollama_url(url_idx, payload["model"]) log.info(f"url: {url}") + api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + prefix_id = api_config.get("prefix_id", None) + if prefix_id: + payload["model"] = payload["model"].replace(f"{prefix_id}.", "") + return await post_streaming_url( f"{url}/v1/chat/completions", json.dumps(payload), diff --git a/backend/open_webui/apps/openai/main.py b/backend/open_webui/apps/openai/main.py index 9c7b37911..5a4dba62f 100644 --- a/backend/open_webui/apps/openai/main.py +++ b/backend/open_webui/apps/openai/main.py @@ -206,10 +206,10 @@ async def speech(request: Request, user=Depends(get_verified_user)): raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND) -async def aiohttp_get(url, key): +async def aiohttp_get(url, key=None): timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) try: - headers = {"Authorization": f"Bearer {key}"} + headers = {"Authorization": f"Bearer {key}"} if key else {} async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: async with session.get(url, headers=headers) as response: return await response.json() diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts index fcb37f7a4..4ad2ba596 100644 --- a/src/lib/apis/ollama/index.ts +++ b/src/lib/apis/ollama/index.ts @@ -1,5 +1,44 @@ import { OLLAMA_API_BASE_URL } from '$lib/constants'; + +export const verifyOllamaConnection = async ( + token: string = '', + url: string = '', + key: string = '' +) => { + let error = null; + + const res = await fetch( + `${OLLAMA_API_BASE_URL}/verify`, + { + method: 'POST', + headers: { + Accept: 'application/json', + Authorization: `Bearer ${token}`, + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + url, + key + }) + } + ) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = `Ollama: ${err?.error?.message ?? 'Network Problem'}`; + return []; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getOllamaConfig = async (token: string = '') => { let error = null; diff --git a/src/lib/components/admin/Settings/Connections.svelte b/src/lib/components/admin/Settings/Connections.svelte index 63f603807..6031b591c 100644 --- a/src/lib/components/admin/Settings/Connections.svelte +++ b/src/lib/components/admin/Settings/Connections.svelte @@ -13,10 +13,11 @@ import Switch from '$lib/components/common/Switch.svelte'; import Spinner from '$lib/components/common/Spinner.svelte'; import Tooltip from '$lib/components/common/Tooltip.svelte'; + import Plus from '$lib/components/icons/Plus.svelte'; import OpenAIConnection from './Connections/OpenAIConnection.svelte'; - import OpenAIConnectionModal from './Connections/OpenAIConnectionModal.svelte'; - import Plus from '$lib/components/icons/Plus.svelte'; + import AddConnectionModal from './Connections/AddConnectionModal.svelte'; + import OllamaConnection from './Connections/OllamaConnection.svelte'; const i18n = getContext('i18n'); @@ -38,10 +39,14 @@ let pipelineUrls = {}; let showAddOpenAIConnectionModal = false; + let showAddOllamaConnectionModal = false; const updateOpenAIHandler = async () => { if (ENABLE_OPENAI_API !== null) { - OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS.map((url) => url.replace(/\/$/, '')); + OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS.filter( + (url, urlIdx) => OPENAI_API_BASE_URLS.indexOf(url) === urlIdx && url !== '' + ).map((url) => url.replace(/\/$/, '')); + // Check if API KEYS length is same than API URLS length if (OPENAI_API_KEYS.length !== OPENAI_API_BASE_URLS.length) { // if there are more keys than urls, remove the extra keys @@ -76,9 +81,10 @@ const updateOllamaHandler = async () => { if (ENABLE_OLLAMA_API !== null) { - OLLAMA_BASE_URLS = OLLAMA_BASE_URLS.filter((url) => url !== '').map((url) => - url.replace(/\/$/, '') - ); + // Remove duplicate URLs + OLLAMA_BASE_URLS = OLLAMA_BASE_URLS.filter( + (url, urlIdx) => OLLAMA_BASE_URLS.indexOf(url) === urlIdx && url !== '' + ).map((url) => url.replace(/\/$/, '')); console.log(OLLAMA_BASE_URLS); @@ -110,6 +116,13 @@ await updateOpenAIHandler(); }; + const addOllamaConnectionHandler = async (connection) => { + OLLAMA_BASE_URLS = [...OLLAMA_BASE_URLS, connection.url]; + OLLAMA_API_CONFIGS[connection.url] = connection.config; + + await updateOllamaHandler(); + }; + onMount(async () => { if ($user.role === 'admin') { let ollamaConfig = {}; @@ -160,11 +173,17 @@ }); - + +
{ @@ -219,7 +238,7 @@ pipeline={pipelineUrls[url] ? true : false} bind:url bind:key={OPENAI_API_KEYS[idx]} - bind:config={OPENAI_API_CONFIGS[OPENAI_API_BASE_URLS[idx]]} + bind:config={OPENAI_API_CONFIGS[url]} onSubmit={() => { updateOpenAIHandler(); }} @@ -247,11 +266,7 @@ { - updateOllamaConfig(localStorage.token, ENABLE_OLLAMA_API); - - if (OLLAMA_BASE_URLS.length === 0) { - OLLAMA_BASE_URLS = ['']; - } + updateOllamaHandler(); }} /> @@ -261,85 +276,35 @@
-
+
{$i18n.t('Manage Ollama API Connections')}
- + + +
-
+
{#each OLLAMA_BASE_URLS as url, idx} -
- - -
- - - - - - - -
-
+ { + updateOllamaHandler(); + }} + onDelete={() => { + OLLAMA_BASE_URLS = OLLAMA_BASE_URLS.filter((url, urlIdx) => idx !== urlIdx); + }} + /> {/each}
diff --git a/src/lib/components/admin/Settings/Connections/OpenAIConnectionModal.svelte b/src/lib/components/admin/Settings/Connections/AddConnectionModal.svelte similarity index 91% rename from src/lib/components/admin/Settings/Connections/OpenAIConnectionModal.svelte rename to src/lib/components/admin/Settings/Connections/AddConnectionModal.svelte index c27567412..3f24dc6d7 100644 --- a/src/lib/components/admin/Settings/Connections/OpenAIConnectionModal.svelte +++ b/src/lib/components/admin/Settings/Connections/AddConnectionModal.svelte @@ -5,6 +5,7 @@ import { models } from '$lib/stores'; import { verifyOpenAIConnection } from '$lib/apis/openai'; + import { verifyOllamaConnection } from '$lib/apis/ollama'; import Modal from '$lib/components/common/Modal.svelte'; import Plus from '$lib/components/icons/Plus.svelte'; @@ -19,6 +20,7 @@ export let show = false; export let edit = false; + export let ollama = false; export let connection = null; @@ -33,6 +35,16 @@ let loading = false; + const verifyOllamaHandler = async () => { + const res = await verifyOllamaConnection(localStorage.token, url, key).catch((error) => { + toast.error(error); + }); + + if (res) { + toast.success($i18n.t('Server connection verified')); + } + }; + const verifyOpenAIHandler = async () => { const res = await verifyOpenAIConnection(localStorage.token, url, key).catch((error) => { toast.error(error); @@ -43,6 +55,14 @@ } }; + const verifyHandler = () => { + if (ollama) { + verifyOllamaHandler(); + } else { + verifyOpenAIHandler(); + } + }; + const addModelHandler = () => { if (modelId) { modelIds = [...modelIds, modelId]; @@ -53,7 +73,7 @@ const submitHandler = async () => { loading = true; - if (!url || !key) { + if (!ollama && (!url || !key)) { loading = false; toast.error('URL and Key are required'); return; @@ -159,7 +179,7 @@
@@ -249,9 +269,15 @@
{:else}
- {$i18n.t('Leave empty to include all models from "{{URL}}/models" endpoint', { - URL: url - })} + {#if ollama} + {$i18n.t('Leave empty to include all models from "{{URL}}/api/tags" endpoint', { + URL: url + })} + {:else} + {$i18n.t('Leave empty to include all models from "{{URL}}/models" endpoint', { + URL: url + })} + {/if}
{/if} diff --git a/src/lib/components/admin/Settings/Connections/OllamaConnection.svelte b/src/lib/components/admin/Settings/Connections/OllamaConnection.svelte new file mode 100644 index 000000000..852251a6d --- /dev/null +++ b/src/lib/components/admin/Settings/Connections/OllamaConnection.svelte @@ -0,0 +1,70 @@ + + + { + url = connection.url; + config = { ...connection.config, key: connection.key }; + onSubmit(connection); + }} +/> + +
+ + {#if !(config?.enable ?? true)} +
+ {/if} + + +
+ +
+ + + +
+
diff --git a/src/lib/components/admin/Settings/Connections/OpenAIConnection.svelte b/src/lib/components/admin/Settings/Connections/OpenAIConnection.svelte index fe8cfb809..1e3f5a69b 100644 --- a/src/lib/components/admin/Settings/Connections/OpenAIConnection.svelte +++ b/src/lib/components/admin/Settings/Connections/OpenAIConnection.svelte @@ -5,7 +5,7 @@ import Tooltip from '$lib/components/common/Tooltip.svelte'; import SensitiveInput from '$lib/components/common/SensitiveInput.svelte'; import Cog6 from '$lib/components/icons/Cog6.svelte'; - import OpenAIConnectionModal from './OpenAIConnectionModal.svelte'; + import AddConnectionModal from './AddConnectionModal.svelte'; import { connect } from 'socket.io-client'; export let onDelete = () => {}; @@ -20,7 +20,7 @@ let showConfigModal = false; - { - OLLAMA_URLS = await getOllamaUrls(localStorage.token).catch((error) => { - toast.error(error); - return []; - }); + OLLAMA_BASE_URLS = ollamaConfig.OLLAMA_BASE_URLS; - if (OLLAMA_URLS.length > 0) { - selectedOllamaUrlIdx = 0; - } - })(), - (async () => { - ollamaVersion = await getOllamaVersion(localStorage.token).catch((error) => false); - })() - ]); + if (OLLAMA_BASE_URLS.length > 0) { + selectedOllamaUrlIdx = 0; + } + + ollamaVersion = true; } else { ollamaEnabled = false; toast.error($i18n.t('Ollama API is disabled')); @@ -568,7 +560,7 @@
{$i18n.t('Manage Ollama Models')}
- {#if OLLAMA_URLS.length > 0} + {#if OLLAMA_BASE_URLS.length > 0}