Simplify DMR_BASE_URL logic

This commit is contained in:
Sergei Shitikov 2025-06-16 22:41:47 +02:00
parent 17c677a285
commit d1534f0626
3 changed files with 30 additions and 83 deletions

View File

@@ -904,20 +904,12 @@ ENABLE_DMR_API = PersistentConfig(
"dmr.enable", "dmr.enable",
os.environ.get("ENABLE_DMR_API", "True").lower() == "true", os.environ.get("ENABLE_DMR_API", "True").lower() == "true",
) )
DMR_API_BASE_URL = os.environ.get( DMR_BASE_URL = os.environ.get("DMR_BASE_URL", "http://localhost:12434")
"DMR_API_BASE_URL", "http://localhost:12434" # Remove trailing slash
) DMR_BASE_URL = DMR_BASE_URL[:-1] if DMR_BASE_URL.endswith("/") else DMR_BASE_URL
DMR_BASE_URL = os.environ.get("DMR_BASE_URL", "") DMR_BASE_URL = PersistentConfig(
if DMR_BASE_URL: "DMR_BASE_URL", "dmr.base_url", DMR_BASE_URL
# Remove trailing slash
DMR_BASE_URL = DMR_BASE_URL[:-1] if DMR_BASE_URL.endswith("/") else DMR_BASE_URL
DMR_BASE_URLS = os.environ.get("DMR_BASE_URLS", "")
DMR_BASE_URLS = DMR_BASE_URLS if DMR_BASE_URLS != "" else DMR_BASE_URL or "http://localhost:12434"
DMR_BASE_URLS = [url.strip() for url in DMR_BASE_URLS.split(";")]
DMR_BASE_URLS = PersistentConfig(
"DMR_BASE_URLS", "dmr.base_urls", DMR_BASE_URLS
) )
DMR_API_CONFIGS = PersistentConfig( DMR_API_CONFIGS = PersistentConfig(

View File

@@ -115,7 +115,7 @@ from open_webui.config import (
OPENAI_API_CONFIGS, OPENAI_API_CONFIGS,
# Docker Model Runner # Docker Model Runner
ENABLE_DMR_API, ENABLE_DMR_API,
DMR_BASE_URLS, DMR_BASE_URL,
DMR_API_CONFIGS, DMR_API_CONFIGS,
# Direct Connections # Direct Connections
ENABLE_DIRECT_CONNECTIONS, ENABLE_DIRECT_CONNECTIONS,
@@ -601,7 +601,7 @@ app.state.OPENAI_MODELS = {}
######################################## ########################################
app.state.config.ENABLE_DMR_API = ENABLE_DMR_API app.state.config.ENABLE_DMR_API = ENABLE_DMR_API
app.state.config.DMR_BASE_URLS = DMR_BASE_URLS app.state.config.DMR_BASE_URL = DMR_BASE_URL
app.state.config.DMR_API_CONFIGS = DMR_API_CONFIGS app.state.config.DMR_API_CONFIGS = DMR_API_CONFIGS
app.state.DMR_MODELS = {} app.state.DMR_MODELS = {}

View File

@@ -1,6 +1,6 @@
import logging import logging
import aiohttp import aiohttp
from typing import Optional, Union from typing import Union
from urllib.parse import urlparse from urllib.parse import urlparse
import time import time
@@ -135,25 +135,16 @@ async def send_post_request(
) )
def get_dmr_base_url(request: Request, url_idx: Optional[int] = None): def get_dmr_base_url(request: Request):
"""Get DMR base URL with engine suffix""" """Get DMR base URL with engine suffix"""
urls = request.app.state.config.DMR_BASE_URLS base_url = request.app.state.config.DMR_BASE_URL
if not urls: if not base_url:
raise HTTPException(status_code=500, detail="No DMR base URLs configured") raise HTTPException(status_code=500, detail="No DMR base URL configured")
if url_idx is None:
base = urls[0]
idx = 0
else:
if url_idx >= len(urls):
raise HTTPException(status_code=400, detail="Invalid DMR URL index")
base = urls[url_idx]
idx = url_idx
# Always append the engine prefix for OpenAI-compatible endpoints # Always append the engine prefix for OpenAI-compatible endpoints
if not base.rstrip("/").endswith(DMR_ENGINE_SUFFIX): if not base_url.rstrip("/").endswith(DMR_ENGINE_SUFFIX):
base = base.rstrip("/") + DMR_ENGINE_SUFFIX base_url = base_url.rstrip("/") + DMR_ENGINE_SUFFIX
return base, idx return base_url
########################################## ##########################################
@@ -167,14 +158,14 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
"""Get DMR configuration""" """Get DMR configuration"""
return { return {
"ENABLE_DMR_API": request.app.state.config.ENABLE_DMR_API, "ENABLE_DMR_API": request.app.state.config.ENABLE_DMR_API,
"DMR_BASE_URLS": request.app.state.config.DMR_BASE_URLS, "DMR_BASE_URL": request.app.state.config.DMR_BASE_URL,
"DMR_API_CONFIGS": request.app.state.config.DMR_API_CONFIGS, "DMR_API_CONFIGS": request.app.state.config.DMR_API_CONFIGS,
} }
class DMRConfigForm(BaseModel): class DMRConfigForm(BaseModel):
ENABLE_DMR_API: Optional[bool] = None ENABLE_DMR_API: Optional[bool] = None
DMR_BASE_URLS: list[str] DMR_BASE_URL: str
DMR_API_CONFIGS: dict = {} DMR_API_CONFIGS: dict = {}
@@ -182,18 +173,12 @@ class DMRConfigForm(BaseModel):
async def update_config(request: Request, form_data: DMRConfigForm, user=Depends(get_admin_user)): async def update_config(request: Request, form_data: DMRConfigForm, user=Depends(get_admin_user)):
"""Update DMR configuration""" """Update DMR configuration"""
request.app.state.config.ENABLE_DMR_API = form_data.ENABLE_DMR_API request.app.state.config.ENABLE_DMR_API = form_data.ENABLE_DMR_API
request.app.state.config.DMR_BASE_URLS = form_data.DMR_BASE_URLS request.app.state.config.DMR_BASE_URL = form_data.DMR_BASE_URL
request.app.state.config.DMR_API_CONFIGS = form_data.DMR_API_CONFIGS request.app.state.config.DMR_API_CONFIGS = form_data.DMR_API_CONFIGS
# Clean up configs for non-existent URLs
keys = list(map(str, range(len(request.app.state.config.DMR_BASE_URLS))))
request.app.state.config.DMR_API_CONFIGS = {
k: v for k, v in request.app.state.config.DMR_API_CONFIGS.items() if k in keys
}
return { return {
"ENABLE_DMR_API": request.app.state.config.ENABLE_DMR_API, "ENABLE_DMR_API": request.app.state.config.ENABLE_DMR_API,
"DMR_BASE_URLS": request.app.state.config.DMR_BASE_URLS, "DMR_BASE_URL": request.app.state.config.DMR_BASE_URL,
"DMR_API_CONFIGS": request.app.state.config.DMR_API_CONFIGS, "DMR_API_CONFIGS": request.app.state.config.DMR_API_CONFIGS,
} }
@@ -223,10 +208,9 @@ async def verify_connection(form_data: ConnectionVerificationForm, user=Depends(
@router.get("/models") @router.get("/models")
@router.get("/models/{url_idx}") async def get_models(request: Request, user=Depends(get_verified_user)):
async def get_models(request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)):
"""Get available models from DMR backend""" """Get available models from DMR backend"""
url, idx = get_dmr_base_url(request, url_idx) url = get_dmr_base_url(request)
response = await send_get_request(f"{url}/models", user=user) response = await send_get_request(f"{url}/models", user=user)
if response is None: if response is None:
@@ -242,7 +226,7 @@ async def generate_chat_completion(
bypass_filter: Optional[bool] = False, bypass_filter: Optional[bool] = False,
): ):
"""Generate chat completions using DMR backend""" """Generate chat completions using DMR backend"""
url, idx = get_dmr_base_url(request) url = get_dmr_base_url(request)
log.debug(f"DMR chat_completions: model = {form_data.get('model', 'NO_MODEL')}") log.debug(f"DMR chat_completions: model = {form_data.get('model', 'NO_MODEL')}")
@@ -269,7 +253,7 @@ async def generate_completion(
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
"""Generate completions using DMR backend""" """Generate completions using DMR backend"""
url, idx = get_dmr_base_url(request) url = get_dmr_base_url(request)
# Resolve model ID if needed # Resolve model ID if needed
if "model" in form_data: if "model" in form_data:
@@ -290,75 +274,46 @@ async def generate_completion(
@router.post("/embeddings") @router.post("/embeddings")
async def embeddings(request: Request, form_data: dict, user=Depends(get_verified_user)): async def embeddings(request: Request, form_data: dict, user=Depends(get_verified_user)):
"""Generate embeddings using DMR backend""" """Generate embeddings using DMR backend"""
url, idx = get_dmr_base_url(request) url = get_dmr_base_url(request)
return await send_post_request(f"{url}/embeddings", form_data, stream=False, user=user) return await send_post_request(f"{url}/embeddings", form_data, stream=False, user=user)
# OpenAI-compatible endpoints # OpenAI-compatible endpoints
@router.get("/v1/models") @router.get("/v1/models")
@router.get("/v1/models/{url_idx}") async def get_openai_models(request: Request, user=Depends(get_verified_user)):
async def get_openai_models(request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)):
"""Get available models from DMR backend (OpenAI-compatible)""" """Get available models from DMR backend (OpenAI-compatible)"""
return await get_models(request, url_idx, user) return await get_models(request, user)
@router.post("/v1/chat/completions") @router.post("/v1/chat/completions")
@router.post("/v1/chat/completions/{url_idx}")
async def generate_openai_chat_completion( async def generate_openai_chat_completion(
request: Request, request: Request,
form_data: dict, form_data: dict,
url_idx: Optional[int] = None,
user=Depends(get_verified_user) user=Depends(get_verified_user)
): ):
"""Generate chat completions using DMR backend (OpenAI-compatible)""" """Generate chat completions using DMR backend (OpenAI-compatible)"""
if url_idx is not None: return await generate_chat_completion(request, form_data, user)
url, idx = get_dmr_base_url(request, url_idx)
return await send_post_request(
f"{url}/chat/completions",
form_data,
stream=form_data.get("stream", False),
user=user
)
else:
return await generate_chat_completion(request, form_data, user)
@router.post("/v1/completions") @router.post("/v1/completions")
@router.post("/v1/completions/{url_idx}")
async def generate_openai_completion( async def generate_openai_completion(
request: Request, request: Request,
form_data: dict, form_data: dict,
url_idx: Optional[int] = None,
user=Depends(get_verified_user) user=Depends(get_verified_user)
): ):
"""Generate completions using DMR backend (OpenAI-compatible)""" """Generate completions using DMR backend (OpenAI-compatible)"""
if url_idx is not None: return await generate_completion(request, form_data, user)
url, idx = get_dmr_base_url(request, url_idx)
return await send_post_request(
f"{url}/completions",
form_data,
stream=form_data.get("stream", False),
user=user
)
else:
return await generate_completion(request, form_data, user)
@router.post("/v1/embeddings") @router.post("/v1/embeddings")
@router.post("/v1/embeddings/{url_idx}")
async def generate_openai_embeddings( async def generate_openai_embeddings(
request: Request, request: Request,
form_data: dict, form_data: dict,
url_idx: Optional[int] = None,
user=Depends(get_verified_user) user=Depends(get_verified_user)
): ):
"""Generate embeddings using DMR backend (OpenAI-compatible)""" """Generate embeddings using DMR backend (OpenAI-compatible)"""
if url_idx is not None: return await embeddings(request, form_data, user)
url, idx = get_dmr_base_url(request, url_idx)
return await send_post_request(f"{url}/embeddings", form_data, stream=False, user=user)
else:
return await embeddings(request, form_data, user)
# Internal utility for Open WebUI model aggregation # Internal utility for Open WebUI model aggregation
@@ -368,7 +323,7 @@ async def get_all_models(request: Request, user: UserModel = None):
Returns: dict with 'data' key (list of models) Returns: dict with 'data' key (list of models)
""" """
try: try:
url, idx = get_dmr_base_url(request) url = get_dmr_base_url(request)
response = await send_get_request(f"{url}/models", user=user) response = await send_get_request(f"{url}/models", user=user)
if response is None: if response is None: