From 4651db8c09d90383fc3c8df5670ebd914c68b8e2 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 21 Apr 2024 18:25:53 -0500 Subject: [PATCH] refac: litellm model name validation --- backend/apps/litellm/main.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py index 40619be2f..52e0c7002 100644 --- a/backend/apps/litellm/main.py +++ b/backend/apps/litellm/main.py @@ -12,7 +12,7 @@ import json import time import requests -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from typing import Optional, List from utils.utils import get_verified_user, get_current_user, get_admin_user @@ -25,6 +25,7 @@ log.setLevel(SRC_LOG_LEVELS["LITELLM"]) from config import MODEL_FILTER_ENABLED, MODEL_FILTER_LIST, DATA_DIR +from litellm.utils import get_llm_provider import asyncio import subprocess @@ -165,6 +166,8 @@ class LiteLLMConfigForm(BaseModel): model_list: Optional[List[dict]] = None router_settings: Optional[dict] = None + model_config = ConfigDict(protected_namespaces=()) + @app.post("/config/update") async def update_config(form_data: LiteLLMConfigForm, user=Depends(get_admin_user)): @@ -236,21 +239,28 @@ class AddLiteLLMModelForm(BaseModel): model_name: str litellm_params: dict + model_config = ConfigDict(protected_namespaces=()) + @app.post("/model/new") async def add_model_to_config( form_data: AddLiteLLMModelForm, user=Depends(get_admin_user) ): - # TODO: Validate model form + try: + get_llm_provider(model=form_data.model_name) + app.state.CONFIG["model_list"].append(form_data.model_dump()) - app.state.CONFIG["model_list"].append(form_data.model_dump()) + with open(LITELLM_CONFIG_DIR, "w") as file: + yaml.dump(app.state.CONFIG, file) - with open(LITELLM_CONFIG_DIR, "w") as file: - yaml.dump(app.state.CONFIG, file) + await restart_litellm() - await restart_litellm() - - return {"message": MESSAGES.MODEL_ADDED(form_data.model_name)} + return {"message": MESSAGES.MODEL_ADDED(form_data.model_name)} + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) + ) class DeleteLiteLLMModelForm(BaseModel):