refac: tasks

This commit is contained in:
Timothy Jaeryang Baek 2025-06-08 20:58:31 +04:00
parent 51fe33395b
commit 0c57980e72
4 changed files with 42 additions and 23 deletions

View File

@ -8,6 +8,7 @@ import shutil
import sys import sys
import time import time
import random import random
from uuid import uuid4
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from urllib.parse import urlencode, parse_qs, urlparse from urllib.parse import urlencode, parse_qs, urlparse
@ -432,6 +433,7 @@ from open_webui.utils.auth import (
from open_webui.utils.plugin import install_tool_and_function_dependencies from open_webui.utils.plugin import install_tool_and_function_dependencies
from open_webui.utils.oauth import OAuthManager from open_webui.utils.oauth import OAuthManager
from open_webui.utils.security_headers import SecurityHeadersMiddleware from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.utils.redis import get_redis_connection
from open_webui.tasks import ( from open_webui.tasks import (
list_task_ids_by_chat_id, list_task_ids_by_chat_id,
@ -485,7 +487,9 @@ https://github.com/open-webui/open-webui
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
app.state.instance_id = os.environ.get("INSTANCE_ID", str(uuid4()))
start_logger() start_logger()
if RESET_CONFIG_ON_START: if RESET_CONFIG_ON_START:
reset_config() reset_config()
@ -497,6 +501,13 @@ async def lifespan(app: FastAPI):
log.info("Installing external dependencies of functions and tools...") log.info("Installing external dependencies of functions and tools...")
install_tool_and_function_dependencies() install_tool_and_function_dependencies()
app.state.redis = get_redis_connection(
redis_url=REDIS_URL,
redis_sentinels=get_sentinels_from_env(
REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT
),
)
if THREAD_POOL_SIZE and THREAD_POOL_SIZE > 0: if THREAD_POOL_SIZE and THREAD_POOL_SIZE > 0:
limiter = anyio.to_thread.current_default_thread_limiter() limiter = anyio.to_thread.current_default_thread_limiter()
limiter.total_tokens = THREAD_POOL_SIZE limiter.total_tokens = THREAD_POOL_SIZE
@ -516,10 +527,12 @@ app = FastAPI(
oauth_manager = OAuthManager(app) oauth_manager = OAuthManager(app)
app.state.instance_id = None
app.state.config = AppConfig( app.state.config = AppConfig(
redis_url=REDIS_URL, redis_url=REDIS_URL,
redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT), redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT),
) )
app.state.redis = None
app.state.WEBUI_NAME = WEBUI_NAME app.state.WEBUI_NAME = WEBUI_NAME
app.state.LICENSE_METADATA = None app.state.LICENSE_METADATA = None
@ -1386,26 +1399,30 @@ async def chat_action(
@app.post("/api/tasks/stop/{task_id}") @app.post("/api/tasks/stop/{task_id}")
async def stop_task_endpoint(task_id: str, user=Depends(get_verified_user)): async def stop_task_endpoint(
request: Request, task_id: str, user=Depends(get_verified_user)
):
try: try:
result = await stop_task(task_id) result = await stop_task(request, task_id)
return result return result
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
@app.get("/api/tasks") @app.get("/api/tasks")
async def list_tasks_endpoint(user=Depends(get_verified_user)): async def list_tasks_endpoint(request: Request, user=Depends(get_verified_user)):
return {"tasks": list_tasks()} return {"tasks": list_tasks(request)}
@app.get("/api/tasks/chat/{chat_id}") @app.get("/api/tasks/chat/{chat_id}")
async def list_tasks_by_chat_id_endpoint(chat_id: str, user=Depends(get_verified_user)): async def list_tasks_by_chat_id_endpoint(
request: Request, chat_id: str, user=Depends(get_verified_user)
):
chat = Chats.get_chat_by_id(chat_id) chat = Chats.get_chat_by_id(chat_id)
if chat is None or chat.user_id != user.id: if chat is None or chat.user_id != user.id:
return {"task_ids": []} return {"task_ids": []}
task_ids = list_task_ids_by_chat_id(chat_id) task_ids = list_task_ids_by_chat_id(request, chat_id)
print(f"Task IDs for chat {chat_id}: {task_ids}") print(f"Task IDs for chat {chat_id}: {task_ids}")
return {"task_ids": task_ids} return {"task_ids": task_ids}

View File

@ -2,13 +2,15 @@
import asyncio import asyncio
from typing import Dict from typing import Dict
from uuid import uuid4 from uuid import uuid4
import json
from redis import Redis
# A dictionary to keep track of active tasks # A dictionary to keep track of active tasks
tasks: Dict[str, asyncio.Task] = {} tasks: Dict[str, asyncio.Task] = {}
chat_tasks = {} chat_tasks = {}
def cleanup_task(task_id: str, id=None): def cleanup_task(request, task_id: str, id=None):
""" """
Remove a completed or canceled task from the global `tasks` dictionary. Remove a completed or canceled task from the global `tasks` dictionary.
""" """
@ -21,7 +23,7 @@ def cleanup_task(task_id: str, id=None):
chat_tasks.pop(id, None) chat_tasks.pop(id, None)
def create_task(coroutine, id=None): def create_task(request, coroutine, id=None):
""" """
Create a new asyncio task and add it to the global task dictionary. Create a new asyncio task and add it to the global task dictionary.
""" """
@ -29,7 +31,7 @@ def create_task(coroutine, id=None):
task = asyncio.create_task(coroutine) # Create the task task = asyncio.create_task(coroutine) # Create the task
# Add a done callback for cleanup # Add a done callback for cleanup
task.add_done_callback(lambda t: cleanup_task(task_id, id)) task.add_done_callback(lambda t: cleanup_task(request, task_id, id))
tasks[task_id] = task tasks[task_id] = task
# If an ID is provided, associate the task with that ID # If an ID is provided, associate the task with that ID
@ -41,28 +43,21 @@ def create_task(coroutine, id=None):
return task_id, task return task_id, task
def get_task(task_id: str): def list_tasks(request):
"""
Retrieve a task by its task ID.
"""
return tasks.get(task_id)
def list_tasks():
""" """
List all currently active task IDs. List all currently active task IDs.
""" """
return list(tasks.keys()) return list(tasks.keys())
def list_task_ids_by_chat_id(id): def list_task_ids_by_chat_id(request, id):
""" """
List all tasks associated with a specific ID. List all tasks associated with a specific ID.
""" """
return chat_tasks.get(id, []) return chat_tasks.get(id, [])
async def stop_task(task_id: str): async def stop_task(request, task_id: str):
""" """
Cancel a running task and remove it from the global task list. Cancel a running task and remove it from the global task list.
""" """

View File

@ -2266,7 +2266,9 @@ async def process_chat_response(
if "data:image/png;base64" in line: if "data:image/png;base64" in line:
image_url = "" image_url = ""
# Extract base64 image data from the line # Extract base64 image data from the line
image_data, content_type = load_b64_image_data(line) image_data, content_type = (
load_b64_image_data(line)
)
if image_data is not None: if image_data is not None:
image_url = upload_image( image_url = upload_image(
request, request,
@ -2412,7 +2414,7 @@ async def process_chat_response(
# background_tasks.add_task(post_response_handler, response, events) # background_tasks.add_task(post_response_handler, response, events)
task_id, _ = create_task( task_id, _ = create_task(
post_response_handler(response, events), id=metadata["chat_id"] request, post_response_handler(response, events), id=metadata["chat_id"]
) )
return {"status": True, "task_id": task_id} return {"status": True, "task_id": task_id}

View File

@ -2,6 +2,7 @@ import socketio
import redis import redis
from redis import asyncio as aioredis from redis import asyncio as aioredis
from urllib.parse import urlparse from urllib.parse import urlparse
from typing import Optional
def parse_redis_service_url(redis_url): def parse_redis_service_url(redis_url):
@ -18,7 +19,9 @@ def parse_redis_service_url(redis_url):
} }
def get_redis_connection(redis_url, redis_sentinels, decode_responses=True): def get_redis_connection(
redis_url, redis_sentinels, decode_responses=True
) -> Optional[redis.Redis]:
if redis_sentinels: if redis_sentinels:
redis_config = parse_redis_service_url(redis_url) redis_config = parse_redis_service_url(redis_url)
sentinel = redis.sentinel.Sentinel( sentinel = redis.sentinel.Sentinel(
@ -32,9 +35,11 @@ def get_redis_connection(redis_url, redis_sentinels, decode_responses=True):
# Get a master connection from Sentinel # Get a master connection from Sentinel
return sentinel.master_for(redis_config["service"]) return sentinel.master_for(redis_config["service"])
else: elif redis_url:
# Standard Redis connection # Standard Redis connection
return redis.Redis.from_url(redis_url, decode_responses=decode_responses) return redis.Redis.from_url(redis_url, decode_responses=decode_responses)
else:
return None
def get_sentinels_from_env(sentinel_hosts_env, sentinel_port_env): def get_sentinels_from_env(sentinel_hosts_env, sentinel_port_env):