Mirror of https://github.com/open-webui/open-webui, synced 2025-06-23 02:16:52 +00:00

Commit d1534f0626 (parent 17c677a285): Simplify DMR_BASE_URL logic
@@ -904,20 +904,12 @@ ENABLE_DMR_API = PersistentConfig(
     "dmr.enable",
     os.environ.get("ENABLE_DMR_API", "True").lower() == "true",
 )
-DMR_API_BASE_URL = os.environ.get(
-    "DMR_API_BASE_URL", "http://localhost:12434"
-)
+DMR_BASE_URL = os.environ.get("DMR_BASE_URL", "http://localhost:12434")
+# Remove trailing slash
+DMR_BASE_URL = DMR_BASE_URL[:-1] if DMR_BASE_URL.endswith("/") else DMR_BASE_URL
 
-DMR_BASE_URL = os.environ.get("DMR_BASE_URL", "")
-if DMR_BASE_URL:
-    # Remove trailing slash
-    DMR_BASE_URL = DMR_BASE_URL[:-1] if DMR_BASE_URL.endswith("/") else DMR_BASE_URL
-
-DMR_BASE_URLS = os.environ.get("DMR_BASE_URLS", "")
-DMR_BASE_URLS = DMR_BASE_URLS if DMR_BASE_URLS != "" else DMR_BASE_URL or "http://localhost:12434"
-DMR_BASE_URLS = [url.strip() for url in DMR_BASE_URLS.split(";")]
-DMR_BASE_URLS = PersistentConfig(
-    "DMR_BASE_URLS", "dmr.base_urls", DMR_BASE_URLS
+DMR_BASE_URL = PersistentConfig(
+    "DMR_BASE_URL", "dmr.base_url", DMR_BASE_URL
 )
 
 DMR_API_CONFIGS = PersistentConfig(
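For reference, a minimal standalone sketch of the new resolution logic above (PersistentConfig omitted; the default URL and trailing-slash handling are taken from the hunk):

import os

# Read the single backend URL; the default mirrors the diff above.
DMR_BASE_URL = os.environ.get("DMR_BASE_URL", "http://localhost:12434")
# Strip a trailing slash so later f"{url}/models"-style joins stay clean.
DMR_BASE_URL = DMR_BASE_URL[:-1] if DMR_BASE_URL.endswith("/") else DMR_BASE_URL

print(DMR_BASE_URL)  # e.g. "http://localhost:12434"

The semicolon-separated DMR_BASE_URLS environment variable is no longer read after this change.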
@@ -115,7 +115,7 @@ from open_webui.config import (
     OPENAI_API_CONFIGS,
     # Docker Model Runner
     ENABLE_DMR_API,
-    DMR_BASE_URLS,
+    DMR_BASE_URL,
     DMR_API_CONFIGS,
     # Direct Connections
     ENABLE_DIRECT_CONNECTIONS,
@@ -601,7 +601,7 @@ app.state.OPENAI_MODELS = {}
 ########################################
 
 app.state.config.ENABLE_DMR_API = ENABLE_DMR_API
-app.state.config.DMR_BASE_URLS = DMR_BASE_URLS
+app.state.config.DMR_BASE_URL = DMR_BASE_URL
 app.state.config.DMR_API_CONFIGS = DMR_API_CONFIGS
 
 app.state.DMR_MODELS = {}
@@ -1,6 +1,6 @@
 import logging
 import aiohttp
-from typing import Optional, Union
+from typing import Union
 from urllib.parse import urlparse
 import time
 
@@ -135,25 +135,16 @@ async def send_post_request(
     )
 
 
-def get_dmr_base_url(request: Request, url_idx: Optional[int] = None):
+def get_dmr_base_url(request: Request):
     """Get DMR base URL with engine suffix"""
-    urls = request.app.state.config.DMR_BASE_URLS
-    if not urls:
-        raise HTTPException(status_code=500, detail="No DMR base URLs configured")
-
-    if url_idx is None:
-        base = urls[0]
-        idx = 0
-    else:
-        if url_idx >= len(urls):
-            raise HTTPException(status_code=400, detail="Invalid DMR URL index")
-        base = urls[url_idx]
-        idx = url_idx
+    base_url = request.app.state.config.DMR_BASE_URL
+    if not base_url:
+        raise HTTPException(status_code=500, detail="No DMR base URL configured")
 
     # Always append the engine prefix for OpenAI-compatible endpoints
-    if not base.rstrip("/").endswith(DMR_ENGINE_SUFFIX):
-        base = base.rstrip("/") + DMR_ENGINE_SUFFIX
-    return base, idx
+    if not base_url.rstrip("/").endswith(DMR_ENGINE_SUFFIX):
+        base_url = base_url.rstrip("/") + DMR_ENGINE_SUFFIX
+    return base_url
 
 
 ##########################################
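The suffix handling itself is unchanged; only the single-URL signature is new. A standalone sketch of that behavior, assuming a placeholder value for DMR_ENGINE_SUFFIX (the real constant is defined elsewhere in this file and is not part of this diff):

DMR_ENGINE_SUFFIX = "/engines/v1"  # placeholder value, assumed for illustration only

def with_engine_suffix(base_url: str) -> str:
    # Append the engine suffix exactly once, tolerating a trailing slash.
    if not base_url.rstrip("/").endswith(DMR_ENGINE_SUFFIX):
        base_url = base_url.rstrip("/") + DMR_ENGINE_SUFFIX
    return base_url

assert with_engine_suffix("http://localhost:12434/") == "http://localhost:12434/engines/v1"
assert with_engine_suffix("http://localhost:12434/engines/v1") == "http://localhost:12434/engines/v1"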
@@ -167,14 +158,14 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
     """Get DMR configuration"""
     return {
         "ENABLE_DMR_API": request.app.state.config.ENABLE_DMR_API,
-        "DMR_BASE_URLS": request.app.state.config.DMR_BASE_URLS,
+        "DMR_BASE_URL": request.app.state.config.DMR_BASE_URL,
         "DMR_API_CONFIGS": request.app.state.config.DMR_API_CONFIGS,
     }
 
 
 class DMRConfigForm(BaseModel):
     ENABLE_DMR_API: Optional[bool] = None
-    DMR_BASE_URLS: list[str]
+    DMR_BASE_URL: str
     DMR_API_CONFIGS: dict = {}
 
 
@@ -182,18 +173,12 @@ class DMRConfigForm(BaseModel):
 async def update_config(request: Request, form_data: DMRConfigForm, user=Depends(get_admin_user)):
     """Update DMR configuration"""
     request.app.state.config.ENABLE_DMR_API = form_data.ENABLE_DMR_API
-    request.app.state.config.DMR_BASE_URLS = form_data.DMR_BASE_URLS
+    request.app.state.config.DMR_BASE_URL = form_data.DMR_BASE_URL
     request.app.state.config.DMR_API_CONFIGS = form_data.DMR_API_CONFIGS
 
-    # Clean up configs for non-existent URLs
-    keys = list(map(str, range(len(request.app.state.config.DMR_BASE_URLS))))
-    request.app.state.config.DMR_API_CONFIGS = {
-        k: v for k, v in request.app.state.config.DMR_API_CONFIGS.items() if k in keys
-    }
-
     return {
         "ENABLE_DMR_API": request.app.state.config.ENABLE_DMR_API,
-        "DMR_BASE_URLS": request.app.state.config.DMR_BASE_URLS,
+        "DMR_BASE_URL": request.app.state.config.DMR_BASE_URL,
         "DMR_API_CONFIGS": request.app.state.config.DMR_API_CONFIGS,
     }
 
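After this change the admin config payload carries a single string instead of a list. A sketch of the body update_config now expects (field names taken from DMRConfigForm above; the route decorator sits outside this hunk, so the path and verb are not shown here):

payload = {
    "ENABLE_DMR_API": True,
    "DMR_BASE_URL": "http://localhost:12434",  # previously DMR_BASE_URLS: ["http://...", ...]
    "DMR_API_CONFIGS": {},
}

The per-index cleanup of DMR_API_CONFIGS is dropped along with the URL list, since there are no indices left to prune.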
@@ -223,10 +208,9 @@ async def verify_connection(form_data: ConnectionVerificationForm, user=Depends(
 
 
 @router.get("/models")
-@router.get("/models/{url_idx}")
-async def get_models(request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)):
+async def get_models(request: Request, user=Depends(get_verified_user)):
     """Get available models from DMR backend"""
-    url, idx = get_dmr_base_url(request, url_idx)
+    url = get_dmr_base_url(request)
 
     response = await send_get_request(f"{url}/models", user=user)
     if response is None:
@@ -242,7 +226,7 @@ async def generate_chat_completion(
     bypass_filter: Optional[bool] = False,
 ):
     """Generate chat completions using DMR backend"""
-    url, idx = get_dmr_base_url(request)
+    url = get_dmr_base_url(request)
 
     log.debug(f"DMR chat_completions: model = {form_data.get('model', 'NO_MODEL')}")
 
@@ -269,7 +253,7 @@ async def generate_completion(
     user=Depends(get_verified_user),
 ):
     """Generate completions using DMR backend"""
-    url, idx = get_dmr_base_url(request)
+    url = get_dmr_base_url(request)
 
     # Resolve model ID if needed
     if "model" in form_data:
@@ -290,75 +274,46 @@ async def generate_completion(
 @router.post("/embeddings")
 async def embeddings(request: Request, form_data: dict, user=Depends(get_verified_user)):
     """Generate embeddings using DMR backend"""
-    url, idx = get_dmr_base_url(request)
+    url = get_dmr_base_url(request)
 
     return await send_post_request(f"{url}/embeddings", form_data, stream=False, user=user)
 
 
 # OpenAI-compatible endpoints
 @router.get("/v1/models")
-@router.get("/v1/models/{url_idx}")
-async def get_openai_models(request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)):
+async def get_openai_models(request: Request, user=Depends(get_verified_user)):
     """Get available models from DMR backend (OpenAI-compatible)"""
-    return await get_models(request, url_idx, user)
+    return await get_models(request, user)
 
 
 @router.post("/v1/chat/completions")
-@router.post("/v1/chat/completions/{url_idx}")
 async def generate_openai_chat_completion(
     request: Request,
     form_data: dict,
-    url_idx: Optional[int] = None,
     user=Depends(get_verified_user)
 ):
     """Generate chat completions using DMR backend (OpenAI-compatible)"""
-    if url_idx is not None:
-        url, idx = get_dmr_base_url(request, url_idx)
-        return await send_post_request(
-            f"{url}/chat/completions",
-            form_data,
-            stream=form_data.get("stream", False),
-            user=user
-        )
-    else:
-        return await generate_chat_completion(request, form_data, user)
+    return await generate_chat_completion(request, form_data, user)
 
 
 @router.post("/v1/completions")
-@router.post("/v1/completions/{url_idx}")
 async def generate_openai_completion(
     request: Request,
     form_data: dict,
-    url_idx: Optional[int] = None,
     user=Depends(get_verified_user)
 ):
     """Generate completions using DMR backend (OpenAI-compatible)"""
-    if url_idx is not None:
-        url, idx = get_dmr_base_url(request, url_idx)
-        return await send_post_request(
-            f"{url}/completions",
-            form_data,
-            stream=form_data.get("stream", False),
-            user=user
-        )
-    else:
-        return await generate_completion(request, form_data, user)
+    return await generate_completion(request, form_data, user)
 
 
 @router.post("/v1/embeddings")
-@router.post("/v1/embeddings/{url_idx}")
 async def generate_openai_embeddings(
     request: Request,
     form_data: dict,
-    url_idx: Optional[int] = None,
     user=Depends(get_verified_user)
 ):
     """Generate embeddings using DMR backend (OpenAI-compatible)"""
-    if url_idx is not None:
-        url, idx = get_dmr_base_url(request, url_idx)
-        return await send_post_request(f"{url}/embeddings", form_data, stream=False, user=user)
-    else:
-        return await embeddings(request, form_data, user)
+    return await embeddings(request, form_data, user)
 
 
 # Internal utility for Open WebUI model aggregation
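Client-visible effect of this hunk: the per-backend index routes are gone, so OpenAI-compatible calls no longer take a trailing {url_idx} segment. A hedged client-side sketch; the /api/dmr mount prefix, server address, model name, and missing auth header are all assumptions for illustration, not confirmed by this diff:

import requests  # any HTTP client works

WEBUI = "http://localhost:8080"   # assumed Open WebUI address
BASE = f"{WEBUI}/api/dmr"         # assumed mount prefix for this router

body = {"model": "ai/smollm2", "messages": [{"role": "user", "content": "hi"}]}

# Before this commit a specific backend could be addressed by index:
#   requests.post(f"{BASE}/v1/chat/completions/0", json=body)
# After it there is a single DMR backend, so the index segment is dropped:
response = requests.post(f"{BASE}/v1/chat/completions", json=body)
print(response.status_code)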
@@ -368,7 +323,7 @@ async def get_all_models(request: Request, user: UserModel = None):
     Returns: dict with 'data' key (list of models)
     """
     try:
-        url, idx = get_dmr_base_url(request)
+        url = get_dmr_base_url(request)
 
         response = await send_get_request(f"{url}/models", user=user)
         if response is None: