diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py index 05cd2d223..2038f4302 100644 --- a/backend/open_webui/apps/ollama/main.py +++ b/backend/open_webui/apps/ollama/main.py @@ -46,7 +46,11 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) -app = FastAPI(docs_url="/docs" if ENV == "dev" else None, openapi_url="/openapi.json" if ENV == "dev" else None, redoc_url=None) +app = FastAPI( + docs_url="/docs" if ENV == "dev" else None, + openapi_url="/openapi.json" if ENV == "dev" else None, + redoc_url=None, +) app.add_middleware( CORSMiddleware, @@ -90,34 +94,26 @@ async def get_status(): @app.get("/config") async def get_config(user=Depends(get_admin_user)): - return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API} + return { + "ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API, + "OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS, + } class OllamaConfigForm(BaseModel): - enable_ollama_api: Optional[bool] = None + ENABLE_OLLAMA_API: Optional[bool] = None + OLLAMA_BASE_URLS: list[str] @app.post("/config/update") async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)): - app.state.config.ENABLE_OLLAMA_API = form_data.enable_ollama_api - return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API} + app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API + app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS - -@app.get("/urls") -async def get_ollama_api_urls(user=Depends(get_admin_user)): - return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS} - - -class UrlUpdateForm(BaseModel): - urls: list[str] - - -@app.post("/urls/update") -async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): - app.state.config.OLLAMA_BASE_URLS = form_data.urls - - log.info(f"app.state.config.OLLAMA_BASE_URLS: {app.state.config.OLLAMA_BASE_URLS}") - return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS} + return { 
+ "ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API, + "OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS, + } async def fetch_url(url): diff --git a/backend/open_webui/apps/openai/main.py b/backend/open_webui/apps/openai/main.py index ed6ad5cbb..ef79c205b 100644 --- a/backend/open_webui/apps/openai/main.py +++ b/backend/open_webui/apps/openai/main.py @@ -16,6 +16,7 @@ from open_webui.config import ( MODEL_FILTER_LIST, OPENAI_API_BASE_URLS, OPENAI_API_KEYS, + OPENAI_API_CONFIGS, AppConfig, ) from open_webui.env import ( @@ -43,7 +44,11 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["OPENAI"]) -app = FastAPI(docs_url="/docs" if ENV == "dev" else None, openapi_url="/openapi.json" if ENV == "dev" else None, redoc_url=None) +app = FastAPI( + docs_url="/docs" if ENV == "dev" else None, + openapi_url="/openapi.json" if ENV == "dev" else None, + redoc_url=None, +) app.add_middleware( @@ -62,6 +67,7 @@ app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS +app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS app.state.MODELS = {} @@ -77,48 +83,58 @@ async def check_url(request: Request, call_next): @app.get("/config") async def get_config(user=Depends(get_admin_user)): - return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API} + return { + "ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API, + "OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS, + "OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS, + "OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS, + } class OpenAIConfigForm(BaseModel): - enable_openai_api: Optional[bool] = None + ENABLE_OPENAI_API: Optional[bool] = None + OPENAI_API_BASE_URLS: list[str] + OPENAI_API_KEYS: list[str] + OPENAI_API_CONFIGS: dict @app.post("/config/update") async def update_config(form_data: OpenAIConfigForm, 
user=Depends(get_admin_user)): - app.state.config.ENABLE_OPENAI_API = form_data.enable_openai_api - return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API} + app.state.config.ENABLE_OPENAI_API = form_data.ENABLE_OPENAI_API + app.state.config.OPENAI_API_BASE_URLS = form_data.OPENAI_API_BASE_URLS + app.state.config.OPENAI_API_KEYS = form_data.OPENAI_API_KEYS -class UrlsUpdateForm(BaseModel): - urls: list[str] + # Check if API KEYS length is same as API URLS length + if len(app.state.config.OPENAI_API_KEYS) != len( + app.state.config.OPENAI_API_BASE_URLS + ): + if len(app.state.config.OPENAI_API_KEYS) > len( + app.state.config.OPENAI_API_BASE_URLS + ): + app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[ + : len(app.state.config.OPENAI_API_BASE_URLS) + ] + else: + app.state.config.OPENAI_API_KEYS += [""] * ( + len(app.state.config.OPENAI_API_BASE_URLS) + - len(app.state.config.OPENAI_API_KEYS) + ) + app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS -class KeysUpdateForm(BaseModel): - keys: list[str] + # Remove any config entry whose URL is no longer in the base URL list + config_urls = list(app.state.config.OPENAI_API_CONFIGS.keys()) + for url in config_urls: + if url not in app.state.config.OPENAI_API_BASE_URLS: + app.state.config.OPENAI_API_CONFIGS.pop(url, None) - -@app.get("/urls") -async def get_openai_urls(user=Depends(get_admin_user)): - return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS} - - -@app.post("/urls/update") -async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)): - await get_all_models() - app.state.config.OPENAI_API_BASE_URLS = form_data.urls - return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS} - - -@app.get("/keys") -async def get_openai_keys(user=Depends(get_admin_user)): - return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS} - - -@app.post("/keys/update") -async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)): - 
app.state.config.OPENAI_API_KEYS = form_data.keys - return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS} + return { + "ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API, + "OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS, + "OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS, + "OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS, + } @app.post("/audio/speech") @@ -190,7 +206,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND) -async def fetch_url(url, key): +async def aiohttp_get(url, key): timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) try: headers = {"Authorization": f"Bearer {key}"} @@ -248,12 +264,8 @@ def merge_models_lists(model_lists): return merged_list -def is_openai_api_disabled(): - return not app.state.config.ENABLE_OPENAI_API - - async def get_all_models_raw() -> list: - if is_openai_api_disabled(): + if not app.state.config.ENABLE_OPENAI_API: return [] # Check if API KEYS length is same than API URLS length @@ -269,12 +281,55 @@ async def get_all_models_raw() -> list: else: app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys) - tasks = [ - fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx]) - for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS) - ] + tasks = [] + for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS): + if url not in app.state.config.OPENAI_API_CONFIGS: + tasks.append( + aiohttp_get(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx]) + ) + else: + api_config = app.state.config.OPENAI_API_CONFIGS[url] + + enabled = api_config.get("enabled", True) + model_ids = api_config.get("model_ids", []) + + if enabled: + if len(model_ids) == 0: + tasks.append( + aiohttp_get( + f"{url}/models", app.state.config.OPENAI_API_KEYS[idx] + ) + ) + else: + model_list = { + "object": "list", + "data": [ + { + "id": model_id, + "name": model_id, 
+ "owned_by": "openai", + "openai": {"id": model_id}, + "urlIdx": idx, + } + for model_id in model_ids + ], + } + + tasks.append(asyncio.ensure_future(asyncio.sleep(0, model_list))) responses = await asyncio.gather(*tasks) + + for idx, response in enumerate(responses): + if response: + url = app.state.config.OPENAI_API_BASE_URLS[idx] + api_config = app.state.config.OPENAI_API_CONFIGS[url] + + prefix_id = api_config.get("prefix_id", None) + + if prefix_id: + for model in response["data"]: + model["id"] = f"{prefix_id}.{model['id']}" + log.debug(f"get_all_models:responses() {responses}") return responses @@ -290,7 +345,7 @@ async def get_all_models(raw: Literal[False] = False) -> dict[str, list]: ... async def get_all_models(raw=False) -> dict[str, list] | list: log.info("get_all_models()") - if is_openai_api_disabled(): + if not app.state.config.ENABLE_OPENAI_API: return [] if raw else {"data": []} responses = await get_all_models_raw() @@ -342,7 +397,6 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us r = None - timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) async with aiohttp.ClientSession(timeout=timeout) as session: try: @@ -361,7 +415,8 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us if "api.openai.com" in url: # Filter models according to the specified conditions response_data["data"] = [ - model for model in response_data.get("data", []) + model + for model in response_data.get("data", []) if not any( name in model["id"] for name in [ @@ -381,7 +436,9 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us # ClientError covers all aiohttp requests issues log.exception(f"Client error: {str(e)}") # Handle aiohttp-specific connection issues, timeout etc. 
- raise HTTPException(status_code=500, detail="Open WebUI: Server Connection Error") + raise HTTPException( + status_code=500, detail="Open WebUI: Server Connection Error" + ) except Exception as e: log.exception(f"Unexpected error: {e}") # Generic error handler in case parsing JSON or other steps fail @@ -389,6 +446,49 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us raise HTTPException(status_code=500, detail=error_detail) +class ConnectionVerificationForm(BaseModel): + url: str + key: str + + +@app.post("/verify") +async def verify_connection( + form_data: ConnectionVerificationForm, user=Depends(get_admin_user) +): + url = form_data.url + key = form_data.key + + headers = {} + headers["Authorization"] = f"Bearer {key}" + headers["Content-Type"] = "application/json" + + timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) + async with aiohttp.ClientSession(timeout=timeout) as session: + try: + async with session.get(f"{url}/models", headers=headers) as r: + if r.status != 200: + # Extract response error details if available + error_detail = f"HTTP Error: {r.status}" + res = await r.json() + if "error" in res: + error_detail = f"External Error: {res['error']}" + raise Exception(error_detail) + + response_data = await r.json() + return response_data + + except aiohttp.ClientError as e: + # ClientError covers all aiohttp requests issues + log.exception(f"Client error: {str(e)}") + # Handle aiohttp-specific connection issues, timeout etc. 
+ raise HTTPException( + status_code=500, detail="Open WebUI: Server Connection Error" + ) + except Exception as e: + log.exception(f"Unexpected error: {e}") + # Generic error handler in case parsing JSON or other steps fail + error_detail = f"Unexpected error: {str(e)}" + raise HTTPException(status_code=500, detail=error_detail) @app.post("/chat/completions") @@ -418,6 +518,14 @@ async def generate_chat_completion( model = app.state.MODELS[payload.get("model")] idx = model["urlIdx"] + api_config = app.state.config.OPENAI_API_CONFIGS.get( + app.state.config.OPENAI_API_BASE_URLS[idx], {} + ) + prefix_id = api_config.get("prefix_id", None) + + if prefix_id: + payload["model"] = payload["model"].replace(f"{prefix_id}.", "") + if "pipeline" in model and model.get("pipeline"): payload["user"] = { "name": user.name, diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index cd1a55e26..7da3b0d6f 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -607,6 +607,12 @@ OLLAMA_BASE_URLS = PersistentConfig( "OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS ) +OLLAMA_API_CONFIGS = PersistentConfig( + "OLLAMA_API_CONFIGS", + "ollama.api_configs", + {}, +) + #################################### # OPENAI_API #################################### @@ -647,15 +653,20 @@ OPENAI_API_BASE_URLS = PersistentConfig( "OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS ) -OPENAI_API_KEY = "" +OPENAI_API_CONFIGS = PersistentConfig( + "OPENAI_API_CONFIGS", + "openai.api_configs", + {}, +) +# Get the actual OpenAI API key based on the base URL +OPENAI_API_KEY = "" try: OPENAI_API_KEY = OPENAI_API_KEYS.value[ OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1") ] except Exception: pass - OPENAI_API_BASE_URL = "https://api.openai.com/v1" #################################### diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts index d4e994312..fcb37f7a4 100644 --- a/src/lib/apis/ollama/index.ts +++ 
b/src/lib/apis/ollama/index.ts @@ -32,7 +32,13 @@ export const getOllamaConfig = async (token: string = '') => { return res; }; -export const updateOllamaConfig = async (token: string = '', enable_ollama_api: boolean) => { +type OllamaConfig = { + ENABLE_OLLAMA_API: boolean, + OLLAMA_BASE_URLS: string[], + OLLAMA_API_CONFIGS: object +} + +export const updateOllamaConfig = async (token: string = '', config: OllamaConfig) => { let error = null; const res = await fetch(`${OLLAMA_API_BASE_URL}/config/update`, { @@ -43,7 +49,7 @@ export const updateOllamaConfig = async (token: string = '', enable_ollama_api: ...(token && { authorization: `Bearer ${token}` }) }, body: JSON.stringify({ - enable_ollama_api: enable_ollama_api + ...config }) }) .then(async (res) => { diff --git a/src/lib/apis/openai/index.ts b/src/lib/apis/openai/index.ts index 2bb11d12a..a9f581249 100644 --- a/src/lib/apis/openai/index.ts +++ b/src/lib/apis/openai/index.ts @@ -32,7 +32,17 @@ export const getOpenAIConfig = async (token: string = '') => { return res; }; -export const updateOpenAIConfig = async (token: string = '', enable_openai_api: boolean) => { + +type OpenAIConfig = { + ENABLE_OPENAI_API: boolean; + OPENAI_API_BASE_URLS: string[]; + OPENAI_API_KEYS: string[]; + OPENAI_API_CONFIGS: object; +} + + + +export const updateOpenAIConfig = async (token: string = '', config: OpenAIConfig) => { let error = null; const res = await fetch(`${OPENAI_API_BASE_URL}/config/update`, { @@ -43,7 +53,7 @@ export const updateOpenAIConfig = async (token: string = '', enable_openai_api: ...(token && { authorization: `Bearer ${token}` }) }, body: JSON.stringify({ - enable_openai_api: enable_openai_api + ...config }) }) .then(async (res) => { @@ -99,6 +109,7 @@ export const getOpenAIUrls = async (token: string = '') => { return res.OPENAI_API_BASE_URLS; }; + export const updateOpenAIUrls = async (token: string = '', urls: string[]) => { let error = null; @@ -231,41 +242,43 @@ export const getOpenAIModels = async 
(token: string, urlIdx?: number) => { return res; }; -export const getOpenAIModelsDirect = async ( - base_url: string = 'https://api.openai.com/v1', - api_key: string = '' +export const verifyOpenAIConnection = async ( + token: string = '', + url: string = 'https://api.openai.com/v1', + key: string = '' ) => { let error = null; - const res = await fetch(`${base_url}/models`, { - method: 'GET', - headers: { - 'Content-Type': 'application/json', - Authorization: `Bearer ${api_key}` + const res = await fetch( + `${OPENAI_API_BASE_URL}/verify`, + { + method: 'POST', + headers: { + Accept: 'application/json', + Authorization: `Bearer ${token}`, + 'Content-Type': 'application/json', + + }, + body: JSON.stringify({ + url, + key + }) } - }) + ) .then(async (res) => { if (!res.ok) throw await res.json(); return res.json(); }) .catch((err) => { - console.log(err); error = `OpenAI: ${err?.error?.message ?? 'Network Problem'}`; - return null; + return []; }); if (error) { throw error; } - const models = Array.isArray(res) ? res : (res?.data ?? null); - - return models - .map((model) => ({ id: model.id, name: model.name ?? model.id, external: true })) - .filter((model) => (base_url.includes('openai') ? model.name.includes('gpt') : true)) - .sort((a, b) => { - return a.name.localeCompare(b.name); - }); + return res; }; export const generateOpenAIChatCompletion = async ( diff --git a/src/lib/components/admin/Settings.svelte b/src/lib/components/admin/Settings.svelte index 7e1e489b1..20af84cc0 100644 --- a/src/lib/components/admin/Settings.svelte +++ b/src/lib/components/admin/Settings.svelte @@ -39,7 +39,7 @@ }); -