diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py
index 760174837..6ac797d2d 100644
--- a/backend/open_webui/env.py
+++ b/backend/open_webui/env.py
@@ -84,6 +84,7 @@ log_sources = [
     "COMFYUI",
     "CONFIG",
     "DB",
+    "DMR",
     "IMAGES",
     "MAIN",
     "MODELS",
diff --git a/backend/open_webui/routers/docker_model_runner.py b/backend/open_webui/routers/docker_model_runner.py
index 8f24ea180..c73ba411f 100644
--- a/backend/open_webui/routers/docker_model_runner.py
+++ b/backend/open_webui/routers/docker_model_runner.py
@@ -1,27 +1,34 @@
 import logging
-import aiohttp
-from typing import Union
-from urllib.parse import urlparse
 import time
+from typing import Optional, Union
 
+import aiohttp
+from aiocache import cached
 from fastapi import APIRouter, Depends, Request, HTTPException
 from fastapi.responses import StreamingResponse
+from pydantic import BaseModel, ConfigDict
 from starlette.background import BackgroundTask
-from pydantic import BaseModel
 
-from open_webui.models.users import UserModel
-from open_webui.utils.auth import get_admin_user, get_verified_user
-from open_webui.constants import ERROR_MESSAGES
 from open_webui.env import (
-    AIOHTTP_CLIENT_SESSION_SSL, 
+    AIOHTTP_CLIENT_SESSION_SSL,
     AIOHTTP_CLIENT_TIMEOUT,
     AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
+    BYPASS_MODEL_ACCESS_CONTROL,
     ENABLE_FORWARD_USER_INFO_HEADERS,
     SRC_LOG_LEVELS
 )
+from open_webui.models.users import UserModel
+from open_webui.models.models import Models
+from open_webui.utils.auth import get_admin_user, get_verified_user
+from open_webui.utils.access_control import has_access
+from open_webui.utils.payload import (
+    apply_model_params_to_body_openai,
+    apply_model_system_prompt_to_body,
+)
+from open_webui.constants import ERROR_MESSAGES
 
 log = logging.getLogger(__name__)
-log.setLevel(SRC_LOG_LEVELS.get("DMR", logging.INFO))
+log.setLevel(SRC_LOG_LEVELS["DMR"])
 
 router = APIRouter()
@@ -35,7 +42,7 @@ DMR_ENGINE_SUFFIX = "/engines/llama.cpp/v1"
 ##########################################
 
 async def send_get_request(url, user: UserModel = None):
-    """Send GET request to DMR backend with proper error handling"""
+    """Send a GET request to DMR backend with proper error handling"""
     timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
     try:
         async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
@@ -77,6 +84,7 @@ async def send_post_request(
     url: str,
     payload: Union[str, bytes, dict],
     stream: bool = False,
+    content_type: Optional[str] = None,
     user: UserModel = None,
 ):
     """Send POST request to DMR backend with proper error handling"""
@@ -90,7 +98,7 @@ async def send_post_request(
             url,
             data=payload if isinstance(payload, (str, bytes)) else aiohttp.JsonPayload(payload),
             headers={
-                "Content-Type": "application/json",
+                "Content-Type": content_type or "application/json",
                 **(
                     {
                         "X-OpenWebUI-User-Name": user.name,
@@ -107,10 +115,14 @@ async def send_post_request(
         r.raise_for_status()
 
         if stream:
+            response_headers = dict(r.headers)
+            if content_type:
+                response_headers["Content-Type"] = content_type
+
             return StreamingResponse(
                 r.content,
                 status_code=r.status,
-                headers=dict(r.headers),
+                headers=response_headers,
                 background=BackgroundTask(cleanup_response, response=r, session=session),
             )
         else:
@@ -139,7 +151,7 @@ def get_dmr_base_url(request: Request):
     """Get DMR base URL with engine suffix"""
     base_url = request.app.state.config.DMR_BASE_URL
     if not base_url:
-        raise HTTPException(status_code=500, detail="No DMR base URL configured")
+        raise HTTPException(status_code=500, detail="DMR base URL not configured")
URL not configured") # Always append the engine prefix for OpenAI-compatible endpoints if not base_url.rstrip("/").endswith(DMR_ENGINE_SUFFIX): @@ -147,181 +159,17 @@ def get_dmr_base_url(request: Request): return base_url -########################################## -# -# API routes -# -########################################## - -@router.get("/config") -async def get_config(request: Request, user=Depends(get_admin_user)): - """Get DMR configuration""" - return { - "ENABLE_DMR_API": request.app.state.config.ENABLE_DMR_API, - "DMR_BASE_URL": request.app.state.config.DMR_BASE_URL, - "DMR_API_CONFIGS": request.app.state.config.DMR_API_CONFIGS, - } - - -class DMRConfigForm(BaseModel): - ENABLE_DMR_API: Optional[bool] = None - DMR_BASE_URL: str - DMR_API_CONFIGS: dict = {} - - -@router.post("/config/update") -async def update_config(request: Request, form_data: DMRConfigForm, user=Depends(get_admin_user)): - """Update DMR configuration""" - request.app.state.config.ENABLE_DMR_API = form_data.ENABLE_DMR_API - request.app.state.config.DMR_BASE_URL = form_data.DMR_BASE_URL - request.app.state.config.DMR_API_CONFIGS = form_data.DMR_API_CONFIGS - - return { - "ENABLE_DMR_API": request.app.state.config.ENABLE_DMR_API, - "DMR_BASE_URL": request.app.state.config.DMR_BASE_URL, - "DMR_API_CONFIGS": request.app.state.config.DMR_API_CONFIGS, - } - - -class ConnectionVerificationForm(BaseModel): - url: str - - -@router.post("/verify") -async def verify_connection(form_data: ConnectionVerificationForm, user=Depends(get_admin_user)): - """Verify connection to DMR backend""" - url = form_data.url - - # Append engine suffix if not present - if not url.rstrip("/").endswith(DMR_ENGINE_SUFFIX): - url = url.rstrip("/") + DMR_ENGINE_SUFFIX - - try: - response = await send_get_request(f"{url}/models", user=user) - if response is not None: - return {"status": "success", "message": "Connection verified"} - else: - raise HTTPException(status_code=400, detail="Failed to connect to DMR backend") - except Exception as e: - log.exception(f"DMR connection verification failed: {e}") - raise HTTPException(status_code=400, detail=f"Connection verification failed: {e}") - - -@router.get("/models") -async def get_models(request: Request, user=Depends(get_verified_user)): - """Get available models from DMR backend""" - url = get_dmr_base_url(request) - - response = await send_get_request(f"{url}/models", user=user) - if response is None: - raise HTTPException(status_code=500, detail="Failed to fetch models from DMR backend") - return response - - -@router.post("/chat/completions") -async def generate_chat_completion( - request: Request, - form_data: dict, - user=Depends(get_verified_user), - bypass_filter: Optional[bool] = False, -): - """Generate chat completions using DMR backend""" - url = get_dmr_base_url(request) - - log.debug(f"DMR chat_completions: model = {form_data.get('model', 'NO_MODEL')}") - - # Resolve model ID if needed - if "model" in form_data: - models = await get_all_models(request, user=user) - for m in models.get("data", []): - if m.get("id") == form_data["model"] or m.get("name") == form_data["model"]: - form_data["model"] = m["id"] - break - - return await send_post_request( - f"{url}/chat/completions", - form_data, - stream=form_data.get("stream", False), - user=user - ) - - -@router.post("/completions") -async def generate_completion( - request: Request, - form_data: dict, - user=Depends(get_verified_user), -): - """Generate completions using DMR backend""" - url = get_dmr_base_url(request) - - # Resolve 
-    if "model" in form_data:
-        models = await get_all_models(request, user=user)
-        for m in models.get("data", []):
-            if m.get("id") == form_data["model"] or m.get("name") == form_data["model"]:
-                form_data["model"] = m["id"]
-                break
-
-    return await send_post_request(
-        f"{url}/completions",
-        form_data,
-        stream=form_data.get("stream", False),
-        user=user
-    )
-
-
-@router.post("/embeddings")
-async def embeddings(request: Request, form_data: dict, user=Depends(get_verified_user)):
-    """Generate embeddings using DMR backend"""
-    url = get_dmr_base_url(request)
-
-    return await send_post_request(f"{url}/embeddings", form_data, stream=False, user=user)
-
-
-# OpenAI-compatible endpoints
-@router.get("/v1/models")
-async def get_openai_models(request: Request, user=Depends(get_verified_user)):
-    """Get available models from DMR backend (OpenAI-compatible)"""
-    return await get_models(request, user)
-
-
-@router.post("/v1/chat/completions")
-async def generate_openai_chat_completion(
-    request: Request,
-    form_data: dict,
-    user=Depends(get_verified_user)
-):
-    """Generate chat completions using DMR backend (OpenAI-compatible)"""
-    return await generate_chat_completion(request, form_data, user)
-
-
-@router.post("/v1/completions")
-async def generate_openai_completion(
-    request: Request,
-    form_data: dict,
-    user=Depends(get_verified_user)
-):
-    """Generate completions using DMR backend (OpenAI-compatible)"""
-    return await generate_completion(request, form_data, user)
-
-
-@router.post("/v1/embeddings")
-async def generate_openai_embeddings(
-    request: Request,
-    form_data: dict,
-    user=Depends(get_verified_user)
-):
-    """Generate embeddings using DMR backend (OpenAI-compatible)"""
-    return await embeddings(request, form_data, user)
-
-
-# Internal utility for Open WebUI model aggregation
+@cached(ttl=1)
 async def get_all_models(request: Request, user: UserModel = None):
     """
     Fetch all models from the DMR backend in OpenAI-compatible format for internal use.
 
     Returns: dict with 'data' key (list of models)
     """
+    log.info("get_all_models() - DMR")
+
+    if not request.app.state.config.ENABLE_DMR_API:
+        return {"data": []}
+
     try:
         url = get_dmr_base_url(request)
 
@@ -399,3 +247,502 @@ async def get_all_models(request: Request, user: UserModel = None):
     except Exception as e:
         log.exception(f"DMR get_all_models failed: {e}")
         return {"data": []}
+
+
+async def get_filtered_models(models, user):
+    """Filter models based on user access control"""
+    if BYPASS_MODEL_ACCESS_CONTROL:
+        return models.get("data", [])
+
+    filtered_models = []
+    for model in models.get("data", []):
+        model_info = Models.get_model_by_id(model["id"])
+        if model_info:
+            if user.id == model_info.user_id or has_access(
+                user.id, type="read", access_control=model_info.access_control
+            ):
+                filtered_models.append(model)
+        else:
+            # If no model info found and user is admin, include it
+            if user.role == "admin":
+                filtered_models.append(model)
+    return filtered_models
+
+
+##########################################
+#
+# Configuration endpoints
+#
+##########################################
+
+@router.head("/")
+@router.get("/")
+async def get_status():
+    return {"status": True}
+
+
+@router.get("/config")
+async def get_config(request: Request, user=Depends(get_admin_user)):
+    """Get DMR configuration"""
+    return {
+        "ENABLE_DMR_API": request.app.state.config.ENABLE_DMR_API,
+        "DMR_BASE_URL": request.app.state.config.DMR_BASE_URL,
+        "DMR_API_CONFIGS": request.app.state.config.DMR_API_CONFIGS,
+    }
+
+
+class DMRConfigForm(BaseModel):
+    ENABLE_DMR_API: Optional[bool] = None
+    DMR_BASE_URL: str
+    DMR_API_CONFIGS: dict = {}
+
+
+@router.post("/config/update")
+async def update_config(request: Request, form_data: DMRConfigForm, user=Depends(get_admin_user)):
+    """Update DMR configuration"""
+    request.app.state.config.ENABLE_DMR_API = form_data.ENABLE_DMR_API
+    request.app.state.config.DMR_BASE_URL = form_data.DMR_BASE_URL
+    request.app.state.config.DMR_API_CONFIGS = form_data.DMR_API_CONFIGS
+
+    return {
+        "ENABLE_DMR_API": request.app.state.config.ENABLE_DMR_API,
+        "DMR_BASE_URL": request.app.state.config.DMR_BASE_URL,
+        "DMR_API_CONFIGS": request.app.state.config.DMR_API_CONFIGS,
+    }
+
+
+class ConnectionVerificationForm(BaseModel):
+    url: str
+
+
+@router.post("/verify")
+async def verify_connection(form_data: ConnectionVerificationForm, user=Depends(get_admin_user)):
+    """Verify connection to DMR backend"""
+    url = form_data.url
+
+    # Append engine suffix if not present
+    if not url.rstrip("/").endswith(DMR_ENGINE_SUFFIX):
+        url = url.rstrip("/") + DMR_ENGINE_SUFFIX
+
+    try:
+        response = await send_get_request(f"{url}/models", user=user)
+        if response is not None:
+            return {"status": "success", "message": "Connection verified"}
+        else:
+            raise HTTPException(status_code=400, detail="Failed to connect to DMR backend")
+    except Exception as e:
+        log.exception(f"DMR connection verification failed: {e}")
+        raise HTTPException(status_code=400, detail=f"Connection verification failed: {e}")
+
+
+##########################################
+#
+# Model endpoints
+#
+##########################################
+
+@router.get("/api/tags")
+async def get_dmr_tags(request: Request, user=Depends(get_verified_user)):
+    """Get available models from DMR backend (Ollama-compatible format)"""
+    models = await get_all_models(request, user=user)
+
+    if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
+        filtered_models = await get_filtered_models(models, user)
+        models = {"data": filtered_models}
+
+    # Convert to Ollama-compatible format
+    ollama_models = []
+    for model in models.get("data", []):
+        ollama_model = {
+            "model": model["id"],
+            "name": model["name"],
+            "size": model.get("size", 0),
+            "digest": model.get("digest", ""),
+            "details": model.get("details", {}),
+            "expires_at": model.get("expires_at"),
+            "size_vram": model.get("size_vram", 0),
+        }
+        ollama_models.append(ollama_model)
+
+    return {"models": ollama_models}
+
+
+@router.get("/models")
+async def get_models(request: Request, user=Depends(get_verified_user)):
+    """Get available models from DMR backend (OpenAI-compatible format)"""
+    models = await get_all_models(request, user=user)
+
+    if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
+        filtered_models = await get_filtered_models(models, user)
+        return {"data": filtered_models}
+
+    return models
+
+
+class ModelNameForm(BaseModel):
+    name: str
+
+
+@router.post("/api/show")
+async def show_model_info(
+    request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
+):
+    """Show model information (Ollama-compatible)"""
+    models = await get_all_models(request, user=user)
+
+    # Find the model
+    model_found = None
+    for model in models.get("data", []):
+        if model["id"] == form_data.name or model["name"] == form_data.name:
+            model_found = model
+            break
+
+    if not model_found:
+        raise HTTPException(
+            status_code=400,
+            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
+        )
+
+    # Return model info in Ollama format
+    return {
+        "model": model_found["id"],
+        "details": model_found.get("dmr", {}),
+        "modelfile": "",  # DMR doesn't provide modelfile
+        "parameters": "",  # DMR doesn't provide parameters
+        "template": "",  # DMR doesn't provide template
+    }
+
+
+##########################################
+#
+# Generation endpoints
+#
+##########################################
+
+class GenerateChatCompletionForm(BaseModel):
+    model: str
+    messages: list[dict]
+    stream: Optional[bool] = False
+    model_config = ConfigDict(extra="allow")
+
+
+@router.post("/api/chat")
+async def generate_chat_completion(
+    request: Request,
+    form_data: dict,
+    user=Depends(get_verified_user),
+    bypass_filter: Optional[bool] = False,
+):
+    """Generate chat completions using DMR backend (Ollama-compatible)"""
+    if BYPASS_MODEL_ACCESS_CONTROL:
+        bypass_filter = True
+
+    metadata = form_data.pop("metadata", None)
+
+    try:
+        completion_form = GenerateChatCompletionForm(**form_data)
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(status_code=400, detail=str(e))
+
+    payload = {**completion_form.model_dump(exclude_none=True)}
+    if "metadata" in payload:
+        del payload["metadata"]
+
+    model_id = payload["model"]
+    model_info = Models.get_model_by_id(model_id)
+
+    if model_info:
+        if model_info.base_model_id:
+            payload["model"] = model_info.base_model_id
+
+        params = model_info.params.model_dump()
+        if params:
+            system = params.pop("system", None)
+            payload = apply_model_params_to_body_openai(params, payload)
+            payload = apply_model_system_prompt_to_body(system, payload, metadata, user)
+
+        # Check if user has access to the model
+        if not bypass_filter and user.role == "user":
+            if not (
+                user.id == model_info.user_id
+                or has_access(
+                    user.id, type="read", access_control=model_info.access_control
+                )
+            ):
+                raise HTTPException(status_code=403, detail="Model not found")
+    elif not bypass_filter:
+        if user.role != "admin":
+            raise HTTPException(status_code=403, detail="Model not found")
+
+    url = get_dmr_base_url(request)
+
+    log.debug(f"DMR chat completion: model = {payload.get('model', 'NO_MODEL')}")
{payload.get('model', 'NO_MODEL')}") + + return await send_post_request( + f"{url}/chat/completions", + payload, + stream=payload.get("stream", False), + content_type="application/x-ndjson" if payload.get("stream") else None, + user=user, + ) + + +class GenerateCompletionForm(BaseModel): + model: str + prompt: str + stream: Optional[bool] = False + model_config = ConfigDict(extra="allow") + + +@router.post("/api/generate") +async def generate_completion( + request: Request, + form_data: dict, + user=Depends(get_verified_user), +): + """Generate completions using DMR backend (Ollama-compatible)""" + try: + completion_form = GenerateCompletionForm(**form_data) + except Exception as e: + log.exception(e) + raise HTTPException(status_code=400, detail=str(e)) + + payload = {**completion_form.model_dump(exclude_none=True)} + + model_id = payload["model"] + model_info = Models.get_model_by_id(model_id) + + if model_info: + if model_info.base_model_id: + payload["model"] = model_info.base_model_id + + params = model_info.params.model_dump() + if params: + payload = apply_model_params_to_body_openai(params, payload) + + # Check if user has access to the model + if user.role == "user": + if not ( + user.id == model_info.user_id + or has_access( + user.id, type="read", access_control=model_info.access_control + ) + ): + raise HTTPException(status_code=403, detail="Model not found") + else: + if user.role != "admin": + raise HTTPException(status_code=403, detail="Model not found") + + url = get_dmr_base_url(request) + + return await send_post_request( + f"{url}/completions", + payload, + stream=payload.get("stream", False), + user=user, + ) + + +class GenerateEmbeddingsForm(BaseModel): + model: str + input: Union[str, list[str]] + model_config = ConfigDict(extra="allow") + + +@router.post("/api/embeddings") +async def embeddings( + request: Request, + form_data: dict, + user=Depends(get_verified_user) +): + """Generate embeddings using DMR backend (Ollama-compatible)""" + try: + embedding_form = GenerateEmbeddingsForm(**form_data) + except Exception as e: + log.exception(e) + raise HTTPException(status_code=400, detail=str(e)) + + payload = {**embedding_form.model_dump(exclude_none=True)} + + model_id = payload["model"] + model_info = Models.get_model_by_id(model_id) + + if model_info: + if model_info.base_model_id: + payload["model"] = model_info.base_model_id + + # Check if user has access to the model + if user.role == "user": + if not ( + user.id == model_info.user_id + or has_access( + user.id, type="read", access_control=model_info.access_control + ) + ): + raise HTTPException(status_code=403, detail="Model not found") + else: + if user.role != "admin": + raise HTTPException(status_code=403, detail="Model not found") + + url = get_dmr_base_url(request) + + return await send_post_request(f"{url}/embeddings", payload, stream=False, user=user) + + +########################################## +# +# OpenAI-compatible endpoints +# +########################################## + +@router.get("/v1/models") +async def get_openai_models(request: Request, user=Depends(get_verified_user)): + """Get available models from DMR backend (OpenAI-compatible)""" + models = await get_all_models(request, user=user) + + if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: + filtered_models = await get_filtered_models(models, user) + models = {"data": filtered_models} + + # Convert to OpenAI format + openai_models = [] + for model in models.get("data", []): + openai_model = { + "id": model["id"], + "object": 
"model", + "created": model.get("created", int(time.time())), + "owned_by": "docker", + } + openai_models.append(openai_model) + + return { + "object": "list", + "data": openai_models, + } + + +@router.post("/v1/chat/completions") +async def generate_openai_chat_completion( + request: Request, + form_data: dict, + user=Depends(get_verified_user) +): + """Generate chat completions using DMR backend (OpenAI-compatible)""" + metadata = form_data.pop("metadata", None) + + payload = {**form_data} + if "metadata" in payload: + del payload["metadata"] + + model_id = payload["model"] + model_info = Models.get_model_by_id(model_id) + + if model_info: + if model_info.base_model_id: + payload["model"] = model_info.base_model_id + + params = model_info.params.model_dump() + if params: + system = params.pop("system", None) + payload = apply_model_params_to_body_openai(params, payload) + payload = apply_model_system_prompt_to_body(system, payload, metadata, user) + + # Check if user has access to the model + if user.role == "user": + if not ( + user.id == model_info.user_id + or has_access( + user.id, type="read", access_control=model_info.access_control + ) + ): + raise HTTPException(status_code=403, detail="Model not found") + else: + if user.role != "admin": + raise HTTPException(status_code=403, detail="Model not found") + + url = get_dmr_base_url(request) + + return await send_post_request( + f"{url}/chat/completions", + payload, + stream=payload.get("stream", False), + user=user, + ) + + +@router.post("/v1/completions") +async def generate_openai_completion( + request: Request, + form_data: dict, + user=Depends(get_verified_user) +): + """Generate completions using DMR backend (OpenAI-compatible)""" + payload = {**form_data} + + model_id = payload["model"] + model_info = Models.get_model_by_id(model_id) + + if model_info: + if model_info.base_model_id: + payload["model"] = model_info.base_model_id + + params = model_info.params.model_dump() + if params: + payload = apply_model_params_to_body_openai(params, payload) + + # Check if user has access to the model + if user.role == "user": + if not ( + user.id == model_info.user_id + or has_access( + user.id, type="read", access_control=model_info.access_control + ) + ): + raise HTTPException(status_code=403, detail="Model not found") + else: + if user.role != "admin": + raise HTTPException(status_code=403, detail="Model not found") + + url = get_dmr_base_url(request) + + return await send_post_request( + f"{url}/completions", + payload, + stream=payload.get("stream", False), + user=user, + ) + + +@router.post("/v1/embeddings") +async def generate_openai_embeddings( + request: Request, + form_data: dict, + user=Depends(get_verified_user) +): + """Generate embeddings using DMR backend (OpenAI-compatible)""" + payload = {**form_data} + + model_id = payload["model"] + model_info = Models.get_model_by_id(model_id) + + if model_info: + if model_info.base_model_id: + payload["model"] = model_info.base_model_id + + # Check if user has access to the model + if user.role == "user": + if not ( + user.id == model_info.user_id + or has_access( + user.id, type="read", access_control=model_info.access_control + ) + ): + raise HTTPException(status_code=403, detail="Model not found") + else: + if user.role != "admin": + raise HTTPException(status_code=403, detail="Model not found") + + url = get_dmr_base_url(request) + + return await send_post_request(f"{url}/embeddings", payload, stream=False, user=user)