From c44fc82ecddbb926240276c8a7cf75cc8ad4dd1d Mon Sep 17 00:00:00 2001
From: "Timothy J. Baek"
Date: Sun, 9 Jun 2024 12:43:54 -0700
Subject: [PATCH] refac: openai

---
 backend/apps/openai/main.py | 226 ++++++++++++++++++++++--------------
 1 file changed, 137 insertions(+), 89 deletions(-)

diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py
index ab24b4113..f685fd1ed 100644
--- a/backend/apps/openai/main.py
+++ b/backend/apps/openai/main.py
@@ -345,113 +345,98 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
         )
 
 
-@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
-async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
+@app.post("/chat/completions")
+@app.post("/chat/completions/{url_idx}")
+async def generate_chat_completion(
+    form_data: dict,
+    url_idx: Optional[int] = None,
+    user=Depends(get_verified_user),
+):
     idx = 0
+    payload = {**form_data}
 
-    body = await request.body()
-    # TODO: Remove below after gpt-4-vision fix from Open AI
-    # Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
+    model_id = form_data.get("model")
+    model_info = Models.get_model_by_id(model_id)
 
-    payload = None
+    if model_info:
+        print(model_info)
+        if model_info.base_model_id:
+            payload["model"] = model_info.base_model_id
 
-    try:
-        if "chat/completions" in path:
-            body = body.decode("utf-8")
-            body = json.loads(body)
+        model_info.params = model_info.params.model_dump()
 
-            payload = {**body}
+        if model_info.params:
+            if model_info.params.get("temperature", None) is not None:
+                payload["temperature"] = float(model_info.params.get("temperature"))
 
-            model_id = body.get("model")
-            model_info = Models.get_model_by_id(model_id)
+            if model_info.params.get("top_p", None):
+                payload["top_p"] = int(model_info.params.get("top_p", None))
 
-            if model_info:
-                print(model_info)
-                if model_info.base_model_id:
-                    payload["model"] = model_info.base_model_id
+            if model_info.params.get("max_tokens", None):
+                payload["max_tokens"] = int(model_info.params.get("max_tokens", None))
 
-                model_info.params = model_info.params.model_dump()
+            if model_info.params.get("frequency_penalty", None):
+                payload["frequency_penalty"] = int(
+                    model_info.params.get("frequency_penalty", None)
+                )
 
-                if model_info.params:
-                    if model_info.params.get("temperature", None) is not None:
-                        payload["temperature"] = float(
-                            model_info.params.get("temperature")
+            if model_info.params.get("seed", None):
+                payload["seed"] = model_info.params.get("seed", None)
+
+            if model_info.params.get("stop", None):
+                payload["stop"] = (
+                    [
+                        bytes(stop, "utf-8").decode("unicode_escape")
+                        for stop in model_info.params["stop"]
+                    ]
+                    if model_info.params.get("stop", None)
+                    else None
+                )
+
+            if model_info.params.get("system", None):
+                # Check if the payload already has a system message
+                # If not, add a system message to the payload
+                if payload.get("messages"):
+                    for message in payload["messages"]:
+                        if message.get("role") == "system":
+                            message["content"] = (
+                                model_info.params.get("system", None) + message["content"]
                             )
+                            break
+                    else:
+                        payload["messages"].insert(
+                            0,
+                            {
+                                "role": "system",
+                                "content": model_info.params.get("system", None),
+                            },
+                        )
 
-                    if model_info.params.get("top_p", None):
-                        payload["top_p"] = int(model_info.params.get("top_p", None))
+    else:
+        pass
 
-                    if model_info.params.get("max_tokens", None):
-                        payload["max_tokens"] = int(
-                            model_info.params.get("max_tokens", None)
-                        )
+    model = app.state.MODELS[payload.get("model")]
+    idx = model["urlIdx"]
 
-                    if model_info.params.get("frequency_penalty", None):
-                        payload["frequency_penalty"] = int(
-                            model_info.params.get("frequency_penalty", None)
-                        )
+    if "pipeline" in model and model.get("pipeline"):
+        payload["user"] = {"name": user.name, "id": user.id}
 
-                    if model_info.params.get("seed", None):
-                        payload["seed"] = model_info.params.get("seed", None)
+    # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
+    # This is a workaround until OpenAI fixes the issue with this model
+    if payload.get("model") == "gpt-4-vision-preview":
+        if "max_tokens" not in payload:
+            payload["max_tokens"] = 4000
+        log.debug("Modified payload:", payload)
 
-                    if model_info.params.get("stop", None):
-                        payload["stop"] = (
-                            [
-                                bytes(stop, "utf-8").decode("unicode_escape")
-                                for stop in model_info.params["stop"]
-                            ]
-                            if model_info.params.get("stop", None)
-                            else None
-                        )
-
-                    if model_info.params.get("system", None):
-                        # Check if the payload already has a system message
-                        # If not, add a system message to the payload
-                        if payload.get("messages"):
-                            for message in payload["messages"]:
-                                if message.get("role") == "system":
-                                    message["content"] = (
-                                        model_info.params.get("system", None)
-                                        + message["content"]
-                                    )
-                                    break
-                            else:
-                                payload["messages"].insert(
-                                    0,
-                                    {
-                                        "role": "system",
-                                        "content": model_info.params.get("system", None),
-                                    },
-                                )
-            else:
-                pass
-
-            model = app.state.MODELS[payload.get("model")]
-
-            idx = model["urlIdx"]
-
-            if "pipeline" in model and model.get("pipeline"):
-                payload["user"] = {"name": user.name, "id": user.id}
-
-            # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
-            # This is a workaround until OpenAI fixes the issue with this model
-            if payload.get("model") == "gpt-4-vision-preview":
-                if "max_tokens" not in payload:
-                    payload["max_tokens"] = 4000
-                log.debug("Modified payload:", payload)
-
-            # Convert the modified body back to JSON
-            payload = json.dumps(payload)
-
-    except json.JSONDecodeError as e:
-        log.error("Error loading request body into a dictionary:", e)
+    # Convert the modified body back to JSON
+    payload = json.dumps(payload)
 
     print(payload)
 
     url = app.state.config.OPENAI_API_BASE_URLS[idx]
     key = app.state.config.OPENAI_API_KEYS[idx]
 
-    target_url = f"{url}/{path}"
+    print(payload)
 
     headers = {}
     headers["Authorization"] = f"Bearer {key}"
@@ -464,9 +449,72 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
     try:
         session = aiohttp.ClientSession(trust_env=True)
         r = await session.request(
-            method=request.method,
-            url=target_url,
-            data=payload if payload else body,
+            method="POST",
+            url=f"{url}/chat/completions",
+            data=payload,
+            headers=headers,
+        )
+
+        r.raise_for_status()
+
+        # Check if response is SSE
+        if "text/event-stream" in r.headers.get("Content-Type", ""):
+            streaming = True
+            return StreamingResponse(
+                r.content,
+                status_code=r.status,
+                headers=dict(r.headers),
+                background=BackgroundTask(
+                    cleanup_response, response=r, session=session
+                ),
+            )
+        else:
+            response_data = await r.json()
+            return response_data
+    except Exception as e:
+        log.exception(e)
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = await r.json()
+                print(res)
+                if "error" in res:
+                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
+            except:
+                error_detail = f"External: {e}"
+        raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
+    finally:
+        if not streaming and session:
+            if r:
+                r.close()
+            await session.close()
+
+
+@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
+async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
+    idx = 0
+
+    body = await request.body()
+
+    url = app.state.config.OPENAI_API_BASE_URLS[idx]
+    key = app.state.config.OPENAI_API_KEYS[idx]
+
+    target_url = f"{url}/{path}"
+
+    headers = {}
+    headers["Authorization"] = f"Bearer {key}"
+    headers["Content-Type"] = "application/json"
+
+    r = None
+    session = None
+    streaming = False
+
+    try:
+        session = aiohttp.ClientSession(trust_env=True)
+        r = await session.request(
+            method=request.method,
+            url=target_url,
+            data=body,
             headers=headers,
         )
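
Note: the parameter mapping in the first hunk casts "top_p" and "frequency_penalty" with int(), although both are defined as floats in the OpenAI API, and the bare truthiness checks (e.g. if model_info.params.get("top_p", None):) silently drop legitimate zero values. A minimal sketch of a type-correct mapping, assuming a hypothetical helper named apply_model_params that is not part of this patch:

    # Hypothetical helper, not in this patch: copy known sampling
    # parameters onto the payload with the casts the OpenAI API expects.
    PARAM_CASTS = {
        "temperature": float,
        "top_p": float,              # float in the OpenAI API, not int
        "max_tokens": int,
        "frequency_penalty": float,  # float in the OpenAI API, not int
    }

    def apply_model_params(params: dict, payload: dict) -> dict:
        for name, cast in PARAM_CASTS.items():
            value = params.get(name)
            if value is not None:  # explicit None check keeps 0 and 0.0
                payload[name] = cast(value)
        return payload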
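Note: the system-prompt injection added above relies on Python's for/else: the else branch runs only when the loop finishes without hitting break, i.e. when no existing system message was found. A standalone illustration of the idiom, with demo values only:

    messages = [{"role": "user", "content": "hi"}]
    system = "You are concise. "

    for message in messages:
        if message.get("role") == "system":
            # Found one: prepend the configured system prompt to it.
            message["content"] = system + message["content"]
            break
    else:
        # Loop ended without break: no system message, so insert one.
        messages.insert(0, {"role": "system", "content": system})

    assert messages[0]["role"] == "system"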
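Note: when the upstream reply is SSE, the aiohttp response and session must outlive the request handler so the client can finish reading the stream; that is why the finally block skips cleanup while streaming is True and defers it to the BackgroundTask instead. The cleanup_response helper is defined elsewhere in main.py and not shown in this patch; a sketch of its assumed shape:

    from typing import Optional

    import aiohttp

    async def cleanup_response(
        response: Optional[aiohttp.ClientResponse],
        session: Optional[aiohttp.ClientSession],
    ):
        # Assumed shape: runs once the StreamingResponse has finished
        # sending the body, then releases the HTTP resources.
        if response:
            response.close()
        if session:
            await session.close()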