enh: block api user with model filter

This commit is contained in:
Timothy J. Baek 2024-08-26 14:24:56 +02:00
parent 7fc049a513
commit b96239fb0b
2 changed files with 25 additions and 0 deletions

View File

@ -737,6 +737,14 @@ async def generate_chat_completion(
del payload["metadata"]
model_id = form_data.model
if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
raise HTTPException(
status_code=403,
detail="Model not found",
)
model_info = Models.get_model_by_id(model_id)
if model_info:
@ -797,6 +805,14 @@ async def generate_openai_chat_completion(
del payload["metadata"]
model_id = completion_form.model
if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
raise HTTPException(
status_code=403,
detail="Model not found",
)
model_info = Models.get_model_by_id(model_id)
if model_info:

View File

@ -981,11 +981,20 @@ async def get_models(user=Depends(get_verified_user)):
@app.post("/api/chat/completions")
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Model not found",
)
model = app.state.MODELS[model_id]
if model.get("pipe"):
return await generate_function_chat_completion(form_data, user=user)