From eb1ede119ee3b49cb55d540f42455a63a78550ec Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Sun, 2 Feb 2025 20:50:54 -0800 Subject: [PATCH] refac: reasoning tag --- backend/open_webui/utils/middleware.py | 233 ++++++++++++++++--------- 1 file changed, 152 insertions(+), 81 deletions(-) diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 5105c8bb4..4cf29fef9 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -8,6 +8,8 @@ from typing import Any, Optional import random import json import inspect +import re + from uuid import uuid4 from concurrent.futures import ThreadPoolExecutor @@ -987,6 +989,7 @@ async def process_chat_response( pass event_emitter = None + event_caller = None if ( "session_id" in metadata and metadata["session_id"] @@ -996,10 +999,11 @@ async def process_chat_response( and metadata["message_id"] ): event_emitter = get_event_emitter(metadata) + event_caller = get_event_call(metadata) + # Non-streaming response if not isinstance(response, StreamingResponse): if event_emitter: - if "selected_model_id" in response: Chats.upsert_message_to_chat_by_id_and_message_id( metadata["chat_id"], @@ -1064,22 +1068,136 @@ async def process_chat_response( else: return response + # Non standard response if not any( content_type in response.headers["Content-Type"] for content_type in ["text/event-stream", "application/x-ndjson"] ): return response - if event_emitter: - + # Streaming response + if event_emitter and event_caller: task_id = str(uuid4()) # Create a unique task ID. # Handle as a background task async def post_response_handler(response, events): + def serialize_content_blocks(content_blocks): + content = "" + + for block in content_blocks: + if block["type"] == "text": + content = f"{content}{block['content'].strip()}\n" + elif block["type"] == "reasoning": + reasoning_display_content = "\n".join( + (f"> {line}" if not line.startswith(">") else line) + for line in block["content"].splitlines() + ) + + reasoning_duration = block.get("duration", None) + + if reasoning_duration: + content = f'{content}
\nThought for {reasoning_duration} seconds\n{reasoning_display_content}\n
\n' + else: + content = f'{content}
\nThinking…\n{reasoning_display_content}\n
\n' + + else: + content = f"{content}{block['type']}: {block['content']}\n" + + return content + + def tag_content_handler(content_type, tags, content, content_blocks): + def extract_attributes(tag_content): + """Extract attributes from a tag if they exist.""" + attributes = {} + # Match attributes in the format: key="value" (ignores single quotes for simplicity) + matches = re.findall(r'(\w+)\s*=\s*"([^"]+)"', tag_content) + for key, value in matches: + attributes[key] = value + return attributes + + if content_blocks[-1]["type"] == "text": + for tag in tags: + # Match start tag e.g., or + start_tag_pattern = rf"<{tag}(.*?)>" + match = re.search(start_tag_pattern, content) + if match: + # Extract attributes in the tag (if present) + attributes = extract_attributes(match.group(1)) + # Remove the start tag from the currently handling text block + content_blocks[-1]["content"] = content_blocks[-1][ + "content" + ].replace(match.group(0), "") + if not content_blocks[-1]["content"]: + content_blocks.pop() + # Append the new block + content_blocks.append( + { + "type": content_type, + "tag": tag, + "attributes": attributes, + "content": "", + "started_at": time.time(), + } + ) + break + elif content_blocks[-1]["type"] == content_type: + tag = content_blocks[-1]["tag"] + # Match end tag e.g., + end_tag_pattern = rf"" + if re.search(end_tag_pattern, content): + block_content = content_blocks[-1]["content"] + # Strip start and end tags from the content + start_tag_pattern = rf"<{tag}(.*?)>" + block_content = re.sub( + start_tag_pattern, "", block_content + ).strip() + block_content = re.sub( + end_tag_pattern, "", block_content + ).strip() + if block_content: + content_blocks[-1]["content"] = block_content + content_blocks[-1]["ended_at"] = time.time() + content_blocks[-1]["duration"] = int( + content_blocks[-1]["ended_at"] + - content_blocks[-1]["started_at"] + ) + # Reset the content_blocks by appending a new text block + content_blocks.append( + { + "type": "text", + "content": "", + } + ) + # Clean processed content + content = re.sub( + rf"<{tag}(.*?)>(.|\n)*?", + "", + content, + flags=re.DOTALL, + ) + else: + # Remove the block if content is empty + content_blocks.pop() + return content, content_blocks + message = Chats.get_message_by_id_and_message_id( metadata["chat_id"], metadata["message_id"] ) + content = message.get("content", "") if message else "" + content_blocks = [ + { + "type": "text", + "content": content, + } + ] + + # We might want to disable this by default + DETECT_REASONING = True + DETECT_CODE_INTERPRETER = True + + reasoning_tags = ["think", "reason", "reasoning", "thought", "Thought"] + code_interpreter_tags = ["oi::code_interpreter"] try: for event in events: @@ -1099,16 +1217,6 @@ async def process_chat_response( }, ) - # We might want to disable this by default - detect_reasoning = True - reasoning_tags = ["think", "reason", "reasoning", "thought", "Thought"] - current_tag = None - - reasoning_start_time = None - - reasoning_content = "" - ongoing_content = "" - async for line in response.body_iterator: line = line.decode("utf-8") if isinstance(line, bytes) else line data = line @@ -1144,73 +1252,28 @@ async def process_chat_response( if value: content = f"{content}{value}" + content_blocks[-1]["content"] = ( + content_blocks[-1]["content"] + value + ) - if detect_reasoning: - for tag in reasoning_tags: - start_tag = f"<{tag}>\n" - end_tag = f"\n" + print(f"Content: {content}") + print(f"Content Blocks: {content_blocks}") - if start_tag in content: - # Remove the start tag - content = content.replace(start_tag, "") - ongoing_content = content + if DETECT_REASONING: + content, content_blocks = tag_content_handler( + "reasoning", + reasoning_tags, + content, + content_blocks, + ) - reasoning_start_time = time.time() - reasoning_content = "" - - current_tag = tag - break - - if reasoning_start_time is not None: - # Remove the last value from the content - content = content[: -len(value)] - - reasoning_content += value - - end_tag = f"\n" - if end_tag in reasoning_content: - reasoning_end_time = time.time() - reasoning_duration = int( - reasoning_end_time - - reasoning_start_time - ) - reasoning_content = ( - reasoning_content.strip( - f"<{current_tag}>\n" - ) - .strip(end_tag) - .strip() - ) - - if reasoning_content: - reasoning_display_content = "\n".join( - ( - f"> {line}" - if not line.startswith(">") - else line - ) - for line in reasoning_content.splitlines() - ) - - # Format reasoning with
tag - content = f'{ongoing_content}
\nThought for {reasoning_duration} seconds\n{reasoning_display_content}\n
\n' - else: - content = "" - - reasoning_start_time = None - else: - - reasoning_display_content = "\n".join( - ( - f"> {line}" - if not line.startswith(">") - else line - ) - for line in reasoning_content.splitlines() - ) - - # Show ongoing thought process - content = f'{ongoing_content}
\nThinking…\n{reasoning_display_content}\n
\n' + if DETECT_CODE_INTERPRETER: + content, content_blocks = tag_content_handler( + "code_interpreter", + code_interpreter_tags, + content, + content_blocks, + ) if ENABLE_REALTIME_CHAT_SAVE: # Save message in the database @@ -1218,12 +1281,16 @@ async def process_chat_response( metadata["chat_id"], metadata["message_id"], { - "content": content, + "content": serialize_content_blocks( + content_blocks + ), }, ) else: data = { - "content": content, + "content": serialize_content_blocks( + content_blocks + ), } await event_emitter( @@ -1240,7 +1307,11 @@ async def process_chat_response( continue title = Chats.get_chat_title_by_id(metadata["chat_id"]) - data = {"done": True, "content": content, "title": title} + data = { + "done": True, + "content": serialize_content_blocks(content_blocks), + "title": title, + } if not ENABLE_REALTIME_CHAT_SAVE: # Save message in the database @@ -1248,7 +1319,7 @@ async def process_chat_response( metadata["chat_id"], metadata["message_id"], { - "content": content, + "content": serialize_content_blocks(content_blocks), }, )