diff --git a/backend/main.py b/backend/main.py index 02b75ac71..5fcb0e742 100644 --- a/backend/main.py +++ b/backend/main.py @@ -283,21 +283,6 @@ def get_filter_function_ids(model): return filter_ids -async def get_content_from_response(response) -> Optional[str]: - content = None - if hasattr(response, "body_iterator"): - async for chunk in response.body_iterator: - data = json.loads(chunk.decode("utf-8")) - content = data["choices"][0]["message"]["content"] - - # Cleanup any remaining background tasks if necessary - if response.background is not None: - await response.background() - else: - content = response["choices"][0]["message"]["content"] - return content - - def get_tool_call_payload(messages, task_model_id, content): user_message = get_last_user_message(messages) history = "\n".join( @@ -403,8 +388,8 @@ def get_tool_with_custom_params( # Mutation on extra_params -def get_configured_tools( - tool_ids: list[str], extra_params: dict, user: UserModel +def get_tools( + tool_ids: list[str], user: UserModel, extra_params: dict ) -> dict[str, dict]: tools = {} for tool_id in tool_ids: @@ -420,6 +405,7 @@ def get_configured_tools( extra_params["__id__"] = tool_id has_citation = hasattr(module, "citation") and module.citation handles_files = hasattr(module, "file_handler") and module.file_handler + if hasattr(module, "valves") and hasattr(module, "Valves"): valves = Tools.get_tool_valves_by_id(tool_id) or {} module.valves = module.Valves(**valves) @@ -459,35 +445,51 @@ def get_configured_tools( return tools +async def get_content_from_response(response) -> Optional[str]: + content = None + if hasattr(response, "body_iterator"): + async for chunk in response.body_iterator: + data = json.loads(chunk.decode("utf-8")) + content = data["choices"][0]["message"]["content"] + + # Cleanup any remaining background tasks if necessary + if response.background is not None: + await response.background() + else: + content = response["choices"][0]["message"]["content"] + return content + + async def chat_completion_tools_handler( body: dict, user: UserModel, extra_params: dict ) -> tuple[dict, dict]: skip_files = False contexts = [] citations = [] - task_model_id = get_task_model_id(body["model"]) + task_model_id = get_task_model_id(body["model"]) # If tool_ids field is present, call the functions tool_ids = body.pop("tool_ids", None) if not tool_ids: return body, {} log.debug(f"{tool_ids=}") + custom_params = { **extra_params, "__model__": app.state.MODELS[task_model_id], "__messages__": body["messages"], "__files__": body.get("files", []), } - configured_tools = get_configured_tools(tool_ids, custom_params, user) + tools = get_tools(tool_ids, user, custom_params) + log.info(f"{tools=}") - log.info(f"{configured_tools=}") - - specs = [tool["spec"] for tool in configured_tools.values()] + specs = [tool["spec"] for tool in tools.values()] tools_specs = json.dumps(specs) template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE content = tool_calling_generation_template(template, tools_specs) payload = get_tool_call_payload(body["messages"], task_model_id, content) + try: payload = filter_pipeline(payload, user) except Exception as e: @@ -503,16 +505,18 @@ async def chat_completion_tools_handler( result = json.loads(content) tool_name = result.get("name", None) - if tool_name not in configured_tools: + if tool_name not in tools: return body, {} tool_params = result.get("parameters", {}) - toolkit_id = configured_tools[tool_name]["toolkit_id"] + toolkit_id = tools[tool_name]["toolkit_id"] + try: - tool_output = await configured_tools[tool_name]["callable"](**tool_params) + tool_output = await tools[tool_name]["callable"](**tool_params) except Exception as e: tool_output = str(e) - if configured_tools[tool_name]["citation"]: + + if tools[tool_name]["citation"]: citations.append( { "source": {"name": f"TOOL:{toolkit_id}/{tool_name}"}, @@ -520,7 +524,7 @@ async def chat_completion_tools_handler( "metadata": [{"source": tool_name}], } ) - if configured_tools[tool_name]["file_handler"]: + if tools[tool_name]["file_handler"]: skip_files = True if isinstance(tool_output, str):