From e427ef767bb52f72a17785c50094370871573c94 Mon Sep 17 00:00:00 2001
From: "Timothy J. Baek"
Date: Wed, 29 May 2024 11:28:42 -0700
Subject: [PATCH] fix: openai proxy

---
 backend/apps/openai/main.py          | 47 ++++++++++++++++++----------
 backend/apps/webui/routers/models.py |  1 -
 2 files changed, 31 insertions(+), 17 deletions(-)

diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py
index 3b55d4d16..29c157308 100644
--- a/backend/apps/openai/main.py
+++ b/backend/apps/openai/main.py
@@ -338,21 +338,36 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
             model_info.params = model_info.params.model_dump()
 
         if model_info.params:
-            payload["temperature"] = model_info.params.get("temperature", None)
-            payload["top_p"] = model_info.params.get("top_p", None)
-            payload["max_tokens"] = model_info.params.get("max_tokens", None)
-            payload["frequency_penalty"] = model_info.params.get(
-                "frequency_penalty", None
-            )
-            payload["seed"] = model_info.params.get("seed", None)
-            payload["stop"] = (
-                [
-                    bytes(stop, "utf-8").decode("unicode_escape")
-                    for stop in model_info.params["stop"]
-                ]
-                if model_info.params.get("stop", None)
-                else None
-            )
+            if model_info.params.get("temperature", None):
+                payload["temperature"] = int(
+                    model_info.params.get("temperature")
+                )
+
+            if model_info.params.get("top_p", None):
+                payload["top_p"] = int(model_info.params.get("top_p", None))
+
+            if model_info.params.get("max_tokens", None):
+                payload["max_tokens"] = int(
+                    model_info.params.get("max_tokens", None)
+                )
+
+            if model_info.params.get("frequency_penalty", None):
+                payload["frequency_penalty"] = int(
+                    model_info.params.get("frequency_penalty", None)
+                )
+
+            if model_info.params.get("seed", None):
+                payload["seed"] = model_info.params.get("seed", None)
+
+            if model_info.params.get("stop", None):
+                payload["stop"] = (
+                    [
+                        bytes(stop, "utf-8").decode("unicode_escape")
+                        for stop in model_info.params["stop"]
+                    ]
+                    if model_info.params.get("stop", None)
+                    else None
+                )
 
         if model_info.params.get("system", None):
             # Check if the payload already has a system message
@@ -376,7 +391,6 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
     else:
         pass
 
-    print(app.state.MODELS)
     model = app.state.MODELS[payload.get("model")]
     idx = model["urlIdx"]
 
@@ -442,6 +456,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
     if r is not None:
         try:
             res = r.json()
+            print(res)
             if "error" in res:
                 error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
         except:
diff --git a/backend/apps/webui/routers/models.py b/backend/apps/webui/routers/models.py
index 363737e25..acc1c6b47 100644
--- a/backend/apps/webui/routers/models.py
+++ b/backend/apps/webui/routers/models.py
@@ -82,7 +82,6 @@ async def update_model_by_id(
     else:
         if form_data.id in request.app.state.MODELS:
             model = Models.insert_new_model(form_data, user.id)
-            print(model)
             if model:
                 return model
             else:
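
---

Review note (not part of the patch): the new guards use truthiness checks, so a
legitimate value of 0 or 0.0 for temperature, top_p, or frequency_penalty is
silently skipped, and the int() coercion truncates fractional values even
though the OpenAI API defines those three parameters as floats (a configured
temperature of 0.7 becomes 0). Below is a minimal standalone sketch of the
same override logic with those two issues addressed; apply_model_params and
the sample dicts are hypothetical names used for illustration only, not code
from this repository.

    def apply_model_params(payload: dict, params: dict) -> dict:
        """Copy configured sampling parameters into the request payload,
        leaving unset values out instead of sending explicit nulls."""
        # These three are floats in the OpenAI API; int() would truncate them.
        for key in ("temperature", "top_p", "frequency_penalty"):
            if params.get(key) is not None:  # "is not None" keeps a valid 0.0
                payload[key] = float(params[key])

        if params.get("max_tokens") is not None:
            payload["max_tokens"] = int(params["max_tokens"])  # genuinely an int

        if params.get("seed") is not None:
            payload["seed"] = params["seed"]

        if params.get("stop"):
            # Decode escapes stored as literal text (e.g. "\\n" -> "\n").
            payload["stop"] = [
                bytes(stop, "utf-8").decode("unicode_escape")
                for stop in params["stop"]
            ]

        return payload


    if __name__ == "__main__":
        payload = {"model": "gpt-4o", "messages": []}
        params = {"temperature": 0.7, "max_tokens": 256, "stop": ["\\n\\n"]}
        print(apply_model_params(payload, params))
        # -> {'model': 'gpt-4o', 'messages': [], 'temperature': 0.7,
        #     'max_tokens': 256, 'stop': ['\n\n']}

Like the patch, the sketch omits absent keys from the payload entirely rather
than assigning None, which matches how upstream APIs expect optional
parameters to be handled.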