From fdc89cbceeca5286d83dd5717c20052a73e7ab3a Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Mon, 12 Aug 2024 15:53:47 +0100 Subject: [PATCH] tool calling refactor --- backend/main.py | 204 +++++++++++++++--------------------------------- 1 file changed, 62 insertions(+), 142 deletions(-) diff --git a/backend/main.py b/backend/main.py index 50b83f437..e5b7d174a 100644 --- a/backend/main.py +++ b/backend/main.py @@ -298,30 +298,6 @@ async def get_content_from_response(response) -> Optional[str]: return content -async def call_tool_from_completion( - result: dict, extra_params: dict, toolkit_module -) -> Optional[str]: - if "name" not in result: - return None - - tool = getattr(toolkit_module, result["name"]) - try: - # Get the signature of the function - sig = inspect.signature(tool) - params = result["parameters"] - for key, value in extra_params.items(): - if key in sig.parameters: - params[key] = value - - if inspect.iscoroutinefunction(tool): - return await tool(**params) - else: - return tool(**params) - except Exception as e: - print(f"Error: {e}") - return None - - def get_tool_call_payload(messages, task_model_id, content): user_message = get_last_user_message(messages) history = "\n".join( @@ -342,90 +318,6 @@ def get_tool_call_payload(messages, task_model_id, content): } -async def get_tool_call_response( - messages, files, tool_id, template, task_model_id, user, extra_params -) -> tuple[Optional[str], Optional[dict], bool]: - """ - return: tuple of (function_result, citation, file_handler) where - - function_result: Optional[str] is the result of the tool call if successful - - citation: Optional[dict] is the citation object if the tool has citation - - file_handler: bool, True if tool handles files - """ - tool = Tools.get_tool_by_id(tool_id) - if tool is None: - return None, None, False - - tools_specs = json.dumps(tool.specs, indent=2) - log.debug(f"{tool.specs=}") - content = tool_calling_generation_template(template, tools_specs) - payload = get_tool_call_payload(messages, task_model_id, content) - - try: - payload = filter_pipeline(payload, user) - except Exception as e: - raise e - - if tool_id in webui_app.state.TOOLS: - toolkit_module = webui_app.state.TOOLS[tool_id] - else: - toolkit_module, _ = load_toolkit_module_by_id(tool_id) - webui_app.state.TOOLS[tool_id] = toolkit_module - - custom_params = { - **extra_params, - "__model__": app.state.MODELS[task_model_id], - "__id__": tool_id, - "__messages__": messages, - "__files__": files, - } - try: - if hasattr(toolkit_module, "UserValves"): - custom_params["__user__"]["valves"] = toolkit_module.UserValves( - **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) - ) - - except Exception as e: - print(e) - - file_handler = hasattr(toolkit_module, "file_handler") - - if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"): - valves = Tools.get_tool_valves_by_id(tool_id) - toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {})) - - try: - response = await generate_chat_completions(form_data=payload, user=user) - content = await get_content_from_response(response) - - if content is None: - return None, None, False - - # Parse the function response - log.debug(f"content: {content}") - result = json.loads(content) - - function_result = await call_tool_from_completion( - result, custom_params, toolkit_module - ) - - if hasattr(toolkit_module, "citation") and toolkit_module.citation: - citation = { - "source": {"name": f"TOOL:{tool.name}/{result['name']}"}, - "document": [function_result], - "metadata": [{"source": result["name"]}], - } - else: - citation = None - - # Add the function result to the system prompt - if function_result is not None: - return function_result, citation, file_handler - except Exception as e: - print(f"Error: {e}") - - return None, None, False - - async def chat_completion_inlets_handler(body, model, extra_params): skip_files = None @@ -511,6 +403,7 @@ def get_tool_with_custom_params( return new_tool +# Mutation on extra_params def get_configured_tools( tool_ids: list[str], extra_params: dict, user: UserModel ) -> dict[str, dict]: @@ -525,8 +418,7 @@ def get_configured_tools( module, _ = load_toolkit_module_by_id(tool_id) webui_app.state.TOOLS[tool_id] = module - more_params = {"__id__": tool_id} - custom_params = more_params | extra_params + extra_params["__id__"] = tool_id has_citation = hasattr(module, "citation") and module.citation handles_files = hasattr(module, "file_handler") and module.file_handler if hasattr(module, "valves") and hasattr(module, "Valves"): @@ -534,27 +426,27 @@ def get_configured_tools( module.valves = module.Valves(**valves) if hasattr(module, "UserValves"): - custom_params["__user__"]["valves"] = module.UserValves( # type: ignore + extra_params["__user__"]["valves"] = module.UserValves( # type: ignore **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) ) for spec in toolkit.specs: name = spec["name"] callable = getattr(module, name) + # convert to function that takes only model params and inserts custom params - custom_callable = get_tool_with_custom_params(callable, custom_params) + custom_callable = get_tool_with_custom_params(callable, extra_params) tool_dict = { "spec": spec, "citation": has_citation, "file_handler": handles_files, - "toolkit_module": module, + "toolkit_id": tool_id, "callable": custom_callable, } if name in tools: log.warning(f"Tool {name} already exists in another toolkit!") - mod_name = tools[name]["toolkit_module"].__name__ - log.warning(f"Collision between {toolkit} and {mod_name}.") + log.warning(f"Collision between {toolkit} and {tool_id}.") log.warning(f"Discarding {toolkit}.{name}") else: tools[name] = tool_dict @@ -571,40 +463,68 @@ async def chat_completion_tools_handler( task_model_id = get_task_model_id(body["model"]) # If tool_ids field is present, call the functions - if "tool_ids" not in body: + tool_ids = body.pop("tool_ids", None) + if not tool_ids: return body, {} - log.debug(f"tool_ids: {body['tool_ids']}") - log.info(f"{get_configured_tools(body['tool_ids'], extra_params, user)=}") - kwargs = { - "messages": body["messages"], - "files": body.get("files", []), - "template": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - "task_model_id": task_model_id, - "user": user, - "extra_params": extra_params, + log.debug(f"{tool_ids=}") + custom_params = { + **extra_params, + "__model__": app.state.MODELS[task_model_id], + "__messages__": body["messages"], + "__files__": body.get("files", []), } + configured_tools = get_configured_tools(tool_ids, custom_params, user) - for tool_id in body["tool_ids"]: - log.debug(f"{tool_id=}") + log.info(f"{configured_tools=}") + + specs = [tool["spec"] for tool in configured_tools.values()] + tools_specs = json.dumps(specs) + template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE + content = tool_calling_generation_template(template, tools_specs) + payload = get_tool_call_payload(body["messages"], task_model_id, content) + try: + payload = filter_pipeline(payload, user) + except Exception as e: + raise e + + try: + response = await generate_chat_completions(form_data=payload, user=user) + log.debug(f"{response=}") + content = await get_content_from_response(response) + log.debug(f"{content=}") + if content is None: + return body, {} + + result = json.loads(content) + tool_name = result.get("name", None) + if tool_name not in configured_tools: + return body, {} + + tool_params = result.get("parameters", {}) + toolkit_id = configured_tools[tool_name]["toolkit_id"] try: - response, citation, file_handler = await get_tool_call_response( - tool_id=tool_id, **kwargs - ) - - if isinstance(response, str): - contexts.append(response) - - if citation: - citations.append(citation) - - if file_handler: - skip_files = True - + tool_output = await configured_tools[tool_name]["callable"](**tool_params) except Exception as e: - log.exception(f"Error: {e}") + tool_output = str(e) + if configured_tools[tool_name]["citation"]: + citations.append( + { + "source": {"name": f"TOOL:{toolkit_id}/{tool_name}"}, + "document": [tool_output], + "metadata": [{"source": tool_name}], + } + ) + if configured_tools[tool_name]["file_handler"]: + skip_files = True + + if isinstance(tool_output, str): + contexts.append(tool_output) + + except Exception as e: + print(f"Error: {e}") + content = None - del body["tool_ids"] log.debug(f"tool_contexts: {contexts}") if skip_files and "files" in body: