From e7da506add41b0c9f6a63462c6a317a8a0b139a9 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Tue, 4 Feb 2025 21:01:53 -0800 Subject: [PATCH] refac: middleware --- backend/open_webui/main.py | 4 +- backend/open_webui/utils/middleware.py | 91 +++++++++++++++----------- src/lib/apis/openai/index.ts | 2 +- 3 files changed, 58 insertions(+), 39 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 265cb10c5..1707c8299 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -889,9 +889,10 @@ async def chat_completion( } form_data["metadata"] = metadata - form_data, events = await process_chat_payload( + form_data, metadata, events = await process_chat_payload( request, form_data, metadata, user, model ) + except Exception as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -900,6 +901,7 @@ async def chat_completion( try: response = await chat_completion_handler(request, form_data, user) + return await process_chat_response( request, response, form_data, user, events, metadata, tasks ) diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 81989a654..3f14a683e 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -183,7 +183,7 @@ async def chat_completion_filter_functions_handler(request, body, model, extra_p async def chat_completion_tools_handler( - request: Request, body: dict, user: UserModel, models, extra_params: dict + request: Request, body: dict, user: UserModel, models, tools ) -> tuple[dict, dict]: async def get_content_from_response(response) -> Optional[str]: content = None @@ -218,35 +218,15 @@ async def chat_completion_tools_handler( "metadata": {"task": str(TASKS.FUNCTION_CALLING)}, } - # If tool_ids field is present, call the functions - metadata = body.get("metadata", {}) - - tool_ids = metadata.get("tool_ids", None) - log.debug(f"{tool_ids=}") - if not tool_ids: - return body, {} - - skip_files = False - sources = [] - task_model_id = get_task_model_id( body["model"], request.app.state.config.TASK_MODEL, request.app.state.config.TASK_MODEL_EXTERNAL, models, ) - tools = get_tools( - request, - tool_ids, - user, - { - **extra_params, - "__model__": models[task_model_id], - "__messages__": body["messages"], - "__files__": metadata.get("files", []), - }, - ) - log.info(f"{tools=}") + + skip_files = False + sources = [] specs = [tool["spec"] for tool in tools.values()] tools_specs = json.dumps(specs) @@ -281,6 +261,8 @@ async def chat_completion_tools_handler( result = json.loads(content) async def tool_call_handler(tool_call): + nonlocal skip_files + log.debug(f"{tool_call=}") tool_function_name = tool_call.get("name", None) @@ -725,6 +707,12 @@ async def process_chat_payload(request, form_data, metadata, user, model): # Initialize events to store additional event to be sent to the client # Initialize contexts and citation models = request.app.state.MODELS + task_model_id = get_task_model_id( + form_data["model"], + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, + models, + ) events = [] sources = [] @@ -809,15 +797,41 @@ async def process_chat_payload(request, form_data, metadata, user, model): } form_data["metadata"] = metadata - if not form_data["metadata"].get("function_calling") == "native": - # If the function calling is not native, then call the tools function calling handler - try: - form_data, flags = await chat_completion_tools_handler( - request, form_data, user, models, extra_params - ) - sources.extend(flags.get("sources", [])) - except Exception as e: - log.exception(e) + tool_ids = metadata.get("tool_ids", None) + log.debug(f"{tool_ids=}") + + if tool_ids: + # If tool_ids field is present, then get the tools + tools = get_tools( + request, + tool_ids, + user, + { + **extra_params, + "__model__": models[task_model_id], + "__messages__": form_data["messages"], + "__files__": metadata.get("files", []), + }, + ) + log.info(f"{tools=}") + + if metadata.get("function_calling") == "native": + # If the function calling is native, then call the tools function calling handler + metadata["tools"] = tools + form_data["tools"] = [ + {"type": "function", "function": tool.get("spec", {})} + for tool in tools.values() + ] + else: + # If the function calling is not native, then call the tools function calling handler + try: + form_data, flags = await chat_completion_tools_handler( + request, form_data, user, models, tools + ) + sources.extend(flags.get("sources", [])) + + except Exception as e: + log.exception(e) try: form_data, flags = await chat_completion_files_handler(request, form_data, user) @@ -833,11 +847,11 @@ async def process_chat_payload(request, form_data, metadata, user, model): if "document" in source: for doc_idx, doc_context in enumerate(source["document"]): - metadata = source.get("metadata") + doc_metadata = source.get("metadata") doc_source_id = None - if metadata: - doc_source_id = metadata[doc_idx].get("source", source_id) + if doc_metadata: + doc_source_id = doc_metadata[doc_idx].get("source", source_id) if source_id: context_string += f"{doc_source_id if doc_source_id is not None else source_id}{doc_context}\n" @@ -894,12 +908,15 @@ async def process_chat_payload(request, form_data, metadata, user, model): } ) - return form_data, events + return form_data, metadata, events async def process_chat_response( request, response, form_data, user, events, metadata, tasks ): + + print("metadata", metadata) + async def background_tasks_handler(): message_map = Chats.get_messages_by_chat_id(metadata["chat_id"]) message = message_map.get(metadata["message_id"]) if message_map else None diff --git a/src/lib/apis/openai/index.ts b/src/lib/apis/openai/index.ts index 5ddfbe935..a801bcdbb 100644 --- a/src/lib/apis/openai/index.ts +++ b/src/lib/apis/openai/index.ts @@ -322,7 +322,7 @@ export const generateOpenAIChatCompletion = async ( return res.json(); }) .catch((err) => { - error = `${err?.detail ?? 'Network Problem'}`; + error = `${err?.detail ?? err}`; return null; });