rename tool calling helpers to use 'tool' instead of 'function'

This commit is contained in:
Michael Poluektov 2024-08-11 14:56:16 +01:00
parent 2efcda837c
commit 790bdcf9fc
2 changed files with 7 additions and 7 deletions

View File

@ -72,7 +72,7 @@ from utils.utils import (
from utils.task import ( from utils.task import (
title_generation_template, title_generation_template,
search_query_generation_template, search_query_generation_template,
tools_function_calling_generation_template, tool_calling_generation_template,
) )
from utils.misc import ( from utils.misc import (
get_last_user_message, get_last_user_message,
@ -322,7 +322,7 @@ async def call_tool_from_completion(
return None return None
def get_function_calling_payload(messages, task_model_id, content): def get_tool_calling_payload(messages, task_model_id, content):
user_message = get_last_user_message(messages) user_message = get_last_user_message(messages)
history = "\n".join( history = "\n".join(
f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\"" f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
@ -342,7 +342,7 @@ def get_function_calling_payload(messages, task_model_id, content):
} }
async def get_function_call_response( async def get_tool_call_response(
messages, files, tool_id, template, task_model_id, user, extra_params messages, files, tool_id, template, task_model_id, user, extra_params
) -> tuple[Optional[str], Optional[dict], bool]: ) -> tuple[Optional[str], Optional[dict], bool]:
tool = Tools.get_tool_by_id(tool_id) tool = Tools.get_tool_by_id(tool_id)
@ -350,8 +350,8 @@ async def get_function_call_response(
return None, None, False return None, None, False
tools_specs = json.dumps(tool.specs, indent=2) tools_specs = json.dumps(tool.specs, indent=2)
content = tools_function_calling_generation_template(template, tools_specs) content = tool_calling_generation_template(template, tools_specs)
payload = get_function_calling_payload(messages, task_model_id, content) payload = get_tool_calling_payload(messages, task_model_id, content)
try: try:
payload = filter_pipeline(payload, user) payload = filter_pipeline(payload, user)
@ -502,7 +502,7 @@ async def chat_completion_tools_handler(body, user, extra_params):
for tool_id in body["tool_ids"]: for tool_id in body["tool_ids"]:
print(tool_id) print(tool_id)
try: try:
response, citation, file_handler = await get_function_call_response( response, citation, file_handler = await get_tool_call_response(
messages=body["messages"], messages=body["messages"],
files=body.get("files", []), files=body.get("files", []),
tool_id=tool_id, tool_id=tool_id,

View File

@ -121,6 +121,6 @@ def search_query_generation_template(
return template return template
def tools_function_calling_generation_template(template: str, tools_specs: str) -> str: def tool_calling_generation_template(template: str, tools_specs: str) -> str:
template = template.replace("{{TOOLS}}", tools_specs) template = template.replace("{{TOOLS}}", tools_specs)
return template return template