From 0c57980e72cd6180baaa09166be04468493aa22d Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Sun, 8 Jun 2025 20:58:31 +0400 Subject: [PATCH] refac: tasks --- backend/open_webui/main.py | 29 ++++++++++++++++++++------ backend/open_webui/tasks.py | 21 +++++++------------ backend/open_webui/utils/middleware.py | 6 ++++-- backend/open_webui/utils/redis.py | 9 ++++++-- 4 files changed, 42 insertions(+), 23 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index e623d2813..0ccdee4a0 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -8,6 +8,7 @@ import shutil import sys import time import random +from uuid import uuid4 from contextlib import asynccontextmanager from urllib.parse import urlencode, parse_qs, urlparse @@ -432,6 +433,7 @@ from open_webui.utils.auth import ( from open_webui.utils.plugin import install_tool_and_function_dependencies from open_webui.utils.oauth import OAuthManager from open_webui.utils.security_headers import SecurityHeadersMiddleware +from open_webui.utils.redis import get_redis_connection from open_webui.tasks import ( list_task_ids_by_chat_id, @@ -485,7 +487,9 @@ https://github.com/open-webui/open-webui @asynccontextmanager async def lifespan(app: FastAPI): + app.state.instance_id = os.environ.get("INSTANCE_ID", str(uuid4())) start_logger() + if RESET_CONFIG_ON_START: reset_config() @@ -497,6 +501,13 @@ async def lifespan(app: FastAPI): log.info("Installing external dependencies of functions and tools...") install_tool_and_function_dependencies() + app.state.redis = get_redis_connection( + redis_url=REDIS_URL, + redis_sentinels=get_sentinels_from_env( + REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT + ), + ) + if THREAD_POOL_SIZE and THREAD_POOL_SIZE > 0: limiter = anyio.to_thread.current_default_thread_limiter() limiter.total_tokens = THREAD_POOL_SIZE @@ -516,10 +527,12 @@ app = FastAPI( oauth_manager = OAuthManager(app) +app.state.instance_id = None app.state.config = AppConfig( redis_url=REDIS_URL, redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT), ) +app.state.redis = None app.state.WEBUI_NAME = WEBUI_NAME app.state.LICENSE_METADATA = None @@ -1386,26 +1399,30 @@ async def chat_action( @app.post("/api/tasks/stop/{task_id}") -async def stop_task_endpoint(task_id: str, user=Depends(get_verified_user)): +async def stop_task_endpoint( + request: Request, task_id: str, user=Depends(get_verified_user) +): try: - result = await stop_task(task_id) + result = await stop_task(request, task_id) return result except ValueError as e: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) @app.get("/api/tasks") -async def list_tasks_endpoint(user=Depends(get_verified_user)): - return {"tasks": list_tasks()} +async def list_tasks_endpoint(request: Request, user=Depends(get_verified_user)): + return {"tasks": list_tasks(request)} @app.get("/api/tasks/chat/{chat_id}") -async def list_tasks_by_chat_id_endpoint(chat_id: str, user=Depends(get_verified_user)): +async def list_tasks_by_chat_id_endpoint( + request: Request, chat_id: str, user=Depends(get_verified_user) +): chat = Chats.get_chat_by_id(chat_id) if chat is None or chat.user_id != user.id: return {"task_ids": []} - task_ids = list_task_ids_by_chat_id(chat_id) + task_ids = list_task_ids_by_chat_id(request, chat_id) print(f"Task IDs for chat {chat_id}: {task_ids}") return {"task_ids": task_ids} diff --git a/backend/open_webui/tasks.py b/backend/open_webui/tasks.py index e575e6885..530a59083 100644 --- a/backend/open_webui/tasks.py +++ b/backend/open_webui/tasks.py @@ -2,13 +2,15 @@ import asyncio from typing import Dict from uuid import uuid4 +import json +from redis import Redis # A dictionary to keep track of active tasks tasks: Dict[str, asyncio.Task] = {} chat_tasks = {} -def cleanup_task(task_id: str, id=None): +def cleanup_task(request, task_id: str, id=None): """ Remove a completed or canceled task from the global `tasks` dictionary. """ @@ -21,7 +23,7 @@ def cleanup_task(task_id: str, id=None): chat_tasks.pop(id, None) -def create_task(coroutine, id=None): +def create_task(request, coroutine, id=None): """ Create a new asyncio task and add it to the global task dictionary. """ @@ -29,7 +31,7 @@ def create_task(coroutine, id=None): task = asyncio.create_task(coroutine) # Create the task # Add a done callback for cleanup - task.add_done_callback(lambda t: cleanup_task(task_id, id)) + task.add_done_callback(lambda t: cleanup_task(request, task_id, id)) tasks[task_id] = task # If an ID is provided, associate the task with that ID @@ -41,28 +43,21 @@ def create_task(coroutine, id=None): return task_id, task -def get_task(task_id: str): - """ - Retrieve a task by its task ID. - """ - return tasks.get(task_id) - - -def list_tasks(): +def list_tasks(request): """ List all currently active task IDs. """ return list(tasks.keys()) -def list_task_ids_by_chat_id(id): +def list_task_ids_by_chat_id(request, id): """ List all tasks associated with a specific ID. """ return chat_tasks.get(id, []) -async def stop_task(task_id: str): +async def stop_task(request, task_id: str): """ Cancel a running task and remove it from the global task list. """ diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index a07262fa9..a5a9b8e07 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -2266,7 +2266,9 @@ async def process_chat_response( if "data:image/png;base64" in line: image_url = "" # Extract base64 image data from the line - image_data, content_type = load_b64_image_data(line) + image_data, content_type = ( + load_b64_image_data(line) + ) if image_data is not None: image_url = upload_image( request, @@ -2412,7 +2414,7 @@ async def process_chat_response( # background_tasks.add_task(post_response_handler, response, events) task_id, _ = create_task( - post_response_handler(response, events), id=metadata["chat_id"] + request, post_response_handler(response, events), id=metadata["chat_id"] ) return {"status": True, "task_id": task_id} diff --git a/backend/open_webui/utils/redis.py b/backend/open_webui/utils/redis.py index e0a53e73d..85eae55b6 100644 --- a/backend/open_webui/utils/redis.py +++ b/backend/open_webui/utils/redis.py @@ -2,6 +2,7 @@ import socketio import redis from redis import asyncio as aioredis from urllib.parse import urlparse +from typing import Optional def parse_redis_service_url(redis_url): @@ -18,7 +19,9 @@ def parse_redis_service_url(redis_url): } -def get_redis_connection(redis_url, redis_sentinels, decode_responses=True): +def get_redis_connection( + redis_url, redis_sentinels, decode_responses=True +) -> Optional[redis.Redis]: if redis_sentinels: redis_config = parse_redis_service_url(redis_url) sentinel = redis.sentinel.Sentinel( @@ -32,9 +35,11 @@ def get_redis_connection(redis_url, redis_sentinels, decode_responses=True): # Get a master connection from Sentinel return sentinel.master_for(redis_config["service"]) - else: + elif redis_url: # Standard Redis connection return redis.Redis.from_url(redis_url, decode_responses=decode_responses) + else: + return None def get_sentinels_from_env(sentinel_hosts_env, sentinel_port_env):