diff --git a/backend/main.py b/backend/main.py index c45862de6..33b4f5f58 100644 --- a/backend/main.py +++ b/backend/main.py @@ -58,7 +58,6 @@ from config import ( SRC_LOG_LEVELS, WEBHOOK_URL, ENABLE_ADMIN_EXPORT, - MODEL_CONFIG, ) from constants import ERROR_MESSAGES @@ -287,6 +286,38 @@ async def update_model_filter_config( } +class ModelConfig(BaseModel): + id: str + name: str + description: str + vision_capable: bool + + +class SetModelConfigForm(BaseModel): + ollama: List[ModelConfig] + litellm: List[ModelConfig] + openai: List[ModelConfig] + + +@app.post("/api/config/models") +async def update_model_config( + form_data: SetModelConfigForm, user=Depends(get_admin_user) +): + data = form_data.model_dump() + + ollama_app.state.MODEL_CONFIG = data.get("ollama", []) + + openai_app.state.MODEL_CONFIG = data.get("openai", []) + + litellm_app.state.MODEL_CONFIG = data.get("litellm", []) + + return { + "ollama": ollama_app.state.MODEL_CONFIG, + "openai": openai_app.state.MODEL_CONFIG, + "litellm": litellm_app.state.MODEL_CONFIG, + } + + @app.get("/api/webhook") async def get_webhook_url(user=Depends(get_admin_user)): return {