mirror of
https://github.com/open-webui/open-webui
synced 2025-06-22 18:07:17 +00:00
fixes and improvements
This commit is contained in:
parent
b1be5b6b89
commit
0f70e87416
@ -84,6 +84,7 @@ log_sources = [
|
||||
"COMFYUI",
|
||||
"CONFIG",
|
||||
"DB",
|
||||
"DMR",
|
||||
"IMAGES",
|
||||
"MAIN",
|
||||
"MODELS",
|
||||
|
@ -1,27 +1,34 @@
|
||||
import logging
|
||||
import aiohttp
|
||||
from typing import Union
|
||||
from urllib.parse import urlparse
|
||||
import time
|
||||
from typing import Optional, Union
|
||||
|
||||
import aiohttp
|
||||
from aiocache import cached
|
||||
from fastapi import APIRouter, Depends, Request, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from starlette.background import BackgroundTask
|
||||
from pydantic import BaseModel
|
||||
|
||||
from open_webui.models.users import UserModel
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import (
|
||||
AIOHTTP_CLIENT_SESSION_SSL,
|
||||
AIOHTTP_CLIENT_TIMEOUT,
|
||||
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
|
||||
BYPASS_MODEL_ACCESS_CONTROL,
|
||||
ENABLE_FORWARD_USER_INFO_HEADERS,
|
||||
SRC_LOG_LEVELS
|
||||
)
|
||||
from open_webui.models.users import UserModel
|
||||
from open_webui.models.models import Models
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.access_control import has_access
|
||||
from open_webui.utils.payload import (
|
||||
apply_model_params_to_body_openai,
|
||||
apply_model_system_prompt_to_body,
|
||||
)
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS.get("DMR", logging.INFO))
|
||||
log.setLevel(SRC_LOG_LEVELS["DMR"])
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@ -35,7 +42,7 @@ DMR_ENGINE_SUFFIX = "/engines/llama.cpp/v1"
|
||||
##########################################
|
||||
|
||||
async def send_get_request(url, user: UserModel = None):
|
||||
"""Send GET request to DMR backend with proper error handling"""
|
||||
"""Send a GET request to DMR backend with proper error handling"""
|
||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
||||
@ -77,6 +84,7 @@ async def send_post_request(
|
||||
url: str,
|
||||
payload: Union[str, bytes, dict],
|
||||
stream: bool = False,
|
||||
content_type: Optional[str] = None,
|
||||
user: UserModel = None,
|
||||
):
|
||||
"""Send POST request to DMR backend with proper error handling"""
|
||||
@ -90,7 +98,7 @@ async def send_post_request(
|
||||
url,
|
||||
data=payload if isinstance(payload, (str, bytes)) else aiohttp.JsonPayload(payload),
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Content-Type": content_type or "application/json",
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
@ -107,10 +115,14 @@ async def send_post_request(
|
||||
r.raise_for_status()
|
||||
|
||||
if stream:
|
||||
response_headers = dict(r.headers)
|
||||
if content_type:
|
||||
response_headers["Content-Type"] = content_type
|
||||
|
||||
return StreamingResponse(
|
||||
r.content,
|
||||
status_code=r.status,
|
||||
headers=dict(r.headers),
|
||||
headers=response_headers,
|
||||
background=BackgroundTask(cleanup_response, response=r, session=session),
|
||||
)
|
||||
else:
|
||||
@ -139,7 +151,7 @@ def get_dmr_base_url(request: Request):
|
||||
"""Get DMR base URL with engine suffix"""
|
||||
base_url = request.app.state.config.DMR_BASE_URL
|
||||
if not base_url:
|
||||
raise HTTPException(status_code=500, detail="No DMR base URL configured")
|
||||
raise HTTPException(status_code=500, detail="DMR base URL not configured")
|
||||
|
||||
# Always append the engine prefix for OpenAI-compatible endpoints
|
||||
if not base_url.rstrip("/").endswith(DMR_ENGINE_SUFFIX):
|
||||
@ -147,181 +159,17 @@ def get_dmr_base_url(request: Request):
|
||||
return base_url
|
||||
|
||||
|
||||
##########################################
|
||||
#
|
||||
# API routes
|
||||
#
|
||||
##########################################
|
||||
|
||||
@router.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
    """Return the DMR settings currently held on the application config.

    Access is gated by the ``get_admin_user`` dependency; the resolved user
    is not otherwise used.
    """
    config = request.app.state.config
    return {
        "ENABLE_DMR_API": config.ENABLE_DMR_API,
        "DMR_BASE_URL": config.DMR_BASE_URL,
        "DMR_API_CONFIGS": config.DMR_API_CONFIGS,
    }
|
||||
|
||||
|
||||
class DMRConfigForm(BaseModel):
    """Request body for the DMR configuration update endpoint."""

    # Optional flag; None when the client does not send it.
    ENABLE_DMR_API: Optional[bool] = None
    # Required base URL of the DMR backend.
    DMR_BASE_URL: str
    # Free-form configuration mapping; empty when not supplied.
    DMR_API_CONFIGS: dict = {}
|
||||
|
||||
|
||||
@router.post("/config/update")
async def update_config(request: Request, form_data: DMRConfigForm, user=Depends(get_admin_user)):
    """Overwrite the stored DMR settings with the submitted form and echo them back.

    Access is gated by the ``get_admin_user`` dependency.
    """
    config = request.app.state.config
    config.ENABLE_DMR_API = form_data.ENABLE_DMR_API
    config.DMR_BASE_URL = form_data.DMR_BASE_URL
    config.DMR_API_CONFIGS = form_data.DMR_API_CONFIGS

    # Values are read back from the config object, not from the form.
    return {
        "ENABLE_DMR_API": config.ENABLE_DMR_API,
        "DMR_BASE_URL": config.DMR_BASE_URL,
        "DMR_API_CONFIGS": config.DMR_API_CONFIGS,
    }
|
||||
|
||||
|
||||
class ConnectionVerificationForm(BaseModel):
    """Request body for the DMR connection verification endpoint."""

    # URL of the DMR backend to probe.
    url: str
|
||||
|
||||
|
||||
@router.post("/verify")
async def verify_connection(form_data: ConnectionVerificationForm, user=Depends(get_admin_user)):
    """Verify that a DMR backend answers at the submitted URL.

    The engine suffix is appended when missing, then the backend's
    ``/models`` endpoint is queried.

    Returns:
        ``{"status": "success", "message": "Connection verified"}`` when the
        request yields a response.

    Raises:
        HTTPException: 400 when the request fails or yields no response.
    """
    # Strip trailing slashes up front so an already-suffixed URL does not
    # produce a double slash before "models".
    url = form_data.url.rstrip("/")
    if not url.endswith(DMR_ENGINE_SUFFIX):
        url += DMR_ENGINE_SUFFIX

    try:
        response = await send_get_request(f"{url}/models", user=user)
    except Exception as e:
        log.exception(f"DMR connection verification failed: {e}")
        raise HTTPException(
            status_code=400, detail=f"Connection verification failed: {e}"
        ) from e

    # Checked outside the try block so this HTTPException is not caught and
    # re-wrapped by the generic handler above.
    if response is None:
        raise HTTPException(status_code=400, detail="Failed to connect to DMR backend")

    return {"status": "success", "message": "Connection verified"}
|
||||
|
||||
|
||||
@router.get("/models")
async def get_models(request: Request, user=Depends(get_verified_user)):
    """Fetch the model list from the DMR backend and return it unchanged.

    Raises:
        HTTPException: 500 when the backend request yields no response.
    """
    models_url = f"{get_dmr_base_url(request)}/models"
    result = await send_get_request(models_url, user=user)
    if result is not None:
        return result
    raise HTTPException(status_code=500, detail="Failed to fetch models from DMR backend")
|
||||
|
||||
|
||||
@router.post("/chat/completions")
async def generate_chat_completion(
    request: Request,
    form_data: dict,
    user=Depends(get_verified_user),
    bypass_filter: Optional[bool] = False,
):
    """Forward a chat completion request to the DMR backend.

    When ``form_data`` names a model, the name is matched against the ids and
    names of the known DMR models and replaced by the matching model's id.
    ``bypass_filter`` is accepted but not used by this implementation.
    """
    url = get_dmr_base_url(request)

    log.debug(f"DMR chat_completions: model = {form_data.get('model', 'NO_MODEL')}")

    if "model" in form_data:
        known = await get_all_models(request, user=user)
        # First entry whose id or name equals the requested model, if any.
        match = next(
            (
                entry
                for entry in known.get("data", [])
                if entry.get("id") == form_data["model"]
                or entry.get("name") == form_data["model"]
            ),
            None,
        )
        if match is not None:
            form_data["model"] = match["id"]

    wants_stream = form_data.get("stream", False)
    return await send_post_request(
        f"{url}/chat/completions",
        form_data,
        stream=wants_stream,
        user=user,
    )
|
||||
|
||||
|
||||
@router.post("/completions")
async def generate_completion(
    request: Request,
    form_data: dict,
    user=Depends(get_verified_user),
):
    """Forward a text completion request to the DMR backend.

    A requested model is resolved against the known DMR models by id or name
    and rewritten to the matching id before the request is sent.
    """
    url = get_dmr_base_url(request)

    if "model" in form_data:
        catalogue = await get_all_models(request, user=user)
        candidates = catalogue.get("data", [])
        # Stop at the first candidate matching by id or by name.
        hit = next(
            (
                c
                for c in candidates
                if c.get("id") == form_data["model"] or c.get("name") == form_data["model"]
            ),
            None,
        )
        if hit is not None:
            form_data["model"] = hit["id"]

    endpoint = f"{url}/completions"
    return await send_post_request(
        endpoint,
        form_data,
        stream=form_data.get("stream", False),
        user=user,
    )
|
||||
|
||||
|
||||
@router.post("/embeddings")
async def embeddings(request: Request, form_data: dict, user=Depends(get_verified_user)):
    """Forward an embeddings request to the DMR backend without streaming."""
    endpoint = f"{get_dmr_base_url(request)}/embeddings"
    return await send_post_request(endpoint, form_data, stream=False, user=user)
|
||||
|
||||
|
||||
# OpenAI-compatible endpoints
|
||||
@router.get("/v1/models")
async def get_openai_models(request: Request, user=Depends(get_verified_user)):
    """Serve the model list under the ``/v1`` path by delegating to ``get_models``."""
    listing = await get_models(request, user)
    return listing
|
||||
|
||||
|
||||
@router.post("/v1/chat/completions")
async def generate_openai_chat_completion(
    request: Request,
    form_data: dict,
    user=Depends(get_verified_user)
):
    """Serve chat completions under ``/v1`` by delegating to ``generate_chat_completion``."""
    result = await generate_chat_completion(request, form_data, user)
    return result
|
||||
|
||||
|
||||
@router.post("/v1/completions")
async def generate_openai_completion(
    request: Request,
    form_data: dict,
    user=Depends(get_verified_user)
):
    """Serve text completions under ``/v1`` by delegating to ``generate_completion``."""
    result = await generate_completion(request, form_data, user)
    return result
|
||||
|
||||
|
||||
@router.post("/v1/embeddings")
async def generate_openai_embeddings(
    request: Request,
    form_data: dict,
    user=Depends(get_verified_user)
):
    """Serve embeddings under ``/v1`` by delegating to ``embeddings``."""
    result = await embeddings(request, form_data, user)
    return result
|
||||
|
||||
|
||||
# Internal utility for Open WebUI model aggregation
|
||||
@cached(ttl=1)
|
||||
async def get_all_models(request: Request, user: UserModel = None):
|
||||
"""
|
||||
Fetch all models from the DMR backend in OpenAI-compatible format for internal use.
|
||||
Returns: dict with 'data' key (list of models)
|
||||
"""
|
||||
log.info("get_all_models() - DMR")
|
||||
|
||||
if not request.app.state.config.ENABLE_DMR_API:
|
||||
return {"data": []}
|
||||
|
||||
try:
|
||||
url = get_dmr_base_url(request)
|
||||
|
||||
@ -399,3 +247,502 @@ async def get_all_models(request: Request, user: UserModel = None):
|
||||
except Exception as e:
|
||||
log.exception(f"DMR get_all_models failed: {e}")
|
||||
return {"data": []}
|
||||
|
||||
|
||||
async def get_filtered_models(models, user):
    """Return the entries of ``models["data"]`` that ``user`` may see.

    With ``BYPASS_MODEL_ACCESS_CONTROL`` set, every entry is returned
    unfiltered. Otherwise an entry with a stored model record is kept when the
    user owns the record or has read access to it; an entry without a stored
    record is kept only for admins.
    """
    if BYPASS_MODEL_ACCESS_CONTROL:
        return models.get("data", [])

    def _is_visible(entry):
        record = Models.get_model_by_id(entry["id"])
        if not record:
            # No stored record: only admins get to see the raw backend model.
            return user.role == "admin"
        return user.id == record.user_id or has_access(
            user.id, type="read", access_control=record.access_control
        )

    return [entry for entry in models.get("data", []) if _is_visible(entry)]
|
||||
|
||||
|
||||
##########################################
|
||||
#
|
||||
# Configuration endpoints
|
||||
#
|
||||
##########################################
|
||||
|
||||
@router.head("/")
@router.get("/")
async def get_status():
    """Answer HEAD and GET on the router root with a constant status payload."""
    return dict(status=True)
|
||||
|
||||
|
||||
@router.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
    """Expose the three DMR settings stored on the application config.

    The ``get_admin_user`` dependency guards the route; ``user`` itself is
    unused.
    """
    settings = request.app.state.config
    exposed = ("ENABLE_DMR_API", "DMR_BASE_URL", "DMR_API_CONFIGS")
    return {key: getattr(settings, key) for key in exposed}
|
||||
|
||||
|
||||
class DMRConfigForm(BaseModel):
    """Payload accepted by ``POST /config/update``."""

    # May be omitted by the client, in which case it is None.
    ENABLE_DMR_API: Optional[bool] = None
    # Base URL of the DMR backend; the only required field.
    DMR_BASE_URL: str
    # Arbitrary configuration mapping, defaulting to empty.
    DMR_API_CONFIGS: dict = {}
|
||||
|
||||
|
||||
@router.post("/config/update")
async def update_config(request: Request, form_data: DMRConfigForm, user=Depends(get_admin_user)):
    """Store the submitted DMR settings on the application config and return them.

    The ``get_admin_user`` dependency guards the route.
    """
    settings = request.app.state.config
    fields = ("ENABLE_DMR_API", "DMR_BASE_URL", "DMR_API_CONFIGS")

    # Copy each field from the form onto the config, in declaration order.
    for field in fields:
        setattr(settings, field, getattr(form_data, field))

    # Echo what the config object now holds.
    return {field: getattr(settings, field) for field in fields}
|
||||
|
||||
|
||||
class ConnectionVerificationForm(BaseModel):
    """Payload accepted by ``POST /verify``."""

    # Address of the DMR backend whose reachability should be checked.
    url: str
|
||||
|
||||
|
||||
@router.post("/verify")
async def verify_connection(form_data: ConnectionVerificationForm, user=Depends(get_admin_user)):
    """Check that a DMR backend responds at the submitted URL.

    The engine suffix is added when the URL does not already end with it, and
    the backend's ``/models`` endpoint is requested.

    Returns:
        ``{"status": "success", "message": "Connection verified"}`` when a
        response is obtained.

    Raises:
        HTTPException: 400 when the request raises or yields no response.
    """
    # Normalise trailing slashes first; otherwise an already-suffixed URL
    # ending in "/" would be requested as ".../v1//models".
    url = form_data.url.rstrip("/")
    if not url.endswith(DMR_ENGINE_SUFFIX):
        url += DMR_ENGINE_SUFFIX

    try:
        response = await send_get_request(f"{url}/models", user=user)
    except Exception as e:
        log.exception(f"DMR connection verification failed: {e}")
        raise HTTPException(
            status_code=400, detail=f"Connection verification failed: {e}"
        ) from e

    # Kept outside the try block: raising here previously got caught by the
    # broad handler, logged as an exception and wrapped a second time.
    if response is None:
        raise HTTPException(status_code=400, detail="Failed to connect to DMR backend")

    return {"status": "success", "message": "Connection verified"}
|
||||
|
||||
|
||||
##########################################
|
||||
#
|
||||
# Model endpoints
|
||||
#
|
||||
##########################################
|
||||
|
||||
@router.get("/api/tags")
async def get_dmr_tags(request: Request, user=Depends(get_verified_user)):
    """List DMR models as ``{"models": [...]}`` with Ollama-style entries.

    Users with role ``"user"`` only see models that pass access filtering,
    unless ``BYPASS_MODEL_ACCESS_CONTROL`` is set.
    """
    models = await get_all_models(request, user=user)

    restrict = user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL
    if restrict:
        models = {"data": await get_filtered_models(models, user)}

    # "id" and "name" are required on each entry; the rest fall back to defaults.
    return {
        "models": [
            {
                "model": entry["id"],
                "name": entry["name"],
                "size": entry.get("size", 0),
                "digest": entry.get("digest", ""),
                "details": entry.get("details", {}),
                "expires_at": entry.get("expires_at"),
                "size_vram": entry.get("size_vram", 0),
            }
            for entry in models.get("data", [])
        ]
    }
|
||||
|
||||
|
||||
@router.get("/models")
async def get_models(request: Request, user=Depends(get_verified_user)):
    """Return the DMR model list, access-filtered for regular users.

    Filtering applies only to role ``"user"`` and only while
    ``BYPASS_MODEL_ACCESS_CONTROL`` is off; otherwise the full result of
    ``get_all_models`` is returned as is.
    """
    models = await get_all_models(request, user=user)

    needs_filtering = user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL
    if not needs_filtering:
        return models

    return {"data": await get_filtered_models(models, user)}
|
||||
|
||||
|
||||
class ModelNameForm(BaseModel):
    """Payload accepted by ``POST /api/show``."""

    # Compared against both the id and the name of each known model.
    name: str
|
||||
|
||||
|
||||
@router.post("/api/show")
async def show_model_info(
    request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
):
    """Return Ollama-style details for one DMR model.

    The model is looked up by id or name among the results of
    ``get_all_models``.

    Raises:
        HTTPException: 400 when no model matches ``form_data.name``.
    """
    models = await get_all_models(request, user=user)

    wanted = form_data.name
    # First model whose id or name matches; None when there is no match.
    match = next(
        (
            entry
            for entry in models.get("data", [])
            if entry["id"] == wanted or entry["name"] == wanted
        ),
        None,
    )

    if not match:
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
        )

    # modelfile, parameters and template are always returned empty.
    return {
        "model": match["id"],
        "details": match.get("dmr", {}),
        "modelfile": "",
        "parameters": "",
        "template": "",
    }
|
||||
|
||||
|
||||
##########################################
|
||||
#
|
||||
# Generation endpoints
|
||||
#
|
||||
##########################################
|
||||
|
||||
class GenerateChatCompletionForm(BaseModel):
    """Validation model for ``POST /api/chat`` request bodies.

    Fields beyond the declared ones are kept (``extra="allow"``).
    """

    # Identifier of the model to use.
    model: str
    # Conversation messages, each a plain dict.
    messages: list[dict]
    # Whether a streamed response is requested; off by default.
    stream: Optional[bool] = False
    model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@router.post("/api/chat")
async def generate_chat_completion(
    request: Request,
    form_data: dict,
    user=Depends(get_verified_user),
    bypass_filter: Optional[bool] = False,
):
    """Validate a chat request and forward it to the DMR ``/chat/completions`` endpoint.

    ``metadata`` is removed from ``form_data`` (the caller's dict is mutated)
    and only used when applying the model's system prompt. When a stored model
    record exists, its base model id and parameters are applied to the payload.

    Args:
        request: Incoming request; used to resolve the DMR base URL.
        form_data: Raw request body; validated with ``GenerateChatCompletionForm``.
        user: Verified user issuing the request.
        bypass_filter: Skip access checks; forced on when
            ``BYPASS_MODEL_ACCESS_CONTROL`` is set.

    Raises:
        HTTPException: 400 when validation fails; 403 when the user may not
            use the model.
    """
    if BYPASS_MODEL_ACCESS_CONTROL:
        bypass_filter = True

    metadata = form_data.pop("metadata", None)

    try:
        completion_form = GenerateChatCompletionForm(**form_data)
    except Exception as e:
        log.exception(e)
        raise HTTPException(status_code=400, detail=str(e)) from e

    # "metadata" was popped before validation, so it cannot reappear here.
    payload = completion_form.model_dump(exclude_none=True)

    model_info = Models.get_model_by_id(payload["model"])

    if model_info:
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id

        params = model_info.params.model_dump()
        if params:
            system = params.pop("system", None)
            payload = apply_model_params_to_body_openai(params, payload)
            payload = apply_model_system_prompt_to_body(system, payload, metadata, user)

        # Regular users need ownership or read access to the stored model.
        if not bypass_filter and user.role == "user":
            permitted = user.id == model_info.user_id or has_access(
                user.id, type="read", access_control=model_info.access_control
            )
            if not permitted:
                raise HTTPException(status_code=403, detail="Model not found")
    elif not bypass_filter and user.role != "admin":
        # Models without a stored record are reserved for admins.
        raise HTTPException(status_code=403, detail="Model not found")

    url = get_dmr_base_url(request)

    log.debug(f"DMR chat completion: model = {payload.get('model', 'NO_MODEL')}")

    return await send_post_request(
        f"{url}/chat/completions",
        payload,
        stream=payload.get("stream", False),
        content_type="application/x-ndjson" if payload.get("stream") else None,
        user=user,
    )
|
||||
|
||||
|
||||
class GenerateCompletionForm(BaseModel):
    """Validation model for ``POST /api/generate`` request bodies.

    Fields beyond the declared ones are kept (``extra="allow"``).
    """

    # Identifier of the model to use.
    model: str
    # Prompt text to complete.
    prompt: str
    # Whether a streamed response is requested; off by default.
    stream: Optional[bool] = False
    model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@router.post("/api/generate")
async def generate_completion(
    request: Request,
    form_data: dict,
    user=Depends(get_verified_user),
):
    """Validate a completion request and forward it to the DMR ``/completions`` endpoint.

    When a stored model record exists, its base model id and parameters are
    applied to the payload. Access checks are skipped while
    ``BYPASS_MODEL_ACCESS_CONTROL`` is set, matching ``/api/chat`` and the
    model listing.

    Raises:
        HTTPException: 400 when validation fails; 403 when the user may not
            use the model.
    """
    try:
        completion_form = GenerateCompletionForm(**form_data)
    except Exception as e:
        log.exception(e)
        raise HTTPException(status_code=400, detail=str(e)) from e

    payload = completion_form.model_dump(exclude_none=True)
    model_info = Models.get_model_by_id(payload["model"])
    enforce_access = not BYPASS_MODEL_ACCESS_CONTROL

    if model_info:
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id

        params = model_info.params.model_dump()
        if params:
            payload = apply_model_params_to_body_openai(params, payload)

        # Regular users need ownership or read access to the stored model.
        if enforce_access and user.role == "user":
            permitted = user.id == model_info.user_id or has_access(
                user.id, type="read", access_control=model_info.access_control
            )
            if not permitted:
                raise HTTPException(status_code=403, detail="Model not found")
    elif enforce_access and user.role != "admin":
        # Models without a stored record are reserved for admins.
        raise HTTPException(status_code=403, detail="Model not found")

    url = get_dmr_base_url(request)

    return await send_post_request(
        f"{url}/completions",
        payload,
        stream=payload.get("stream", False),
        user=user,
    )
|
||||
|
||||
|
||||
class GenerateEmbeddingsForm(BaseModel):
    """Validation model for ``POST /api/embeddings`` request bodies.

    Fields beyond the declared ones are kept (``extra="allow"``).
    """

    # Identifier of the embedding model to use.
    model: str
    # A single string or a list of strings to embed.
    input: Union[str, list[str]]
    model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@router.post("/api/embeddings")
async def embeddings(
    request: Request,
    form_data: dict,
    user=Depends(get_verified_user)
):
    """Validate an embeddings request and forward it to the DMR ``/embeddings`` endpoint.

    A stored model record's base model id replaces the requested model id.
    Access checks are skipped while ``BYPASS_MODEL_ACCESS_CONTROL`` is set,
    matching ``/api/chat`` and the model listing.

    Raises:
        HTTPException: 400 when validation fails; 403 when the user may not
            use the model.
    """
    try:
        embedding_form = GenerateEmbeddingsForm(**form_data)
    except Exception as e:
        log.exception(e)
        raise HTTPException(status_code=400, detail=str(e)) from e

    payload = embedding_form.model_dump(exclude_none=True)
    model_info = Models.get_model_by_id(payload["model"])
    enforce_access = not BYPASS_MODEL_ACCESS_CONTROL

    if model_info:
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id

        # Regular users need ownership or read access to the stored model.
        if enforce_access and user.role == "user":
            permitted = user.id == model_info.user_id or has_access(
                user.id, type="read", access_control=model_info.access_control
            )
            if not permitted:
                raise HTTPException(status_code=403, detail="Model not found")
    elif enforce_access and user.role != "admin":
        # Models without a stored record are reserved for admins.
        raise HTTPException(status_code=403, detail="Model not found")

    url = get_dmr_base_url(request)

    return await send_post_request(f"{url}/embeddings", payload, stream=False, user=user)
|
||||
|
||||
|
||||
##########################################
|
||||
#
|
||||
# OpenAI-compatible endpoints
|
||||
#
|
||||
##########################################
|
||||
|
||||
@router.get("/v1/models")
async def get_openai_models(request: Request, user=Depends(get_verified_user)):
    """List DMR models as an OpenAI-style ``{"object": "list", "data": [...]}``.

    Users with role ``"user"`` only see models that pass access filtering,
    unless ``BYPASS_MODEL_ACCESS_CONTROL`` is set.
    """
    models = await get_all_models(request, user=user)

    restrict = user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL
    if restrict:
        models = {"data": await get_filtered_models(models, user)}

    # Entries lacking "created" are stamped with the current time.
    data = [
        {
            "id": entry["id"],
            "object": "model",
            "created": entry.get("created", int(time.time())),
            "owned_by": "docker",
        }
        for entry in models.get("data", [])
    ]

    return {"object": "list", "data": data}
|
||||
|
||||
|
||||
@router.post("/v1/chat/completions")
async def generate_openai_chat_completion(
    request: Request,
    form_data: dict,
    user=Depends(get_verified_user)
):
    """Forward an OpenAI-style chat request to the DMR ``/chat/completions`` endpoint.

    ``metadata`` is removed from ``form_data`` (the caller's dict is mutated)
    and only used when applying the model's system prompt. When a stored model
    record exists, its base model id and parameters are applied to the payload.
    Access checks are skipped while ``BYPASS_MODEL_ACCESS_CONTROL`` is set.

    Raises:
        HTTPException: 400 when ``model`` is missing; 403 when the user may
            not use the model.
    """
    metadata = form_data.pop("metadata", None)

    # Shallow copy; "metadata" was popped above so it is not part of it.
    payload = {**form_data}

    model_id = payload.get("model")
    if model_id is None:
        # Reject explicitly instead of failing with a KeyError (HTTP 500).
        raise HTTPException(status_code=400, detail="Missing required field: model")

    model_info = Models.get_model_by_id(model_id)
    enforce_access = not BYPASS_MODEL_ACCESS_CONTROL

    if model_info:
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id

        params = model_info.params.model_dump()
        if params:
            system = params.pop("system", None)
            payload = apply_model_params_to_body_openai(params, payload)
            payload = apply_model_system_prompt_to_body(system, payload, metadata, user)

        # Regular users need ownership or read access to the stored model.
        if enforce_access and user.role == "user":
            permitted = user.id == model_info.user_id or has_access(
                user.id, type="read", access_control=model_info.access_control
            )
            if not permitted:
                raise HTTPException(status_code=403, detail="Model not found")
    elif enforce_access and user.role != "admin":
        # Models without a stored record are reserved for admins.
        raise HTTPException(status_code=403, detail="Model not found")

    url = get_dmr_base_url(request)

    return await send_post_request(
        f"{url}/chat/completions",
        payload,
        stream=payload.get("stream", False),
        user=user,
    )
|
||||
|
||||
|
||||
@router.post("/v1/completions")
async def generate_openai_completion(
    request: Request,
    form_data: dict,
    user=Depends(get_verified_user)
):
    """Forward an OpenAI-style completion request to the DMR ``/completions`` endpoint.

    When a stored model record exists, its base model id and parameters are
    applied to a copy of ``form_data``. Access checks are skipped while
    ``BYPASS_MODEL_ACCESS_CONTROL`` is set.

    Raises:
        HTTPException: 400 when ``model`` is missing; 403 when the user may
            not use the model.
    """
    payload = {**form_data}

    model_id = payload.get("model")
    if model_id is None:
        # Reject explicitly instead of failing with a KeyError (HTTP 500).
        raise HTTPException(status_code=400, detail="Missing required field: model")

    model_info = Models.get_model_by_id(model_id)
    enforce_access = not BYPASS_MODEL_ACCESS_CONTROL

    if model_info:
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id

        params = model_info.params.model_dump()
        if params:
            payload = apply_model_params_to_body_openai(params, payload)

        # Regular users need ownership or read access to the stored model.
        if enforce_access and user.role == "user":
            permitted = user.id == model_info.user_id or has_access(
                user.id, type="read", access_control=model_info.access_control
            )
            if not permitted:
                raise HTTPException(status_code=403, detail="Model not found")
    elif enforce_access and user.role != "admin":
        # Models without a stored record are reserved for admins.
        raise HTTPException(status_code=403, detail="Model not found")

    url = get_dmr_base_url(request)

    return await send_post_request(
        f"{url}/completions",
        payload,
        stream=payload.get("stream", False),
        user=user,
    )
|
||||
|
||||
|
||||
@router.post("/v1/embeddings")
async def generate_openai_embeddings(
    request: Request,
    form_data: dict,
    user=Depends(get_verified_user)
):
    """Forward an OpenAI-style embeddings request to the DMR ``/embeddings`` endpoint.

    A stored model record's base model id replaces the requested model id in
    a copy of ``form_data``. Access checks are skipped while
    ``BYPASS_MODEL_ACCESS_CONTROL`` is set.

    Raises:
        HTTPException: 400 when ``model`` is missing; 403 when the user may
            not use the model.
    """
    payload = {**form_data}

    model_id = payload.get("model")
    if model_id is None:
        # Reject explicitly instead of failing with a KeyError (HTTP 500).
        raise HTTPException(status_code=400, detail="Missing required field: model")

    model_info = Models.get_model_by_id(model_id)
    enforce_access = not BYPASS_MODEL_ACCESS_CONTROL

    if model_info:
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id

        # Regular users need ownership or read access to the stored model.
        if enforce_access and user.role == "user":
            permitted = user.id == model_info.user_id or has_access(
                user.id, type="read", access_control=model_info.access_control
            )
            if not permitted:
                raise HTTPException(status_code=403, detail="Model not found")
    elif enforce_access and user.role != "admin":
        # Models without a stored record are reserved for admins.
        raise HTTPException(status_code=403, detail="Model not found")

    url = get_dmr_base_url(request)

    return await send_post_request(f"{url}/embeddings", payload, stream=False, user=user)
|
||||
|
Loading…
Reference in New Issue
Block a user