From c6c0bc19d89951cf6fb99adbf9fa6fa442453aff Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Wed, 12 Jun 2024 13:31:05 -0700 Subject: [PATCH 1/2] fix: filter pipeline --- backend/main.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/backend/main.py b/backend/main.py index 04b4ebffc..e2771187a 100644 --- a/backend/main.py +++ b/backend/main.py @@ -472,13 +472,10 @@ def filter_pipeline(payload, user): if r is not None: try: res = r.json() - if "detail" in res: - return JSONResponse( - status_code=r.status_code, - content=res, - ) except: pass + if "detail" in res: + raise Exception(r.status_code, res["detail"]) else: pass @@ -489,6 +486,7 @@ def filter_pipeline(payload, user): if "title" in payload: del payload["title"] + return payload @@ -510,7 +508,14 @@ class PipelineMiddleware(BaseHTTPMiddleware): user = get_current_user( get_http_authorization_cred(request.headers.get("Authorization")) ) - data = filter_pipeline(data, user) + + try: + data = filter_pipeline(data, user) + except Exception as e: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) modified_body_bytes = json.dumps(data).encode("utf-8") # Replace the request body with the modified one From e82027310dbf7a6ee6ad54c332c5bc2f012fe0b4 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Wed, 12 Jun 2024 13:34:34 -0700 Subject: [PATCH 2/2] fix --- backend/main.py | 58 ++++++++++++++++++++++++++++++++++++------------- 1 file changed, 43 insertions(+), 15 deletions(-) diff --git a/backend/main.py b/backend/main.py index e2771187a..de8827d12 100644 --- a/backend/main.py +++ b/backend/main.py @@ -196,7 +196,11 @@ async def get_function_call_response(messages, tool_id, template, task_model_id, "stream": False, } - payload = filter_pipeline(payload, user) + try: + payload = filter_pipeline(payload, user) + except Exception as e: + raise e + model = app.state.MODELS[task_model_id] response = None @@ -326,16 +330,19 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): print(data["tool_ids"]) for tool_id in data["tool_ids"]: print(tool_id) - response = await get_function_call_response( - messages=data["messages"], - tool_id=tool_id, - template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - task_model_id=task_model_id, - user=user, - ) + try: + response = await get_function_call_response( + messages=data["messages"], + tool_id=tool_id, + template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + task_model_id=task_model_id, + user=user, + ) - if response: - context += ("\n" if context != "" else "") + response + if response: + context += ("\n" if context != "" else "") + response + except Exception as e: + print(f"Error: {e}") del data["tool_ids"] print(f"tool_context: {context}") @@ -767,7 +774,14 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): } print(payload) - payload = filter_pipeline(payload, user) + + try: + payload = filter_pipeline(payload, user) + except Exception as e: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) if model["owned_by"] == "ollama": return await generate_ollama_chat_completion( @@ -824,7 +838,14 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) } print(payload) - payload = filter_pipeline(payload, user) + + try: + payload = filter_pipeline(payload, user) + except Exception as e: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) if model["owned_by"] == "ollama": return await generate_ollama_chat_completion( @@ -861,9 +882,16 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_ print(model_id) template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE - return await get_function_call_response( - form_data["messages"], form_data["tool_id"], template, model_id, user - ) + try: + context = await get_function_call_response( + form_data["messages"], form_data["tool_id"], template, model_id, user + ) + return context + except Exception as e: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) @app.post("/api/chat/completions")