diff --git a/backend/main.py b/backend/main.py index d8f352a58..e058bc58d 100644 --- a/backend/main.py +++ b/backend/main.py @@ -369,22 +369,22 @@ async def chat_completion_filter_functions_handler(body, model, extra_params): return body, {} -def get_tool_with_custom_params( - tool: Callable, custom_params: dict +def apply_extra_params_to_tool_function( + function: Callable, custom_params: dict ) -> Callable[..., Awaitable]: - sig = inspect.signature(tool) + sig = inspect.signature(function) extra_params = { key: value for key, value in custom_params.items() if key in sig.parameters } - is_coroutine = inspect.iscoroutinefunction(tool) + is_coroutine = inspect.iscoroutinefunction(function) - async def new_tool(**kwargs): + async def new_function(**kwargs): extra_kwargs = kwargs | extra_params if is_coroutine: - return await tool(**extra_kwargs) - return tool(**extra_kwargs) + return await function(**extra_kwargs) + return function(**extra_kwargs) - return new_tool + return new_function # Mutation on extra_params @@ -403,6 +403,7 @@ def get_tools( webui_app.state.TOOLS[tool_id] = module extra_params["__id__"] = tool_id + has_citation = hasattr(module, "citation") and module.citation handles_files = hasattr(module, "file_handler") and module.file_handler @@ -420,11 +421,12 @@ def get_tools( for val in spec.get("parameters", {}).get("properties", {}).values(): if val["type"] == "str": val["type"] = "string" - name = spec["name"] - callable = getattr(module, name) + function_name = spec["name"] # convert to function that takes only model params and inserts custom params - custom_callable = get_tool_with_custom_params(callable, extra_params) + callable = apply_extra_params_to_tool_function( + getattr(module, function_name), extra_params + ) # TODO: This needs to be a pydantic model tool_dict = { @@ -432,16 +434,16 @@ def get_tools( "citation": has_citation, "file_handler": handles_files, "toolkit_id": tool_id, - "callable": custom_callable, + "callable": callable, } - # TODO: if collision, prepend toolkit name - if name in tools: - log.warning(f"Tool {name} already exists in another toolkit!") - log.warning(f"Collision between {toolkit} and {tool_id}.") - log.warning(f"Discarding {toolkit}.{name}") - else: - tools[name] = tool_dict + # TODO: if collision, prepend toolkit name + if function_name in tools: + log.warning(f"Tool {function_name} already exists in another toolkit!") + log.warning(f"Collision between {toolkit} and {tool_id}.") + log.warning(f"Discarding {toolkit}.{function_name}") + else: + tools[function_name] = tool_dict return tools @@ -487,8 +489,12 @@ async def chat_completion_tools_handler( specs = [tool["spec"] for tool in tools.values()] tools_specs = json.dumps(specs) - content = tool_calling_generation_template(app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, tools_specs) - payload = get_tools_function_calling_payload(body["messages"], task_model_id, content) + content = tool_calling_generation_template( + app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, tools_specs + ) + payload = get_tools_function_calling_payload( + body["messages"], task_model_id, content + ) try: payload = filter_pipeline(payload, user) @@ -501,7 +507,6 @@ async def chat_completion_tools_handler( content = await get_content_from_response(response) log.debug(f"{content=}") - if content is None: return body, {} @@ -512,7 +517,7 @@ async def chat_completion_tools_handler( tool_params = result.get("parameters", {}) toolkit_id = tools[tool_name]["toolkit_id"] - + try: tool_output = await tools[tool_name]["callable"](**tool_params) except Exception as e: