enh: builtin tools

This commit is contained in:
Tim Baek
2026-01-07 07:00:32 -05:00
parent 60e916d6c0
commit 2789f6a24d
2 changed files with 142 additions and 59 deletions

View File

@@ -142,6 +142,104 @@ DEFAULT_SOLUTION_TAGS = [("<|begin_of_solution|>", "<|end_of_solution|>")]
DEFAULT_CODE_INTERPRETER_TAGS = [("<code_interpreter>", "</code_interpreter>")]
def get_citation_source_from_tool_result(
    tool_name: str,
    tool_params: dict,
    tool_result: str,
    tool_id: str = "",
) -> dict:
    """
    Parse a tool's result and convert it to a source dict for citation display.

    For web_search: extracts title, link, snippet from each search result.
    For other tools: wraps the raw result as a generic source.

    Args:
        tool_name: Name of the tool that produced the result.
        tool_params: Parameters the tool was called with; echoed into the
            metadata of the generic (non-web_search) source.
        tool_result: Raw tool output. For web_search this is expected to be a
            JSON array of objects with "title", "link" and "snippet" keys.
        tool_id: Optional tool identifier; tool_name is used when it is empty.

    Returns:
        A dict with "source", "document" and "metadata" keys, where
        "document" and "metadata" are parallel lists. A web_search payload
        that cannot be interpreted is wrapped verbatim as a single document
        instead of raising.
    """
    fallback_id = tool_id or tool_name

    def raw_source() -> dict:
        # Used whenever a web_search payload cannot be read as a result list.
        return {
            "source": {"name": tool_name, "id": fallback_id},
            "document": [str(tool_result)],
            "metadata": [{"source": fallback_id}],
        }

    if tool_name != "web_search":
        # Fallback for other tools
        return {
            "source": {"name": tool_name, "id": fallback_id},
            "document": [str(tool_result)],
            "metadata": [{"source": fallback_id, "parameters": tool_params}],
        }

    # Only the JSON decoding can fail here, so it is the only guarded call.
    # Citation extraction is best-effort and must not break the tool loop.
    try:
        # Parse JSON array: [{"title": "...", "link": "...", "snippet": "..."}]
        results = json.loads(tool_result)
    except Exception as e:
        log.exception(f"Error parsing tool result for {tool_name}: {e}")
        return raw_source()

    if not isinstance(results, list):
        # Valid JSON that is not a result list (e.g. an error object) is an
        # expected condition, not an exception: wrap it as-is.
        return raw_source()

    documents = []
    metadata = []
    for result in results:
        if not isinstance(result, dict):
            # A single malformed entry must not discard the valid ones.
            continue
        # "or" (rather than a .get default) also covers explicit JSON nulls,
        # which would otherwise be rendered as the text "None".
        title = result.get("title") or ""
        link = result.get("link") or ""
        snippet = result.get("snippet") or ""
        documents.append(f"{title}\n{snippet}")
        metadata.append(
            {
                "source": link,
                "name": title,
                "url": link,
            }
        )

    return {
        "source": {"name": "web_search", "id": "web_search"},
        "document": documents,
        "metadata": metadata,
    }
def apply_source_context_to_messages(
    request: Request,
    messages: list,
    sources: list,
    user_message: str,
) -> list:
    """
    Render citation sources as <source> tags and inject them into the messages.

    Each (document, metadata) pair of every source becomes one
    <source id="N" name="...">...</source> line. N is a 1-based number
    assigned per distinct source id in order of first appearance. The joined
    context is passed through the RAG template together with user_message and
    added to the system message (appended) when RAG_SYSTEM_CONTEXT is set,
    otherwise to the user message.

    Returns messages unchanged when there are no sources, no user message, or
    the rendered context is empty.
    """
    if not (sources and user_message):
        return messages

    numbering = {}
    fragments = []
    for entry in sources:
        pairs = zip(entry.get("document", []), entry.get("metadata", []))
        for text, info in pairs:
            # Prefer the per-document source, then the source-level id.
            key = info.get("source") or entry.get("source", {}).get("id") or "N/A"
            # setdefault evaluates the candidate number before inserting, so a
            # new key receives the next 1-based index.
            number = numbering.setdefault(key, len(numbering) + 1)
            label = entry.get("source", {}).get("name")
            name_attr = f' name="{label}"' if label else ""
            fragments.append(f'<source id="{number}"{name_attr}>{text}</source>\n')

    context = "".join(fragments).strip()
    if not context:
        return messages

    # Pick the injection target first; the template call is the same for both.
    if RAG_SYSTEM_CONTEXT:
        inject, append = add_or_update_system_message, True
    else:
        inject, append = add_or_update_user_message, False

    return inject(
        rag_template(request.app.state.config.RAG_TEMPLATE, context, user_message),
        messages,
        append=append,
    )
def process_tool_result(
request,
tool_function_name,
@@ -1567,6 +1665,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
"__event_emitter__": event_emitter,
},
features,
model,
)
for name, tool_dict in builtin_tools.items():
if name not in tools_dict:
@@ -1599,58 +1698,10 @@ async def process_chat_payload(request, form_data, user, metadata, model):
log.exception(e)
# If context is not empty, insert it into the messages
if len(sources) > 0:
context_string = ""
citation_idx_map = {}
for source in sources:
if "document" in source:
for document_text, document_metadata in zip(
source["document"], source["metadata"]
):
source_name = source.get("source", {}).get("name", None)
source_id = (
document_metadata.get("source", None)
or source.get("source", {}).get("id", None)
or "N/A"
)
if source_id not in citation_idx_map:
citation_idx_map[source_id] = len(citation_idx_map) + 1
context_string += (
f'<source id="{citation_idx_map[source_id]}"'
+ (f' name="{source_name}"' if source_name else "")
+ f">{document_text}</source>\n"
)
context_string = context_string.strip()
if prompt is None:
raise Exception("No user message found")
if context_string != "":
if RAG_SYSTEM_CONTEXT:
# Inject into system message for KV prefix caching
form_data["messages"] = add_or_update_system_message(
rag_template(
request.app.state.config.RAG_TEMPLATE,
context_string,
prompt,
),
form_data["messages"],
append=True,
)
else:
# Inject into user message
form_data["messages"] = add_or_update_user_message(
rag_template(
request.app.state.config.RAG_TEMPLATE,
context_string,
prompt,
),
form_data["messages"],
append=False,
)
if sources and prompt:
form_data["messages"] = apply_source_context_to_messages(
request, form_data["messages"], sources, prompt
)
# If there are citations, add them to the data_items
sources = [
@@ -2977,6 +3028,7 @@ async def process_chat_response(
await stream_body_handler(response, form_data)
tool_call_retries = 0
tool_call_sources = [] # Track citation sources from tool results
while (
len(tool_calls) > 0
@@ -3111,6 +3163,19 @@ async def process_chat_response(
)
)
# Extract citation sources from web_search results
if tool_function_name == "web_search" and tool_result:
try:
citation_source = get_citation_source_from_tool_result(
tool_name=tool_function_name,
tool_params=tool_function_params,
tool_result=tool_result,
tool_id=tool.get("tool_id", "") if tool else ""
)
tool_call_sources.append(citation_source)
except Exception as e:
log.exception(f"Error extracting citation source: {e}")
results.append(
{
"tool_call_id": tool_call_id,
@@ -3136,6 +3201,19 @@ async def process_chat_response(
}
)
# Emit citation sources for UI display
for source in tool_call_sources:
await event_emitter({"type": "source", "data": source})
# Apply source context to messages for model
if tool_call_sources:
user_msg = get_last_user_message(form_data["messages"])
if user_msg:
form_data["messages"] = apply_source_context_to_messages(
request, form_data["messages"], tool_call_sources, user_msg
)
tool_call_sources.clear()
await event_emitter(
{
"type": "chat:completion",