From 17c677a285a3fca0329a239ba40b326803ca2551 Mon Sep 17 00:00:00 2001
From: Sergei Shitikov
Date: Mon, 16 Jun 2025 22:29:36 +0200
Subject: [PATCH] Add Docker Model Runner backend

---
 backend/open_webui/config.py                  |   9 +-
 backend/open_webui/main.py                    |   2 -
 .../open_webui/routers/docker_model_runner.py | 430 ++++++++++++++++--
 backend/open_webui/utils/chat.py              |  11 +
 docker-compose.yaml                           |   2 +
 5 files changed, 401 insertions(+), 53 deletions(-)

diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index d72885f2d..89614c1ec 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -904,12 +904,13 @@ ENABLE_DMR_API = PersistentConfig(
     "dmr.enable",
     os.environ.get("ENABLE_DMR_API", "True").lower() == "true",
 )
-
-DMR_API_KEYS = [k.strip() for k in os.environ.get("DMR_API_KEYS", "").split(";")]
-DMR_API_KEYS = PersistentConfig("DMR_API_KEYS", "dmr.api_keys", DMR_API_KEYS)
+DMR_API_BASE_URL = os.environ.get(
+    "DMR_API_BASE_URL", "http://localhost:12434"
+)
 
 DMR_BASE_URL = os.environ.get("DMR_BASE_URL", "")
 if DMR_BASE_URL:
+    # Remove trailing slash
     DMR_BASE_URL = DMR_BASE_URL[:-1] if DMR_BASE_URL.endswith("/") else DMR_BASE_URL
 
 DMR_BASE_URLS = os.environ.get("DMR_BASE_URLS", "")
@@ -1464,7 +1465,7 @@ FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
 DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = """### Task:
 Suggest 3-5 relevant follow-up questions or prompts that the user might naturally ask next in this conversation as a **user**, based on the chat history, to help continue or deepen the discussion.
 ### Guidelines:
-- Write all follow-up questions from the user’s point of view, directed to the assistant.
+- Write all follow-up questions from the user's point of view, directed to the assistant.
 - Make questions concise, clear, and directly related to the discussed topic(s).
 - Only suggest follow-ups that make sense given the chat content and do not repeat what was already covered.
 - If the conversation is very short or not specific, suggest more general (but relevant) follow-ups the user might ask.
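Note on the configuration above: `DMR_BASE_URL` is normalized to drop a trailing slash, and `DMR_BASE_URLS` (parsed further down in config.py, outside this hunk) allows multiple backends. A minimal sketch of the intended behavior, assuming the semicolon-separated list format used by the other backends (the separator and the fallback are assumptions; the parsing code itself is not shown in this patch):

    import os

    # Hypothetical values, for illustration only.
    os.environ["DMR_BASE_URL"] = "http://localhost:12434/"
    os.environ["DMR_BASE_URLS"] = ""

    # Mirrors the normalization above: strip a single trailing slash.
    dmr_base_url = os.environ.get("DMR_BASE_URL", "")
    if dmr_base_url.endswith("/"):
        dmr_base_url = dmr_base_url[:-1]

    # Assumed: fall back to the single URL when no list is given,
    # consistent with how OLLAMA_BASE_URLS is handled in config.py.
    raw = os.environ.get("DMR_BASE_URLS", "")
    dmr_base_urls = [u.strip() for u in raw.split(";") if u.strip()] or [dmr_base_url]

    print(dmr_base_urls)  # ['http://localhost:12434']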
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 8f11d812e..7b4058c9a 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -116,7 +116,6 @@ from open_webui.config import (
     # Docker Model Runner
     ENABLE_DMR_API,
     DMR_BASE_URLS,
-    DMR_API_KEYS,
     DMR_API_CONFIGS,
     # Direct Connections
     ENABLE_DIRECT_CONNECTIONS,
@@ -603,7 +602,6 @@ app.state.OPENAI_MODELS = {}
 
 app.state.config.ENABLE_DMR_API = ENABLE_DMR_API
 app.state.config.DMR_BASE_URLS = DMR_BASE_URLS
-app.state.config.DMR_API_KEYS = DMR_API_KEYS
 app.state.config.DMR_API_CONFIGS = DMR_API_CONFIGS
 
 app.state.DMR_MODELS = {}
diff --git a/backend/open_webui/routers/docker_model_runner.py b/backend/open_webui/routers/docker_model_runner.py
index 7623cfb88..6113573e9 100644
--- a/backend/open_webui/routers/docker_model_runner.py
+++ b/backend/open_webui/routers/docker_model_runner.py
@@ -1,41 +1,173 @@
-from contextlib import contextmanager
-from typing import Optional
+import logging
+import aiohttp
+from typing import Optional, Union
+from urllib.parse import urlparse
+import time
 
-from fastapi import APIRouter, Depends, Request
+from fastapi import APIRouter, Depends, Request, HTTPException
+from fastapi.responses import StreamingResponse
+from starlette.background import BackgroundTask
 from pydantic import BaseModel
 
 from open_webui.models.users import UserModel
-from open_webui.routers import openai
-from open_webui.routers.openai import ConnectionVerificationForm
 from open_webui.utils.auth import get_admin_user, get_verified_user
+from open_webui.constants import ERROR_MESSAGES
+from open_webui.env import (
+    AIOHTTP_CLIENT_SESSION_SSL,
+    AIOHTTP_CLIENT_TIMEOUT,
+    AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
+    ENABLE_FORWARD_USER_INFO_HEADERS,
+    SRC_LOG_LEVELS,
+)
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS.get("DMR", logging.INFO))
 
 router = APIRouter()
 
-@contextmanager
-def _dmr_context(request: Request):
-    orig_urls = request.app.state.config.OPENAI_API_BASE_URLS
-    orig_keys = request.app.state.config.OPENAI_API_KEYS
-    orig_configs = request.app.state.config.OPENAI_API_CONFIGS
-    orig_models = request.app.state.OPENAI_MODELS
-    request.app.state.config.OPENAI_API_BASE_URLS = request.app.state.config.DMR_BASE_URLS
-    request.app.state.config.OPENAI_API_KEYS = request.app.state.config.DMR_API_KEYS
-    request.app.state.config.OPENAI_API_CONFIGS = request.app.state.config.DMR_API_CONFIGS
-    request.app.state.OPENAI_MODELS = request.app.state.DMR_MODELS
-    try:
-        yield
-    finally:
-        request.app.state.config.OPENAI_API_BASE_URLS = orig_urls
-        request.app.state.config.OPENAI_API_KEYS = orig_keys
-        request.app.state.config.OPENAI_API_CONFIGS = orig_configs
-        request.app.state.OPENAI_MODELS = orig_models
+# DMR-specific constants
+DMR_ENGINE_SUFFIX = "/engines/llama.cpp/v1"
+
+
+##########################################
+#
+# Utility functions
+#
+##########################################
+
+
+async def send_get_request(url: str, user: UserModel = None):
+    """Send GET request to DMR backend with proper error handling"""
+    timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
+    try:
+        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
+            async with session.get(
+                url,
+                headers={
+                    "Content-Type": "application/json",
+                    **(
+                        {
+                            "X-OpenWebUI-User-Name": user.name,
+                            "X-OpenWebUI-User-Id": user.id,
+                            "X-OpenWebUI-User-Email": user.email,
+                            "X-OpenWebUI-User-Role": user.role,
+                        }
+                        if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                        else {}
+                    ),
+                },
+                ssl=AIOHTTP_CLIENT_SESSION_SSL,
+            ) as response:
+                return await response.json()
+    except Exception as e:
+        log.error(f"DMR connection error: {e}")
+        return None
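For reference, a minimal sketch of how `send_get_request` would be exercised against a local Docker Model Runner instance; the host/port is illustrative, and the snippet assumes it runs outside any existing event loop:

    import asyncio

    async def list_dmr_models():
        # Illustrative URL; DMR_ENGINE_SUFFIX is defined above.
        url = "http://localhost:12434" + DMR_ENGINE_SUFFIX + "/models"
        response = await send_get_request(url)
        if response is None:
            # send_get_request logs the error and returns None on failure.
            print("DMR backend unreachable")
        else:
            print([m.get("id") for m in response.get("data", [])])

    asyncio.run(list_dmr_models())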
+
+
+async def cleanup_response(
+    response: Optional[aiohttp.ClientResponse],
+    session: Optional[aiohttp.ClientSession],
+):
+    """Clean up aiohttp resources"""
+    if response:
+        response.close()
+    if session:
+        await session.close()
+
+
+async def send_post_request(
+    url: str,
+    payload: Union[str, bytes, dict],
+    stream: bool = False,
+    user: UserModel = None,
+):
+    """Send POST request to DMR backend with proper error handling"""
+    r = None
+    # Initialize session as well, so the except block cannot hit an unbound
+    # name if ClientSession construction itself fails.
+    session = None
+    try:
+        session = aiohttp.ClientSession(
+            trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
+        )
+
+        r = await session.post(
+            url,
+            data=payload
+            if isinstance(payload, (str, bytes))
+            else aiohttp.JsonPayload(payload),
+            headers={
+                "Content-Type": "application/json",
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                    else {}
+                ),
+            },
+            ssl=AIOHTTP_CLIENT_SESSION_SSL,
+        )
+        r.raise_for_status()
+
+        if stream:
+            return StreamingResponse(
+                r.content,
+                status_code=r.status,
+                headers=dict(r.headers),
+                background=BackgroundTask(
+                    cleanup_response, response=r, session=session
+                ),
+            )
+        else:
+            res = await r.json()
+            await cleanup_response(r, session)
+            return res
+
+    except Exception as e:
+        detail = None
+        if r is not None:
+            try:
+                res = await r.json()
+                if "error" in res:
+                    detail = f"DMR: {res.get('error', 'Unknown error')}"
+            except Exception:
+                detail = f"DMR: {e}"
+
+        await cleanup_response(r, session)
+        raise HTTPException(
+            status_code=r.status if r else 500,
+            detail=detail if detail else "Open WebUI: DMR Connection Error",
+        )
+
+
+def get_dmr_base_url(request: Request, url_idx: Optional[int] = None):
+    """Get DMR base URL with engine suffix"""
+    urls = request.app.state.config.DMR_BASE_URLS
+    if not urls:
+        raise HTTPException(status_code=500, detail="No DMR base URLs configured")
+
+    if url_idx is None:
+        base = urls[0]
+        idx = 0
+    else:
+        # Reject negative indices as well as out-of-range ones.
+        if url_idx < 0 or url_idx >= len(urls):
+            raise HTTPException(status_code=400, detail="Invalid DMR URL index")
+        base = urls[url_idx]
+        idx = url_idx
+
+    # Always append the engine prefix for OpenAI-compatible endpoints
+    if not base.rstrip("/").endswith(DMR_ENGINE_SUFFIX):
+        base = base.rstrip("/") + DMR_ENGINE_SUFFIX
+    return base, idx
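The suffix handling in `get_dmr_base_url` is idempotent, so configured URLs may include or omit the engine path. A standalone sketch of just that normalization, with made-up URLs:

    DMR_ENGINE_SUFFIX = "/engines/llama.cpp/v1"

    def with_engine_suffix(base: str) -> str:
        # Mirrors the normalization in get_dmr_base_url above.
        base = base.rstrip("/")
        return base if base.endswith(DMR_ENGINE_SUFFIX) else base + DMR_ENGINE_SUFFIX

    assert (
        with_engine_suffix("http://model-runner.docker.internal:12434")
        == "http://model-runner.docker.internal:12434/engines/llama.cpp/v1"
    )
    # Already-suffixed URLs (with or without a trailing slash) pass through.
    assert (
        with_engine_suffix("http://localhost:12434/engines/llama.cpp/v1/")
        == "http://localhost:12434/engines/llama.cpp/v1"
    )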
+
+
+##########################################
+#
+# API routes
+#
+##########################################
 
 
 @router.get("/config")
 async def get_config(request: Request, user=Depends(get_admin_user)):
+    """Get DMR configuration"""
     return {
         "ENABLE_DMR_API": request.app.state.config.ENABLE_DMR_API,
         "DMR_BASE_URLS": request.app.state.config.DMR_BASE_URLS,
-        "DMR_API_KEYS": request.app.state.config.DMR_API_KEYS,
         "DMR_API_CONFIGS": request.app.state.config.DMR_API_CONFIGS,
     }
@@ -43,25 +175,17 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
 class DMRConfigForm(BaseModel):
     ENABLE_DMR_API: Optional[bool] = None
     DMR_BASE_URLS: list[str]
-    DMR_API_KEYS: list[str] = []
     DMR_API_CONFIGS: dict = {}
 
 
 @router.post("/config/update")
 async def update_config(request: Request, form_data: DMRConfigForm, user=Depends(get_admin_user)):
+    """Update DMR configuration"""
     request.app.state.config.ENABLE_DMR_API = form_data.ENABLE_DMR_API
     request.app.state.config.DMR_BASE_URLS = form_data.DMR_BASE_URLS
-    request.app.state.config.DMR_API_KEYS = form_data.DMR_API_KEYS
     request.app.state.config.DMR_API_CONFIGS = form_data.DMR_API_CONFIGS
 
-    if len(request.app.state.config.DMR_API_KEYS) != len(request.app.state.config.DMR_BASE_URLS):
-        if len(request.app.state.config.DMR_API_KEYS) > len(request.app.state.config.DMR_BASE_URLS):
-            request.app.state.config.DMR_API_KEYS = request.app.state.config.DMR_API_KEYS[: len(request.app.state.config.DMR_BASE_URLS)]
-        else:
-            request.app.state.config.DMR_API_KEYS += [""] * (
-                len(request.app.state.config.DMR_BASE_URLS) - len(request.app.state.config.DMR_API_KEYS)
-            )
-
+    # Clean up configs for non-existent URLs
     keys = list(map(str, range(len(request.app.state.config.DMR_BASE_URLS))))
     request.app.state.config.DMR_API_CONFIGS = {
         k: v for k, v in request.app.state.config.DMR_API_CONFIGS.items() if k in keys
@@ -70,41 +194,253 @@ async def update_config(request: Request, form_data: DMRConfigForm, user=Depends
     return {
         "ENABLE_DMR_API": request.app.state.config.ENABLE_DMR_API,
         "DMR_BASE_URLS": request.app.state.config.DMR_BASE_URLS,
-        "DMR_API_KEYS": request.app.state.config.DMR_API_KEYS,
         "DMR_API_CONFIGS": request.app.state.config.DMR_API_CONFIGS,
     }
 
 
+class ConnectionVerificationForm(BaseModel):
+    url: str
+
+
 @router.post("/verify")
 async def verify_connection(form_data: ConnectionVerificationForm, user=Depends(get_admin_user)):
-    return await openai.verify_connection(form_data, user)
+    """Verify connection to DMR backend"""
+    url = form_data.url
+
+    # Append engine suffix if not present
+    if not url.rstrip("/").endswith(DMR_ENGINE_SUFFIX):
+        url = url.rstrip("/") + DMR_ENGINE_SUFFIX
+
+    # send_get_request already logs failures and returns None, so no try/except
+    # is needed here; wrapping the HTTPException below in one would only
+    # catch and re-wrap our own 400.
+    response = await send_get_request(f"{url}/models", user=user)
+    if response is not None:
+        return {"status": "success", "message": "Connection verified"}
+    raise HTTPException(status_code=400, detail="Failed to connect to DMR backend")
 
 
 @router.get("/models")
 @router.get("/models/{url_idx}")
 async def get_models(request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)):
-    with _dmr_context(request):
-        return await openai.get_models(request, url_idx=url_idx, user=user)
+    """Get available models from DMR backend"""
+    url, idx = get_dmr_base_url(request, url_idx)
+
+    response = await send_get_request(f"{url}/models", user=user)
+    if response is None:
+        raise HTTPException(status_code=500, detail="Failed to fetch models from DMR backend")
+    return response
 
 
 @router.post("/chat/completions")
-async def generate_chat_completion(request: Request, form_data: dict, user=Depends(get_verified_user)):
-    with _dmr_context(request):
-        return await openai.generate_chat_completion(request, form_data, user=user)
+async def generate_chat_completion(
+    request: Request,
+    form_data: dict,
+    user=Depends(get_verified_user),
+    bypass_filter: Optional[bool] = False,
+):
+    """Generate chat completions using DMR backend"""
+    url, idx = get_dmr_base_url(request)
+
+    log.debug(f"DMR chat_completions: model = {form_data.get('model', 'NO_MODEL')}")
+
+    # Resolve model ID if needed
+    if "model" in form_data:
+        models = await get_all_models(request, user=user)
+        for m in models.get("data", []):
+            if m.get("id") == form_data["model"] or m.get("name") == form_data["model"]:
+                form_data["model"] = m["id"]
+                break
+
+    return await send_post_request(
+        f"{url}/chat/completions",
+        form_data,
+        stream=form_data.get("stream", False),
+        user=user,
+    )
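A client-side sketch of calling the route above. The `/api/dmr` prefix is an assumption (the router mount point in main.py is not part of this patch), as are the model tag and API key:

    import requests

    resp = requests.post(
        "http://localhost:8080/api/dmr/chat/completions",  # assumed mount point
        headers={"Authorization": "Bearer <open-webui-api-key>"},
        json={
            "model": "ai/llama3.2",  # hypothetical DMR model tag
            "messages": [{"role": "user", "content": "Hello"}],
            "stream": False,
        },
        timeout=60,
    )
    print(resp.json()["choices"][0]["message"]["content"])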
 
 
 @router.post("/completions")
-async def completions(request: Request, form_data: dict, user=Depends(get_verified_user)):
-    with _dmr_context(request):
-        return await openai.completions(request, form_data, user=user)
+async def generate_completion(
+    request: Request,
+    form_data: dict,
+    user=Depends(get_verified_user),
+):
+    """Generate completions using DMR backend"""
+    url, idx = get_dmr_base_url(request)
+
+    # Resolve model ID if needed
+    if "model" in form_data:
+        models = await get_all_models(request, user=user)
+        for m in models.get("data", []):
+            if m.get("id") == form_data["model"] or m.get("name") == form_data["model"]:
+                form_data["model"] = m["id"]
+                break
+
+    return await send_post_request(
+        f"{url}/completions",
+        form_data,
+        stream=form_data.get("stream", False),
+        user=user,
+    )
 
 
 @router.post("/embeddings")
 async def embeddings(request: Request, form_data: dict, user=Depends(get_verified_user)):
-    with _dmr_context(request):
-        return await openai.embeddings(request, form_data, user=user)
+    """Generate embeddings using DMR backend"""
+    url, idx = get_dmr_base_url(request)
+
+    return await send_post_request(f"{url}/embeddings", form_data, stream=False, user=user)
+
+
+# OpenAI-compatible endpoints
+@router.get("/v1/models")
+@router.get("/v1/models/{url_idx}")
+async def get_openai_models(request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)):
+    """Get available models from DMR backend (OpenAI-compatible)"""
+    return await get_models(request, url_idx, user)
+
+
+@router.post("/v1/chat/completions")
+@router.post("/v1/chat/completions/{url_idx}")
+async def generate_openai_chat_completion(
+    request: Request,
+    form_data: dict,
+    url_idx: Optional[int] = None,
+    user=Depends(get_verified_user),
+):
+    """Generate chat completions using DMR backend (OpenAI-compatible)"""
+    if url_idx is not None:
+        url, idx = get_dmr_base_url(request, url_idx)
+        return await send_post_request(
+            f"{url}/chat/completions",
+            form_data,
+            stream=form_data.get("stream", False),
+            user=user,
+        )
+    else:
+        return await generate_chat_completion(request, form_data, user)
+
+
+@router.post("/v1/completions")
+@router.post("/v1/completions/{url_idx}")
+async def generate_openai_completion(
+    request: Request,
+    form_data: dict,
+    url_idx: Optional[int] = None,
+    user=Depends(get_verified_user),
+):
+    """Generate completions using DMR backend (OpenAI-compatible)"""
+    if url_idx is not None:
+        url, idx = get_dmr_base_url(request, url_idx)
+        return await send_post_request(
+            f"{url}/completions",
+            form_data,
+            stream=form_data.get("stream", False),
+            user=user,
+        )
+    else:
+        return await generate_completion(request, form_data, user)
+
+
+@router.post("/v1/embeddings")
+@router.post("/v1/embeddings/{url_idx}")
+async def generate_openai_embeddings(
+    request: Request,
+    form_data: dict,
+    url_idx: Optional[int] = None,
+    user=Depends(get_verified_user),
+):
+    """Generate embeddings using DMR backend (OpenAI-compatible)"""
+    if url_idx is not None:
+        url, idx = get_dmr_base_url(request, url_idx)
+        return await send_post_request(f"{url}/embeddings", form_data, stream=False, user=user)
+    else:
+        return await embeddings(request, form_data, user)
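Because these `/v1` aliases mirror the OpenAI surface, a stock OpenAI client can in principle be pointed at them. Sketch only: the base URL and bearer-token handling are assumptions, not established by this patch:

    from openai import OpenAI

    client = OpenAI(
        base_url="http://localhost:8080/api/dmr/v1",  # assumed mount point
        api_key="<open-webui-api-key>",
    )
    completion = client.chat.completions.create(
        model="ai/llama3.2",  # hypothetical DMR model tag
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(completion.choices[0].message.content)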
 
 
+# Internal utility for Open WebUI model aggregation
 async def get_all_models(request: Request, user: UserModel = None):
-    with _dmr_context(request):
-        return await openai.get_all_models.__wrapped__(request, user)
+    """
+    Fetch all models from the DMR backend in OpenAI-compatible format for internal use.
+
+    Returns: dict with 'data' key (list of models)
+    """
+    try:
+        url, idx = get_dmr_base_url(request)
+
+        response = await send_get_request(f"{url}/models", user=user)
+        if response is None:
+            return {"data": []}
+
+        # Ensure response is in correct format
+        if isinstance(response, dict) and "data" in response:
+            # Transform models to include Open WebUI required fields
+            models = []
+            for m in response["data"]:
+                # Ensure each model has a 'name' field for frontend compatibility
+                if "name" not in m:
+                    m["name"] = m["id"]
+
+                # Add Open WebUI specific fields
+                model = {
+                    "id": m["id"],
+                    "name": m["name"],
+                    "object": m.get("object", "model"),
+                    "created": m.get("created", int(time.time())),
+                    "owned_by": "docker",
+                    "dmr": m,  # Store original DMR model data
+                    "connection_type": "local",
+                    "tags": [],
+                }
+                models.append(model)
+
+            return {"data": models}
+        elif isinstance(response, list):
+            # Convert list to OpenAI format with Open WebUI fields
+            models = []
+            for m in response:
+                if isinstance(m, str):
+                    model = {
+                        "id": m,
+                        "name": m,
+                        "object": "model",
+                        "created": int(time.time()),
+                        "owned_by": "docker",
+                        "dmr": {"id": m, "name": m},
+                        "connection_type": "local",
+                        "tags": [],
+                    }
+                elif isinstance(m, dict):
+                    model_id = m.get("id") or m.get("name") or str(m)
+                    model = {
+                        "id": model_id,
+                        "name": m.get("name", model_id),
+                        "object": m.get("object", "model"),
+                        "created": m.get("created", int(time.time())),
+                        "owned_by": "docker",
+                        "dmr": m,
+                        "connection_type": "local",
+                        "tags": [],
+                    }
+                else:
+                    # Skip entries that are neither strings nor dicts; without
+                    # this guard a stale `model` from a previous iteration
+                    # would be appended again.
+                    continue
+                models.append(model)
+            return {"data": models}
+        else:
+            # Fallback: wrap a single response object with Open WebUI fields
+            if response:
+                model = {
+                    "id": str(response),
+                    "name": str(response),
+                    "object": "model",
+                    "created": int(time.time()),
+                    "owned_by": "docker",
+                    "dmr": response,
+                    "connection_type": "local",
+                    "tags": [],
+                }
+                return {"data": [model]}
+            return {"data": []}
+    except Exception as e:
+        log.exception(f"DMR get_all_models failed: {e}")
+        return {"data": []}
diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py
index 268c910e3..5313d78d9 100644
--- a/backend/open_webui/utils/chat.py
+++ b/backend/open_webui/utils/chat.py
@@ -31,6 +31,10 @@ from open_webui.routers.ollama import (
     generate_chat_completion as generate_ollama_chat_completion,
 )
 
+from open_webui.routers.docker_model_runner import (
+    generate_chat_completion as generate_docker_chat_completion,
+)
+
 from open_webui.routers.pipelines import (
     process_pipeline_inlet_filter,
     process_pipeline_outlet_filter,
@@ -274,6 +278,13 @@ async def generate_chat_completion(
             )
         else:
             return convert_response_ollama_to_openai(response)
+    elif model.get("owned_by") == "docker":
+        # Using DMR endpoints
+        return await generate_docker_chat_completion(
+            request=request,
+            form_data=form_data,
+            user=user,
+        )
     else:
         return await generate_openai_chat_completion(
             request=request,
diff --git a/docker-compose.yaml b/docker-compose.yaml
index 74249febd..d188713ed 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -25,6 +25,8 @@ services:
     environment:
       - 'OLLAMA_BASE_URL=http://ollama:11434'
       - 'WEBUI_SECRET_KEY='
+      - 'DMR_BASE_URL=${DMR_BASE_URL:-http://model-runner.docker.internal:12434}'
+      - 'ENABLE_DMR_API=${ENABLE_DMR_API:-false}'
    extra_hosts:
      - host.docker.internal:host-gateway
    restart: unless-stopped
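A quick smoke test for the compose wiring above, assuming it is run from inside the open-webui container (`model-runner.docker.internal` only resolves on Docker's internal network):

    import requests

    base = "http://model-runner.docker.internal:12434/engines/llama.cpp/v1"
    # Lists the models DMR exposes through its llama.cpp engine.
    print(requests.get(f"{base}/models", timeout=10).json())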