refac: litellm model name validation

This commit is contained in:
Timothy J. Baek 2024-04-21 18:25:53 -05:00
parent 5997774ab8
commit 4651db8c09

View File

@ -12,7 +12,7 @@ import json
import time
import requests
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from typing import Optional, List
from utils.utils import get_verified_user, get_current_user, get_admin_user
@ -25,6 +25,7 @@ log.setLevel(SRC_LOG_LEVELS["LITELLM"])
from config import MODEL_FILTER_ENABLED, MODEL_FILTER_LIST, DATA_DIR
from litellm.utils import get_llm_provider
import asyncio
import subprocess
@ -165,6 +166,8 @@ class LiteLLMConfigForm(BaseModel):
model_list: Optional[List[dict]] = None
router_settings: Optional[dict] = None
model_config = ConfigDict(protected_namespaces=())
@app.post("/config/update")
async def update_config(form_data: LiteLLMConfigForm, user=Depends(get_admin_user)):
@ -236,21 +239,28 @@ class AddLiteLLMModelForm(BaseModel):
model_name: str
litellm_params: dict
model_config = ConfigDict(protected_namespaces=())
@app.post("/model/new")
async def add_model_to_config(
form_data: AddLiteLLMModelForm, user=Depends(get_admin_user)
):
# TODO: Validate model form
try:
get_llm_provider(model=form_data.model_name)
app.state.CONFIG["model_list"].append(form_data.model_dump())
app.state.CONFIG["model_list"].append(form_data.model_dump())
with open(LITELLM_CONFIG_DIR, "w") as file:
yaml.dump(app.state.CONFIG, file)
with open(LITELLM_CONFIG_DIR, "w") as file:
yaml.dump(app.state.CONFIG, file)
await restart_litellm()
await restart_litellm()
return {"message": MESSAGES.MODEL_ADDED(form_data.model_name)}
return {"message": MESSAGES.MODEL_ADDED(form_data.model_name)}
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
)
class DeleteLiteLLMModelForm(BaseModel):