From a50a8e2ef9befb6962b640eb1f3f4a205e6f35db Mon Sep 17 00:00:00 2001
From: Timothy Jaeryang Baek
Date: Fri, 23 May 2025 18:47:50 +0400
Subject: [PATCH] refac: ollama ps

---
 backend/open_webui/routers/ollama.py | 125 ++++++++++++++++++---------
 1 file changed, 83 insertions(+), 42 deletions(-)

diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py
index 7c313ea97..86b4b3f7d 100644
--- a/backend/open_webui/routers/ollama.py
+++ b/backend/open_webui/routers/ollama.py
@@ -300,6 +300,22 @@ async def update_config(
     }
 
 
+def merge_ollama_models_lists(model_lists):
+    merged_models = {}
+
+    for idx, model_list in enumerate(model_lists):
+        if model_list is not None:
+            for model in model_list:
+                id = model["model"]
+                if id not in merged_models:
+                    model["urls"] = [idx]
+                    merged_models[id] = model
+                else:
+                    merged_models[id]["urls"].append(idx)
+
+    return list(merged_models.values())
+
+
 @cached(ttl=1)
 async def get_all_models(request: Request, user: UserModel = None):
     log.info("get_all_models()")
@@ -364,23 +380,8 @@ async def get_all_models(request: Request, user: UserModel = None):
                     if connection_type:
                         model["connection_type"] = connection_type
 
-        def merge_models_lists(model_lists):
-            merged_models = {}
-
-            for idx, model_list in enumerate(model_lists):
-                if model_list is not None:
-                    for model in model_list:
-                        id = model["model"]
-                        if id not in merged_models:
-                            model["urls"] = [idx]
-                            merged_models[id] = model
-                        else:
-                            merged_models[id]["urls"].append(idx)
-
-            return list(merged_models.values())
-
         models = {
-            "models": merge_models_lists(
+            "models": merge_ollama_models_lists(
                 map(
                     lambda response: response.get("models", []) if response else None,
                     responses,
@@ -468,6 +469,72 @@ async def get_ollama_tags(
     return models
 
 
+@router.get("/api/ps")
+async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
+    """
+    List models that are currently loaded into Ollama memory, and which node they are loaded on.
+    """
+    if request.app.state.config.ENABLE_OLLAMA_API:
+        request_tasks = []
+        for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
+            if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
+                url not in request.app.state.config.OLLAMA_API_CONFIGS  # Legacy support
+            ):
+                request_tasks.append(send_get_request(f"{url}/api/ps", user=user))
+            else:
+                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
+                    str(idx),
+                    request.app.state.config.OLLAMA_API_CONFIGS.get(
+                        url, {}
+                    ),  # Legacy support
+                )
+
+                enable = api_config.get("enable", True)
+                key = api_config.get("key", None)
+
+                if enable:
+                    request_tasks.append(
+                        send_get_request(f"{url}/api/ps", key, user=user)
+                    )
+                else:
+                    request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
+
+        responses = await asyncio.gather(*request_tasks)
+
+        for idx, response in enumerate(responses):
+            if response:
+                url = request.app.state.config.OLLAMA_BASE_URLS[idx]
+                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
+                    str(idx),
+                    request.app.state.config.OLLAMA_API_CONFIGS.get(
+                        url, {}
+                    ),  # Legacy support
+                )
+
+                prefix_id = api_config.get("prefix_id", None)
+
+                for model in response.get("models", []):
+                    if prefix_id:
+                        model["model"] = f"{prefix_id}.{model['model']}"
+
+        models = {
+            "models": merge_ollama_models_lists(
+                map(
+                    lambda response: response.get("models", []) if response else None,
+                    responses,
+                )
+            )
+        }
+
+        if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
+            models["models"] = await get_filtered_models(models, user)
+
+    else:
+        models = {"models": []}
+
+    return models
+
+
 @router.get("/api/version")
 @router.get("/api/version/{url_idx}")
 async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
@@ -541,32 +608,6 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
         return {"version": False}
 
 
-@router.get("/api/ps")
-async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
-    """
-    List models that are currently loaded into Ollama memory, and which node they are loaded on.
-    """
-    if request.app.state.config.ENABLE_OLLAMA_API:
-        request_tasks = [
-            send_get_request(
-                f"{url}/api/ps",
-                request.app.state.config.OLLAMA_API_CONFIGS.get(
-                    str(idx),
-                    request.app.state.config.OLLAMA_API_CONFIGS.get(
-                        url, {}
-                    ),  # Legacy support
-                ).get("key", None),
-                user=user,
-            )
-            for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
-        ]
-        responses = await asyncio.gather(*request_tasks)
-
-        return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses))
-    else:
-        return {}
-
-
 class ModelNameForm(BaseModel):
     name: str
 