diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py
index 2038f4302..67268a95b 100644
--- a/backend/open_webui/apps/ollama/main.py
+++ b/backend/open_webui/apps/ollama/main.py
@@ -17,10 +17,14 @@ from open_webui.config import (
ENABLE_OLLAMA_API,
MODEL_FILTER_LIST,
OLLAMA_BASE_URLS,
+ OLLAMA_API_CONFIGS,
UPLOAD_DIR,
AppConfig,
)
-from open_webui.env import AIOHTTP_CLIENT_TIMEOUT
+from open_webui.env import (
+ AIOHTTP_CLIENT_TIMEOUT,
+ AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
+)
from open_webui.constants import ERROR_MESSAGES
@@ -67,6 +71,8 @@ app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
+app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS
+
app.state.MODELS = {}
@@ -92,17 +98,64 @@ async def get_status():
return {"status": True}
+class ConnectionVerificationForm(BaseModel):
+ url: str
+ key: Optional[str] = None
+
+
+@app.post("/verify")
+async def verify_connection(
+ form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
+):
+ url = form_data.url
+ key = form_data.key
+
+ headers = {}
+ if key:
+ headers["Authorization"] = f"Bearer {key}"
+
+ timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
+ async with aiohttp.ClientSession(timeout=timeout) as session:
+ try:
+ async with session.get(f"{url}/api/version", headers=headers) as r:
+ if r.status != 200:
+ # Extract response error details if available
+ error_detail = f"HTTP Error: {r.status}"
+ res = await r.json()
+ if "error" in res:
+ error_detail = f"External Error: {res['error']}"
+ raise Exception(error_detail)
+
+ response_data = await r.json()
+ return response_data
+
+ except aiohttp.ClientError as e:
+            # ClientError covers all aiohttp request issues
+ log.exception(f"Client error: {str(e)}")
+ # Handle aiohttp-specific connection issues, timeout etc.
+ raise HTTPException(
+ status_code=500, detail="Open WebUI: Server Connection Error"
+ )
+ except Exception as e:
+ log.exception(f"Unexpected error: {e}")
+ # Generic error handler in case parsing JSON or other steps fail
+ error_detail = f"Unexpected error: {str(e)}"
+ raise HTTPException(status_code=500, detail=error_detail)
+
+
@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
return {
"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
+ "OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS,
}
class OllamaConfigForm(BaseModel):
ENABLE_OLLAMA_API: Optional[bool] = None
OLLAMA_BASE_URLS: list[str]
+ OLLAMA_API_CONFIGS: dict
@app.post("/config/update")
@@ -110,17 +163,27 @@ async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user
app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
+ app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS
+
+    # Remove configs for URLs that are no longer in OLLAMA_BASE_URLS
+    # (iterate the config keys, not the base URLs, so stale entries are dropped)
+    for url in list(app.state.config.OLLAMA_API_CONFIGS.keys()):
+        if url not in app.state.config.OLLAMA_BASE_URLS:
+            app.state.config.OLLAMA_API_CONFIGS.pop(url, None)
+
return {
"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
+ "OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS,
}
-async def fetch_url(url):
- timeout = aiohttp.ClientTimeout(total=3)
+async def aiohttp_get(url, key=None):
+ timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
try:
+ headers = {"Authorization": f"Bearer {key}"} if key else {}
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
- async with session.get(url) as response:
+ async with session.get(url, headers=headers) as response:
return await response.json()
except Exception as e:
# Handle connection error here
@@ -204,13 +267,42 @@ def merge_models_lists(model_lists):
async def get_all_models():
log.info("get_all_models()")
-
if app.state.config.ENABLE_OLLAMA_API:
- tasks = [
- fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS
- ]
+ tasks = []
+ for idx, url in enumerate(app.state.config.OLLAMA_BASE_URLS):
+ if url not in app.state.config.OLLAMA_API_CONFIGS:
+ tasks.append(aiohttp_get(f"{url}/api/tags"))
+ else:
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+ enable = api_config.get("enable", True)
+
+ if enable:
+ tasks.append(aiohttp_get(f"{url}/api/tags"))
+ else:
+ tasks.append(None)
+
responses = await asyncio.gather(*tasks)
+ for idx, response in enumerate(responses):
+ if response:
+ url = app.state.config.OLLAMA_BASE_URLS[idx]
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+
+ prefix_id = api_config.get("prefix_id", None)
+ model_ids = api_config.get("model_ids", [])
+
+ if len(model_ids) != 0:
+ response["models"] = list(
+ filter(
+ lambda model: model["model"] in model_ids,
+ response["models"],
+ )
+ )
+
+ if prefix_id:
+ for model in response["models"]:
+ model["model"] = f"{prefix_id}.{model['model']}"
+
models = {
"models": merge_models_lists(
map(
@@ -279,7 +371,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
if url_idx is None:
# returns lowest version
tasks = [
- fetch_url(f"{url}/api/version")
+ aiohttp_get(f"{url}/api/version")
for url in app.state.config.OLLAMA_BASE_URLS
]
responses = await asyncio.gather(*tasks)
@@ -718,6 +810,10 @@ async def generate_completion(
)
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+ prefix_id = api_config.get("prefix_id", None)
+ if prefix_id:
+        form_data.model = form_data.model.removeprefix(f"{prefix_id}.")
log.info(f"url: {url}")
return await post_streaming_url(
@@ -799,6 +895,11 @@ async def generate_chat_completion(
log.info(f"url: {url}")
log.debug(f"generate_chat_completion() - 2.payload = {payload}")
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+ prefix_id = api_config.get("prefix_id", None)
+ if prefix_id:
+ payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
+
return await post_streaming_url(
f"{url}/api/chat",
json.dumps(payload),
@@ -874,6 +975,11 @@ async def generate_openai_chat_completion(
url = get_ollama_url(url_idx, payload["model"])
log.info(f"url: {url}")
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+ prefix_id = api_config.get("prefix_id", None)
+ if prefix_id:
+ payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
+
return await post_streaming_url(
f"{url}/v1/chat/completions",
json.dumps(payload),
diff --git a/backend/open_webui/apps/openai/main.py b/backend/open_webui/apps/openai/main.py
index 9c7b37911..5a4dba62f 100644
--- a/backend/open_webui/apps/openai/main.py
+++ b/backend/open_webui/apps/openai/main.py
@@ -206,10 +206,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
-async def aiohttp_get(url, key):
+async def aiohttp_get(url, key=None):
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
try:
- headers = {"Authorization": f"Bearer {key}"}
+ headers = {"Authorization": f"Bearer {key}"} if key else {}
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(url, headers=headers) as response:
return await response.json()
diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts
index fcb37f7a4..4ad2ba596 100644
--- a/src/lib/apis/ollama/index.ts
+++ b/src/lib/apis/ollama/index.ts
@@ -1,5 +1,44 @@
import { OLLAMA_API_BASE_URL } from '$lib/constants';
+
+export const verifyOllamaConnection = async (
+ token: string = '',
+ url: string = '',
+ key: string = ''
+) => {
+ let error = null;
+
+ const res = await fetch(
+ `${OLLAMA_API_BASE_URL}/verify`,
+ {
+ method: 'POST',
+ headers: {
+ Accept: 'application/json',
+ Authorization: `Bearer ${token}`,
+ 'Content-Type': 'application/json',
+ },
+ body: JSON.stringify({
+ url,
+ key
+ })
+ }
+ )
+ .then(async (res) => {
+ if (!res.ok) throw await res.json();
+ return res.json();
+ })
+ .catch((err) => {
+			error = `Ollama: ${err?.detail ?? err?.error?.message ?? 'Network Problem'}`;
+ return [];
+ });
+
+ if (error) {
+ throw error;
+ }
+
+ return res;
+};
+
export const getOllamaConfig = async (token: string = '') => {
let error = null;
diff --git a/src/lib/components/admin/Settings/Connections.svelte b/src/lib/components/admin/Settings/Connections.svelte
index 63f603807..6031b591c 100644
--- a/src/lib/components/admin/Settings/Connections.svelte
+++ b/src/lib/components/admin/Settings/Connections.svelte
@@ -13,10 +13,11 @@
import Switch from '$lib/components/common/Switch.svelte';
import Spinner from '$lib/components/common/Spinner.svelte';
import Tooltip from '$lib/components/common/Tooltip.svelte';
+ import Plus from '$lib/components/icons/Plus.svelte';
import OpenAIConnection from './Connections/OpenAIConnection.svelte';
- import OpenAIConnectionModal from './Connections/OpenAIConnectionModal.svelte';
- import Plus from '$lib/components/icons/Plus.svelte';
+ import AddConnectionModal from './Connections/AddConnectionModal.svelte';
+ import OllamaConnection from './Connections/OllamaConnection.svelte';
const i18n = getContext('i18n');
@@ -38,10 +39,14 @@
let pipelineUrls = {};
let showAddOpenAIConnectionModal = false;
+ let showAddOllamaConnectionModal = false;
const updateOpenAIHandler = async () => {
if (ENABLE_OPENAI_API !== null) {
- OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS.map((url) => url.replace(/\/$/, ''));
+			const keepUrl = OPENAI_API_BASE_URLS.map((url, idx) => url !== '' && OPENAI_API_BASE_URLS.indexOf(url) === idx); // keep keys aligned with URLs
+			OPENAI_API_KEYS = OPENAI_API_KEYS.filter((_, idx) => keepUrl[idx] ?? true);
+			OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS.filter((_, idx) => keepUrl[idx]).map((url) => url.replace(/\/$/, ''));
+
// Check if API KEYS length is same than API URLS length
if (OPENAI_API_KEYS.length !== OPENAI_API_BASE_URLS.length) {
// if there are more keys than urls, remove the extra keys
@@ -76,9 +81,10 @@
const updateOllamaHandler = async () => {
if (ENABLE_OLLAMA_API !== null) {
- OLLAMA_BASE_URLS = OLLAMA_BASE_URLS.filter((url) => url !== '').map((url) =>
- url.replace(/\/$/, '')
- );
+ // Remove duplicate URLs
+ OLLAMA_BASE_URLS = OLLAMA_BASE_URLS.filter(
+ (url, urlIdx) => OLLAMA_BASE_URLS.indexOf(url) === urlIdx && url !== ''
+ ).map((url) => url.replace(/\/$/, ''));
console.log(OLLAMA_BASE_URLS);
@@ -110,6 +116,13 @@
await updateOpenAIHandler();
};
+ const addOllamaConnectionHandler = async (connection) => {
+ OLLAMA_BASE_URLS = [...OLLAMA_BASE_URLS, connection.url];
+ OLLAMA_API_CONFIGS[connection.url] = connection.config;
+
+ await updateOllamaHandler();
+ };
+
onMount(async () => {
if ($user.role === 'admin') {
let ollamaConfig = {};
@@ -160,11 +173,17 @@
});
-