feat: ollama unload model

This commit is contained in:
Timothy Jaeryang Baek
2025-05-23 19:45:29 +04:00
parent 19b69fcb66
commit 1cf21d3fa2
3 changed files with 127 additions and 6 deletions

View File

@@ -623,6 +623,70 @@ class ModelNameForm(BaseModel):
name: str
@router.post("/api/unload")
async def unload_model(
request: Request,
form_data: ModelNameForm,
user=Depends(get_admin_user),
):
model_name = form_data.name
if not model_name:
raise HTTPException(
status_code=400, detail="Missing 'name' of model to unload."
)
# Refresh/load models if needed, get mapping from name to URLs
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
# Canonicalize model name (if not supplied with version)
if ":" not in model_name:
model_name = f"{model_name}:latest"
if model_name not in models:
raise HTTPException(
status_code=400, detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model_name)
)
url_indices = models[model_name]["urls"]
# Send unload to ALL url_indices
results = []
errors = []
for idx in url_indices:
url = request.app.state.config.OLLAMA_BASE_URLS[idx]
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
str(idx), request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
)
key = get_api_key(idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
prefix_id = api_config.get("prefix_id", None)
if prefix_id and model_name.startswith(f"{prefix_id}."):
model_name = model_name[len(f"{prefix_id}.") :]
payload = {"model": model_name, "keep_alive": 0, "prompt": ""}
try:
res = await send_post_request(
url=f"{url}/api/generate",
payload=json.dumps(payload),
stream=False,
key=key,
user=user,
)
results.append({"url_idx": idx, "success": True, "response": res})
except Exception as e:
log.exception(f"Failed to unload model on node {idx}: {e}")
errors.append({"url_idx": idx, "success": False, "error": str(e)})
if len(errors) > 0:
raise HTTPException(
status_code=500,
detail=f"Failed to unload model on {len(errors)} nodes: {errors}",
)
return {"status": True}
@router.post("/api/pull")
@router.post("/api/pull/{url_idx}")
async def pull_model(