From 1cf21d3fa219edefd1599deb11267b36a0be422a Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Fri, 23 May 2025 19:45:29 +0400 Subject: [PATCH] feat: ollama unload model --- backend/open_webui/routers/ollama.py | 64 +++++++++++++++++++ src/lib/apis/ollama/index.ts | 25 ++++++++ .../chat/ModelSelector/Selector.svelte | 44 +++++++++++-- 3 files changed, 127 insertions(+), 6 deletions(-) diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index da2e8b388..0d5e2c243 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -623,6 +623,70 @@ class ModelNameForm(BaseModel): name: str +@router.post("/api/unload") +async def unload_model( + request: Request, + form_data: ModelNameForm, + user=Depends(get_admin_user), +): + model_name = form_data.name + if not model_name: + raise HTTPException( + status_code=400, detail="Missing 'name' of model to unload." + ) + + # Refresh/load models if needed, get mapping from name to URLs + await get_all_models(request, user=user) + models = request.app.state.OLLAMA_MODELS + + # Canonicalize model name (if not supplied with version) + if ":" not in model_name: + model_name = f"{model_name}:latest" + + if model_name not in models: + raise HTTPException( + status_code=400, detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model_name) + ) + url_indices = models[model_name]["urls"] + + # Send unload to ALL url_indices + results = [] + errors = [] + for idx in url_indices: + url = request.app.state.config.OLLAMA_BASE_URLS[idx] + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( + str(idx), request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + ) + key = get_api_key(idx, url, request.app.state.config.OLLAMA_API_CONFIGS) + + prefix_id = api_config.get("prefix_id", None) + if prefix_id and model_name.startswith(f"{prefix_id}."): + model_name = model_name[len(f"{prefix_id}.") :] + + payload = {"model": model_name, "keep_alive": 0, "prompt": ""} + + try: + res = await send_post_request( + url=f"{url}/api/generate", + payload=json.dumps(payload), + stream=False, + key=key, + user=user, + ) + results.append({"url_idx": idx, "success": True, "response": res}) + except Exception as e: + log.exception(f"Failed to unload model on node {idx}: {e}") + errors.append({"url_idx": idx, "success": False, "error": str(e)}) + + if len(errors) > 0: + raise HTTPException( + status_code=500, + detail=f"Failed to unload model on {len(errors)} nodes: {errors}", + ) + + return {"status": True} + + @router.post("/api/pull") @router.post("/api/pull/{url_idx}") async def pull_model( diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts index f159555da..489055c1b 100644 --- a/src/lib/apis/ollama/index.ts +++ b/src/lib/apis/ollama/index.ts @@ -355,6 +355,31 @@ export const generateChatCompletion = async (token: string = '', body: object) = return [res, controller]; }; +export const unloadModel = async (token: string, tagName: string) => { + let error = null; + + const res = await fetch(`${OLLAMA_API_BASE_URL}/api/unload`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + name: tagName + }) + }).catch((err) => { + error = err; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const createModel = async (token: string, payload: object, urlIdx: string | null = null) => { let error = null; diff --git a/src/lib/components/chat/ModelSelector/Selector.svelte b/src/lib/components/chat/ModelSelector/Selector.svelte index 968619c31..86504cceb 100644 --- a/src/lib/components/chat/ModelSelector/Selector.svelte +++ b/src/lib/components/chat/ModelSelector/Selector.svelte @@ -10,7 +10,7 @@ import Check from '$lib/components/icons/Check.svelte'; import Search from '$lib/components/icons/Search.svelte'; - import { deleteModel, getOllamaVersion, pullModel } from '$lib/apis/ollama'; + import { deleteModel, getOllamaVersion, pullModel, unloadModel } from '$lib/apis/ollama'; import { user, @@ -31,6 +31,7 @@ import { goto } from '$app/navigation'; import dayjs from '$lib/dayjs'; import relativeTime from 'dayjs/plugin/relativeTime'; + import ArrowUpTray from '$lib/components/icons/ArrowUpTray.svelte'; dayjs.extend(relativeTime); const i18n = getContext('i18n'); @@ -312,6 +313,22 @@ toast.success(`${model} download has been canceled`); } }; + + const unloadModelHandler = async (model: string) => { + const res = await unloadModel(localStorage.token, model).catch((error) => { + toast.error($i18n.t('Error unloading model: {{error}}', { error })); + }); + + if (res) { + toast.success($i18n.t('Model unloaded successfully')); + models.set( + await getModels( + localStorage.token, + $config?.features?.enable_direct_connections && ($settings?.directConnections ?? null) + ) + ); + } + }; - {#if value === item.value} -
- -
- {/if} +
+ {#if $user?.role === 'admin' && item.model.owned_by === 'ollama' && item.model.ollama?.expires_at && new Date(item.model.ollama?.expires_at * 1000) > new Date()} + + + + {/if} + + {#if value === item.value} +
+ +
+ {/if} +
{:else}