From 14c3d0c2d14a49da7942863430e11920fdec1598 Mon Sep 17 00:00:00 2001 From: Gunwoo Hur Date: Tue, 27 May 2025 18:08:58 +0900 Subject: [PATCH] Prevent duplicate function module loads with caching helper and refactor --- backend/open_webui/functions.py | 9 +++++---- backend/open_webui/routers/functions.py | 18 +++++++++--------- backend/open_webui/utils/chat.py | 8 +++++--- backend/open_webui/utils/filter.py | 10 +++++----- backend/open_webui/utils/models.py | 8 +++++--- backend/open_webui/utils/plugin.py | 18 ++++++++++++++++++ 6 files changed, 47 insertions(+), 24 deletions(-) diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py index aa7dbccf9..6d8203839 100644 --- a/backend/open_webui/functions.py +++ b/backend/open_webui/functions.py @@ -28,7 +28,10 @@ from open_webui.socket.main import ( from open_webui.models.functions import Functions from open_webui.models.models import Models -from open_webui.utils.plugin import load_function_module_by_id +from open_webui.utils.plugin import ( + load_function_module_by_id, + get_function_module_from_cache, +) from open_webui.utils.tools import get_tools from open_webui.utils.access_control import has_access @@ -53,9 +56,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"]) def get_function_module_by_id(request: Request, pipe_id: str): - # Check if function is already loaded - function_module, _, _ = load_function_module_by_id(pipe_id) - request.app.state.FUNCTIONS[pipe_id] = function_module + function_module, _, _ = get_function_module_from_cache(request, pipe_id) if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): valves = Functions.get_function_valves_by_id(pipe_id) diff --git a/backend/open_webui/routers/functions.py b/backend/open_webui/routers/functions.py index 2748fa95c..f274dffea 100644 --- a/backend/open_webui/routers/functions.py +++ b/backend/open_webui/routers/functions.py @@ -12,7 +12,11 @@ from open_webui.models.functions import ( FunctionResponse, Functions, ) -from open_webui.utils.plugin import load_function_module_by_id, replace_imports +from open_webui.utils.plugin import ( + load_function_module_by_id, + replace_imports, + get_function_module_from_cache, +) from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status @@ -358,8 +362,7 @@ async def get_function_valves_spec_by_id( ): function = Functions.get_function_by_id(id) if function: - function_module, function_type, frontmatter = load_function_module_by_id(id) - request.app.state.FUNCTIONS[id] = function_module + function_module, function_type, frontmatter = get_function_module_from_cache(request, id) if hasattr(function_module, "Valves"): Valves = function_module.Valves @@ -383,8 +386,7 @@ async def update_function_valves_by_id( ): function = Functions.get_function_by_id(id) if function: - function_module, function_type, frontmatter = load_function_module_by_id(id) - request.app.state.FUNCTIONS[id] = function_module + function_module, function_type, frontmatter = get_function_module_from_cache(request, id) if hasattr(function_module, "Valves"): Valves = function_module.Valves @@ -443,8 +445,7 @@ async def get_function_user_valves_spec_by_id( ): function = Functions.get_function_by_id(id) if function: - function_module, function_type, frontmatter = load_function_module_by_id(id) - request.app.state.FUNCTIONS[id] = function_module + function_module, function_type, frontmatter = get_function_module_from_cache(request, id) if hasattr(function_module, "UserValves"): UserValves = function_module.UserValves @@ -464,8 +465,7 @@ async def update_function_user_valves_by_id( function = Functions.get_function_by_id(id) if function: - function_module, function_type, frontmatter = load_function_module_by_id(id) - request.app.state.FUNCTIONS[id] = function_module + function_module, function_type, frontmatter = get_function_module_from_cache(request, id) if hasattr(function_module, "UserValves"): UserValves = function_module.UserValves diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index d846e35b6..4bd744e3c 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -40,7 +40,10 @@ from open_webui.models.functions import Functions from open_webui.models.models import Models -from open_webui.utils.plugin import load_function_module_by_id +from open_webui.utils.plugin import ( + load_function_module_by_id, + get_function_module_from_cache, +) from open_webui.utils.models import get_all_models, check_model_access from open_webui.utils.payload import convert_payload_openai_to_ollama from open_webui.utils.response import ( @@ -392,8 +395,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A } ) - function_module, _, _ = load_function_module_by_id(action_id) - request.app.state.FUNCTIONS[action_id] = function_module + function_module, _, _ = get_function_module_from_cache(request, action_id) if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): valves = Functions.get_function_valves_by_id(action_id) diff --git a/backend/open_webui/utils/filter.py b/backend/open_webui/utils/filter.py index 8a4a7ba49..c9adee7d7 100644 --- a/backend/open_webui/utils/filter.py +++ b/backend/open_webui/utils/filter.py @@ -1,7 +1,10 @@ import inspect import logging -from open_webui.utils.plugin import load_function_module_by_id +from open_webui.utils.plugin import ( + load_function_module_by_id, + get_function_module_from_cache, +) from open_webui.models.functions import Functions from open_webui.env import SRC_LOG_LEVELS @@ -13,10 +16,7 @@ def get_function_module(request, function_id): """ Get the function module by its ID. """ - - function_module, _, _ = load_function_module_by_id(function_id) - request.app.state.FUNCTIONS[function_id] = function_module - + function_module, _, _ = get_function_module_from_cache(request, function_id) return function_module diff --git a/backend/open_webui/utils/models.py b/backend/open_webui/utils/models.py index 684d2074e..adb63f520 100644 --- a/backend/open_webui/utils/models.py +++ b/backend/open_webui/utils/models.py @@ -13,7 +13,10 @@ from open_webui.models.functions import Functions from open_webui.models.models import Models -from open_webui.utils.plugin import load_function_module_by_id +from open_webui.utils.plugin import ( + load_function_module_by_id, + get_function_module_from_cache, +) from open_webui.utils.access_control import has_access @@ -239,8 +242,7 @@ async def get_all_models(request, user: UserModel = None): ] def get_function_module_by_id(function_id): - function_module, _, _ = load_function_module_by_id(function_id) - request.app.state.FUNCTIONS[function_id] = function_module + function_module, _, _ = get_function_module_from_cache(request, function_id) return function_module for model in models: diff --git a/backend/open_webui/utils/plugin.py b/backend/open_webui/utils/plugin.py index 9c2ee1bbd..a7a4325e3 100644 --- a/backend/open_webui/utils/plugin.py +++ b/backend/open_webui/utils/plugin.py @@ -166,6 +166,24 @@ def load_function_module_by_id(function_id, content=None): os.unlink(temp_file.name) +def get_function_module_from_cache(request, function_id): + if ( + hasattr(request.app.state, "FUNCTIONS") + and function_id in request.app.state.FUNCTIONS + ): + return request.app.state.FUNCTIONS[function_id], None, None + + function_module, function_type, frontmatter = load_function_module_by_id( + function_id + ) + + if not hasattr(request.app.state, "FUNCTIONS"): + request.app.state.FUNCTIONS = {} + + request.app.state.FUNCTIONS[function_id] = function_module + return function_module, function_type, frontmatter + + def install_frontmatter_requirements(requirements: str): if requirements: try: