This commit is contained in:
Timothy J. Baek 2024-08-17 16:24:11 +02:00
parent c1823b4b73
commit c4946d42e0

View File

@ -283,21 +283,6 @@ def get_filter_function_ids(model):
return filter_ids return filter_ids
async def get_content_from_response(response) -> Optional[str]:
content = None
if hasattr(response, "body_iterator"):
async for chunk in response.body_iterator:
data = json.loads(chunk.decode("utf-8"))
content = data["choices"][0]["message"]["content"]
# Cleanup any remaining background tasks if necessary
if response.background is not None:
await response.background()
else:
content = response["choices"][0]["message"]["content"]
return content
def get_tool_call_payload(messages, task_model_id, content): def get_tool_call_payload(messages, task_model_id, content):
user_message = get_last_user_message(messages) user_message = get_last_user_message(messages)
history = "\n".join( history = "\n".join(
@ -403,8 +388,8 @@ def get_tool_with_custom_params(
# Mutation on extra_params # Mutation on extra_params
def get_configured_tools( def get_tools(
tool_ids: list[str], extra_params: dict, user: UserModel tool_ids: list[str], user: UserModel, extra_params: dict
) -> dict[str, dict]: ) -> dict[str, dict]:
tools = {} tools = {}
for tool_id in tool_ids: for tool_id in tool_ids:
@ -420,6 +405,7 @@ def get_configured_tools(
extra_params["__id__"] = tool_id extra_params["__id__"] = tool_id
has_citation = hasattr(module, "citation") and module.citation has_citation = hasattr(module, "citation") and module.citation
handles_files = hasattr(module, "file_handler") and module.file_handler handles_files = hasattr(module, "file_handler") and module.file_handler
if hasattr(module, "valves") and hasattr(module, "Valves"): if hasattr(module, "valves") and hasattr(module, "Valves"):
valves = Tools.get_tool_valves_by_id(tool_id) or {} valves = Tools.get_tool_valves_by_id(tool_id) or {}
module.valves = module.Valves(**valves) module.valves = module.Valves(**valves)
@ -459,35 +445,51 @@ def get_configured_tools(
return tools return tools
async def get_content_from_response(response) -> Optional[str]:
content = None
if hasattr(response, "body_iterator"):
async for chunk in response.body_iterator:
data = json.loads(chunk.decode("utf-8"))
content = data["choices"][0]["message"]["content"]
# Cleanup any remaining background tasks if necessary
if response.background is not None:
await response.background()
else:
content = response["choices"][0]["message"]["content"]
return content
async def chat_completion_tools_handler( async def chat_completion_tools_handler(
body: dict, user: UserModel, extra_params: dict body: dict, user: UserModel, extra_params: dict
) -> tuple[dict, dict]: ) -> tuple[dict, dict]:
skip_files = False skip_files = False
contexts = [] contexts = []
citations = [] citations = []
task_model_id = get_task_model_id(body["model"])
task_model_id = get_task_model_id(body["model"])
# If tool_ids field is present, call the functions # If tool_ids field is present, call the functions
tool_ids = body.pop("tool_ids", None) tool_ids = body.pop("tool_ids", None)
if not tool_ids: if not tool_ids:
return body, {} return body, {}
log.debug(f"{tool_ids=}") log.debug(f"{tool_ids=}")
custom_params = { custom_params = {
**extra_params, **extra_params,
"__model__": app.state.MODELS[task_model_id], "__model__": app.state.MODELS[task_model_id],
"__messages__": body["messages"], "__messages__": body["messages"],
"__files__": body.get("files", []), "__files__": body.get("files", []),
} }
configured_tools = get_configured_tools(tool_ids, custom_params, user) tools = get_tools(tool_ids, user, custom_params)
log.info(f"{tools=}")
log.info(f"{configured_tools=}") specs = [tool["spec"] for tool in tools.values()]
specs = [tool["spec"] for tool in configured_tools.values()]
tools_specs = json.dumps(specs) tools_specs = json.dumps(specs)
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
content = tool_calling_generation_template(template, tools_specs) content = tool_calling_generation_template(template, tools_specs)
payload = get_tool_call_payload(body["messages"], task_model_id, content) payload = get_tool_call_payload(body["messages"], task_model_id, content)
try: try:
payload = filter_pipeline(payload, user) payload = filter_pipeline(payload, user)
except Exception as e: except Exception as e:
@ -503,16 +505,18 @@ async def chat_completion_tools_handler(
result = json.loads(content) result = json.loads(content)
tool_name = result.get("name", None) tool_name = result.get("name", None)
if tool_name not in configured_tools: if tool_name not in tools:
return body, {} return body, {}
tool_params = result.get("parameters", {}) tool_params = result.get("parameters", {})
toolkit_id = configured_tools[tool_name]["toolkit_id"] toolkit_id = tools[tool_name]["toolkit_id"]
try: try:
tool_output = await configured_tools[tool_name]["callable"](**tool_params) tool_output = await tools[tool_name]["callable"](**tool_params)
except Exception as e: except Exception as e:
tool_output = str(e) tool_output = str(e)
if configured_tools[tool_name]["citation"]:
if tools[tool_name]["citation"]:
citations.append( citations.append(
{ {
"source": {"name": f"TOOL:{toolkit_id}/{tool_name}"}, "source": {"name": f"TOOL:{toolkit_id}/{tool_name}"},
@ -520,7 +524,7 @@ async def chat_completion_tools_handler(
"metadata": [{"source": tool_name}], "metadata": [{"source": tool_name}],
} }
) )
if configured_tools[tool_name]["file_handler"]: if tools[tool_name]["file_handler"]:
skip_files = True skip_files = True
if isinstance(tool_output, str): if isinstance(tool_output, str):