diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index ab5e98880..1ea79aa26 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -965,14 +965,24 @@ async def get_models(request: Request, user=Depends(get_verified_user)): return filtered_models - models = await get_all_models(request, user=user) + all_models = await get_all_models(request, user=user) - # Filter out filter pipelines - models = [ - model - for model in models - if "pipeline" not in model or model["pipeline"].get("type", None) != "filter" - ] + models = [] + for model in all_models: + # Filter out filter pipelines + if "pipeline" in model and model["pipeline"].get("type", None) == "filter": + continue + + model_tags = [ + tag.get("name") + for tag in model.get("info", {}).get("meta", {}).get("tags", []) + ] + tags = [tag.get("name") for tag in model.get("tags", [])] + + tags = list(set(model_tags + tags)) + model["tags"] = [{"name": tag} for tag in tags] + + models.append(model) model_order_list = request.app.state.config.MODEL_ORDER_LIST if model_order_list: diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 959b8417a..35dc50431 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -295,7 +295,7 @@ async def update_config( } -@cached(ttl=3) +@cached(ttl=1) async def get_all_models(request: Request, user: UserModel = None): log.info("get_all_models()") if request.app.state.config.ENABLE_OLLAMA_API: @@ -336,6 +336,7 @@ async def get_all_models(request: Request, user: UserModel = None): ) prefix_id = api_config.get("prefix_id", None) + tags = api_config.get("tags", []) model_ids = api_config.get("model_ids", []) if len(model_ids) != 0 and "models" in response: @@ -350,6 +351,10 @@ async def get_all_models(request: Request, user: UserModel = None): for model in response.get("models", []): model["model"] = f"{prefix_id}.{model['model']}" + if tags: + for model in response.get("models", []): + model["tags"] = tags + def merge_models_lists(model_lists): merged_models = {} diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index bef286ca9..0310014cf 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -353,6 +353,7 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list: ) prefix_id = api_config.get("prefix_id", None) + tags = api_config.get("tags", []) if prefix_id: for model in ( @@ -360,6 +361,12 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list: ): model["id"] = f"{prefix_id}.{model['id']}" + if tags: + for model in ( + response if isinstance(response, list) else response.get("data", []) + ): + model["tags"] = tags + log.debug(f"get_all_models:responses() {responses}") return responses @@ -377,7 +384,7 @@ async def get_filtered_models(models, user): return filtered_models -@cached(ttl=3) +@cached(ttl=1) async def get_all_models(request: Request, user: UserModel) -> dict[str, list]: log.info("get_all_models()") diff --git a/backend/open_webui/utils/models.py b/backend/open_webui/utils/models.py index 149e41a41..b631c2ae3 100644 --- a/backend/open_webui/utils/models.py +++ b/backend/open_webui/utils/models.py @@ -49,6 +49,7 @@ async def get_all_base_models(request: Request, user: UserModel = None): "created": int(time.time()), "owned_by": "ollama", "ollama": model, + "tags": model.get("tags", []), } for model in ollama_models["models"] ] diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index 3fb4a5d01..674f24267 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -114,6 +114,13 @@ export const getModels = async ( } } + const tags = apiConfig.tags; + if (tags) { + for (const model of models) { + model.tags = tags; + } + } + localModels = localModels.concat(models); } } diff --git a/src/lib/components/AddConnectionModal.svelte b/src/lib/components/AddConnectionModal.svelte index f3132640a..7a82f340c 100644 --- a/src/lib/components/AddConnectionModal.svelte +++ b/src/lib/components/AddConnectionModal.svelte @@ -14,6 +14,7 @@ import SensitiveInput from '$lib/components/common/SensitiveInput.svelte'; import Tooltip from '$lib/components/common/Tooltip.svelte'; import Switch from '$lib/components/common/Switch.svelte'; + import Tags from './common/Tags.svelte'; export let onSubmit: Function = () => {}; export let onDelete: Function = () => {}; @@ -31,6 +32,7 @@ let prefixId = ''; let enable = true; + let tags = []; let modelId = ''; let modelIds = []; @@ -88,6 +90,7 @@ key, config: { enable: enable, + tags: tags, prefix_id: prefixId, model_ids: modelIds } @@ -101,6 +104,7 @@ url = ''; key = ''; prefixId = ''; + tags = []; modelIds = []; }; @@ -110,6 +114,7 @@ key = connection.key; enable = connection.config?.enable ?? true; + tags = connection.config?.tags ?? []; prefixId = connection.config?.prefix_id ?? ''; modelIds = connection.config?.model_ids ?? []; } @@ -244,6 +249,29 @@ +