refac/fix: multi-replica stop task (response)

This commit is contained in:
Timothy Jaeryang Baek 2025-06-08 21:20:30 +04:00
parent 0c57980e72
commit d8d8380a78
2 changed files with 102 additions and 0 deletions

View File

@ -10,6 +10,7 @@ import time
import random import random
from uuid import uuid4 from uuid import uuid4
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from urllib.parse import urlencode, parse_qs, urlparse from urllib.parse import urlencode, parse_qs, urlparse
from pydantic import BaseModel from pydantic import BaseModel
@ -20,6 +21,7 @@ from aiocache import cached
import aiohttp import aiohttp
import anyio.to_thread import anyio.to_thread
import requests import requests
from redis import Redis
from fastapi import ( from fastapi import (
@ -436,6 +438,7 @@ from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.utils.redis import get_redis_connection from open_webui.utils.redis import get_redis_connection
from open_webui.tasks import ( from open_webui.tasks import (
redis_task_command_listener,
list_task_ids_by_chat_id, list_task_ids_by_chat_id,
stop_task, stop_task,
list_tasks, list_tasks,
@ -508,6 +511,11 @@ async def lifespan(app: FastAPI):
), ),
) )
if isinstance(app.state.redis, Redis):
app.state.redis_task_command_listener = asyncio.create_task(
redis_task_command_listener(app)
)
if THREAD_POOL_SIZE and THREAD_POOL_SIZE > 0: if THREAD_POOL_SIZE and THREAD_POOL_SIZE > 0:
limiter = anyio.to_thread.current_default_thread_limiter() limiter = anyio.to_thread.current_default_thread_limiter()
limiter.total_tokens = THREAD_POOL_SIZE limiter.total_tokens = THREAD_POOL_SIZE
@ -516,6 +524,9 @@ async def lifespan(app: FastAPI):
yield yield
if hasattr(app.state, "redis_task_command_listener"):
app.state.redis_task_command_listener.cancel()
app = FastAPI( app = FastAPI(
title="Open WebUI", title="Open WebUI",

View File

@ -4,16 +4,88 @@ from typing import Dict
from uuid import uuid4 from uuid import uuid4
import json import json
from redis import Redis from redis import Redis
from fastapi import Request
from typing import Dict, List, Optional
# A dictionary to keep track of active tasks # A dictionary to keep track of active tasks
tasks: Dict[str, asyncio.Task] = {} tasks: Dict[str, asyncio.Task] = {}
chat_tasks = {} chat_tasks = {}
# Redis hash holding every registered task: field = task ID, value = chat ID
# (an empty string when the task is not tied to a chat).
REDIS_TASKS_KEY = "open-webui:tasks"
# Key prefix for the per-chat sets of task IDs; the full key is
# f"{REDIS_CHAT_TASKS_KEY}:{chat_id}".
REDIS_CHAT_TASKS_KEY = "open-webui:tasks:chat"
# Pub/sub channel on which task commands are published as JSON documents.
REDIS_PUBSUB_CHANNEL = "open-webui:tasks:commands"
def is_redis(request: Request) -> bool:
    """
    Report whether the application behind `request` holds a Redis client.

    Returns True only when `request.app.state` exposes a `redis` attribute
    and that attribute is a `Redis` instance; otherwise returns False.
    """
    # A missing attribute falls back to None, which never passes isinstance.
    client = getattr(request.app.state, "redis", None)
    return isinstance(client, Redis)
async def redis_task_command_listener(app):
    """
    Listen for task commands published on the shared Redis channel.

    Subscribes to `REDIS_PUBSUB_CHANNEL` and, for every JSON command of the
    form `{"action": "stop", "task_id": ...}`, cancels the matching task if
    it is registered in this process's `tasks` dictionary. Commands naming a
    task that is not held locally are ignored.

    `app.state.redis` is handled as the synchronous `redis.Redis` client it
    is annotated as: its pub/sub methods block and return plain values, so
    they cannot be awaited or iterated with `async for`. The blocking calls
    are run in worker threads through `asyncio.to_thread` so the event loop
    stays responsive.

    The coroutine runs until it is cancelled (or a Redis call raises); the
    pub/sub connection is closed on the way out.
    """
    redis: Redis = app.state.redis
    pubsub = redis.pubsub()
    try:
        await asyncio.to_thread(pubsub.subscribe, REDIS_PUBSUB_CHANNEL)
        print("Subscribed to Redis task command channel")

        while True:
            # Poll with a short timeout instead of blocking forever: the
            # worker thread is released regularly, so cancelling this
            # coroutine never leaves a thread stuck on the socket.
            message = await asyncio.to_thread(
                pubsub.get_message,
                ignore_subscribe_messages=True,
                timeout=1.0,
            )
            if not message or message["type"] != "message":
                continue

            # Best-effort handling: a malformed command must not kill the
            # listener, so errors are reported and the loop continues.
            try:
                command = json.loads(message["data"])
                if command.get("action") == "stop":
                    task_id = command.get("task_id")
                    local_task = tasks.get(task_id)
                    if local_task:
                        # Cancelled here, on the event loop thread, because
                        # asyncio tasks are not thread-safe.
                        local_task.cancel()
            except Exception as e:
                print(f"Error handling distributed task command: {e}")
    finally:
        # Release the dedicated pub/sub connection on cancellation or error.
        pubsub.close()
### ------------------------------
### REDIS-ENABLED HANDLERS
### ------------------------------
def redis_save_task(redis: Redis, task_id: str, chat_id: Optional[str]):
    """
    Register a task in Redis.

    Writes `task_id` into the `REDIS_TASKS_KEY` hash with the chat ID as its
    value (an empty string when `chat_id` is falsy) and, when a chat ID is
    given, adds `task_id` to that chat's set. Both commands are queued on a
    single pipeline and sent with one `execute()` call.
    """
    batch = redis.pipeline()

    owner = chat_id if chat_id else ""
    batch.hset(REDIS_TASKS_KEY, task_id, owner)

    if chat_id:
        chat_key = f"{REDIS_CHAT_TASKS_KEY}:{chat_id}"
        batch.sadd(chat_key, task_id)

    batch.execute()
def redis_cleanup_task(redis: Redis, task_id: str, chat_id: Optional[str]):
    """
    Remove a task's bookkeeping entries from Redis.

    Deletes `task_id` from the `REDIS_TASKS_KEY` hash and, when `chat_id` is
    truthy, removes it from that chat's set of task IDs. Both commands are
    queued on one pipeline and sent with a single `execute()` call.

    No explicit emptiness check is made on the chat set: Redis removes a set
    key by itself once SREM takes out its last member. A separate
    SCARD-then-DELETE step would also be racy, since another replica could
    add a task to the set between the two commands and have it deleted.
    """
    pipe = redis.pipeline()
    pipe.hdel(REDIS_TASKS_KEY, task_id)
    if chat_id:
        pipe.srem(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id)
    pipe.execute()
def redis_list_tasks(redis: Redis) -> List[str]:
    """
    Return the task IDs stored as field names of the `REDIS_TASKS_KEY` hash.

    The result is always a fresh list built from whatever `hkeys` returns.
    """
    task_ids = redis.hkeys(REDIS_TASKS_KEY)
    return [*task_ids]
def redis_list_chat_tasks(redis: Redis, chat_id: str) -> List[str]:
    """
    Return the task IDs recorded in the Redis set for `chat_id`.

    The set lives under `f"{REDIS_CHAT_TASKS_KEY}:{chat_id}"`; its members
    are copied into a new list.
    """
    chat_key = f"{REDIS_CHAT_TASKS_KEY}:{chat_id}"
    members = redis.smembers(chat_key)
    return [*members]
def redis_send_command(redis: Redis, command: dict):
    """
    Publish `command` to every subscriber of `REDIS_PUBSUB_CHANNEL`.

    The dictionary is serialized with `json.dumps` before being published.
    """
    payload = json.dumps(command)
    redis.publish(REDIS_PUBSUB_CHANNEL, payload)
def cleanup_task(request, task_id: str, id=None): def cleanup_task(request, task_id: str, id=None):
""" """
Remove a completed or canceled task from the global `tasks` dictionary. Remove a completed or canceled task from the global `tasks` dictionary.
""" """
if is_redis(request):
redis_cleanup_task(request.app.state.redis, task_id, id)
tasks.pop(task_id, None) # Remove the task if it exists tasks.pop(task_id, None) # Remove the task if it exists
# If an ID is provided, remove the task from the chat_tasks dictionary # If an ID is provided, remove the task from the chat_tasks dictionary
@ -40,6 +112,9 @@ def create_task(request, coroutine, id=None):
else: else:
chat_tasks[id] = [task_id] chat_tasks[id] = [task_id]
if is_redis(request):
redis_save_task(request.app.state.redis, task_id, id)
return task_id, task return task_id, task
@ -47,6 +122,8 @@ def list_tasks(request):
""" """
List all currently active task IDs. List all currently active task IDs.
""" """
if is_redis(request):
return redis_list_tasks(request.app.state.redis)
return list(tasks.keys()) return list(tasks.keys())
@ -54,6 +131,8 @@ def list_task_ids_by_chat_id(request, id):
""" """
List all tasks associated with a specific ID. List all tasks associated with a specific ID.
""" """
if is_redis(request):
return redis_list_chat_tasks(request.app.state.redis, id)
return chat_tasks.get(id, []) return chat_tasks.get(id, [])
@ -61,6 +140,18 @@ async def stop_task(request, task_id: str):
""" """
Cancel a running task and remove it from the global task list. Cancel a running task and remove it from the global task list.
""" """
if is_redis(request):
# PUBSUB: All instances check if they have this task, and stop if so.
redis_send_command(
request.app.state.redis,
{
"action": "stop",
"task_id": task_id,
},
)
# Optionally check if task_id still in Redis a few moments later for feedback?
return {"status": True, "message": f"Stop signal sent for {task_id}"}
task = tasks.get(task_id) task = tasks.get(task_id)
if not task: if not task:
raise ValueError(f"Task with ID {task_id} not found.") raise ValueError(f"Task with ID {task_id} not found.")