From 46c4da48642ccb670c5a696c0d8c62572d8b2b4c Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Tue, 25 Feb 2025 01:00:29 -0800 Subject: [PATCH] enh: "stream" hook --- backend/open_webui/utils/filter.py | 7 +++- backend/open_webui/utils/middleware.py | 48 +++++++++++++++++++++----- 2 files changed, 46 insertions(+), 9 deletions(-) diff --git a/backend/open_webui/utils/filter.py b/backend/open_webui/utils/filter.py index de51bd46e..8ff12bdf2 100644 --- a/backend/open_webui/utils/filter.py +++ b/backend/open_webui/utils/filter.py @@ -61,7 +61,12 @@ async def process_filter_functions( try: # Prepare parameters sig = inspect.signature(handler) - params = {"body": form_data} | { + + params = {"body": form_data} + if filter_type == "stream": + params = {"event": form_data} + + params = params | { k: v for k, v in { **extra_params, diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 8c82b7074..bb8c33d6d 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -1048,6 +1048,21 @@ async def process_chat_response( ): return response + extra_params = { + "__event_emitter__": event_emitter, + "__event_call__": event_caller, + "__user__": { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + "__metadata__": metadata, + "__request__": request, + "__model__": metadata.get("model"), + } + filter_ids = get_sorted_filter_ids(form_data.get("model")) + # Streaming response if event_emitter and event_caller: task_id = str(uuid4()) # Create a unique task ID. @@ -1402,16 +1417,12 @@ async def process_chat_response( ("reasoning", "/reasoning"), ("thought", "/thought"), ("Thought", "/Thought"), - ("|begin_of_thought|", "|end_of_thought|") + ("|begin_of_thought|", "|end_of_thought|"), ] - code_interpreter_tags = [ - ("code_interpreter", "/code_interpreter") - ] + code_interpreter_tags = [("code_interpreter", "/code_interpreter")] - solution_tags = [ - ("|begin_of_solution|", "|end_of_solution|") - ] + solution_tags = [("|begin_of_solution|", "|end_of_solution|")] try: for event in events: @@ -1455,6 +1466,14 @@ async def process_chat_response( try: data = json.loads(data) + data, _ = await process_filter_functions( + request=request, + filter_ids=filter_ids, + filter_type="stream", + form_data=data, + extra_params=extra_params, + ) + if "selected_model_id" in data: model_id = data["selected_model_id"] Chats.upsert_message_to_chat_by_id_and_message_id( @@ -1968,16 +1987,29 @@ async def process_chat_response( return {"status": True, "task_id": task_id} else: - # Fallback to the original response async def stream_wrapper(original_generator, events): def wrap_item(item): return f"data: {item}\n\n" for event in events: + event, _ = await process_filter_functions( + request=request, + filter_ids=filter_ids, + filter_type="stream", + form_data=event, + extra_params=extra_params, + ) yield wrap_item(json.dumps(event)) async for data in original_generator: + data, _ = await process_filter_functions( + request=request, + filter_ids=filter_ids, + filter_type="stream", + form_data=data, + extra_params=extra_params, + ) yield data return StreamingResponse(