diff --git a/backend/main.py b/backend/main.py index 04b4ebffc..de8827d12 100644 --- a/backend/main.py +++ b/backend/main.py @@ -196,7 +196,11 @@ async def get_function_call_response(messages, tool_id, template, task_model_id, "stream": False, } - payload = filter_pipeline(payload, user) + try: + payload = filter_pipeline(payload, user) + except Exception as e: + raise e + model = app.state.MODELS[task_model_id] response = None @@ -326,16 +330,19 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): print(data["tool_ids"]) for tool_id in data["tool_ids"]: print(tool_id) - response = await get_function_call_response( - messages=data["messages"], - tool_id=tool_id, - template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - task_model_id=task_model_id, - user=user, - ) + try: + response = await get_function_call_response( + messages=data["messages"], + tool_id=tool_id, + template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + task_model_id=task_model_id, + user=user, + ) - if response: - context += ("\n" if context != "" else "") + response + if response: + context += ("\n" if context != "" else "") + response + except Exception as e: + print(f"Error: {e}") del data["tool_ids"] print(f"tool_context: {context}") @@ -472,13 +479,10 @@ def filter_pipeline(payload, user): if r is not None: try: res = r.json() - if "detail" in res: - return JSONResponse( - status_code=r.status_code, - content=res, - ) except: pass + if "detail" in res: + raise Exception(r.status_code, res["detail"]) else: pass @@ -489,6 +493,7 @@ def filter_pipeline(payload, user): if "title" in payload: del payload["title"] + return payload @@ -510,7 +515,14 @@ class PipelineMiddleware(BaseHTTPMiddleware): user = get_current_user( get_http_authorization_cred(request.headers.get("Authorization")) ) - data = filter_pipeline(data, user) + + try: + data = filter_pipeline(data, user) + except Exception as e: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) modified_body_bytes = json.dumps(data).encode("utf-8") # Replace the request body with the modified one @@ -762,7 +774,14 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): } print(payload) - payload = filter_pipeline(payload, user) + + try: + payload = filter_pipeline(payload, user) + except Exception as e: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) if model["owned_by"] == "ollama": return await generate_ollama_chat_completion( @@ -819,7 +838,14 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) } print(payload) - payload = filter_pipeline(payload, user) + + try: + payload = filter_pipeline(payload, user) + except Exception as e: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) if model["owned_by"] == "ollama": return await generate_ollama_chat_completion( @@ -856,9 +882,16 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_ print(model_id) template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE - return await get_function_call_response( - form_data["messages"], form_data["tool_id"], template, model_id, user - ) + try: + context = await get_function_call_response( + form_data["messages"], form_data["tool_id"], template, model_id, user + ) + return context + except Exception as e: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) @app.post("/api/chat/completions")