This commit is contained in:
Timothy J. Baek 2024-08-17 17:01:35 +02:00
parent 536b40890a
commit 0ae6ca608c
2 changed files with 11 additions and 6 deletions

View File

@ -72,7 +72,7 @@ from utils.utils import (
from utils.task import ( from utils.task import (
title_generation_template, title_generation_template,
search_query_generation_template, search_query_generation_template,
tool_calling_generation_template, tools_function_calling_generation_template,
) )
from utils.misc import ( from utils.misc import (
get_last_user_message, get_last_user_message,
@ -466,11 +466,12 @@ async def chat_completion_tools_handler(
specs = [tool["spec"] for tool in tools.values()] specs = [tool["spec"] for tool in tools.values()]
tools_specs = json.dumps(specs) tools_specs = json.dumps(specs)
content = tool_calling_generation_template( tools_function_calling_prompt = tools_function_calling_generation_template(
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, tools_specs app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, tools_specs
) )
log.info(f"{tools_function_calling_prompt=}")
payload = get_tools_function_calling_payload( payload = get_tools_function_calling_payload(
body["messages"], task_model_id, content body["messages"], task_model_id, tools_function_calling_prompt
) )
try: try:
@ -496,14 +497,18 @@ async def chat_completion_tools_handler(
tool_function_params = result.get("parameters", {}) tool_function_params = result.get("parameters", {})
try: try:
tool_output = await tools[tool_function_name]["callable"](**tool_function_params) tool_output = await tools[tool_function_name]["callable"](
**tool_function_params
)
except Exception as e: except Exception as e:
tool_output = str(e) tool_output = str(e)
if tools[tool_function_name]["citation"]: if tools[tool_function_name]["citation"]:
citations.append( citations.append(
{ {
"source": {"name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"}, "source": {
"name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
},
"document": [tool_output], "document": [tool_output],
"metadata": [{"source": tool_function_name}], "metadata": [{"source": tool_function_name}],
} }

View File

@ -121,6 +121,6 @@ def search_query_generation_template(
return template return template
def tool_calling_generation_template(template: str, tools_specs: str) -> str: def tools_function_calling_generation_template(template: str, tools_specs: str) -> str:
template = template.replace("{{TOOLS}}", tools_specs) template = template.replace("{{TOOLS}}", tools_specs)
return template return template