refac/fix: multi-replica tasks

This commit is contained in:
Timothy Jaeryang Baek 2025-06-09 17:21:10 +04:00
parent 5fae7360fb
commit ea8dc333ee
3 changed files with 30 additions and 30 deletions

View File

@ -513,7 +513,7 @@ async def lifespan(app: FastAPI):
async_mode=True, async_mode=True,
) )
if isinstance(app.state.redis, Redis): if app.state.redis is not None:
app.state.redis_task_command_listener = asyncio.create_task( app.state.redis_task_command_listener = asyncio.create_task(
redis_task_command_listener(app) redis_task_command_listener(app)
) )
@ -1424,7 +1424,7 @@ async def stop_task_endpoint(
@app.get("/api/tasks") @app.get("/api/tasks")
async def list_tasks_endpoint(request: Request, user=Depends(get_verified_user)): async def list_tasks_endpoint(request: Request, user=Depends(get_verified_user)):
return {"tasks": list_tasks(request)} return {"tasks": await list_tasks(request)}
@app.get("/api/tasks/chat/{chat_id}") @app.get("/api/tasks/chat/{chat_id}")
@ -1435,7 +1435,7 @@ async def list_tasks_by_chat_id_endpoint(
if chat is None or chat.user_id != user.id: if chat is None or chat.user_id != user.id:
return {"task_ids": []} return {"task_ids": []}
task_ids = list_task_ids_by_chat_id(request, chat_id) task_ids = await list_task_ids_by_chat_id(request, chat_id)
print(f"Task IDs for chat {chat_id}: {task_ids}") print(f"Task IDs for chat {chat_id}: {task_ids}")
return {"task_ids": task_ids} return {"task_ids": task_ids}

View File

@ -3,7 +3,7 @@ import asyncio
from typing import Dict from typing import Dict
from uuid import uuid4 from uuid import uuid4
import json import json
from redis import Redis from redis.asyncio import Redis
from fastapi import Request from fastapi import Request
from typing import Dict, List, Optional from typing import Dict, List, Optional
@ -19,18 +19,16 @@ REDIS_PUBSUB_CHANNEL = "open-webui:tasks:commands"
def is_redis(request: Request) -> bool: def is_redis(request: Request) -> bool:
# Called everywhere a request is available to check Redis # Called everywhere a request is available to check Redis
return hasattr(request.app.state, "redis") and isinstance( return hasattr(request.app.state, "redis") and (request.app.state.redis is not None)
request.app.state.redis, Redis
)
async def redis_task_command_listener(app): async def redis_task_command_listener(app):
redis: Redis = app.state.redis redis: Redis = app.state.redis
pubsub = redis.pubsub() pubsub = redis.pubsub()
await pubsub.subscribe(REDIS_PUBSUB_CHANNEL) await pubsub.subscribe(REDIS_PUBSUB_CHANNEL)
print("Subscribed to Redis task command channel")
async for message in pubsub.listen(): async for message in pubsub.listen():
print(f"Received message: {message}")
if message["type"] != "message": if message["type"] != "message":
continue continue
try: try:
@ -49,42 +47,42 @@ async def redis_task_command_listener(app):
### ------------------------------ ### ------------------------------
def redis_save_task(redis: Redis, task_id: str, chat_id: Optional[str]): async def redis_save_task(redis: Redis, task_id: str, chat_id: Optional[str]):
pipe = redis.pipeline() pipe = redis.pipeline()
pipe.hset(REDIS_TASKS_KEY, task_id, chat_id or "") pipe.hset(REDIS_TASKS_KEY, task_id, chat_id or "")
if chat_id: if chat_id:
pipe.sadd(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id) pipe.sadd(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id)
pipe.execute() await pipe.execute()
def redis_cleanup_task(redis: Redis, task_id: str, chat_id: Optional[str]): async def redis_cleanup_task(redis: Redis, task_id: str, chat_id: Optional[str]):
pipe = redis.pipeline() pipe = redis.pipeline()
pipe.hdel(REDIS_TASKS_KEY, task_id) pipe.hdel(REDIS_TASKS_KEY, task_id)
if chat_id: if chat_id:
pipe.srem(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id) pipe.srem(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id)
if pipe.scard(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}").execute()[-1] == 0: if (await pipe.scard(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}").execute())[-1] == 0:
pipe.delete(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}") # Remove if empty set pipe.delete(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}") # Remove if empty set
pipe.execute() await pipe.execute()
def redis_list_tasks(redis: Redis) -> List[str]: async def redis_list_tasks(redis: Redis) -> List[str]:
return list(redis.hkeys(REDIS_TASKS_KEY)) return list(await redis.hkeys(REDIS_TASKS_KEY))
def redis_list_chat_tasks(redis: Redis, chat_id: str) -> List[str]: async def redis_list_chat_tasks(redis: Redis, chat_id: str) -> List[str]:
return list(redis.smembers(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}")) return list(await redis.smembers(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}"))
def redis_send_command(redis: Redis, command: dict): async def redis_send_command(redis: Redis, command: dict):
redis.publish(REDIS_PUBSUB_CHANNEL, json.dumps(command)) await redis.publish(REDIS_PUBSUB_CHANNEL, json.dumps(command))
def cleanup_task(request, task_id: str, id=None): async def cleanup_task(request, task_id: str, id=None):
""" """
Remove a completed or canceled task from the global `tasks` dictionary. Remove a completed or canceled task from the global `tasks` dictionary.
""" """
if is_redis(request): if is_redis(request):
redis_cleanup_task(request.app.state.redis, task_id, id) await redis_cleanup_task(request.app.state.redis, task_id, id)
tasks.pop(task_id, None) # Remove the task if it exists tasks.pop(task_id, None) # Remove the task if it exists
@ -95,7 +93,7 @@ def cleanup_task(request, task_id: str, id=None):
chat_tasks.pop(id, None) chat_tasks.pop(id, None)
def create_task(request, coroutine, id=None): async def create_task(request, coroutine, id=None):
""" """
Create a new asyncio task and add it to the global task dictionary. Create a new asyncio task and add it to the global task dictionary.
""" """
@ -103,7 +101,9 @@ def create_task(request, coroutine, id=None):
task = asyncio.create_task(coroutine) # Create the task task = asyncio.create_task(coroutine) # Create the task
# Add a done callback for cleanup # Add a done callback for cleanup
task.add_done_callback(lambda t: cleanup_task(request, task_id, id)) task.add_done_callback(
lambda t: asyncio.create_task(cleanup_task(request, task_id, id))
)
tasks[task_id] = task tasks[task_id] = task
# If an ID is provided, associate the task with that ID # If an ID is provided, associate the task with that ID
@ -113,26 +113,26 @@ def create_task(request, coroutine, id=None):
chat_tasks[id] = [task_id] chat_tasks[id] = [task_id]
if is_redis(request): if is_redis(request):
redis_save_task(request.app.state.redis, task_id, id) await redis_save_task(request.app.state.redis, task_id, id)
return task_id, task return task_id, task
def list_tasks(request): async def list_tasks(request):
""" """
List all currently active task IDs. List all currently active task IDs.
""" """
if is_redis(request): if is_redis(request):
return redis_list_tasks(request.app.state.redis) return await redis_list_tasks(request.app.state.redis)
return list(tasks.keys()) return list(tasks.keys())
def list_task_ids_by_chat_id(request, id): async def list_task_ids_by_chat_id(request, id):
""" """
List all tasks associated with a specific ID. List all tasks associated with a specific ID.
""" """
if is_redis(request): if is_redis(request):
return redis_list_chat_tasks(request.app.state.redis, id) return await redis_list_chat_tasks(request.app.state.redis, id)
return chat_tasks.get(id, []) return chat_tasks.get(id, [])
@ -142,7 +142,7 @@ async def stop_task(request, task_id: str):
""" """
if is_redis(request): if is_redis(request):
# PUBSUB: All instances check if they have this task, and stop if so. # PUBSUB: All instances check if they have this task, and stop if so.
redis_send_command( await redis_send_command(
request.app.state.redis, request.app.state.redis,
{ {
"action": "stop", "action": "stop",

View File

@ -2413,7 +2413,7 @@ async def process_chat_response(
await response.background() await response.background()
# background_tasks.add_task(post_response_handler, response, events) # background_tasks.add_task(post_response_handler, response, events)
task_id, _ = create_task( task_id, _ = await create_task(
request, post_response_handler(response, events), id=metadata["chat_id"] request, post_response_handler(response, events), id=metadata["chat_id"]
) )
return {"status": True, "task_id": task_id} return {"status": True, "task_id": task_id}