mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge pull request #14392 from hurxxxx/fix/improve-loading-functions
perf: Prevent duplicate function loads with caching helper and refactor
This commit is contained in:
commit
75b2e4a659
@ -28,7 +28,10 @@ from open_webui.socket.main import (
|
|||||||
from open_webui.models.functions import Functions
|
from open_webui.models.functions import Functions
|
||||||
from open_webui.models.models import Models
|
from open_webui.models.models import Models
|
||||||
|
|
||||||
from open_webui.utils.plugin import load_function_module_by_id
|
from open_webui.utils.plugin import (
|
||||||
|
load_function_module_by_id,
|
||||||
|
get_function_module_from_cache,
|
||||||
|
)
|
||||||
from open_webui.utils.tools import get_tools
|
from open_webui.utils.tools import get_tools
|
||||||
from open_webui.utils.access_control import has_access
|
from open_webui.utils.access_control import has_access
|
||||||
|
|
||||||
@ -53,9 +56,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
|||||||
|
|
||||||
|
|
||||||
def get_function_module_by_id(request: Request, pipe_id: str):
|
def get_function_module_by_id(request: Request, pipe_id: str):
|
||||||
# Check if function is already loaded
|
function_module, _, _ = get_function_module_from_cache(request, pipe_id)
|
||||||
function_module, _, _ = load_function_module_by_id(pipe_id)
|
|
||||||
request.app.state.FUNCTIONS[pipe_id] = function_module
|
|
||||||
|
|
||||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||||
valves = Functions.get_function_valves_by_id(pipe_id)
|
valves = Functions.get_function_valves_by_id(pipe_id)
|
||||||
|
@ -12,7 +12,11 @@ from open_webui.models.functions import (
|
|||||||
FunctionResponse,
|
FunctionResponse,
|
||||||
Functions,
|
Functions,
|
||||||
)
|
)
|
||||||
from open_webui.utils.plugin import load_function_module_by_id, replace_imports
|
from open_webui.utils.plugin import (
|
||||||
|
load_function_module_by_id,
|
||||||
|
replace_imports,
|
||||||
|
get_function_module_from_cache,
|
||||||
|
)
|
||||||
from open_webui.config import CACHE_DIR
|
from open_webui.config import CACHE_DIR
|
||||||
from open_webui.constants import ERROR_MESSAGES
|
from open_webui.constants import ERROR_MESSAGES
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
@ -358,8 +362,7 @@ async def get_function_valves_spec_by_id(
|
|||||||
):
|
):
|
||||||
function = Functions.get_function_by_id(id)
|
function = Functions.get_function_by_id(id)
|
||||||
if function:
|
if function:
|
||||||
function_module, function_type, frontmatter = load_function_module_by_id(id)
|
function_module, function_type, frontmatter = get_function_module_from_cache(request, id)
|
||||||
request.app.state.FUNCTIONS[id] = function_module
|
|
||||||
|
|
||||||
if hasattr(function_module, "Valves"):
|
if hasattr(function_module, "Valves"):
|
||||||
Valves = function_module.Valves
|
Valves = function_module.Valves
|
||||||
@ -383,8 +386,7 @@ async def update_function_valves_by_id(
|
|||||||
):
|
):
|
||||||
function = Functions.get_function_by_id(id)
|
function = Functions.get_function_by_id(id)
|
||||||
if function:
|
if function:
|
||||||
function_module, function_type, frontmatter = load_function_module_by_id(id)
|
function_module, function_type, frontmatter = get_function_module_from_cache(request, id)
|
||||||
request.app.state.FUNCTIONS[id] = function_module
|
|
||||||
|
|
||||||
if hasattr(function_module, "Valves"):
|
if hasattr(function_module, "Valves"):
|
||||||
Valves = function_module.Valves
|
Valves = function_module.Valves
|
||||||
@ -443,8 +445,7 @@ async def get_function_user_valves_spec_by_id(
|
|||||||
):
|
):
|
||||||
function = Functions.get_function_by_id(id)
|
function = Functions.get_function_by_id(id)
|
||||||
if function:
|
if function:
|
||||||
function_module, function_type, frontmatter = load_function_module_by_id(id)
|
function_module, function_type, frontmatter = get_function_module_from_cache(request, id)
|
||||||
request.app.state.FUNCTIONS[id] = function_module
|
|
||||||
|
|
||||||
if hasattr(function_module, "UserValves"):
|
if hasattr(function_module, "UserValves"):
|
||||||
UserValves = function_module.UserValves
|
UserValves = function_module.UserValves
|
||||||
@ -464,8 +465,7 @@ async def update_function_user_valves_by_id(
|
|||||||
function = Functions.get_function_by_id(id)
|
function = Functions.get_function_by_id(id)
|
||||||
|
|
||||||
if function:
|
if function:
|
||||||
function_module, function_type, frontmatter = load_function_module_by_id(id)
|
function_module, function_type, frontmatter = get_function_module_from_cache(request, id)
|
||||||
request.app.state.FUNCTIONS[id] = function_module
|
|
||||||
|
|
||||||
if hasattr(function_module, "UserValves"):
|
if hasattr(function_module, "UserValves"):
|
||||||
UserValves = function_module.UserValves
|
UserValves = function_module.UserValves
|
||||||
|
@ -40,7 +40,10 @@ from open_webui.models.functions import Functions
|
|||||||
from open_webui.models.models import Models
|
from open_webui.models.models import Models
|
||||||
|
|
||||||
|
|
||||||
from open_webui.utils.plugin import load_function_module_by_id
|
from open_webui.utils.plugin import (
|
||||||
|
load_function_module_by_id,
|
||||||
|
get_function_module_from_cache,
|
||||||
|
)
|
||||||
from open_webui.utils.models import get_all_models, check_model_access
|
from open_webui.utils.models import get_all_models, check_model_access
|
||||||
from open_webui.utils.payload import convert_payload_openai_to_ollama
|
from open_webui.utils.payload import convert_payload_openai_to_ollama
|
||||||
from open_webui.utils.response import (
|
from open_webui.utils.response import (
|
||||||
@ -392,8 +395,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
function_module, _, _ = load_function_module_by_id(action_id)
|
function_module, _, _ = get_function_module_from_cache(request, action_id)
|
||||||
request.app.state.FUNCTIONS[action_id] = function_module
|
|
||||||
|
|
||||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||||
valves = Functions.get_function_valves_by_id(action_id)
|
valves = Functions.get_function_valves_by_id(action_id)
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from open_webui.utils.plugin import load_function_module_by_id
|
from open_webui.utils.plugin import (
|
||||||
|
load_function_module_by_id,
|
||||||
|
get_function_module_from_cache,
|
||||||
|
)
|
||||||
from open_webui.models.functions import Functions
|
from open_webui.models.functions import Functions
|
||||||
from open_webui.env import SRC_LOG_LEVELS
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
@ -13,10 +16,7 @@ def get_function_module(request, function_id):
|
|||||||
"""
|
"""
|
||||||
Get the function module by its ID.
|
Get the function module by its ID.
|
||||||
"""
|
"""
|
||||||
|
function_module, _, _ = get_function_module_from_cache(request, function_id)
|
||||||
function_module, _, _ = load_function_module_by_id(function_id)
|
|
||||||
request.app.state.FUNCTIONS[function_id] = function_module
|
|
||||||
|
|
||||||
return function_module
|
return function_module
|
||||||
|
|
||||||
|
|
||||||
|
@ -13,7 +13,10 @@ from open_webui.models.functions import Functions
|
|||||||
from open_webui.models.models import Models
|
from open_webui.models.models import Models
|
||||||
|
|
||||||
|
|
||||||
from open_webui.utils.plugin import load_function_module_by_id
|
from open_webui.utils.plugin import (
|
||||||
|
load_function_module_by_id,
|
||||||
|
get_function_module_from_cache,
|
||||||
|
)
|
||||||
from open_webui.utils.access_control import has_access
|
from open_webui.utils.access_control import has_access
|
||||||
|
|
||||||
|
|
||||||
@ -239,8 +242,7 @@ async def get_all_models(request, user: UserModel = None):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def get_function_module_by_id(function_id):
|
def get_function_module_by_id(function_id):
|
||||||
function_module, _, _ = load_function_module_by_id(function_id)
|
function_module, _, _ = get_function_module_from_cache(request, function_id)
|
||||||
request.app.state.FUNCTIONS[function_id] = function_module
|
|
||||||
return function_module
|
return function_module
|
||||||
|
|
||||||
for model in models:
|
for model in models:
|
||||||
|
@ -166,6 +166,24 @@ def load_function_module_by_id(function_id, content=None):
|
|||||||
os.unlink(temp_file.name)
|
os.unlink(temp_file.name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_function_module_from_cache(request, function_id):
|
||||||
|
if (
|
||||||
|
hasattr(request.app.state, "FUNCTIONS")
|
||||||
|
and function_id in request.app.state.FUNCTIONS
|
||||||
|
):
|
||||||
|
return request.app.state.FUNCTIONS[function_id], None, None
|
||||||
|
|
||||||
|
function_module, function_type, frontmatter = load_function_module_by_id(
|
||||||
|
function_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if not hasattr(request.app.state, "FUNCTIONS"):
|
||||||
|
request.app.state.FUNCTIONS = {}
|
||||||
|
|
||||||
|
request.app.state.FUNCTIONS[function_id] = function_module
|
||||||
|
return function_module, function_type, frontmatter
|
||||||
|
|
||||||
|
|
||||||
def install_frontmatter_requirements(requirements: str):
|
def install_frontmatter_requirements(requirements: str):
|
||||||
if requirements:
|
if requirements:
|
||||||
try:
|
try:
|
||||||
|
Loading…
Reference in New Issue
Block a user