diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 26e6f4a86..d505e77f3 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -511,7 +511,11 @@ from open_webui.utils.chat import ( chat_action as chat_action_handler, ) from open_webui.utils.embeddings import generate_embeddings -from open_webui.utils.middleware import process_chat_payload, process_chat_response +from open_webui.utils.middleware import ( + build_chat_response_context, + process_chat_payload, + process_chat_response, +) from open_webui.utils.access_control import has_access from open_webui.utils.auth import ( @@ -1379,9 +1383,9 @@ async def check_url(request: Request, call_next): # Fallback to cookie token for browser sessions if request.state.token is None and request.cookies.get("token"): from fastapi.security import HTTPAuthorizationCredentials + request.state.token = HTTPAuthorizationCredentials( - scheme="Bearer", - credentials=request.cookies.get("token") + scheme="Bearer", credentials=request.cookies.get("token") ) request.state.enable_api_keys = app.state.config.ENABLE_API_KEYS @@ -1456,9 +1460,7 @@ app.include_router(functions.router, prefix="/api/v1/functions", tags=["function app.include_router( evaluations.router, prefix="/api/v1/evaluations", tags=["evaluations"] ) -app.include_router( - analytics.router, prefix="/api/v1/analytics", tags=["analytics"] -) +app.include_router(analytics.router, prefix="/api/v1/analytics", tags=["analytics"]) app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"]) # SCIM 2.0 API for identity management @@ -1746,9 +1748,11 @@ async def chat_completion( except: pass - return await process_chat_response( - request, response, form_data, user, metadata, model, events, tasks + ctx = build_chat_response_context( + request, form_data, user, model, metadata, tasks, events ) + + return await process_chat_response(response, ctx) except asyncio.CancelledError: log.info("Chat processing was cancelled") try: diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 184ed1267..4150bdff6 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -2416,8 +2416,79 @@ def get_event_emitter_and_caller(metadata): return event_emitter, event_caller -async def background_tasks_handler(request, form_data, user, metadata, tasks): +def build_chat_response_context( + request, form_data, user, model, metadata, tasks, events +): event_emitter, event_caller = get_event_emitter_and_caller(metadata) + return { + "request": request, + "form_data": form_data, + "user": user, + "model": model, + "metadata": metadata, + "tasks": tasks, + "events": events, + "event_emitter": event_emitter, + "event_caller": event_caller, + } + + +def get_response_data(response): + if isinstance(response, list) and len(response) == 1: + # If the response is a single-item list, unwrap it #17213 + response = response[0] + + if isinstance(response, JSONResponse): + if isinstance(response.body, bytes): + try: + response_data = json.loads(response.body.decode("utf-8", "replace")) + except json.JSONDecodeError: + response_data = {"error": {"detail": "Invalid JSON response"}} + else: + response_data = response + elif isinstance(response, dict): + response_data = response + else: + response_data = None + + return response, response_data + + +def merge_events_into_response(response_data, events): + if events and isinstance(events, list): + extra_response = {} + for event in events: + if isinstance(event, dict): + extra_response.update(event) + else: + extra_response[event] = True + + return { + **extra_response, + **response_data, + } + return response_data + + +def build_response_object(response, response_data): + if isinstance(response, dict): + return response_data + if isinstance(response, JSONResponse): + return JSONResponse( + content=response_data, + headers=response.headers, + status_code=response.status_code, + ) + return response + + +async def background_tasks_handler(ctx): + request = ctx["request"] + form_data = ctx["form_data"] + user = ctx["user"] + metadata = ctx["metadata"] + tasks = ctx["tasks"] + event_emitter = ctx["event_emitter"] message = None messages = [] @@ -2633,184 +2704,144 @@ async def background_tasks_handler(request, form_data, user, metadata, tasks): pass -async def process_chat_response( - request, response, form_data, user, metadata, model, events, tasks -): - event_emitter, event_caller = get_event_emitter_and_caller(metadata) +async def non_streaming_chat_response_handler(response, ctx): + request = ctx["request"] - # Non-streaming response - if not isinstance(response, StreamingResponse): - if event_emitter: - try: - if isinstance(response, dict) or isinstance(response, JSONResponse): - if isinstance(response, list) and len(response) == 1: - # If the response is a single-item list, unwrap it #17213 - response = response[0] + user = ctx["user"] + metadata = ctx["metadata"] + events = ctx["events"] - if isinstance(response, JSONResponse) and isinstance( - response.body, bytes - ): - try: - response_data = json.loads( - response.body.decode("utf-8", "replace") - ) - except json.JSONDecodeError: - response_data = { - "error": {"detail": "Invalid JSON response"} - } - else: - response_data = response + event_emitter = ctx["event_emitter"] - if "error" in response_data: - error = response_data.get("error") + response, response_data = get_response_data(response) + if response_data is None: + return response - if isinstance(error, dict): - error = error.get("detail", error) - else: - error = str(error) + if event_emitter: + try: + if "error" in response_data: + error = response_data.get("error") - Chats.upsert_message_to_chat_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], - { - "error": {"content": error}, + if isinstance(error, dict): + error = error.get("detail", error) + else: + error = str(error) + + Chats.upsert_message_to_chat_by_id_and_message_id( + metadata["chat_id"], + metadata["message_id"], + { + "error": {"content": error}, + }, + ) + if isinstance(error, str) or isinstance(error, dict): + await event_emitter( + { + "type": "chat:message:error", + "data": {"error": {"content": error}}, + } + ) + + if "selected_model_id" in response_data: + Chats.upsert_message_to_chat_by_id_and_message_id( + metadata["chat_id"], + metadata["message_id"], + { + "selectedModelId": response_data["selected_model_id"], + }, + ) + + choices = response_data.get("choices", []) + if choices and choices[0].get("message", {}).get("content"): + content = response_data["choices"][0]["message"]["content"] + + if content: + await event_emitter( + { + "type": "chat:completion", + "data": response_data, + } + ) + + title = Chats.get_chat_title_by_id(metadata["chat_id"]) + + # Use output from backend if provided (OR-compliant backends) + response_output = response_data.get("output") + + await event_emitter( + { + "type": "chat:completion", + "data": { + "done": True, + "content": content, + **( + {"output": response_output} + if response_output + else {} + ), + "title": title, }, - ) - if isinstance(error, str) or isinstance(error, dict): - await event_emitter( + } + ) + + # Save message in the database + Chats.upsert_message_to_chat_by_id_and_message_id( + metadata["chat_id"], + metadata["message_id"], + { + "role": "assistant", + "content": content, + **({"output": response_output} if response_output else {}), + }, + ) + + # Send a webhook notification if the user is not active + if not Users.is_user_active(user.id): + webhook_url = Users.get_user_webhook_url_by_id(user.id) + if webhook_url: + await post_webhook( + request.app.state.WEBUI_NAME, + webhook_url, + f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}", { - "type": "chat:message:error", - "data": {"error": {"content": error}}, - } - ) - - if "selected_model_id" in response_data: - Chats.upsert_message_to_chat_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], - { - "selectedModelId": response_data["selected_model_id"], - }, - ) - - choices = response_data.get("choices", []) - if choices and choices[0].get("message", {}).get("content"): - content = response_data["choices"][0]["message"]["content"] - - if content: - await event_emitter( - { - "type": "chat:completion", - "data": response_data, - } - ) - - title = Chats.get_chat_title_by_id(metadata["chat_id"]) - - # Use output from backend if provided (OR-compliant backends) - response_output = response_data.get("output") - - await event_emitter( - { - "type": "chat:completion", - "data": { - "done": True, - "content": content, - **( - {"output": response_output} - if response_output - else {} - ), - "title": title, - }, - } - ) - - # Save message in the database - Chats.upsert_message_to_chat_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], - { - "role": "assistant", - "content": content, - **( - {"output": response_output} - if response_output - else {} - ), + "action": "chat", + "message": content, + "title": title, + "url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}", }, ) - # Send a webhook notification if the user is not active - if not Users.is_user_active(user.id): - webhook_url = Users.get_user_webhook_url_by_id(user.id) - if webhook_url: - await post_webhook( - request.app.state.WEBUI_NAME, - webhook_url, - f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}", - { - "action": "chat", - "message": content, - "title": title, - "url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}", - }, - ) + await background_tasks_handler(ctx) - await background_tasks_handler( - request, form_data, user, metadata, tasks - ) + response = build_response_object( + response, merge_events_into_response(response_data, events) + ) + except Exception as e: + log.debug(f"Error occurred while processing request: {e}") + pass - if events and isinstance(events, list): - extra_response = {} - for event in events: - if isinstance(event, dict): - extra_response.update(event) - else: - extra_response[event] = True - - response_data = { - **extra_response, - **response_data, - } - - if isinstance(response, dict): - response = response_data - if isinstance(response, JSONResponse): - response = JSONResponse( - content=response_data, - headers=response.headers, - status_code=response.status_code, - ) - - except Exception as e: - log.debug(f"Error occurred while processing request: {e}") - pass - - return response - else: - if events and isinstance(events, list) and isinstance(response, dict): - extra_response = {} - for event in events: - if isinstance(event, dict): - extra_response.update(event) - else: - extra_response[event] = True - - response = { - **extra_response, - **response, - } - - return response - - # Non standard response - if not any( - content_type in response.headers["Content-Type"] - for content_type in ["text/event-stream", "application/x-ndjson"] - ): return response + if isinstance(response, dict): + response = merge_events_into_response(response_data, events) + + return response + + +async def streaming_chat_response_handler(response, ctx): + request = ctx["request"] + + form_data = ctx["form_data"] + + user = ctx["user"] + model = ctx["model"] + + metadata = ctx["metadata"] + events = ctx["events"] + + event_emitter = ctx["event_emitter"] + event_caller = ctx["event_caller"] + oauth_token = None try: if request.cookies.get("oauth_session_id", None): @@ -2830,6 +2861,7 @@ async def process_chat_response( "__request__": request, "__model__": model, } + filter_functions = [ Functions.get_function_by_id(filter_id) for filter_id in get_sorted_filter_ids( @@ -2837,7 +2869,7 @@ async def process_chat_response( ) ] - # Streaming response + # Standard streaming response handler if event_emitter and event_caller: task_id = str(uuid4()) # Create a unique task ID. model_id = form_data.get("model", "") @@ -4148,9 +4180,9 @@ async def process_chat_response( blocking_code = textwrap.dedent( f""" import builtins - + BLOCKED_MODULES = {CODE_INTERPRETER_BLOCKED_MODULES} - + _real_import = builtins.__import__ def restricted_import(name, globals=None, locals=None, fromlist=(), level=0): if name.split('.')[0] in BLOCKED_MODULES: @@ -4160,7 +4192,7 @@ async def process_chat_response( f"Direct import of module {{name}} is restricted." ) return _real_import(name, globals, locals, fromlist, level) - + builtins.__import__ = restricted_import """ ) @@ -4351,9 +4383,7 @@ async def process_chat_response( } ) - await background_tasks_handler( - request, form_data, user, metadata, tasks - ) + await background_tasks_handler(ctx) except asyncio.CancelledError: log.warning("Task was cancelled!") await event_emitter({"type": "chat:tasks:cancel"}) @@ -4410,3 +4440,19 @@ async def process_chat_response( headers=dict(response.headers), background=response.background, ) + + +async def process_chat_response(response, ctx): + # Non-streaming response + if not isinstance(response, StreamingResponse): + return await non_streaming_chat_response_handler(response, ctx) + + # Non standard response + if not any( + content_type in response.headers["Content-Type"] + for content_type in ["text/event-stream", "application/x-ndjson"] + ): + return response + + # Streaming response + return await streaming_chat_response_handler(response, ctx)