mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
refac: chat requests
This commit is contained in:
@@ -30,7 +30,9 @@ from fastapi import (
|
||||
UploadFile,
|
||||
status,
|
||||
applications,
|
||||
BackgroundTasks,
|
||||
)
|
||||
|
||||
from fastapi.openapi.docs import get_swagger_ui_html
|
||||
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@@ -295,6 +297,7 @@ from open_webui.utils.auth import (
|
||||
from open_webui.utils.oauth import oauth_manager
|
||||
from open_webui.utils.security_headers import SecurityHeadersMiddleware
|
||||
|
||||
from open_webui.tasks import stop_task, list_tasks # Import from tasks.py
|
||||
|
||||
if SAFE_MODE:
|
||||
print("SAFE MODE ENABLED")
|
||||
@@ -822,11 +825,11 @@ async def chat_completion(
|
||||
request: Request,
|
||||
form_data: dict,
|
||||
user=Depends(get_verified_user),
|
||||
bypass_filter: bool = False,
|
||||
):
|
||||
if not request.app.state.MODELS:
|
||||
await get_all_models(request)
|
||||
|
||||
tasks = form_data.pop("background_tasks", None)
|
||||
try:
|
||||
model_id = form_data.get("model", None)
|
||||
if model_id not in request.app.state.MODELS:
|
||||
@@ -834,13 +837,14 @@ async def chat_completion(
|
||||
model = request.app.state.MODELS[model_id]
|
||||
|
||||
# Check if user has access to the model
|
||||
if not bypass_filter and user.role == "user":
|
||||
if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
|
||||
try:
|
||||
check_model_access(user, model)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
metadata = {
|
||||
"user_id": user.id,
|
||||
"chat_id": form_data.pop("chat_id", None),
|
||||
"message_id": form_data.pop("id", None),
|
||||
"session_id": form_data.pop("session_id", None),
|
||||
@@ -859,10 +863,10 @@ async def chat_completion(
|
||||
)
|
||||
|
||||
try:
|
||||
response = await chat_completion_handler(
|
||||
request, form_data, user, bypass_filter
|
||||
response = await chat_completion_handler(request, form_data, user)
|
||||
return await process_chat_response(
|
||||
request, response, user, events, metadata, tasks
|
||||
)
|
||||
return await process_chat_response(response, events, metadata)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
@@ -901,6 +905,20 @@ async def chat_action(
|
||||
)
|
||||
|
||||
|
||||
@app.post("/api/tasks/stop/{task_id}")
async def stop_task_endpoint(task_id: str, user=Depends(get_verified_user)):
    """
    Stop the background task identified by ``task_id``.

    The work is delegated to ``stop_task`` and its result is returned as-is.
    A ``ValueError`` raised by ``stop_task`` is translated into an HTTP 404
    response whose detail is the error text. ``user`` is resolved through
    ``get_verified_user`` and is not otherwise used here.
    """
    try:
        return await stop_task(task_id)
    except ValueError as err:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(err))
|
||||
|
||||
|
||||
@app.get("/api/tasks")
async def list_tasks_endpoint(user=Depends(get_verified_user)):
    """
    Return the task IDs reported by ``list_tasks`` under the ``"tasks"`` key.

    ``user`` is resolved through ``get_verified_user`` and is not otherwise
    used here.
    """
    active_task_ids = list_tasks()
    return {"tasks": active_task_ids}
|
||||
|
||||
|
||||
##################################
|
||||
#
|
||||
# Config Endpoints
|
||||
|
||||
@@ -168,6 +168,66 @@ class ChatTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def update_chat_title_by_id(self, id: str, title: str) -> Optional[ChatModel]:
    """
    Set the ``"title"`` entry of a stored chat's payload and persist it.

    Returns ``None`` when ``get_chat_by_id`` finds no chat; otherwise returns
    whatever ``update_chat_by_id`` returns for the modified payload.
    """
    record = self.get_chat_by_id(id)
    if record is None:
        return None

    # The payload dict is modified in place and then handed to the updater.
    payload = record.chat
    payload["title"] = title
    return self.update_chat_by_id(id, payload)
|
||||
|
||||
def update_chat_tags_by_id(
    self, id: str, tags: list[str], user
) -> Optional[ChatModel]:
    """
    Replace the tags attached to a chat with ``tags``.

    Steps, in order: remove every tag link of the chat for ``user.id``;
    delete each previously attached tag that no chat of that user still
    uses; attach every name in ``tags`` except those equal to "none"
    (case-insensitive). Returns ``None`` when the chat does not exist,
    otherwise the chat as re-read by ``get_chat_by_id``.
    """
    existing = self.get_chat_by_id(id)
    if existing is None:
        return None

    self.delete_all_tags_by_id_and_user_id(id, user.id)

    # `existing` was read before the deletion, so its meta still lists the
    # old tags; drop the ones that are now orphaned.
    for old_tag in existing.meta.get("tags", []):
        remaining_uses = self.count_chats_by_tag_name_and_user_id(old_tag, user.id)
        if remaining_uses == 0:
            Tags.delete_tag_by_name_and_user_id(old_tag, user.id)

    for tag_name in tags:
        if tag_name.lower() != "none":
            self.add_chat_tag_by_id_and_user_id_and_tag_name(id, user.id, tag_name)

    return self.get_chat_by_id(id)
|
||||
|
||||
def get_messages_by_chat_id(self, id: str) -> Optional[dict]:
    """
    Return the message map stored under ``chat["history"]["messages"]``.

    :param id: ID of the chat to read
    :return: ``None`` when ``get_chat_by_id`` finds no chat; otherwise the
        messages mapping, or ``{}`` when the history or its messages are
        missing or falsy
    """
    record = self.get_chat_by_id(id)
    if record is None:
        return None

    # `or {}` (rather than a .get default) also covers a key that is present
    # but holds None, which would otherwise fail on the chained lookup.
    history = record.chat.get("history") or {}
    return history.get("messages") or {}
|
||||
|
||||
def upsert_message_to_chat_by_id_and_message_id(
    self, id: str, message_id: str, message: dict
) -> Optional[ChatModel]:
    """
    Insert or merge a message into a chat's history and persist the chat.

    If ``message_id`` already exists, the fields of ``message`` are merged
    over the stored entry; otherwise a new entry is created. The history's
    ``"currentId"`` is set to ``message_id`` in both cases.

    :param id: ID of the chat to update
    :param message_id: key of the message inside ``history["messages"]``
    :param message: fields to store for that message
    :return: ``None`` when the chat does not exist, otherwise the result of
        ``update_chat_by_id``
    """
    record = self.get_chat_by_id(id)
    if record is None:
        return None

    payload = record.chat
    history = payload.get("history") or {}

    # A chat without any history yet has no "messages" mapping; create it
    # instead of failing with KeyError on the first insert.
    messages = history.get("messages")
    if messages is None:
        messages = history["messages"] = {}

    messages[message_id] = {**messages.get(message_id, {}), **message}
    history["currentId"] = message_id

    payload["history"] = history
    return self.update_chat_by_id(id, payload)
|
||||
|
||||
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
|
||||
with get_db() as db:
|
||||
# Get the existing chat to share
|
||||
|
||||
@@ -82,6 +82,16 @@ async def send_get_request(url, key=None):
|
||||
return None
|
||||
|
||||
|
||||
async def cleanup_response(
    response: Optional[aiohttp.ClientResponse],
    session: Optional[aiohttp.ClientSession],
):
    """
    Release an upstream response and the client session that produced it.

    Either argument may be ``None`` (or otherwise falsy), in which case it is
    skipped. The session is closed even if closing the response raises, so a
    failure while releasing the response cannot leak the session; the
    original exception still propagates afterwards.
    """
    try:
        if response:
            response.close()
    finally:
        if session:
            await session.close()
|
||||
|
||||
|
||||
async def send_post_request(
|
||||
url: str,
|
||||
payload: Union[str, bytes],
|
||||
@@ -89,14 +99,6 @@ async def send_post_request(
|
||||
key: Optional[str] = None,
|
||||
content_type: Optional[str] = None,
|
||||
):
|
||||
async def cleanup_response(
|
||||
response: Optional[aiohttp.ClientResponse],
|
||||
session: Optional[aiohttp.ClientSession],
|
||||
):
|
||||
if response:
|
||||
response.close()
|
||||
if session:
|
||||
await session.close()
|
||||
|
||||
r = None
|
||||
try:
|
||||
|
||||
@@ -217,15 +217,19 @@ async def disconnect(sid):
|
||||
|
||||
def get_event_emitter(request_info):
|
||||
async def __event_emitter__(event_data):
|
||||
await sio.emit(
|
||||
"chat-events",
|
||||
{
|
||||
"chat_id": request_info["chat_id"],
|
||||
"message_id": request_info["message_id"],
|
||||
"data": event_data,
|
||||
},
|
||||
to=request_info["session_id"],
|
||||
)
|
||||
user_id = request_info["user_id"]
|
||||
session_ids = USER_POOL.get(user_id, [])
|
||||
|
||||
for session_id in session_ids:
|
||||
await sio.emit(
|
||||
"chat-events",
|
||||
{
|
||||
"chat_id": request_info["chat_id"],
|
||||
"message_id": request_info["message_id"],
|
||||
"data": event_data,
|
||||
},
|
||||
to=session_id,
|
||||
)
|
||||
|
||||
return __event_emitter__
|
||||
|
||||
|
||||
61
backend/open_webui/tasks.py
Normal file
61
backend/open_webui/tasks.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# tasks.py
|
||||
import asyncio
|
||||
from typing import Dict
|
||||
from uuid import uuid4
|
||||
|
||||
# A dictionary to keep track of active tasks
|
||||
tasks: Dict[str, asyncio.Task] = {}
|
||||
|
||||
|
||||
def cleanup_task(task_id: str):
    """
    Forget ``task_id`` in the module-level ``tasks`` registry.

    IDs that are not registered are ignored, so repeated calls are harmless.
    """
    tasks.pop(task_id, None)
|
||||
|
||||
|
||||
def create_task(coroutine):
    """
    Schedule ``coroutine`` as an asyncio task and register it in ``tasks``.

    The task is stored under a freshly generated UUID4 string and removes
    itself from the registry once it is done.

    :param coroutine: coroutine object to run as a task
    :return: ``(task_id, task)`` tuple
    """
    task_id = str(uuid4())
    task = asyncio.create_task(coroutine)

    def _forget(_finished_task):
        # Done callbacks receive the task itself; only the ID is needed.
        cleanup_task(task_id)

    task.add_done_callback(_forget)
    tasks[task_id] = task
    return task_id, task
|
||||
|
||||
|
||||
def get_task(task_id: str):
    """
    Look up a registered task.

    :param task_id: ID under which the task was registered
    :return: the task stored in ``tasks``, or ``None`` when the ID is unknown
    """
    return tasks.get(task_id)
|
||||
|
||||
|
||||
def list_tasks():
    """
    Return the IDs of all tasks currently held in ``tasks`` as a new list.
    """
    return list(tasks)
|
||||
|
||||
|
||||
async def stop_task(task_id: str):
    """
    Cancel a registered task, wait for it to finish, and unregister it.

    A task may react to cancellation in two ways: let ``CancelledError``
    propagate, or catch it to run its own cleanup and then return normally.
    In both cases the task has ended once the wait is over, so both are
    reported as a successful stop.

    :param task_id: ID under which the task was registered
    :return: ``{"status": True, "message": ...}`` once the task has ended
    :raises ValueError: if no task is registered under ``task_id``
    """
    task = tasks.get(task_id)
    if not task:
        raise ValueError(f"Task with ID {task_id} not found.")

    task.cancel()  # Request cancellation; the task decides how to unwind
    try:
        await task  # Wait until the task has actually finished
    except asyncio.CancelledError:
        # The task let the cancellation propagate; that is the expected path.
        pass

    # Reaching this point means the awaited task is done, whether it ended
    # with CancelledError or returned after handling the cancellation itself.
    tasks.pop(task_id, None)
    return {"status": True, "message": f"Task {task_id} successfully stopped."}
|
||||
@@ -117,7 +117,9 @@ async def generate_chat_completion(
|
||||
form_data, user, bypass_filter=True
|
||||
)
|
||||
return StreamingResponse(
|
||||
stream_wrapper(response.body_iterator), media_type="text/event-stream"
|
||||
stream_wrapper(response.body_iterator),
|
||||
media_type="text/event-stream",
|
||||
background=response.background,
|
||||
)
|
||||
else:
|
||||
return {
|
||||
@@ -141,6 +143,7 @@ async def generate_chat_completion(
|
||||
return StreamingResponse(
|
||||
convert_streaming_response_ollama_to_openai(response),
|
||||
headers=dict(response.headers),
|
||||
background=response.background,
|
||||
)
|
||||
else:
|
||||
return convert_response_ollama_to_openai(response)
|
||||
|
||||
@@ -2,21 +2,31 @@ import time
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import asyncio
|
||||
from aiocache import cached
|
||||
from typing import Any, Optional
|
||||
import random
|
||||
import json
|
||||
import inspect
|
||||
from uuid import uuid4
|
||||
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi import BackgroundTasks
|
||||
|
||||
from starlette.responses import Response, StreamingResponse
|
||||
|
||||
|
||||
from open_webui.models.chats import Chats
|
||||
from open_webui.socket.main import (
|
||||
get_event_call,
|
||||
get_event_emitter,
|
||||
)
|
||||
from open_webui.routers.tasks import generate_queries
|
||||
from open_webui.routers.tasks import (
|
||||
generate_queries,
|
||||
generate_title,
|
||||
generate_chat_tags,
|
||||
)
|
||||
|
||||
|
||||
from open_webui.models.users import UserModel
|
||||
@@ -33,6 +43,7 @@ from open_webui.utils.task import (
|
||||
tools_function_calling_generation_template,
|
||||
)
|
||||
from open_webui.utils.misc import (
|
||||
get_message_list,
|
||||
add_or_update_system_message,
|
||||
get_last_user_message,
|
||||
prepend_to_first_user_message_content,
|
||||
@@ -41,6 +52,8 @@ from open_webui.utils.tools import get_tools
|
||||
from open_webui.utils.plugin import load_function_module_by_id
|
||||
|
||||
|
||||
from open_webui.tasks import create_task
|
||||
|
||||
from open_webui.config import DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
||||
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
|
||||
from open_webui.constants import TASKS
|
||||
@@ -504,28 +517,178 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
||||
return form_data, events
|
||||
|
||||
|
||||
async def process_chat_response(response, events, metadata):
|
||||
async def process_chat_response(request, response, user, events, metadata, tasks):
|
||||
if not isinstance(response, StreamingResponse):
|
||||
return response
|
||||
|
||||
content_type = response.headers["Content-Type"]
|
||||
is_openai = "text/event-stream" in content_type
|
||||
is_ollama = "application/x-ndjson" in content_type
|
||||
|
||||
if not is_openai and not is_ollama:
|
||||
if not any(
|
||||
content_type in response.headers["Content-Type"]
|
||||
for content_type in ["text/event-stream", "application/x-ndjson"]
|
||||
):
|
||||
return response
|
||||
|
||||
async def stream_wrapper(original_generator, events):
|
||||
def wrap_item(item):
|
||||
return f"data: {item}\n\n" if is_openai else f"{item}\n"
|
||||
event_emitter = None
|
||||
if "session_id" in metadata:
|
||||
event_emitter = get_event_emitter(metadata)
|
||||
|
||||
for event in events:
|
||||
yield wrap_item(json.dumps(event))
|
||||
if event_emitter:
|
||||
|
||||
async for data in original_generator:
|
||||
yield data
|
||||
task_id = str(uuid4()) # Create a unique task ID.
|
||||
|
||||
return StreamingResponse(
|
||||
stream_wrapper(response.body_iterator, events),
|
||||
headers=dict(response.headers),
|
||||
)
|
||||
# Handle as a background task
|
||||
async def post_response_handler(response, events):
|
||||
try:
|
||||
for event in events:
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat-completion",
|
||||
"data": event,
|
||||
}
|
||||
)
|
||||
|
||||
content = ""
|
||||
async for line in response.body_iterator:
|
||||
line = line.decode("utf-8") if isinstance(line, bytes) else line
|
||||
data = line
|
||||
|
||||
# Skip empty lines
|
||||
if not data.strip():
|
||||
continue
|
||||
|
||||
# "data: " is the prefix for each event
|
||||
if not data.startswith("data: "):
|
||||
continue
|
||||
|
||||
# Remove the prefix
|
||||
data = data[len("data: ") :]
|
||||
|
||||
try:
|
||||
data = json.loads(data)
|
||||
value = (
|
||||
data.get("choices", [])[0].get("delta", {}).get("content")
|
||||
)
|
||||
|
||||
if value:
|
||||
content = f"{content}{value}"
|
||||
|
||||
# Save message in the database
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"content": content,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
done = "data: [DONE]" in line
|
||||
|
||||
if done:
|
||||
data = {"done": True}
|
||||
else:
|
||||
continue
|
||||
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat-completion",
|
||||
"data": data,
|
||||
}
|
||||
)
|
||||
|
||||
message_map = Chats.get_messages_by_chat_id(metadata["chat_id"])
|
||||
message = message_map.get(metadata["message_id"])
|
||||
|
||||
if message:
|
||||
messages = get_message_list(message_map, message.get("id"))
|
||||
|
||||
if TASKS.TITLE_GENERATION in tasks:
|
||||
res = await generate_title(
|
||||
request,
|
||||
{
|
||||
"model": message["model"],
|
||||
"messages": messages,
|
||||
"chat_id": metadata["chat_id"],
|
||||
},
|
||||
user,
|
||||
)
|
||||
|
||||
if res:
|
||||
title = (
|
||||
res.get("choices", [])[0]
|
||||
.get("message", {})
|
||||
.get("content", message.get("content", "New Chat"))
|
||||
)
|
||||
|
||||
Chats.update_chat_title_by_id(metadata["chat_id"], title)
|
||||
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat-title",
|
||||
"data": title,
|
||||
}
|
||||
)
|
||||
|
||||
if TASKS.TAGS_GENERATION in tasks:
|
||||
res = await generate_chat_tags(
|
||||
request,
|
||||
{
|
||||
"model": message["model"],
|
||||
"messages": messages,
|
||||
"chat_id": metadata["chat_id"],
|
||||
},
|
||||
user,
|
||||
)
|
||||
|
||||
if res:
|
||||
tags_string = (
|
||||
res.get("choices", [])[0]
|
||||
.get("message", {})
|
||||
.get("content", "")
|
||||
)
|
||||
|
||||
tags_string = tags_string[
|
||||
tags_string.find("{") : tags_string.rfind("}") + 1
|
||||
]
|
||||
|
||||
try:
|
||||
tags = json.loads(tags_string).get("tags", [])
|
||||
Chats.update_chat_tags_by_id(
|
||||
metadata["chat_id"], tags, user
|
||||
)
|
||||
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat-tags",
|
||||
"data": tags,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
print("Task was cancelled!")
|
||||
await event_emitter({"type": "task-cancelled"})
|
||||
|
||||
if response.background is not None:
|
||||
await response.background()
|
||||
|
||||
# background_tasks.add_task(post_response_handler, response, events)
|
||||
task_id, _ = create_task(post_response_handler(response, events))
|
||||
return {"status": True, "task_id": task_id}
|
||||
|
||||
else:
|
||||
# Fallback to the original response
|
||||
async def stream_wrapper(original_generator, events):
|
||||
def wrap_item(item):
|
||||
return f"data: {item}\n\n"
|
||||
|
||||
for event in events:
|
||||
yield wrap_item(json.dumps(event))
|
||||
|
||||
async for data in original_generator:
|
||||
yield data
|
||||
|
||||
return StreamingResponse(
|
||||
stream_wrapper(response.body_iterator, events),
|
||||
headers=dict(response.headers),
|
||||
)
|
||||
|
||||
@@ -7,6 +7,34 @@ from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
|
||||
def get_message_list(messages, message_id):
    """
    Reconstructs a list of messages in order up to the specified message_id.

    The chain is built by following each message's ``parentId`` link. The walk
    stops at a message without a parent, at a parent ID that is not present in
    ``messages``, or when a parent ID repeats (a cyclic history).

    :param messages: Message history dict containing all messages, keyed by ID
    :param message_id: ID of the message to reconstruct the chain
    :return: List of ordered messages starting from the root to the given
        message; an empty list when ``messages`` is empty or ``message_id``
        is not found
    """
    current_message = messages.get(message_id) if messages else None

    if not current_message:
        # Always return a list so callers can iterate the result safely.
        return []

    # Collect from the given message up to the root, then flip once at the
    # end; this avoids shifting the whole list on every step.
    chain = []
    visited_ids = {message_id}

    while current_message:
        chain.append(current_message)

        # Root messages may omit "parentId" entirely.
        parent_id = current_message.get("parentId")
        if not parent_id or parent_id in visited_ids:
            break

        visited_ids.add(parent_id)
        current_message = messages.get(parent_id)

    chain.reverse()
    return chain
|
||||
|
||||
|
||||
def get_messages_content(messages: list[dict]) -> str:
|
||||
return "\n".join(
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user