refac: tools

This commit is contained in:
Timothy Jaeryang Baek 2025-03-26 00:40:24 -07:00
parent 82ca88105c
commit 5a7efad59c
3 changed files with 86 additions and 13 deletions

View File

@ -100,7 +100,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
async def chat_completion_tools_handler(
request: Request, body: dict, user: UserModel, models, tools
request: Request, body: dict, extra_params: dict, user: UserModel, models, tools
) -> tuple[dict, dict]:
async def get_content_from_response(response) -> Optional[str]:
content = None
@ -135,6 +135,9 @@ async def chat_completion_tools_handler(
"metadata": {"task": str(TASKS.FUNCTION_CALLING)},
}
event_caller = extra_params["__event_call__"]
metadata = extra_params["__metadata__"]
task_model_id = get_task_model_id(
body["model"],
request.app.state.config.TASK_MODEL,
@ -189,17 +192,33 @@ async def chat_completion_tools_handler(
tool_function_params = tool_call.get("parameters", {})
try:
spec = tools[tool_function_name].get("spec", {})
tool = tools[tool_function_name]
spec = tool.get("spec", {})
allowed_params = (
spec.get("parameters", {}).get("properties", {}).keys()
)
tool_function = tools[tool_function_name]["callable"]
tool_function = tool["callable"]
tool_function_params = {
k: v
for k, v in tool_function_params.items()
if k in allowed_params
}
tool_output = await tool_function(**tool_function_params)
if tool.get("direct", False):
tool_output = await tool_function(**tool_function_params)
else:
tool_output = await event_caller(
{
"type": "execute:tool",
"data": {
"id": str(uuid4()),
"tool": tool,
"params": tool_function_params,
"session_id": metadata.get("session_id", None),
},
}
)
except Exception as e:
tool_output = str(e)
@ -764,12 +783,18 @@ async def process_chat_payload(request, form_data, user, metadata, model):
}
form_data["metadata"] = metadata
# Server side tools
tool_ids = metadata.get("tool_ids", None)
# Client side tools
tool_specs = form_data.get("tool_specs", None)
log.debug(f"{tool_ids=}")
log.debug(f"{tool_specs=}")
tools_dict = {}
if tool_ids:
# If tool_ids field is present, then get the tools
tools = get_tools(
tools_dict = get_tools(
request,
tool_ids,
user,
@ -780,20 +805,30 @@ async def process_chat_payload(request, form_data, user, metadata, model):
"__files__": metadata.get("files", []),
},
)
log.info(f"{tools=}")
log.info(f"{tools_dict=}")
if tool_specs:
for tool in tool_specs:
callable = tool.pop("callable", None)
tools_dict[tool["name"]] = {
"direct": True,
"callable": callable,
"spec": tool,
}
if tools_dict:
if metadata.get("function_calling") == "native":
# If the function calling is native, then call the tools function calling handler
metadata["tools"] = tools
metadata["tools"] = tools_dict
form_data["tools"] = [
{"type": "function", "function": tool.get("spec", {})}
for tool in tools.values()
for tool in tools_dict.values()
]
else:
# If the function calling is not native, then call the tools function calling handler
try:
form_data, flags = await chat_completion_tools_handler(
request, form_data, user, models, tools
request, form_data, extra_params, user, models, tools_dict
)
sources.extend(flags.get("sources", []))
@ -1774,9 +1809,25 @@ async def process_chat_response(
for k, v in tool_function_params.items()
if k in allowed_params
}
tool_result = await tool_function(
**tool_function_params
)
if tool.get("direct", False):
tool_result = await tool_function(
**tool_function_params
)
else:
tool_result = await event_caller(
{
"type": "execute:tool",
"data": {
"id": str(uuid4()),
"tool": tool,
"params": tool_function_params,
"session_id": metadata.get(
"session_id", None
),
},
}
)
except Exception as e:
tool_result = str(e)

View File

@ -1,6 +1,9 @@
import inspect
import logging
import re
import inspect
import uuid
from typing import Any, Awaitable, Callable, get_type_hints
from functools import update_wrapper, partial

View File

@ -203,6 +203,22 @@
};
};
const executeTool = async (data, cb) => {
console.log(data);
// TODO: MCP (SSE) support
// TODO: API Server support
if (cb) {
cb(
JSON.parse(
JSON.stringify({
result: null
})
)
);
}
};
const chatEventHandler = async (event, cb) => {
const chat = $page.url.pathname.includes(`/c/${event.chat_id}`);
@ -256,6 +272,9 @@
if (type === 'execute:python') {
console.log('execute:python', data);
executePythonAsWorker(data.id, data.code, cb);
} else if (type === 'execute:tool') {
console.log('execute:tool', data);
executeTool(data, cb);
} else if (type === 'request:chat:completion') {
console.log(data, $socket.id);
const { session_id, channel, form_data, model } = data;