fixes and improvements

Sergei Shitikov 2025-06-20 16:56:02 +02:00
parent b1be5b6b89
commit 0f70e87416
2 changed files with 531 additions and 183 deletions

View File

@@ -84,6 +84,7 @@ log_sources = [
"COMFYUI",
"CONFIG",
"DB",
"DMR",
"IMAGES",
"MAIN",
"MODELS",

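Note: the new "DMR" entry lets the router below resolve a dedicated log level. A minimal sketch of how a per-source level map like SRC_LOG_LEVELS is typically derived from environment variables; the DMR_LOG_LEVEL naming scheme and the INFO fallback are assumptions, not shown in this diff:

import logging
import os

# List truncated to the entries visible in this hunk.
log_sources = ["COMFYUI", "CONFIG", "DB", "DMR", "IMAGES", "MAIN", "MODELS"]
SRC_LOG_LEVELS = {}
for source in log_sources:
    # Assumed convention: e.g. DMR_LOG_LEVEL=DEBUG overrides the default.
    level = os.environ.get(f"{source}_LOG_LEVEL", "INFO").upper()
    SRC_LOG_LEVELS[source] = getattr(logging, level, logging.INFO)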
View File

@@ -1,27 +1,34 @@
import logging
import aiohttp
from typing import Union
from urllib.parse import urlparse
import time
from typing import Optional, Union
import aiohttp
from aiocache import cached
from fastapi import APIRouter, Depends, Request, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict
from starlette.background import BackgroundTask
from pydantic import BaseModel
from open_webui.models.users import UserModel
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import (
AIOHTTP_CLIENT_SESSION_SSL,
AIOHTTP_CLIENT_TIMEOUT,
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
BYPASS_MODEL_ACCESS_CONTROL,
ENABLE_FORWARD_USER_INFO_HEADERS,
SRC_LOG_LEVELS
)
from open_webui.models.users import UserModel
from open_webui.models.models import Models
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access
from open_webui.utils.payload import (
apply_model_params_to_body_openai,
apply_model_system_prompt_to_body,
)
from open_webui.constants import ERROR_MESSAGES
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS.get("DMR", logging.INFO))
log.setLevel(SRC_LOG_LEVELS["DMR"])
router = APIRouter()
@@ -35,7 +42,7 @@ DMR_ENGINE_SUFFIX = "/engines/llama.cpp/v1"
##########################################
async def send_get_request(url: str, user: Optional[UserModel] = None):
"""Send GET request to DMR backend with proper error handling"""
"""Send a GET request to DMR backend with proper error handling"""
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
try:
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
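Note: the model-list call uses its own, shorter timeout constant so a hung backend cannot stall model discovery. A standalone sketch of the same aiohttp pattern; the timeout value and URL are illustrative:

import asyncio
import aiohttp

async def fetch_models(url: str, timeout_s: float = 10.0):
    timeout = aiohttp.ClientTimeout(total=timeout_s)
    async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
        async with session.get(f"{url}/models") as r:
            r.raise_for_status()
            return await r.json()

# asyncio.run(fetch_models("http://localhost:12434/engines/llama.cpp/v1"))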
@@ -77,6 +84,7 @@ async def send_post_request(
url: str,
payload: Union[str, bytes, dict],
stream: bool = False,
content_type: Optional[str] = None,
user: Optional[UserModel] = None,
):
"""Send POST request to DMR backend with proper error handling"""
@@ -90,7 +98,7 @@
url,
data=payload if isinstance(payload, (str, bytes)) else aiohttp.JsonPayload(payload),
headers={
"Content-Type": "application/json",
"Content-Type": content_type or "application/json",
**(
{
"X-OpenWebUI-User-Name": user.name,
@@ -107,10 +115,14 @@
r.raise_for_status()
if stream:
response_headers = dict(r.headers)
if content_type:
response_headers["Content-Type"] = content_type
return StreamingResponse(
r.content,
status_code=r.status,
headers=dict(r.headers),
headers=response_headers,
background=BackgroundTask(cleanup_response, response=r, session=session),
)
else:
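Note: the new content_type parameter does two jobs at once: it overrides the outgoing request's Content-Type header and, for streams, re-labels the response so Ollama-style clients see application/x-ndjson instead of the backend's own header. A tiny standalone check of the fallback rule:

from typing import Optional

def effective_content_type(override: Optional[str]) -> str:
    # Mirrors the diff: an explicit override wins, JSON is the default.
    return override or "application/json"

assert effective_content_type(None) == "application/json"
assert effective_content_type("application/x-ndjson") == "application/x-ndjson"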
@@ -139,7 +151,7 @@ def get_dmr_base_url(request: Request):
"""Get DMR base URL with engine suffix"""
base_url = request.app.state.config.DMR_BASE_URL
if not base_url:
raise HTTPException(status_code=500, detail="No DMR base URL configured")
raise HTTPException(status_code=500, detail="DMR base URL not configured")
# Always append the engine prefix for OpenAI-compatible endpoints
if not base_url.rstrip("/").endswith(DMR_ENGINE_SUFFIX):
@@ -147,181 +159,17 @@ def get_dmr_base_url(request: Request):
return base_url
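Note: because trailing slashes are stripped before the comparison, configured base URLs normalize predictably whether or not they already carry the engine path. A quick standalone check of the rule; the example host is illustrative:

DMR_ENGINE_SUFFIX = "/engines/llama.cpp/v1"

def with_engine_suffix(base_url: str) -> str:
    if not base_url.rstrip("/").endswith(DMR_ENGINE_SUFFIX):
        base_url = base_url.rstrip("/") + DMR_ENGINE_SUFFIX
    return base_url

assert with_engine_suffix("http://localhost:12434") == "http://localhost:12434/engines/llama.cpp/v1"
assert with_engine_suffix("http://localhost:12434/engines/llama.cpp/v1") == "http://localhost:12434/engines/llama.cpp/v1"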
##########################################
#
# API routes
#
##########################################
@router.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
"""Get DMR configuration"""
return {
"ENABLE_DMR_API": request.app.state.config.ENABLE_DMR_API,
"DMR_BASE_URL": request.app.state.config.DMR_BASE_URL,
"DMR_API_CONFIGS": request.app.state.config.DMR_API_CONFIGS,
}
class DMRConfigForm(BaseModel):
ENABLE_DMR_API: Optional[bool] = None
DMR_BASE_URL: str
DMR_API_CONFIGS: dict = {}
@router.post("/config/update")
async def update_config(request: Request, form_data: DMRConfigForm, user=Depends(get_admin_user)):
"""Update DMR configuration"""
request.app.state.config.ENABLE_DMR_API = form_data.ENABLE_DMR_API
request.app.state.config.DMR_BASE_URL = form_data.DMR_BASE_URL
request.app.state.config.DMR_API_CONFIGS = form_data.DMR_API_CONFIGS
return {
"ENABLE_DMR_API": request.app.state.config.ENABLE_DMR_API,
"DMR_BASE_URL": request.app.state.config.DMR_BASE_URL,
"DMR_API_CONFIGS": request.app.state.config.DMR_API_CONFIGS,
}
class ConnectionVerificationForm(BaseModel):
url: str
@router.post("/verify")
async def verify_connection(form_data: ConnectionVerificationForm, user=Depends(get_admin_user)):
"""Verify connection to DMR backend"""
url = form_data.url
# Append engine suffix if not present
if not url.rstrip("/").endswith(DMR_ENGINE_SUFFIX):
url = url.rstrip("/") + DMR_ENGINE_SUFFIX
try:
response = await send_get_request(f"{url}/models", user=user)
if response is not None:
return {"status": "success", "message": "Connection verified"}
else:
raise HTTPException(status_code=400, detail="Failed to connect to DMR backend")
except Exception as e:
log.exception(f"DMR connection verification failed: {e}")
raise HTTPException(status_code=400, detail=f"Connection verification failed: {e}")
@router.get("/models")
async def get_models(request: Request, user=Depends(get_verified_user)):
"""Get available models from DMR backend"""
url = get_dmr_base_url(request)
response = await send_get_request(f"{url}/models", user=user)
if response is None:
raise HTTPException(status_code=500, detail="Failed to fetch models from DMR backend")
return response
@router.post("/chat/completions")
async def generate_chat_completion(
request: Request,
form_data: dict,
user=Depends(get_verified_user),
bypass_filter: Optional[bool] = False,
):
"""Generate chat completions using DMR backend"""
url = get_dmr_base_url(request)
log.debug(f"DMR chat_completions: model = {form_data.get('model', 'NO_MODEL')}")
# Resolve model ID if needed
if "model" in form_data:
models = await get_all_models(request, user=user)
for m in models.get("data", []):
if m.get("id") == form_data["model"] or m.get("name") == form_data["model"]:
form_data["model"] = m["id"]
break
return await send_post_request(
f"{url}/chat/completions",
form_data,
stream=form_data.get("stream", False),
user=user
)
@router.post("/completions")
async def generate_completion(
request: Request,
form_data: dict,
user=Depends(get_verified_user),
):
"""Generate completions using DMR backend"""
url = get_dmr_base_url(request)
# Resolve model ID if needed
if "model" in form_data:
models = await get_all_models(request, user=user)
for m in models.get("data", []):
if m.get("id") == form_data["model"] or m.get("name") == form_data["model"]:
form_data["model"] = m["id"]
break
return await send_post_request(
f"{url}/completions",
form_data,
stream=form_data.get("stream", False),
user=user
)
@router.post("/embeddings")
async def embeddings(request: Request, form_data: dict, user=Depends(get_verified_user)):
"""Generate embeddings using DMR backend"""
url = get_dmr_base_url(request)
return await send_post_request(f"{url}/embeddings", form_data, stream=False, user=user)
# OpenAI-compatible endpoints
@router.get("/v1/models")
async def get_openai_models(request: Request, user=Depends(get_verified_user)):
"""Get available models from DMR backend (OpenAI-compatible)"""
return await get_models(request, user)
@router.post("/v1/chat/completions")
async def generate_openai_chat_completion(
request: Request,
form_data: dict,
user=Depends(get_verified_user)
):
"""Generate chat completions using DMR backend (OpenAI-compatible)"""
return await generate_chat_completion(request, form_data, user)
@router.post("/v1/completions")
async def generate_openai_completion(
request: Request,
form_data: dict,
user=Depends(get_verified_user)
):
"""Generate completions using DMR backend (OpenAI-compatible)"""
return await generate_completion(request, form_data, user)
@router.post("/v1/embeddings")
async def generate_openai_embeddings(
request: Request,
form_data: dict,
user=Depends(get_verified_user)
):
"""Generate embeddings using DMR backend (OpenAI-compatible)"""
return await embeddings(request, form_data, user)
# Internal utility for Open WebUI model aggregation
@cached(ttl=1)
async def get_all_models(request: Request, user: Optional[UserModel] = None):
"""
Fetch all models from the DMR backend in OpenAI-compatible format for internal use.
Returns: dict with 'data' key (list of models)
"""
log.info("get_all_models() - DMR")
if not request.app.state.config.ENABLE_DMR_API:
return {"data": []}
try:
url = get_dmr_base_url(request)
@@ -399,3 +247,502 @@ async def get_all_models(request: Request, user: Optional[UserModel] = None):
except Exception as e:
log.exception(f"DMR get_all_models failed: {e}")
return {"data": []}
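Note: the @cached(ttl=1) decorator from aiocache memoizes the coroutine's return value for one second, so repeated model lookups inside a single request cycle hit the backend only once. A self-contained sketch of the behavior:

import asyncio
from aiocache import cached

calls = 0

@cached(ttl=1)
async def fetch_once():
    global calls
    calls += 1
    return {"data": []}

async def demo():
    await fetch_once()
    await fetch_once()  # served from the in-memory cache within the TTL
    print(calls)  # -> 1

# asyncio.run(demo())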
async def get_filtered_models(models, user):
"""Filter models based on user access control"""
if BYPASS_MODEL_ACCESS_CONTROL:
return models.get("data", [])
filtered_models = []
for model in models.get("data", []):
model_info = Models.get_model_by_id(model["id"])
if model_info:
if user.id == model_info.user_id or has_access(
user.id, type="read", access_control=model_info.access_control
):
filtered_models.append(model)
else:
# If no model info found and user is admin, include it
if user.role == "admin":
filtered_models.append(model)
return filtered_models
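Note: the filter grants a model to its owner, to anyone with a read grant, and, for models Open WebUI has no record of, to admins only. An illustrative stand-in for that rule; the access_control dict shape is an assumption based on how has_access is called here:

def can_read(user_id: str, role: str, model_info: dict | None) -> bool:
    if model_info is None:
        # No registered model info: only admins see it.
        return role == "admin"
    owner = user_id == model_info.get("user_id")
    granted = user_id in (
        model_info.get("access_control", {}).get("read", {}).get("user_ids", [])
    )
    return owner or granted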
##########################################
#
# Configuration endpoints
#
##########################################
@router.head("/")
@router.get("/")
async def get_status():
return {"status": True}
@router.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
"""Get DMR configuration"""
return {
"ENABLE_DMR_API": request.app.state.config.ENABLE_DMR_API,
"DMR_BASE_URL": request.app.state.config.DMR_BASE_URL,
"DMR_API_CONFIGS": request.app.state.config.DMR_API_CONFIGS,
}
class DMRConfigForm(BaseModel):
ENABLE_DMR_API: Optional[bool] = None
DMR_BASE_URL: str
DMR_API_CONFIGS: dict = {}
@router.post("/config/update")
async def update_config(request: Request, form_data: DMRConfigForm, user=Depends(get_admin_user)):
"""Update DMR configuration"""
if form_data.ENABLE_DMR_API is not None:
    request.app.state.config.ENABLE_DMR_API = form_data.ENABLE_DMR_API
request.app.state.config.DMR_BASE_URL = form_data.DMR_BASE_URL
request.app.state.config.DMR_API_CONFIGS = form_data.DMR_API_CONFIGS
return {
"ENABLE_DMR_API": request.app.state.config.ENABLE_DMR_API,
"DMR_BASE_URL": request.app.state.config.DMR_BASE_URL,
"DMR_API_CONFIGS": request.app.state.config.DMR_API_CONFIGS,
}
class ConnectionVerificationForm(BaseModel):
url: str
@router.post("/verify")
async def verify_connection(form_data: ConnectionVerificationForm, user=Depends(get_admin_user)):
"""Verify connection to DMR backend"""
url = form_data.url
# Append engine suffix if not present
if not url.rstrip("/").endswith(DMR_ENGINE_SUFFIX):
url = url.rstrip("/") + DMR_ENGINE_SUFFIX
try:
response = await send_get_request(f"{url}/models", user=user)
if response is not None:
return {"status": "success", "message": "Connection verified"}
else:
raise HTTPException(status_code=400, detail="Failed to connect to DMR backend")
except Exception as e:
log.exception(f"DMR connection verification failed: {e}")
raise HTTPException(status_code=400, detail=f"Connection verification failed: {e}")
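Note: an illustrative client call against /verify. The /dmr mount prefix and the admin token are assumptions, since the router's registration is not part of this diff; the engine suffix is appended server-side, so the bare backend URL is enough:

import requests

resp = requests.post(
    "http://localhost:8080/dmr/verify",  # assumed mount prefix
    json={"url": "http://localhost:12434"},
    headers={"Authorization": "Bearer <admin-token>"},
)
print(resp.status_code, resp.json())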
##########################################
#
# Model endpoints
#
##########################################
@router.get("/api/tags")
async def get_dmr_tags(request: Request, user=Depends(get_verified_user)):
"""Get available models from DMR backend (Ollama-compatible format)"""
models = await get_all_models(request, user=user)
if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
filtered_models = await get_filtered_models(models, user)
models = {"data": filtered_models}
# Convert to Ollama-compatible format
ollama_models = []
for model in models.get("data", []):
ollama_model = {
"model": model["id"],
"name": model["name"],
"size": model.get("size", 0),
"digest": model.get("digest", ""),
"details": model.get("details", {}),
"expires_at": model.get("expires_at"),
"size_vram": model.get("size_vram", 0),
}
ollama_models.append(ollama_model)
return {"models": ollama_models}
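Note: for reference, one OpenAI-style entry and the Ollama-style record this loop produces from it; the model id and size are illustrative:

openai_entry = {"id": "ai/llama3.2", "name": "ai/llama3.2", "size": 2_019_393_189}

ollama_entry = {
    "model": openai_entry["id"],
    "name": openai_entry["name"],
    "size": openai_entry.get("size", 0),
    "digest": openai_entry.get("digest", ""),
    "details": openai_entry.get("details", {}),
    "expires_at": openai_entry.get("expires_at"),
    "size_vram": openai_entry.get("size_vram", 0),
}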
@router.get("/models")
async def get_models(request: Request, user=Depends(get_verified_user)):
"""Get available models from DMR backend (OpenAI-compatible format)"""
models = await get_all_models(request, user=user)
if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
filtered_models = await get_filtered_models(models, user)
return {"data": filtered_models}
return models
class ModelNameForm(BaseModel):
name: str
@router.post("/api/show")
async def show_model_info(
request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
):
"""Show model information (Ollama-compatible)"""
models = await get_all_models(request, user=user)
# Find the model
model_found = None
for model in models.get("data", []):
if model["id"] == form_data.name or model["name"] == form_data.name:
model_found = model
break
if not model_found:
raise HTTPException(
status_code=400,
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
)
# Return model info in Ollama format
return {
"model": model_found["id"],
"details": model_found.get("dmr", {}),
"modelfile": "", # DMR doesn't provide modelfile
"parameters": "", # DMR doesn't provide parameters
"template": "", # DMR doesn't provide template
}
##########################################
#
# Generation endpoints
#
##########################################
class GenerateChatCompletionForm(BaseModel):
model: str
messages: list[dict]
stream: Optional[bool] = False
model_config = ConfigDict(extra="allow")
@router.post("/api/chat")
async def generate_chat_completion(
request: Request,
form_data: dict,
user=Depends(get_verified_user),
bypass_filter: Optional[bool] = False,
):
"""Generate chat completions using DMR backend (Ollama-compatible)"""
if BYPASS_MODEL_ACCESS_CONTROL:
bypass_filter = True
metadata = form_data.pop("metadata", None)
try:
completion_form = GenerateChatCompletionForm(**form_data)
except Exception as e:
log.exception(e)
raise HTTPException(status_code=400, detail=str(e))
payload = {**completion_form.model_dump(exclude_none=True)}
if "metadata" in payload:
del payload["metadata"]
model_id = payload["model"]
model_info = Models.get_model_by_id(model_id)
if model_info:
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
params = model_info.params.model_dump()
if params:
system = params.pop("system", None)
payload = apply_model_params_to_body_openai(params, payload)
payload = apply_model_system_prompt_to_body(system, payload, metadata, user)
# Check if user has access to the model
if not bypass_filter and user.role == "user":
if not (
user.id == model_info.user_id
or has_access(
user.id, type="read", access_control=model_info.access_control
)
):
raise HTTPException(status_code=403, detail="Model not found")
elif not bypass_filter:
if user.role != "admin":
raise HTTPException(status_code=403, detail="Model not found")
url = get_dmr_base_url(request)
log.debug(f"DMR chat completion: model = {payload.get('model', 'NO_MODEL')}")
return await send_post_request(
f"{url}/chat/completions",
payload,
stream=payload.get("stream", False),
content_type="application/x-ndjson" if payload.get("stream") else None,
user=user,
)
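Note: an illustrative request body for /api/chat. With stream=True the response is re-labeled application/x-ndjson via the content_type override above, which is what Ollama-compatible clients expect; the model id is illustrative:

chat_request = {
    "model": "ai/llama3.2",  # resolved against Models for base_model_id, params, and system prompt
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": True,
}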
class GenerateCompletionForm(BaseModel):
model: str
prompt: str
stream: Optional[bool] = False
model_config = ConfigDict(extra="allow")
@router.post("/api/generate")
async def generate_completion(
request: Request,
form_data: dict,
user=Depends(get_verified_user),
):
"""Generate completions using DMR backend (Ollama-compatible)"""
try:
completion_form = GenerateCompletionForm(**form_data)
except Exception as e:
log.exception(e)
raise HTTPException(status_code=400, detail=str(e))
payload = {**completion_form.model_dump(exclude_none=True)}
model_id = payload["model"]
model_info = Models.get_model_by_id(model_id)
if model_info:
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
params = model_info.params.model_dump()
if params:
payload = apply_model_params_to_body_openai(params, payload)
# Check if user has access to the model
if user.role == "user":
if not (
user.id == model_info.user_id
or has_access(
user.id, type="read", access_control=model_info.access_control
)
):
raise HTTPException(status_code=403, detail="Model not found")
else:
if user.role != "admin":
raise HTTPException(status_code=403, detail="Model not found")
url = get_dmr_base_url(request)
return await send_post_request(
f"{url}/completions",
payload,
stream=payload.get("stream", False),
user=user,
)
class GenerateEmbeddingsForm(BaseModel):
model: str
input: Union[str, list[str]]
model_config = ConfigDict(extra="allow")
@router.post("/api/embeddings")
async def embeddings(
request: Request,
form_data: dict,
user=Depends(get_verified_user)
):
"""Generate embeddings using DMR backend (Ollama-compatible)"""
try:
embedding_form = GenerateEmbeddingsForm(**form_data)
except Exception as e:
log.exception(e)
raise HTTPException(status_code=400, detail=str(e))
payload = {**embedding_form.model_dump(exclude_none=True)}
model_id = payload["model"]
model_info = Models.get_model_by_id(model_id)
if model_info:
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
# Check if user has access to the model
if user.role == "user":
if not (
user.id == model_info.user_id
or has_access(
user.id, type="read", access_control=model_info.access_control
)
):
raise HTTPException(status_code=403, detail="Model not found")
else:
if user.role != "admin":
raise HTTPException(status_code=403, detail="Model not found")
url = get_dmr_base_url(request)
return await send_post_request(f"{url}/embeddings", payload, stream=False, user=user)
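Note: an illustrative /api/embeddings body. Per GenerateEmbeddingsForm, input accepts either a single string or a list; the model id is illustrative:

embeddings_request = {
    "model": "ai/mxbai-embed-large",
    "input": ["first passage", "second passage"],  # a plain string works too
}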
##########################################
#
# OpenAI-compatible endpoints
#
##########################################
@router.get("/v1/models")
async def get_openai_models(request: Request, user=Depends(get_verified_user)):
"""Get available models from DMR backend (OpenAI-compatible)"""
models = await get_all_models(request, user=user)
if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
filtered_models = await get_filtered_models(models, user)
models = {"data": filtered_models}
# Convert to OpenAI format
openai_models = []
for model in models.get("data", []):
openai_model = {
"id": model["id"],
"object": "model",
"created": model.get("created", int(time.time())),
"owned_by": "docker",
}
openai_models.append(openai_model)
return {
"object": "list",
"data": openai_models,
}
@router.post("/v1/chat/completions")
async def generate_openai_chat_completion(
request: Request,
form_data: dict,
user=Depends(get_verified_user)
):
"""Generate chat completions using DMR backend (OpenAI-compatible)"""
metadata = form_data.pop("metadata", None)
payload = {**form_data}
if "metadata" in payload:
del payload["metadata"]
model_id = payload["model"]
model_info = Models.get_model_by_id(model_id)
if model_info:
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
params = model_info.params.model_dump()
if params:
system = params.pop("system", None)
payload = apply_model_params_to_body_openai(params, payload)
payload = apply_model_system_prompt_to_body(system, payload, metadata, user)
# Check if user has access to the model
if user.role == "user":
if not (
user.id == model_info.user_id
or has_access(
user.id, type="read", access_control=model_info.access_control
)
):
raise HTTPException(status_code=403, detail="Model not found")
else:
if user.role != "admin":
raise HTTPException(status_code=403, detail="Model not found")
url = get_dmr_base_url(request)
return await send_post_request(
f"{url}/chat/completions",
payload,
stream=payload.get("stream", False),
user=user,
)
@router.post("/v1/completions")
async def generate_openai_completion(
request: Request,
form_data: dict,
user=Depends(get_verified_user)
):
"""Generate completions using DMR backend (OpenAI-compatible)"""
payload = {**form_data}
model_id = payload["model"]
model_info = Models.get_model_by_id(model_id)
if model_info:
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
params = model_info.params.model_dump()
if params:
payload = apply_model_params_to_body_openai(params, payload)
# Check if user has access to the model
if user.role == "user":
if not (
user.id == model_info.user_id
or has_access(
user.id, type="read", access_control=model_info.access_control
)
):
raise HTTPException(status_code=403, detail="Model not found")
else:
if user.role != "admin":
raise HTTPException(status_code=403, detail="Model not found")
url = get_dmr_base_url(request)
return await send_post_request(
f"{url}/completions",
payload,
stream=payload.get("stream", False),
user=user,
)
@router.post("/v1/embeddings")
async def generate_openai_embeddings(
request: Request,
form_data: dict,
user=Depends(get_verified_user)
):
"""Generate embeddings using DMR backend (OpenAI-compatible)"""
payload = {**form_data}
model_id = payload["model"]
model_info = Models.get_model_by_id(model_id)
if model_info:
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
# Check if user has access to the model
if user.role == "user":
if not (
user.id == model_info.user_id
or has_access(
user.id, type="read", access_control=model_info.access_control
)
):
raise HTTPException(status_code=403, detail="Model not found")
else:
if user.role != "admin":
raise HTTPException(status_code=403, detail="Model not found")
url = get_dmr_base_url(request)
return await send_post_request(f"{url}/embeddings", payload, stream=False, user=user)
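Note: finally, a hedged sketch of how a router like this gets wired into the application. The module path and /dmr prefix are assumptions, as the registration itself is outside this diff:

from fastapi import FastAPI

app = FastAPI()

# Assumed wiring; the real prefix and import path are defined elsewhere in Open WebUI:
# from open_webui.routers import dmr
# app.include_router(dmr.router, prefix="/dmr")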