diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index eb38d570c..40f264860 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -56,12 +56,15 @@ from apps.socket.main import get_event_call, get_event_emitter import inspect import json +import logging from typing import Iterator, Generator, AsyncGenerator from pydantic import BaseModel app = FastAPI() +log = logging.getLogger(__name__) + app.state.config = AppConfig() app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP @@ -243,44 +246,37 @@ def get_pipe_id(form_data: dict) -> str: return pipe_id -def get_function_params(function_module, form_data, user, extra_params={}): +def get_function_params(function_module, form_data, user, extra_params=None): + if extra_params is None: + extra_params = {} + pipe_id = get_pipe_id(form_data) + # Get the signature of the function sig = inspect.signature(function_module.pipe) - params = {"body": form_data} - - for key, value in extra_params.items(): - if key in sig.parameters: - params[key] = value - - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } + params = {"body": form_data} | { + k: v for k, v in extra_params.items() if k in sig.parameters + } + if "__user__" in params and hasattr(function_module, "UserValves"): + user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id) try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id) - ) + params["__user__"]["valves"] = function_module.UserValves(**user_valves) except Exception as e: - print(e) + log.exception(e) + params["__user__"]["valves"] = function_module.UserValves() - params["__user__"] = __user__ return params async def generate_function_chat_completion(form_data, user): model_id = form_data.get("model") model_info = Models.get_model_by_id(model_id) + metadata = form_data.pop("metadata", {}) files = metadata.get("files", []) tool_ids = metadata.get("tool_ids", []) - # Check if tool_ids is None if tool_ids is None: tool_ids = [] @@ -299,18 +295,25 @@ async def generate_function_chat_completion(form_data, user): "__event_emitter__": __event_emitter__, "__event_call__": __event_call__, "__task__": __task__, - } - - extra_params["__tools__"] = get_tools( - app, - tool_ids, - user, - { - **extra_params, - "__model__": app.state.MODELS[form_data["model"]], - "__messages__": form_data["messages"], - "__files__": files, + "__user__": { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, }, + } + extra_params["__tools__"] = ( + get_tools( + app, + tool_ids, + user, + { + **extra_params, + "__model__": app.state.MODELS[form_data["model"]], + "__messages__": form_data["messages"], + "__files__": files, + }, + ), ) if model_info: