diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index c83c2d120..56315c73f 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -40,6 +40,7 @@ from open_webui.retrieval.loaders.youtube import YoutubeLoader from open_webui.env import ( + AIOHTTP_CLIENT_TIMEOUT, OFFLINE_MODE, ENABLE_FORWARD_USER_INFO_HEADERS, AIOHTTP_CLIENT_SESSION_SSL, @@ -596,7 +597,9 @@ async def agenerate_openai_batch_embeddings( if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) - async with aiohttp.ClientSession(trust_env=True) as session: + async with aiohttp.ClientSession( + trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) + ) as session: async with session.post( f"{url}/embeddings", headers=headers, json=form_data ) as r: @@ -685,7 +688,9 @@ async def agenerate_azure_openai_batch_embeddings( if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) - async with aiohttp.ClientSession(trust_env=True) as session: + async with aiohttp.ClientSession( + trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) + ) as session: async with session.post(full_url, headers=headers, json=form_data) as r: r.raise_for_status() data = await r.json() @@ -761,7 +766,9 @@ async def agenerate_ollama_batch_embeddings( if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) - async with aiohttp.ClientSession(trust_env=True) as session: + async with aiohttp.ClientSession( + trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) + ) as session: async with session.post( f"{url}/api/embed", headers=headers, diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py index 3ac9949eb..180d23c1c 100644 --- a/backend/open_webui/routers/configs.py +++ b/backend/open_webui/routers/configs.py @@ -6,6 +6,7 @@ import aiohttp from typing import Optional +from backend.open_webui.env import AIOHTTP_CLIENT_TIMEOUT from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.config import get_config, save_config from open_webui.config import BannerModel @@ -228,7 +229,10 @@ async def verify_tool_servers_config( log.debug( f"Trying to fetch OAuth 2.1 discovery document from {discovery_url}" ) - async with aiohttp.ClientSession(trust_env=True) as session: + async with aiohttp.ClientSession( + trust_env=True, + timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT), + ) as session: async with session.get( discovery_url ) as oauth_server_metadata_response: diff --git a/backend/open_webui/routers/functions.py b/backend/open_webui/routers/functions.py index b6dd07d73..8c098b916 100644 --- a/backend/open_webui/routers/functions.py +++ b/backend/open_webui/routers/functions.py @@ -6,6 +6,7 @@ import aiohttp from pathlib import Path from typing import Optional +from backend.open_webui.env import AIOHTTP_CLIENT_TIMEOUT from open_webui.models.functions import ( FunctionForm, FunctionModel, @@ -39,12 +40,16 @@ router = APIRouter() @router.get("/", response_model=list[FunctionResponse]) -async def get_functions(user=Depends(get_verified_user), db: Session = Depends(get_session)): +async def get_functions( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): return Functions.get_functions(db=db) @router.get("/list", response_model=list[FunctionUserResponse]) -async def get_function_list(user=Depends(get_admin_user), db: Session = Depends(get_session)): +async def get_function_list( + user=Depends(get_admin_user), db: Session = Depends(get_session) +): return Functions.get_function_list(db=db) @@ -54,7 +59,11 @@ async def get_function_list(user=Depends(get_admin_user), db: Session = Depends( @router.get("/export", response_model=list[FunctionModel | FunctionWithValvesModel]) -async def get_functions(include_valves: bool = False, user=Depends(get_admin_user), db: Session = Depends(get_session)): +async def get_functions( + include_valves: bool = False, + user=Depends(get_admin_user), + db: Session = Depends(get_session), +): return Functions.get_functions(include_valves=include_valves, db=db) @@ -112,7 +121,9 @@ async def load_function_from_url( ) try: - async with aiohttp.ClientSession(trust_env=True) as session: + async with aiohttp.ClientSession( + trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) + ) as session: async with session.get( url, headers={"Content-Type": "application/json"} ) as resp: @@ -144,7 +155,10 @@ class SyncFunctionsForm(BaseModel): @router.post("/sync", response_model=list[FunctionWithValvesModel]) async def sync_functions( - request: Request, form_data: SyncFunctionsForm, user=Depends(get_admin_user), db: Session = Depends(get_session) + request: Request, + form_data: SyncFunctionsForm, + user=Depends(get_admin_user), + db: Session = Depends(get_session), ): try: for function in form_data.functions: @@ -182,7 +196,10 @@ async def sync_functions( @router.post("/create", response_model=Optional[FunctionResponse]) async def create_new_function( - request: Request, form_data: FunctionForm, user=Depends(get_admin_user), db: Session = Depends(get_session) + request: Request, + form_data: FunctionForm, + user=Depends(get_admin_user), + db: Session = Depends(get_session), ): if not form_data.id.isidentifier(): raise HTTPException( @@ -205,13 +222,17 @@ async def create_new_function( FUNCTIONS = request.app.state.FUNCTIONS FUNCTIONS[form_data.id] = function_module - function = Functions.insert_new_function(user.id, function_type, form_data, db=db) + function = Functions.insert_new_function( + user.id, function_type, form_data, db=db + ) function_cache_dir = CACHE_DIR / "functions" / form_data.id function_cache_dir.mkdir(parents=True, exist_ok=True) if function_type == "filter" and getattr(function_module, "toggle", None): - Functions.update_function_metadata_by_id(form_data.id, {"toggle": True}, db=db) + Functions.update_function_metadata_by_id( + form_data.id, {"toggle": True}, db=db + ) if function: return function @@ -239,7 +260,9 @@ async def create_new_function( @router.get("/id/{id}", response_model=Optional[FunctionModel]) -async def get_function_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)): +async def get_function_by_id( + id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) +): function = Functions.get_function_by_id(id, db=db) if function: @@ -257,7 +280,9 @@ async def get_function_by_id(id: str, user=Depends(get_admin_user), db: Session @router.post("/id/{id}/toggle", response_model=Optional[FunctionModel]) -async def toggle_function_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)): +async def toggle_function_by_id( + id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) +): function = Functions.get_function_by_id(id, db=db) if function: function = Functions.update_function_by_id( @@ -284,7 +309,9 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user), db: Sessi @router.post("/id/{id}/toggle/global", response_model=Optional[FunctionModel]) -async def toggle_global_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)): +async def toggle_global_by_id( + id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) +): function = Functions.get_function_by_id(id, db=db) if function: function = Functions.update_function_by_id( @@ -312,7 +339,11 @@ async def toggle_global_by_id(id: str, user=Depends(get_admin_user), db: Session @router.post("/id/{id}/update", response_model=Optional[FunctionModel]) async def update_function_by_id( - request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user), db: Session = Depends(get_session) + request: Request, + id: str, + form_data: FunctionForm, + user=Depends(get_admin_user), + db: Session = Depends(get_session), ): try: form_data.content = replace_imports(form_data.content) @@ -354,7 +385,10 @@ async def update_function_by_id( @router.delete("/id/{id}/delete", response_model=bool) async def delete_function_by_id( - request: Request, id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) + request: Request, + id: str, + user=Depends(get_admin_user), + db: Session = Depends(get_session), ): result = Functions.delete_function_by_id(id, db=db) @@ -372,7 +406,9 @@ async def delete_function_by_id( @router.get("/id/{id}/valves", response_model=Optional[dict]) -async def get_function_valves_by_id(id: str, user=Depends(get_admin_user), db: Session = Depends(get_session)): +async def get_function_valves_by_id( + id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) +): function = Functions.get_function_by_id(id, db=db) if function: try: @@ -397,7 +433,10 @@ async def get_function_valves_by_id(id: str, user=Depends(get_admin_user), db: S @router.get("/id/{id}/valves/spec", response_model=Optional[dict]) async def get_function_valves_spec_by_id( - request: Request, id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) + request: Request, + id: str, + user=Depends(get_admin_user), + db: Session = Depends(get_session), ): function = Functions.get_function_by_id(id, db=db) if function: @@ -423,7 +462,11 @@ async def get_function_valves_spec_by_id( @router.post("/id/{id}/valves/update", response_model=Optional[dict]) async def update_function_valves_by_id( - request: Request, id: str, form_data: dict, user=Depends(get_admin_user), db: Session = Depends(get_session) + request: Request, + id: str, + form_data: dict, + user=Depends(get_admin_user), + db: Session = Depends(get_session), ): function = Functions.get_function_by_id(id, db=db) if function: @@ -466,11 +509,15 @@ async def update_function_valves_by_id( @router.get("/id/{id}/valves/user", response_model=Optional[dict]) -async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): +async def get_function_user_valves_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): function = Functions.get_function_by_id(id, db=db) if function: try: - user_valves = Functions.get_user_valves_by_id_and_user_id(id, user.id, db=db) + user_valves = Functions.get_user_valves_by_id_and_user_id( + id, user.id, db=db + ) return user_valves except Exception as e: raise HTTPException( @@ -486,7 +533,10 @@ async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict]) async def get_function_user_valves_spec_by_id( - request: Request, id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) + request: Request, + id: str, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): function = Functions.get_function_by_id(id, db=db) if function: @@ -507,7 +557,11 @@ async def get_function_user_valves_spec_by_id( @router.post("/id/{id}/valves/user/update", response_model=Optional[dict]) async def update_function_user_valves_by_id( - request: Request, id: str, form_data: dict, user=Depends(get_verified_user), db: Session = Depends(get_session) + request: Request, + id: str, + form_data: dict, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): function = Functions.get_function_by_id(id, db=db) diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index d919debc3..6dd156e56 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -54,7 +54,11 @@ def get_tool_module(request, tool_id, load_from_db=True): @router.get("/", response_model=list[ToolUserResponse]) -async def get_tools(request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session)): +async def get_tools( + request: Request, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): tools = [] # Local Tools @@ -143,7 +147,9 @@ async def get_tools(request: Request, user=Depends(get_verified_user), db: Sessi # Admin can see all tools return tools else: - user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)} + user_group_ids = { + group.id for group in Groups.get_groups_by_member_id(user.id, db=db) + } tools = [ tool for tool in tools @@ -159,7 +165,9 @@ async def get_tools(request: Request, user=Depends(get_verified_user), db: Sessi @router.get("/list", response_model=list[ToolAccessResponse]) -async def get_tool_list(user=Depends(get_verified_user), db: Session = Depends(get_session)): +async def get_tool_list( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: tools = Tools.get_tools(db=db) else: @@ -232,7 +240,9 @@ async def load_tool_from_url( ) try: - async with aiohttp.ClientSession(trust_env=True) as session: + async with aiohttp.ClientSession( + trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) + ) as session: async with session.get( url, headers={"Content-Type": "application/json"} ) as resp: @@ -259,9 +269,16 @@ async def load_tool_from_url( @router.get("/export", response_model=list[ToolModel]) -async def export_tools(request: Request, user=Depends(get_verified_user), db: Session = Depends(get_session)): +async def export_tools( + request: Request, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): if user.role != "admin" and not has_permission( - user.id, "workspace.tools_export", request.app.state.config.USER_PERMISSIONS, db=db + user.id, + "workspace.tools_export", + request.app.state.config.USER_PERMISSIONS, + db=db, ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -291,7 +308,10 @@ async def create_new_tools( user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS, db=db ) or has_permission( - user.id, "workspace.tools_import", request.app.state.config.USER_PERMISSIONS, db=db + user.id, + "workspace.tools_import", + request.app.state.config.USER_PERMISSIONS, + db=db, ) ): raise HTTPException( @@ -351,7 +371,9 @@ async def create_new_tools( @router.get("/id/{id}", response_model=Optional[ToolAccessResponse]) -async def get_tools_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): +async def get_tools_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): tools = Tools.get_tool_by_id(id, db=db) if tools: @@ -451,7 +473,10 @@ async def update_tools_by_id( @router.delete("/id/{id}/delete", response_model=bool) async def delete_tools_by_id( - request: Request, id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) + request: Request, + id: str, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): tools = Tools.get_tool_by_id(id, db=db) if not tools: @@ -485,7 +510,9 @@ async def delete_tools_by_id( @router.get("/id/{id}/valves", response_model=Optional[dict]) -async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): +async def get_tools_valves_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): tools = Tools.get_tool_by_id(id, db=db) if tools: try: @@ -510,7 +537,10 @@ async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user), db: S @router.get("/id/{id}/valves/spec", response_model=Optional[dict]) async def get_tools_valves_spec_by_id( - request: Request, id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) + request: Request, + id: str, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): tools = Tools.get_tool_by_id(id, db=db) if tools: @@ -538,7 +568,11 @@ async def get_tools_valves_spec_by_id( @router.post("/id/{id}/valves/update", response_model=Optional[dict]) async def update_tools_valves_by_id( - request: Request, id: str, form_data: dict, user=Depends(get_verified_user), db: Session = Depends(get_session) + request: Request, + id: str, + form_data: dict, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): tools = Tools.get_tool_by_id(id, db=db) if not tools: @@ -590,7 +624,9 @@ async def update_tools_valves_by_id( @router.get("/id/{id}/valves/user", response_model=Optional[dict]) -async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)): +async def get_tools_user_valves_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): tools = Tools.get_tool_by_id(id, db=db) if tools: try: @@ -610,7 +646,10 @@ async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user), @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict]) async def get_tools_user_valves_spec_by_id( - request: Request, id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) + request: Request, + id: str, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): tools = Tools.get_tool_by_id(id, db=db) if tools: @@ -633,7 +672,11 @@ async def get_tools_user_valves_spec_by_id( @router.post("/id/{id}/valves/user/update", response_model=Optional[dict]) async def update_tools_user_valves_by_id( - request: Request, id: str, form_data: dict, user=Depends(get_verified_user), db: Session = Depends(get_session) + request: Request, + id: str, + form_data: dict, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): tools = Tools.get_tool_by_id(id, db=db) diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index edbf86626..15dec7bed 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -923,6 +923,48 @@ def get_image_urls(delta_images, request, metadata, user) -> list[str]: return image_urls +def inject_file_context_into_messages(messages: list) -> None: + """ + Inject file context into each user message that has files. + Modifies messages in-place by prepending file info to message content. + """ + for message in messages: + if message.get("role") != "user": + continue + + files = message.get("files", []) + if not files: + continue + + # Build XML context for this message's files + file_entries = [] + for file in files: + if not file.get("url"): + continue + + attrs = [f'type="{file.get("type", "file")}"'] + if file.get("content_type"): + attrs.append(f'content_type="{file["content_type"]}"') + if file.get("name"): + attrs.append(f'name="{file["name"]}"') + attrs.append(f'url="{file["url"]}"') + file_entries.append(f'') + + if not file_entries: + continue + + files_context = "\n" + "\n".join(file_entries) + "\n\n\n" + + # Prepend to message content + content = message.get("content", "") + if isinstance(content, str): + message["content"] = files_context + content + elif isinstance(content, list): + # For multimodal content, prepend as text item + message["content"] = [{"type": "text", "text": files_context}] + content + + + async def chat_image_generation_handler( request: Request, form_data: dict, extra_params: dict, user ): @@ -1748,6 +1790,10 @@ async def process_chat_payload(request, form_data, user, metadata, model): {"type": "function", "function": tool.get("spec", {})} for tool in tools_dict.values() ] + # Inject file context into each user message that has files attached + inject_file_context_into_messages(form_data.get("messages", [])) + + else: # If the function calling is not native, then call the tools function calling handler try: diff --git a/backend/open_webui/utils/webhook.py b/backend/open_webui/utils/webhook.py index b617abc06..eb2688851 100644 --- a/backend/open_webui/utils/webhook.py +++ b/backend/open_webui/utils/webhook.py @@ -3,7 +3,7 @@ import logging import aiohttp from open_webui.config import WEBUI_FAVICON_URL -from open_webui.env import VERSION +from open_webui.env import AIOHTTP_CLIENT_TIMEOUT, VERSION log = logging.getLogger(__name__) @@ -50,7 +50,9 @@ async def post_webhook(name: str, url: str, message: str, event_data: dict) -> b payload = {**event_data} log.debug(f"payload: {payload}") - async with aiohttp.ClientSession(trust_env=True) as session: + async with aiohttp.ClientSession( + trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) + ) as session: async with session.post(url, json=payload) as r: r_text = await r.text() r.raise_for_status()