From ca0285fc91e1bf0feb25d8e37d52bf37fe7e37ac Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Sat, 18 Jan 2025 17:10:15 -0800 Subject: [PATCH] refac: connections handling --- backend/open_webui/routers/ollama.py | 105 ++++++++++++------ backend/open_webui/routers/openai.py | 35 ++++-- .../admin/Settings/Connections.svelte | 29 +++-- 3 files changed, 113 insertions(+), 56 deletions(-) diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 275146c72..0c50adebd 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -152,10 +152,12 @@ async def send_post_request( ) -def get_api_key(url, configs): +def get_api_key(idx, url, configs): parsed_url = urlparse(url) base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - return configs.get(base_url, {}).get("key", None) + return configs.get(idx, configs.get(base_url, {})).get( + "key", None + ) # Legacy support ########################################## @@ -238,11 +240,13 @@ async def update_config( request.app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS request.app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS - # Remove any extra configs - config_urls = request.app.state.config.OLLAMA_API_CONFIGS.keys() - for url in list(request.app.state.config.OLLAMA_BASE_URLS): - if url not in config_urls: - request.app.state.config.OLLAMA_API_CONFIGS.pop(url, None) + # Remove the API configs that are not in the API URLS + keys = list(map(str, range(len(request.app.state.config.OLLAMA_BASE_URLS)))) + request.app.state.config.OLLAMA_API_CONFIGS = { + key: value + for key, value in request.app.state.config.OLLAMA_API_CONFIGS.items() + if key in keys + } return { "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API, @@ -258,10 +262,18 @@ async def get_all_models(request: Request): request_tasks = [] for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS): - if url not in request.app.state.config.OLLAMA_API_CONFIGS: + if (idx not in request.app.state.config.OLLAMA_API_CONFIGS) or ( + url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support + ): request_tasks.append(send_get_request(f"{url}/api/tags")) else: - api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( + idx, + request.app.state.config.OLLAMA_API_CONFIGS.get( + url, {} + ), # Legacy support + ) + enable = api_config.get("enable", True) key = api_config.get("key", None) @@ -275,7 +287,12 @@ async def get_all_models(request: Request): for idx, response in enumerate(responses): if response: url = request.app.state.config.OLLAMA_BASE_URLS[idx] - api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( + idx, + request.app.state.config.OLLAMA_API_CONFIGS.get( + url, {} + ), # Legacy support + ) prefix_id = api_config.get("prefix_id", None) model_ids = api_config.get("model_ids", []) @@ -349,7 +366,7 @@ async def get_ollama_tags( models = await get_all_models(request) else: url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) + key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS) r = None try: @@ -393,11 +410,14 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None): request_tasks = [ send_get_request( f"{url}/api/version", - request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get( - "key", None - ), + request.app.state.config.OLLAMA_API_CONFIGS.get( + idx, + request.app.state.config.OLLAMA_API_CONFIGS.get( + url, {} + ), # Legacy support + ).get("key", None), ) - for url in request.app.state.config.OLLAMA_BASE_URLS + for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS) ] responses = await asyncio.gather(*request_tasks) responses = list(filter(lambda x: x is not None, responses)) @@ -454,11 +474,14 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u request_tasks = [ send_get_request( f"{url}/api/ps", - request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get( - "key", None - ), + request.app.state.config.OLLAMA_API_CONFIGS.get( + idx, + request.app.state.config.OLLAMA_API_CONFIGS.get( + url, {} + ), # Legacy support + ).get("key", None), ) - for url in request.app.state.config.OLLAMA_BASE_URLS + for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS) ] responses = await asyncio.gather(*request_tasks) @@ -488,7 +511,7 @@ async def pull_model( return await send_post_request( url=f"{url}/api/pull", payload=json.dumps(payload), - key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), + key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), ) @@ -524,7 +547,7 @@ async def push_model( return await send_post_request( url=f"{url}/api/push", payload=form_data.model_dump_json(exclude_none=True).encode(), - key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), + key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), ) @@ -549,7 +572,7 @@ async def create_model( return await send_post_request( url=f"{url}/api/create", payload=form_data.model_dump_json(exclude_none=True).encode(), - key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), + key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), ) @@ -579,7 +602,7 @@ async def copy_model( ) url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) + key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS) try: r = requests.request( @@ -634,7 +657,7 @@ async def delete_model( ) url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) + key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS) try: r = requests.request( @@ -684,7 +707,7 @@ async def show_model_info( url_idx = random.choice(models[form_data.name]["urls"]) url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) + key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS) try: r = requests.request( @@ -753,7 +776,7 @@ async def embed( ) url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) + key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS) try: r = requests.request( @@ -822,7 +845,7 @@ async def embeddings( ) url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS) + key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS) try: r = requests.request( @@ -897,7 +920,10 @@ async def generate_completion( ) url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] - api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( + url_idx, + request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support + ) prefix_id = api_config.get("prefix_id", None) if prefix_id: @@ -906,7 +932,7 @@ async def generate_completion( return await send_post_request( url=f"{url}/api/generate", payload=form_data.model_dump_json(exclude_none=True).encode(), - key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), + key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), ) @@ -1005,7 +1031,10 @@ async def generate_chat_completion( payload["model"] = f"{payload['model']}:latest" url = await get_ollama_url(request, payload["model"], url_idx) - api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( + url_idx, + request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support + ) prefix_id = api_config.get("prefix_id", None) if prefix_id: @@ -1015,7 +1044,7 @@ async def generate_chat_completion( url=f"{url}/api/chat", payload=json.dumps(payload), stream=form_data.stream, - key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), + key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), content_type="application/x-ndjson", ) @@ -1104,7 +1133,10 @@ async def generate_openai_completion( payload["model"] = f"{payload['model']}:latest" url = await get_ollama_url(request, payload["model"], url_idx) - api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( + url_idx, + request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support + ) prefix_id = api_config.get("prefix_id", None) @@ -1115,7 +1147,7 @@ async def generate_openai_completion( url=f"{url}/v1/completions", payload=json.dumps(payload), stream=payload.get("stream", False), - key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), + key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), ) @@ -1178,7 +1210,10 @@ async def generate_openai_chat_completion( payload["model"] = f"{payload['model']}:latest" url = await get_ollama_url(request, payload["model"], url_idx) - api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( + url_idx, + request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support + ) prefix_id = api_config.get("prefix_id", None) if prefix_id: @@ -1188,7 +1223,7 @@ async def generate_openai_chat_completion( url=f"{url}/v1/chat/completions", payload=json.dumps(payload), stream=payload.get("stream", False), - key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), + key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), ) diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index 03f096a61..9f6e29913 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -145,11 +145,13 @@ async def update_config( request.app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS - # Remove any extra configs - config_urls = request.app.state.config.OPENAI_API_CONFIGS.keys() - for idx, url in enumerate(request.app.state.config.OPENAI_API_BASE_URLS): - if url not in config_urls: - request.app.state.config.OPENAI_API_CONFIGS.pop(url, None) + # Remove the API configs that are not in the API URLS + keys = list(map(str, range(len(request.app.state.config.OPENAI_API_BASE_URLS)))) + request.app.state.config.OPENAI_API_CONFIGS = { + key: value + for key, value in request.app.state.config.OPENAI_API_CONFIGS.items() + if key in keys + } return { "ENABLE_OPENAI_API": request.app.state.config.ENABLE_OPENAI_API, @@ -264,14 +266,21 @@ async def get_all_models_responses(request: Request) -> list: request_tasks = [] for idx, url in enumerate(request.app.state.config.OPENAI_API_BASE_URLS): - if url not in request.app.state.config.OPENAI_API_CONFIGS: + if (idx not in request.app.state.config.OPENAI_API_CONFIGS) or ( + url not in request.app.state.config.OPENAI_API_CONFIGS # Legacy support + ): request_tasks.append( send_get_request( f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx] ) ) else: - api_config = request.app.state.config.OPENAI_API_CONFIGS.get(url, {}) + api_config = request.app.state.config.OPENAI_API_CONFIGS.get( + idx, + request.app.state.config.OPENAI_API_CONFIGS.get( + url, {} + ), # Legacy support + ) enable = api_config.get("enable", True) model_ids = api_config.get("model_ids", []) @@ -310,7 +319,12 @@ async def get_all_models_responses(request: Request) -> list: for idx, response in enumerate(responses): if response: url = request.app.state.config.OPENAI_API_BASE_URLS[idx] - api_config = request.app.state.config.OPENAI_API_CONFIGS.get(url, {}) + api_config = request.app.state.config.OPENAI_API_CONFIGS.get( + idx, + request.app.state.config.OPENAI_API_CONFIGS.get( + url, {} + ), # Legacy support + ) prefix_id = api_config.get("prefix_id", None) @@ -585,7 +599,10 @@ async def generate_chat_completion( # Get the API config for the model api_config = request.app.state.config.OPENAI_API_CONFIGS.get( - request.app.state.config.OPENAI_API_BASE_URLS[idx], {} + idx, + request.app.state.config.OPENAI_API_CONFIGS.get( + request.app.state.config.OPENAI_API_BASE_URLS[idx], {} + ), # Legacy support ) prefix_id = api_config.get("prefix_id", None) diff --git a/src/lib/components/admin/Settings/Connections.svelte b/src/lib/components/admin/Settings/Connections.svelte index e9d2baecd..c4dffed0a 100644 --- a/src/lib/components/admin/Settings/Connections.svelte +++ b/src/lib/components/admin/Settings/Connections.svelte @@ -104,14 +104,14 @@ const addOpenAIConnectionHandler = async (connection) => { OPENAI_API_BASE_URLS = [...OPENAI_API_BASE_URLS, connection.url]; OPENAI_API_KEYS = [...OPENAI_API_KEYS, connection.key]; - OPENAI_API_CONFIGS[connection.url] = connection.config; + OPENAI_API_CONFIGS[OPENAI_API_BASE_URLS.length] = connection.config; await updateOpenAIHandler(); }; const addOllamaConnectionHandler = async (connection) => { OLLAMA_BASE_URLS = [...OLLAMA_BASE_URLS, connection.url]; - OLLAMA_API_CONFIGS[connection.url] = connection.config; + OLLAMA_API_CONFIGS[OLLAMA_BASE_URLS.length] = connection.config; await updateOllamaHandler(); }; @@ -141,15 +141,17 @@ OLLAMA_API_CONFIGS = ollamaConfig.OLLAMA_API_CONFIGS; if (ENABLE_OPENAI_API) { - for (const url of OPENAI_API_BASE_URLS) { - if (!OPENAI_API_CONFIGS[url]) { - OPENAI_API_CONFIGS[url] = {}; + // get url and idx + for (const [idx, url] of OPENAI_API_BASE_URLS.entries()) { + if (!OPENAI_API_CONFIGS[idx]) { + // Legacy support, url as key + OPENAI_API_CONFIGS[idx] = OPENAI_API_CONFIGS[url] || {}; } } OPENAI_API_BASE_URLS.forEach(async (url, idx) => { - OPENAI_API_CONFIGS[url] = OPENAI_API_CONFIGS[url] || {}; - if (!(OPENAI_API_CONFIGS[url]?.enable ?? true)) { + OPENAI_API_CONFIGS[idx] = OPENAI_API_CONFIGS[idx] || {}; + if (!(OPENAI_API_CONFIGS[idx]?.enable ?? true)) { return; } const res = await getOpenAIModels(localStorage.token, idx); @@ -160,9 +162,9 @@ } if (ENABLE_OLLAMA_API) { - for (const url of OLLAMA_BASE_URLS) { - if (!OLLAMA_API_CONFIGS[url]) { - OLLAMA_API_CONFIGS[url] = {}; + for (const [idx, url] of OLLAMA_BASE_URLS.entries()) { + if (!OLLAMA_API_CONFIGS[idx]) { + OLLAMA_API_CONFIGS[idx] = OLLAMA_API_CONFIGS[url] || {}; } } } @@ -235,7 +237,7 @@ pipeline={pipelineUrls[url] ? true : false} bind:url bind:key={OPENAI_API_KEYS[idx]} - bind:config={OPENAI_API_CONFIGS[url]} + bind:config={OPENAI_API_CONFIGS[idx]} onSubmit={() => { updateOpenAIHandler(); }} @@ -244,6 +246,8 @@ (url, urlIdx) => idx !== urlIdx ); OPENAI_API_KEYS = OPENAI_API_KEYS.filter((key, keyIdx) => idx !== keyIdx); + + delete OPENAI_API_CONFIGS[idx]; }} /> {/each} @@ -294,13 +298,14 @@ {#each OLLAMA_BASE_URLS as url, idx} { updateOllamaHandler(); }} onDelete={() => { OLLAMA_BASE_URLS = OLLAMA_BASE_URLS.filter((url, urlIdx) => idx !== urlIdx); + delete OLLAMA_API_CONFIGS[idx]; }} /> {/each}