From e82027310dbf7a6ee6ad54c332c5bc2f012fe0b4 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Wed, 12 Jun 2024 13:34:34 -0700 Subject: [PATCH] fix --- backend/main.py | 58 ++++++++++++++++++++++++++++++++++++------------- 1 file changed, 43 insertions(+), 15 deletions(-) diff --git a/backend/main.py b/backend/main.py index e2771187a..de8827d12 100644 --- a/backend/main.py +++ b/backend/main.py @@ -196,7 +196,11 @@ async def get_function_call_response(messages, tool_id, template, task_model_id, "stream": False, } - payload = filter_pipeline(payload, user) + try: + payload = filter_pipeline(payload, user) + except Exception as e: + raise e + model = app.state.MODELS[task_model_id] response = None @@ -326,16 +330,19 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): print(data["tool_ids"]) for tool_id in data["tool_ids"]: print(tool_id) - response = await get_function_call_response( - messages=data["messages"], - tool_id=tool_id, - template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - task_model_id=task_model_id, - user=user, - ) + try: + response = await get_function_call_response( + messages=data["messages"], + tool_id=tool_id, + template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + task_model_id=task_model_id, + user=user, + ) - if response: - context += ("\n" if context != "" else "") + response + if response: + context += ("\n" if context != "" else "") + response + except Exception as e: + print(f"Error: {e}") del data["tool_ids"] print(f"tool_context: {context}") @@ -767,7 +774,14 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): } print(payload) - payload = filter_pipeline(payload, user) + + try: + payload = filter_pipeline(payload, user) + except Exception as e: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) if model["owned_by"] == "ollama": return await generate_ollama_chat_completion( @@ -824,7 +838,14 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) } print(payload) - payload = filter_pipeline(payload, user) + + try: + payload = filter_pipeline(payload, user) + except Exception as e: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) if model["owned_by"] == "ollama": return await generate_ollama_chat_completion( @@ -861,9 +882,16 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_ print(model_id) template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE - return await get_function_call_response( - form_data["messages"], form_data["tool_id"], template, model_id, user - ) + try: + context = await get_function_call_response( + form_data["messages"], form_data["tool_id"], template, model_id, user + ) + return context + except Exception as e: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) @app.post("/api/chat/completions")