refac: reasoning tag

This commit is contained in:
Timothy Jaeryang Baek 2025-02-02 20:50:54 -08:00
parent e4fc2e0e51
commit eb1ede119e

View File

@ -8,6 +8,8 @@ from typing import Any, Optional
import random import random
import json import json
import inspect import inspect
import re
from uuid import uuid4 from uuid import uuid4
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
@ -987,6 +989,7 @@ async def process_chat_response(
pass pass
event_emitter = None event_emitter = None
event_caller = None
if ( if (
"session_id" in metadata "session_id" in metadata
and metadata["session_id"] and metadata["session_id"]
@ -996,10 +999,11 @@ async def process_chat_response(
and metadata["message_id"] and metadata["message_id"]
): ):
event_emitter = get_event_emitter(metadata) event_emitter = get_event_emitter(metadata)
event_caller = get_event_call(metadata)
# Non-streaming response
if not isinstance(response, StreamingResponse): if not isinstance(response, StreamingResponse):
if event_emitter: if event_emitter:
if "selected_model_id" in response: if "selected_model_id" in response:
Chats.upsert_message_to_chat_by_id_and_message_id( Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"], metadata["chat_id"],
@ -1064,22 +1068,136 @@ async def process_chat_response(
else: else:
return response return response
# Non standard response
if not any( if not any(
content_type in response.headers["Content-Type"] content_type in response.headers["Content-Type"]
for content_type in ["text/event-stream", "application/x-ndjson"] for content_type in ["text/event-stream", "application/x-ndjson"]
): ):
return response return response
if event_emitter: # Streaming response
if event_emitter and event_caller:
task_id = str(uuid4()) # Create a unique task ID. task_id = str(uuid4()) # Create a unique task ID.
# Handle as a background task # Handle as a background task
async def post_response_handler(response, events): async def post_response_handler(response, events):
def serialize_content_blocks(content_blocks):
    """Render a list of content blocks into one display string.

    Each block is a dict with at least ``"type"`` and ``"content"`` keys:

    * ``"text"`` blocks are emitted stripped, followed by a newline.
    * ``"reasoning"`` blocks are emitted as a ``<details type="reasoning">``
      element whose body is the reasoning quoted line by line with ``> ``.
      A block that carries a ``"duration"`` key (set once the closing tag
      has been seen) is rendered as done; otherwise it is rendered as
      still in progress.
    * Any other block type is emitted as ``"<type>: <content>"``.

    Args:
        content_blocks: Iterable of block dicts, in display order.

    Returns:
        The concatenated string for all blocks ("" for no blocks).
    """
    parts = []
    for block in content_blocks:
        block_type = block["type"]

        if block_type == "text":
            parts.append(f"{block['content'].strip()}\n")

        elif block_type == "reasoning":
            # Quote every line so the reasoning renders as a blockquote;
            # lines that already start with ">" are left untouched.
            reasoning_display_content = "\n".join(
                line if line.startswith(">") else f"> {line}"
                for line in block["content"].splitlines()
            )

            reasoning_duration = block.get("duration")
            # Compare against None rather than truthiness: the duration is
            # truncated to whole seconds, so a finished block that took less
            # than a second has duration 0 and must still render as done.
            if reasoning_duration is not None:
                parts.append(
                    f'<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n'
                )
            else:
                parts.append(
                    f'<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'
                )

        else:
            parts.append(f"{block_type}: {block['content']}\n")

    return "".join(parts)
def tag_content_handler(content_type, tags, content, content_blocks):
    """Detect ``<tag>…</tag>`` spans in streamed content and track them as blocks.

    Called after each streamed chunk has been appended both to ``content``
    and to ``content_blocks[-1]["content"]``.

    * If the last block is a text block and ``content`` contains a start tag
      for one of ``tags``, the tag is removed from that text block and a new
      block of ``content_type`` is opened.
    * If the last block is an open ``content_type`` block and ``content``
      contains its end tag, the block is closed (tags stripped, duration
      recorded) and a fresh text block is appended. A block that turns out
      to be empty is dropped instead.

    Args:
        content_type: Value stored in the ``"type"`` key of opened blocks.
        tags: Tag names (without angle brackets) that open such a block.
        content: The full content string accumulated so far.
        content_blocks: List of block dicts; modified in place.

    Returns:
        A ``(content, content_blocks)`` tuple. ``content`` has any completed
        ``<tag>…</tag>`` span removed; ``content_blocks`` is the same list
        object that was passed in and always ends with at least one block.
    """

    def extract_attributes(tag_content):
        """Extract attributes from a tag if they exist."""
        # Match attributes in the format: key="value" (ignores single quotes for simplicity).
        # ``tag_content`` is None when the tag carries no attributes.
        return dict(re.findall(r'(\w+)\s*=\s*"([^"]+)"', tag_content or ""))

    def start_tag_pattern(tag):
        # Match <tag> or <tag attr="value">. The attribute part must begin
        # with whitespace so that "reason" does not match "<reasoning>" and
        # "think" does not match "<thinking>". The tag name is escaped so it
        # is always treated literally.
        return rf"<{re.escape(tag)}(\s[^>]*)?>"

    def new_text_block():
        return {"type": "text", "content": ""}

    # The caller indexes content_blocks[-1] unconditionally; never work on
    # (or hand back) an empty list.
    if not content_blocks:
        content_blocks.append(new_text_block())

    last_block = content_blocks[-1]

    if last_block["type"] == "text":
        for tag in tags:
            match = re.search(start_tag_pattern(tag), content)
            if not match:
                continue

            # Extract attributes in the tag (if present)
            attributes = extract_attributes(match.group(1))

            # Remove the start tag from the text block currently being handled
            last_block["content"] = last_block["content"].replace(
                match.group(0), ""
            )
            if not last_block["content"]:
                content_blocks.pop()

            # Append the new block
            content_blocks.append(
                {
                    "type": content_type,
                    "tag": tag,
                    "attributes": attributes,
                    "content": "",
                    "started_at": time.time(),
                }
            )
            break

    elif last_block["type"] == content_type:
        tag = last_block["tag"]
        # Match end tag e.g., </tag>
        end_tag_pattern = rf"</{re.escape(tag)}>"

        if re.search(end_tag_pattern, content):
            # Strip start and end tags from the block content
            block_content = re.sub(
                start_tag_pattern(tag), "", last_block["content"]
            ).strip()
            block_content = re.sub(end_tag_pattern, "", block_content).strip()

            if block_content:
                last_block["content"] = block_content
                last_block["ended_at"] = time.time()
                last_block["duration"] = int(
                    last_block["ended_at"] - last_block["started_at"]
                )
                # Subsequent chunks go into a new text block
                content_blocks.append(new_text_block())
            else:
                # Remove the block if content is empty
                content_blocks.pop()
                # Make sure later chunks land in a text block rather than
                # raising IndexError or being appended to an already closed
                # block.
                if not content_blocks or content_blocks[-1]["type"] != "text":
                    content_blocks.append(new_text_block())

            # Clean processed content in both cases; otherwise the stale
            # start tag would be detected again on the next chunk.
            content = re.sub(
                rf"{start_tag_pattern(tag)}.*?{end_tag_pattern}",
                "",
                content,
                flags=re.DOTALL,
            )

    return content, content_blocks
message = Chats.get_message_by_id_and_message_id( message = Chats.get_message_by_id_and_message_id(
metadata["chat_id"], metadata["message_id"] metadata["chat_id"], metadata["message_id"]
) )
content = message.get("content", "") if message else "" content = message.get("content", "") if message else ""
content_blocks = [
{
"type": "text",
"content": content,
}
]
# We might want to disable this by default
DETECT_REASONING = True
DETECT_CODE_INTERPRETER = True
reasoning_tags = ["think", "reason", "reasoning", "thought", "Thought"]
code_interpreter_tags = ["oi::code_interpreter"]
try: try:
for event in events: for event in events:
@ -1099,16 +1217,6 @@ async def process_chat_response(
}, },
) )
# We might want to disable this by default
detect_reasoning = True
reasoning_tags = ["think", "reason", "reasoning", "thought", "Thought"]
current_tag = None
reasoning_start_time = None
reasoning_content = ""
ongoing_content = ""
async for line in response.body_iterator: async for line in response.body_iterator:
line = line.decode("utf-8") if isinstance(line, bytes) else line line = line.decode("utf-8") if isinstance(line, bytes) else line
data = line data = line
@ -1144,73 +1252,28 @@ async def process_chat_response(
if value: if value:
content = f"{content}{value}" content = f"{content}{value}"
content_blocks[-1]["content"] = (
if detect_reasoning: content_blocks[-1]["content"] + value
for tag in reasoning_tags:
start_tag = f"<{tag}>\n"
end_tag = f"</{tag}>\n"
if start_tag in content:
# Remove the start tag
content = content.replace(start_tag, "")
ongoing_content = content
reasoning_start_time = time.time()
reasoning_content = ""
current_tag = tag
break
if reasoning_start_time is not None:
# Remove the last value from the content
content = content[: -len(value)]
reasoning_content += value
end_tag = f"</{current_tag}>\n"
if end_tag in reasoning_content:
reasoning_end_time = time.time()
reasoning_duration = int(
reasoning_end_time
- reasoning_start_time
)
reasoning_content = (
reasoning_content.strip(
f"<{current_tag}>\n"
)
.strip(end_tag)
.strip()
) )
if reasoning_content: print(f"Content: {content}")
reasoning_display_content = "\n".join( print(f"Content Blocks: {content_blocks}")
(
f"> {line}" if DETECT_REASONING:
if not line.startswith(">") content, content_blocks = tag_content_handler(
else line "reasoning",
) reasoning_tags,
for line in reasoning_content.splitlines() content,
content_blocks,
) )
# Format reasoning with <details> tag if DETECT_CODE_INTERPRETER:
content = f'{ongoing_content}<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n' content, content_blocks = tag_content_handler(
else: "code_interpreter",
content = "" code_interpreter_tags,
content,
reasoning_start_time = None content_blocks,
else:
reasoning_display_content = "\n".join(
(
f"> {line}"
if not line.startswith(">")
else line
) )
for line in reasoning_content.splitlines()
)
# Show ongoing thought process
content = f'{ongoing_content}<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'
if ENABLE_REALTIME_CHAT_SAVE: if ENABLE_REALTIME_CHAT_SAVE:
# Save message in the database # Save message in the database
@ -1218,12 +1281,16 @@ async def process_chat_response(
metadata["chat_id"], metadata["chat_id"],
metadata["message_id"], metadata["message_id"],
{ {
"content": content, "content": serialize_content_blocks(
content_blocks
),
}, },
) )
else: else:
data = { data = {
"content": content, "content": serialize_content_blocks(
content_blocks
),
} }
await event_emitter( await event_emitter(
@ -1240,7 +1307,11 @@ async def process_chat_response(
continue continue
title = Chats.get_chat_title_by_id(metadata["chat_id"]) title = Chats.get_chat_title_by_id(metadata["chat_id"])
data = {"done": True, "content": content, "title": title} data = {
"done": True,
"content": serialize_content_blocks(content_blocks),
"title": title,
}
if not ENABLE_REALTIME_CHAT_SAVE: if not ENABLE_REALTIME_CHAT_SAVE:
# Save message in the database # Save message in the database
@ -1248,7 +1319,7 @@ async def process_chat_response(
metadata["chat_id"], metadata["chat_id"],
metadata["message_id"], metadata["message_id"],
{ {
"content": content, "content": serialize_content_blocks(content_blocks),
}, },
) )