From 13247493ef522337857d4d4987ba1988ffeab4c5 Mon Sep 17 00:00:00 2001 From: cvaz1306 Date: Sat, 28 Dec 2024 13:03:58 -0800 Subject: [PATCH] (fullstack) Adding model loaded indicator to Selector.svelte --- backend/open_webui/functions.py | 2 ++ backend/open_webui/routers/ollama.py | 14 ++++++++++++-- backend/open_webui/routers/openai.py | 1 + backend/open_webui/utils/models.py | 1 + .../chat/ModelSelector/Selector.svelte | 18 ++++++++++++++---- 5 files changed, 30 insertions(+), 6 deletions(-) diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py index 16536a612..19d0b83c0 100644 --- a/backend/open_webui/functions.py +++ b/backend/open_webui/functions.py @@ -108,6 +108,7 @@ async def get_function_models(request): "created": pipe.created_at, "owned_by": "openai", "pipe": pipe_flag, + "loaded": True, } ) else: @@ -125,6 +126,7 @@ async def get_function_models(request): "created": pipe.created_at, "owned_by": "openai", "pipe": pipe_flag, + "loaded": True, } ) diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 275146c72..f215f354c 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -272,6 +272,13 @@ async def get_all_models(request: Request): responses = await asyncio.gather(*request_tasks) + loaded_models_response = requests.get( + f"{request.app.state.config.OLLAMA_BASE_URLS[0]}/api/ps" + ) + loaded_models = [ + model["model"] for model in loaded_models_response.json().get("models", []) + ] + for idx, response in enumerate(responses): if response: url = request.app.state.config.OLLAMA_BASE_URLS[idx] @@ -291,10 +298,11 @@ async def get_all_models(request: Request): if prefix_id: for model in response.get("models", []): model["model"] = f"{prefix_id}.{model['model']}" + for model in response.get("models", []): + model["loaded"] = model["model"] in loaded_models def merge_models_lists(model_lists): merged_models = {} - for idx, model_list in enumerate(model_lists): if model_list is not None: for model in model_list: @@ -304,7 +312,9 @@ async def get_all_models(request: Request): merged_models[id] = model else: merged_models[id]["urls"].append(idx) - + merged_models[id]["loaded"] = merged_models[id].get( + "loaded", False + ) or model.get("loaded", False) return list(merged_models.values()) models = { diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index 4ab381ea4..415de5f1e 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -367,6 +367,7 @@ async def get_all_models(request: Request) -> dict[str, list]: "owned_by": "openai", "openai": model, "urlIdx": idx, + "loaded": True, } for model in models if "api.openai.com" diff --git a/backend/open_webui/utils/models.py b/backend/open_webui/utils/models.py index 975f8cb09..3f9c711d5 100644 --- a/backend/open_webui/utils/models.py +++ b/backend/open_webui/utils/models.py @@ -48,6 +48,7 @@ async def get_all_base_models(request: Request): "created": int(time.time()), "owned_by": "ollama", "ollama": model, + "loaded": model.get("loaded", False), } for model in ollama_models["models"] ] diff --git a/src/lib/components/chat/ModelSelector/Selector.svelte b/src/lib/components/chat/ModelSelector/Selector.svelte index 9e6b1d0fe..454aa45fe 100644 --- a/src/lib/components/chat/ModelSelector/Selector.svelte +++ b/src/lib/components/chat/ModelSelector/Selector.svelte @@ -45,6 +45,9 @@ export let triggerClassName = 'text-lg'; let show = false; + setInterval(async () => { + if (show) models.set(await getModels(localStorage.token)); + }, 5000); let selectedModel = ''; $: selectedModel = items.find((item) => item.value === value) ?? ''; @@ -226,6 +229,7 @@ searchValue = ''; selectedModelIdx = 0; window.setTimeout(() => document.getElementById('model-search-input')?.focus(), 0); + models.set(await getModels(localStorage.token)); }} closeFocus={false} > @@ -292,10 +296,16 @@ {#each filteredItems as item, index}