diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index dddf3fbb2..06c1a0921 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -26,6 +26,7 @@ from utils.misc import ( apply_model_system_prompt_to_body, ) +from utils.tools import get_tools from config import ( SHOW_ADMIN_DETAILS, @@ -47,6 +48,7 @@ from config import ( OAUTH_USERNAME_CLAIM, OAUTH_PICTURE_CLAIM, OAUTH_EMAIL_CLAIM, + ENABLE_TOOLS_FILTER, ) from apps.socket.main import get_event_call, get_event_emitter @@ -271,7 +273,7 @@ def get_function_params(function_module, form_data, user, extra_params={}): return params -async def generate_function_chat_completion(form_data, user): +async def generate_function_chat_completion(form_data, user, files, tool_ids): model_id = form_data.get("model") model_info = Models.get_model_by_id(model_id) metadata = form_data.pop("metadata", None) @@ -286,6 +288,21 @@ async def generate_function_chat_completion(form_data, user): __event_call__ = get_event_call(metadata) __task__ = metadata.get("task", None) + extra_params = { + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + "__task__": __task__, + } + if not ENABLE_TOOLS_FILTER: + tools_params = { + **extra_params, + "__model__": app.state.MODELS[form_data["model"]], + "__messages__": form_data["messages"], + "__files__": files, + } + configured_tools = get_tools(app, tool_ids, user, tools_params) + + extra_params["__tools__"] = configured_tools if model_info: if model_info.base_model_id: form_data["model"] = model_info.base_model_id @@ -298,16 +315,7 @@ async def generate_function_chat_completion(form_data, user): function_module = get_function_module(pipe_id) pipe = function_module.pipe - params = get_function_params( - function_module, - form_data, - user, - { - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - "__task__": __task__, - }, - ) + params = get_function_params(function_module, form_data, user, extra_params) if form_data["stream"]: diff --git a/backend/main.py b/backend/main.py index 49984e9bc..0bd49ef3a 100644 --- a/backend/main.py +++ b/backend/main.py @@ -994,13 +994,11 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u detail="Model not found", ) model = app.state.MODELS[model_id] + files = form_data.pop("files", None) + tool_ids = form_data.pop("tool_ids", None) if model.get("pipe"): - return await generate_function_chat_completion(form_data, user=user) - - for key in ["tool_ids", "files"]: - if key in form_data: - del form_data[key] + return await generate_function_chat_completion(form_data, user, files, tool_ids) if model["owned_by"] == "ollama": return await generate_ollama_chat_completion(form_data, user=user) else: