diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index 6741f2d10..1a055f327 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -143,10 +143,10 @@ class FunctionsTable: for function in Function.select().where(Function.type == type) ] - def get_function_valves_by_id(self, id: str) -> Optional[FunctionValves]: + def get_function_valves_by_id(self, id: str) -> Optional[dict]: try: function = Function.get(Function.id == id) - return FunctionValves(**model_to_dict(function)) + return function.valves if "valves" in function and function.valves else {} except Exception as e: print(f"An error occurred: {e}") return None diff --git a/backend/apps/webui/models/tools.py b/backend/apps/webui/models/tools.py index 0f5755e39..41504bd4a 100644 --- a/backend/apps/webui/models/tools.py +++ b/backend/apps/webui/models/tools.py @@ -114,10 +114,10 @@ class ToolsTable: def get_tools(self) -> List[ToolModel]: return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()] - def get_tool_valves_by_id(self, id: str) -> Optional[ToolValves]: + def get_tool_valves_by_id(self, id: str) -> Optional[dict]: try: tool = Tool.get(Tool.id == id) - return ToolValves(**model_to_dict(tool)) + return tool.valves if "valves" in tool and tool.valves else {} except Exception as e: print(f"An error occurred: {e}") return None diff --git a/backend/apps/webui/routers/functions.py b/backend/apps/webui/routers/functions.py index 4a0e7c564..fa3e3aeb9 100644 --- a/backend/apps/webui/routers/functions.py +++ b/backend/apps/webui/routers/functions.py @@ -127,8 +127,8 @@ async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)): function = Functions.get_function_by_id(id) if function: try: - function_valves = Functions.get_function_valves_by_id(id) - return function_valves.valves + valves = Functions.get_function_valves_by_id(id) + return valves except Exception as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/backend/apps/webui/routers/tools.py b/backend/apps/webui/routers/tools.py index 3ab75187a..ab974391c 100644 --- a/backend/apps/webui/routers/tools.py +++ b/backend/apps/webui/routers/tools.py @@ -133,8 +133,8 @@ async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)): toolkit = Tools.get_tool_by_id(id) if toolkit: try: - tool_valves = Tools.get_tool_valves_by_id(id) - return tool_valves.valves + valves = Tools.get_tool_valves_by_id(id) + return valves except Exception as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/backend/main.py b/backend/main.py index 2a44d2029..991eb5839 100644 --- a/backend/main.py +++ b/backend/main.py @@ -262,6 +262,13 @@ async def get_function_call_response( file_handler = True print("file_handler: ", file_handler) + if hasattr(toolkit_module, "valves") and hasattr( + toolkit_module, "Valves" + ): + toolkit_module.valves = toolkit_module.Valves( + **Tools.get_tool_valves_by_id(tool_id) + ) + function = getattr(toolkit_module, result["name"]) function_result = None try: @@ -402,6 +409,13 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): if hasattr(function_module, "file_handler"): skip_files = function_module.file_handler + if hasattr(function_module, "valves") and hasattr( + function_module, "Valves" + ): + function_module.valves = function_module.Valves( + **Functions.get_function_valves_by_id(filter_id) + ) + try: if hasattr(function_module, "inlet"): inlet = function_module.inlet @@ -884,6 +898,13 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u else: function_module = webui_app.state.FUNCTIONS[pipe_id] + if hasattr(function_module, "valves") and hasattr( + function_module, "Valves" + ): + function_module.valves = function_module.Valves( + **Functions.get_function_valves_by_id(pipe_id) + ) + pipe = function_module.pipe # Get the signature of the function @@ -1105,6 +1126,13 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): ) webui_app.state.FUNCTIONS[filter_id] = function_module + if hasattr(function_module, "valves") and hasattr( + function_module, "Valves" + ): + function_module.valves = function_module.Valves( + **Functions.get_function_valves_by_id(filter_id) + ) + try: if hasattr(function_module, "outlet"): outlet = function_module.outlet