tool calling refactor

This commit is contained in:
Michael Poluektov 2024-08-12 15:53:47 +01:00
parent 6df6170c44
commit fdc89cbcee

View File

@ -298,30 +298,6 @@ async def get_content_from_response(response) -> Optional[str]:
return content
async def call_tool_from_completion(
result: dict, extra_params: dict, toolkit_module
) -> Optional[str]:
if "name" not in result:
return None
tool = getattr(toolkit_module, result["name"])
try:
# Get the signature of the function
sig = inspect.signature(tool)
params = result["parameters"]
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if inspect.iscoroutinefunction(tool):
return await tool(**params)
else:
return tool(**params)
except Exception as e:
print(f"Error: {e}")
return None
def get_tool_call_payload(messages, task_model_id, content):
user_message = get_last_user_message(messages)
history = "\n".join(
@ -342,90 +318,6 @@ def get_tool_call_payload(messages, task_model_id, content):
}
async def get_tool_call_response(
messages, files, tool_id, template, task_model_id, user, extra_params
) -> tuple[Optional[str], Optional[dict], bool]:
"""
return: tuple of (function_result, citation, file_handler) where
- function_result: Optional[str] is the result of the tool call if successful
- citation: Optional[dict] is the citation object if the tool has citation
- file_handler: bool, True if tool handles files
"""
tool = Tools.get_tool_by_id(tool_id)
if tool is None:
return None, None, False
tools_specs = json.dumps(tool.specs, indent=2)
log.debug(f"{tool.specs=}")
content = tool_calling_generation_template(template, tools_specs)
payload = get_tool_call_payload(messages, task_model_id, content)
try:
payload = filter_pipeline(payload, user)
except Exception as e:
raise e
if tool_id in webui_app.state.TOOLS:
toolkit_module = webui_app.state.TOOLS[tool_id]
else:
toolkit_module, _ = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module
custom_params = {
**extra_params,
"__model__": app.state.MODELS[task_model_id],
"__id__": tool_id,
"__messages__": messages,
"__files__": files,
}
try:
if hasattr(toolkit_module, "UserValves"):
custom_params["__user__"]["valves"] = toolkit_module.UserValves(
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
)
except Exception as e:
print(e)
file_handler = hasattr(toolkit_module, "file_handler")
if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"):
valves = Tools.get_tool_valves_by_id(tool_id)
toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {}))
try:
response = await generate_chat_completions(form_data=payload, user=user)
content = await get_content_from_response(response)
if content is None:
return None, None, False
# Parse the function response
log.debug(f"content: {content}")
result = json.loads(content)
function_result = await call_tool_from_completion(
result, custom_params, toolkit_module
)
if hasattr(toolkit_module, "citation") and toolkit_module.citation:
citation = {
"source": {"name": f"TOOL:{tool.name}/{result['name']}"},
"document": [function_result],
"metadata": [{"source": result["name"]}],
}
else:
citation = None
# Add the function result to the system prompt
if function_result is not None:
return function_result, citation, file_handler
except Exception as e:
print(f"Error: {e}")
return None, None, False
async def chat_completion_inlets_handler(body, model, extra_params):
skip_files = None
@ -511,6 +403,7 @@ def get_tool_with_custom_params(
return new_tool
# Mutation on extra_params
def get_configured_tools(
tool_ids: list[str], extra_params: dict, user: UserModel
) -> dict[str, dict]:
@ -525,8 +418,7 @@ def get_configured_tools(
module, _ = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = module
more_params = {"__id__": tool_id}
custom_params = more_params | extra_params
extra_params["__id__"] = tool_id
has_citation = hasattr(module, "citation") and module.citation
handles_files = hasattr(module, "file_handler") and module.file_handler
if hasattr(module, "valves") and hasattr(module, "Valves"):
@ -534,27 +426,27 @@ def get_configured_tools(
module.valves = module.Valves(**valves)
if hasattr(module, "UserValves"):
custom_params["__user__"]["valves"] = module.UserValves( # type: ignore
extra_params["__user__"]["valves"] = module.UserValves( # type: ignore
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
)
for spec in toolkit.specs:
name = spec["name"]
callable = getattr(module, name)
# convert to function that takes only model params and inserts custom params
custom_callable = get_tool_with_custom_params(callable, custom_params)
custom_callable = get_tool_with_custom_params(callable, extra_params)
tool_dict = {
"spec": spec,
"citation": has_citation,
"file_handler": handles_files,
"toolkit_module": module,
"toolkit_id": tool_id,
"callable": custom_callable,
}
if name in tools:
log.warning(f"Tool {name} already exists in another toolkit!")
mod_name = tools[name]["toolkit_module"].__name__
log.warning(f"Collision between {toolkit} and {mod_name}.")
log.warning(f"Collision between {toolkit} and {tool_id}.")
log.warning(f"Discarding {toolkit}.{name}")
else:
tools[name] = tool_dict
@ -571,40 +463,68 @@ async def chat_completion_tools_handler(
task_model_id = get_task_model_id(body["model"])
# If tool_ids field is present, call the functions
if "tool_ids" not in body:
tool_ids = body.pop("tool_ids", None)
if not tool_ids:
return body, {}
log.debug(f"tool_ids: {body['tool_ids']}")
log.info(f"{get_configured_tools(body['tool_ids'], extra_params, user)=}")
kwargs = {
"messages": body["messages"],
"files": body.get("files", []),
"template": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
"task_model_id": task_model_id,
"user": user,
"extra_params": extra_params,
log.debug(f"{tool_ids=}")
custom_params = {
**extra_params,
"__model__": app.state.MODELS[task_model_id],
"__messages__": body["messages"],
"__files__": body.get("files", []),
}
configured_tools = get_configured_tools(tool_ids, custom_params, user)
for tool_id in body["tool_ids"]:
log.debug(f"{tool_id=}")
log.info(f"{configured_tools=}")
specs = [tool["spec"] for tool in configured_tools.values()]
tools_specs = json.dumps(specs)
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
content = tool_calling_generation_template(template, tools_specs)
payload = get_tool_call_payload(body["messages"], task_model_id, content)
try:
payload = filter_pipeline(payload, user)
except Exception as e:
raise e
try:
response = await generate_chat_completions(form_data=payload, user=user)
log.debug(f"{response=}")
content = await get_content_from_response(response)
log.debug(f"{content=}")
if content is None:
return body, {}
result = json.loads(content)
tool_name = result.get("name", None)
if tool_name not in configured_tools:
return body, {}
tool_params = result.get("parameters", {})
toolkit_id = configured_tools[tool_name]["toolkit_id"]
try:
response, citation, file_handler = await get_tool_call_response(
tool_id=tool_id, **kwargs
)
if isinstance(response, str):
contexts.append(response)
if citation:
citations.append(citation)
if file_handler:
skip_files = True
tool_output = await configured_tools[tool_name]["callable"](**tool_params)
except Exception as e:
log.exception(f"Error: {e}")
tool_output = str(e)
if configured_tools[tool_name]["citation"]:
citations.append(
{
"source": {"name": f"TOOL:{toolkit_id}/{tool_name}"},
"document": [tool_output],
"metadata": [{"source": tool_name}],
}
)
if configured_tools[tool_name]["file_handler"]:
skip_files = True
if isinstance(tool_output, str):
contexts.append(tool_output)
except Exception as e:
print(f"Error: {e}")
content = None
del body["tool_ids"]
log.debug(f"tool_contexts: {contexts}")
if skip_files and "files" in body: