From c30941298030ed0126fbd9f027654b1effbc54f3 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Tue, 11 Mar 2025 20:37:30 +0000 Subject: [PATCH] enh: connection tags --- backend/open_webui/main.py | 24 +++++++---- backend/open_webui/routers/ollama.py | 7 +++- backend/open_webui/routers/openai.py | 9 +++- backend/open_webui/utils/models.py | 1 + src/lib/apis/index.ts | 7 ++++ src/lib/components/AddConnectionModal.svelte | 28 +++++++++++++ .../chat/ModelSelector/Selector.svelte | 41 ++++++++++--------- 7 files changed, 88 insertions(+), 29 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index ab5e98880..1ea79aa26 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -965,14 +965,24 @@ async def get_models(request: Request, user=Depends(get_verified_user)): return filtered_models - models = await get_all_models(request, user=user) + all_models = await get_all_models(request, user=user) - # Filter out filter pipelines - models = [ - model - for model in models - if "pipeline" not in model or model["pipeline"].get("type", None) != "filter" - ] + models = [] + for model in all_models: + # Filter out filter pipelines + if "pipeline" in model and model["pipeline"].get("type", None) == "filter": + continue + + model_tags = [ + tag.get("name") + for tag in model.get("info", {}).get("meta", {}).get("tags", []) + ] + tags = [tag.get("name") for tag in model.get("tags", [])] + + tags = list(set(model_tags + tags)) + model["tags"] = [{"name": tag} for tag in tags] + + models.append(model) model_order_list = request.app.state.config.MODEL_ORDER_LIST if model_order_list: diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 959b8417a..35dc50431 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -295,7 +295,7 @@ async def update_config( } -@cached(ttl=3) +@cached(ttl=1) async def get_all_models(request: Request, user: UserModel = None): log.info("get_all_models()") if request.app.state.config.ENABLE_OLLAMA_API: @@ -336,6 +336,7 @@ async def get_all_models(request: Request, user: UserModel = None): ) prefix_id = api_config.get("prefix_id", None) + tags = api_config.get("tags", []) model_ids = api_config.get("model_ids", []) if len(model_ids) != 0 and "models" in response: @@ -350,6 +351,10 @@ async def get_all_models(request: Request, user: UserModel = None): for model in response.get("models", []): model["model"] = f"{prefix_id}.{model['model']}" + if tags: + for model in response.get("models", []): + model["tags"] = tags + def merge_models_lists(model_lists): merged_models = {} diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index bef286ca9..0310014cf 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -353,6 +353,7 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list: ) prefix_id = api_config.get("prefix_id", None) + tags = api_config.get("tags", []) if prefix_id: for model in ( @@ -360,6 +361,12 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list: ): model["id"] = f"{prefix_id}.{model['id']}" + if tags: + for model in ( + response if isinstance(response, list) else response.get("data", []) + ): + model["tags"] = tags + log.debug(f"get_all_models:responses() {responses}") return responses @@ -377,7 +384,7 @@ async def get_filtered_models(models, user): return filtered_models -@cached(ttl=3) +@cached(ttl=1) async def get_all_models(request: Request, user: UserModel) -> dict[str, list]: log.info("get_all_models()") diff --git a/backend/open_webui/utils/models.py b/backend/open_webui/utils/models.py index 149e41a41..b631c2ae3 100644 --- a/backend/open_webui/utils/models.py +++ b/backend/open_webui/utils/models.py @@ -49,6 +49,7 @@ async def get_all_base_models(request: Request, user: UserModel = None): "created": int(time.time()), "owned_by": "ollama", "ollama": model, + "tags": model.get("tags", []), } for model in ollama_models["models"] ] diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index 3fb4a5d01..674f24267 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -114,6 +114,13 @@ export const getModels = async ( } } + const tags = apiConfig.tags; + if (tags) { + for (const model of models) { + model.tags = tags; + } + } + localModels = localModels.concat(models); } } diff --git a/src/lib/components/AddConnectionModal.svelte b/src/lib/components/AddConnectionModal.svelte index f3132640a..7a82f340c 100644 --- a/src/lib/components/AddConnectionModal.svelte +++ b/src/lib/components/AddConnectionModal.svelte @@ -14,6 +14,7 @@ import SensitiveInput from '$lib/components/common/SensitiveInput.svelte'; import Tooltip from '$lib/components/common/Tooltip.svelte'; import Switch from '$lib/components/common/Switch.svelte'; + import Tags from './common/Tags.svelte'; export let onSubmit: Function = () => {}; export let onDelete: Function = () => {}; @@ -31,6 +32,7 @@ let prefixId = ''; let enable = true; + let tags = []; let modelId = ''; let modelIds = []; @@ -88,6 +90,7 @@ key, config: { enable: enable, + tags: tags, prefix_id: prefixId, model_ids: modelIds } @@ -101,6 +104,7 @@ url = ''; key = ''; prefixId = ''; + tags = []; modelIds = []; }; @@ -110,6 +114,7 @@ key = connection.key; enable = connection.config?.enable ?? true; + tags = connection.config?.tags ?? []; prefixId = connection.config?.prefix_id ?? ''; modelIds = connection.config?.model_ids ?? []; } @@ -244,6 +249,29 @@ +
+
+
{$i18n.t('Tags')}
+ +
+ { + tags = [ + ...tags, + { + name: e.detail + } + ]; + }} + on:delete={(e) => { + tags = tags.filter((tag) => tag.name !== e.detail); + }} + /> +
+
+
+
diff --git a/src/lib/components/chat/ModelSelector/Selector.svelte b/src/lib/components/chat/ModelSelector/Selector.svelte index 4db476159..898ea19d7 100644 --- a/src/lib/components/chat/ModelSelector/Selector.svelte +++ b/src/lib/components/chat/ModelSelector/Selector.svelte @@ -77,7 +77,7 @@ const _item = { ...item, modelName: item.model?.name, - tags: item.model?.info?.meta?.tags?.map((tag) => tag.name).join(' '), + tags: (item.model?.tags ?? []).map((tag) => tag.name).join(' '), desc: item.model?.info?.meta?.description }; return _item; @@ -98,7 +98,7 @@ if (selectedTag === '') { return true; } - return item.model?.info?.meta?.tags?.map((tag) => tag.name).includes(selectedTag); + return (item.model?.tags ?? []).map((tag) => tag.name).includes(selectedTag); }) .filter((item) => { if (selectedConnectionType === '') { @@ -116,7 +116,7 @@ if (selectedTag === '') { return true; } - return item.model?.info?.meta?.tags?.map((tag) => tag.name).includes(selectedTag); + return (item.model?.tags ?? []).map((tag) => tag.name).includes(selectedTag); }) .filter((item) => { if (selectedConnectionType === '') { @@ -262,7 +262,7 @@ ollamaVersion = await getOllamaVersion(localStorage.token).catch((error) => false); if (items) { - tags = items.flatMap((item) => item.model?.info?.meta?.tags ?? []).map((tag) => tag.name); + tags = items.flatMap((item) => item.model?.tags ?? []).map((tag) => tag.name); // Remove duplicates and sort tags = Array.from(new Set(tags)).sort((a, b) => a.localeCompare(b)); @@ -291,12 +291,12 @@ onOpenChange={async () => { searchValue = ''; // Do NOT reset filters - keep the previously selected tag/connection type - + await tick(); - + // First check if the currently selected model is visible in the filtered list - const selectedInFiltered = filteredItems.findIndex(item => item.value === value); - + const selectedInFiltered = filteredItems.findIndex((item) => item.value === value); + if (selectedInFiltered >= 0) { // The selected model is visible in the current filter selectedModelIdx = selectedInFiltered; @@ -304,22 +304,23 @@ // The selected model is not visible, default to first item in filtered list selectedModelIdx = 0; } - + await tick(); - + // Scroll to the selected item if it exists in the current filtered view - const itemToScrollTo = selectedInFiltered >= 0 - ? document.querySelector(`[data-value="${value}"]`) - : document.querySelector('[data-arrow-selected="true"]'); - + const itemToScrollTo = + selectedInFiltered >= 0 + ? document.querySelector(`[data-value="${value}"]`) + : document.querySelector('[data-arrow-selected="true"]'); + if (itemToScrollTo) { const container = itemToScrollTo.closest('.overflow-y-auto'); if (container) { const itemTop = itemToScrollTo.offsetTop; const containerHeight = container.clientHeight; const itemHeight = itemToScrollTo.clientHeight; - - container.scrollTop = itemTop - (containerHeight / 2) + (itemHeight / 2); + + container.scrollTop = itemTop - containerHeight / 2 + itemHeight / 2; } } }} @@ -483,9 +484,9 @@ }} >
- {#if $mobile && (item?.model?.info?.meta?.tags ?? []).length > 0} + {#if $mobile && (item?.model?.tags ?? []).length > 0}
- {#each item.model?.info?.meta.tags as tag} + {#each item.model?.tags as tag}
@@ -605,11 +606,11 @@ {/if} - {#if !$mobile && (item?.model?.info?.meta?.tags ?? []).length > 0} + {#if !$mobile && (item?.model?.tags ?? []).length > 0}
- {#each item.model?.info?.meta.tags as tag} + {#each item.model?.tags as tag}