diff --git a/backend/main.py b/backend/main.py index 44fdc6298..512f3d006 100644 --- a/backend/main.py +++ b/backend/main.py @@ -297,6 +297,30 @@ async def get_content_from_response(response) -> Optional[str]: return content +async def call_tool_from_completion( + result: dict, extra_params: dict, toolkit_module +) -> Optional[str]: + if "name" not in result: + return None + + tool = getattr(toolkit_module, result["name"]) + try: + # Get the signature of the function + sig = inspect.signature(tool) + params = result["parameters"] + for key, value in extra_params.items(): + if key in sig.parameters: + params[key] = value + + if inspect.iscoroutinefunction(tool): + return await tool(**params) + else: + return tool(**params) + except Exception as e: + print(f"Error: {e}") + return None + + async def get_function_call_response( messages, files, @@ -306,7 +330,7 @@ async def get_function_call_response( user, __event_emitter__=None, __event_call__=None, -): +) -> tuple[Optional[str], Optional[dict], bool]: tool = Tools.get_tool_by_id(tool_id) if tool is None: return None, None, False @@ -343,7 +367,43 @@ async def get_function_call_response( except Exception as e: raise e - model = app.state.MODELS[task_model_id] + if tool_id in webui_app.state.TOOLS: + toolkit_module = webui_app.state.TOOLS[tool_id] + else: + toolkit_module, _ = load_toolkit_module_by_id(tool_id) + webui_app.state.TOOLS[tool_id] = toolkit_module + + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + + try: + if hasattr(toolkit_module, "UserValves"): + __user__["valves"] = toolkit_module.UserValves( + **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) + ) + + except Exception as e: + print(e) + + extra_params = { + "__model__": app.state.MODELS[task_model_id], + "__id__": tool_id, + "__messages__": messages, + "__files__": files, + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + "__user__": __user__, + } + + file_handler = hasattr(toolkit_module, "file_handler") + + if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"): + valves = Tools.get_tool_valves_by_id(tool_id) + toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {})) try: response = await generate_chat_completions(form_data=payload, user=user) @@ -353,85 +413,21 @@ async def get_function_call_response( return None, None, False # Parse the function response - print(f"content: {content}") + log.debug(f"content: {content}") result = json.loads(content) - print(result) - if "name" not in result: - return None, None, False + function_result = await call_tool_from_completion( + result, extra_params, toolkit_module + ) - # Call the function - if tool_id in webui_app.state.TOOLS: - toolkit_module = webui_app.state.TOOLS[tool_id] - else: - toolkit_module, _ = load_toolkit_module_by_id(tool_id) - webui_app.state.TOOLS[tool_id] = toolkit_module - - file_handler = False - # check if toolkit_module has file_handler self variable - if hasattr(toolkit_module, "file_handler"): - file_handler = True - print("file_handler: ", file_handler) - - if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"): - valves = Tools.get_tool_valves_by_id(tool_id) - toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {})) - - function = getattr(toolkit_module, result["name"]) - function_result = None - citation = None - try: - # Get the signature of the function - sig = inspect.signature(function) - params = result["parameters"] - - # Extra parameters to be passed to the function - extra_params = { - "__model__": model, - "__id__": tool_id, - "__messages__": messages, - "__files__": files, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, + if hasattr(toolkit_module, "citation") and toolkit_module.citation: + citation = { + "source": {"name": f"TOOL:{tool.name}/{result['name']}"}, + "document": [function_result], + "metadata": [{"source": result["name"]}], } - - # Add extra params in contained in function signature - for key, value in extra_params.items(): - if key in sig.parameters: - params[key] = value - - if "__user__" in sig.parameters: - # Call the function with the '__user__' parameter included - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - - try: - if hasattr(toolkit_module, "UserValves"): - __user__["valves"] = toolkit_module.UserValves( - **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) - ) - except Exception as e: - print(e) - - params = {**params, "__user__": __user__} - - if inspect.iscoroutinefunction(function): - function_result = await function(**params) - else: - function_result = function(**params) - - if hasattr(toolkit_module, "citation") and toolkit_module.citation: - citation = { - "source": {"name": f"TOOL:{tool.name}/{result['name']}"}, - "document": [function_result], - "metadata": [{"source": result["name"]}], - } - except Exception as e: - print(e) + else: + citation = None # Add the function result to the system prompt if function_result is not None: