refac: chat requests

This commit is contained in:
Timothy Jaeryang Baek
2024-12-19 01:00:32 -08:00
parent ea0d507e23
commit 2be9e55545
11 changed files with 752 additions and 424 deletions

View File

@@ -117,7 +117,9 @@ async def generate_chat_completion(
form_data, user, bypass_filter=True
)
return StreamingResponse(
stream_wrapper(response.body_iterator), media_type="text/event-stream"
stream_wrapper(response.body_iterator),
media_type="text/event-stream",
background=response.background,
)
else:
return {
@@ -141,6 +143,7 @@ async def generate_chat_completion(
return StreamingResponse(
convert_streaming_response_ollama_to_openai(response),
headers=dict(response.headers),
background=response.background,
)
else:
return convert_response_ollama_to_openai(response)

View File

@@ -2,21 +2,31 @@ import time
import logging
import sys
import asyncio
from aiocache import cached
from typing import Any, Optional
import random
import json
import inspect
from uuid import uuid4
from fastapi import Request
from fastapi import BackgroundTasks
from starlette.responses import Response, StreamingResponse
from open_webui.models.chats import Chats
from open_webui.socket.main import (
get_event_call,
get_event_emitter,
)
from open_webui.routers.tasks import generate_queries
from open_webui.routers.tasks import (
generate_queries,
generate_title,
generate_chat_tags,
)
from open_webui.models.users import UserModel
@@ -33,6 +43,7 @@ from open_webui.utils.task import (
tools_function_calling_generation_template,
)
from open_webui.utils.misc import (
get_message_list,
add_or_update_system_message,
get_last_user_message,
prepend_to_first_user_message_content,
@@ -41,6 +52,8 @@ from open_webui.utils.tools import get_tools
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.tasks import create_task
from open_webui.config import DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
from open_webui.constants import TASKS
@@ -504,28 +517,178 @@ async def process_chat_payload(request, form_data, metadata, user, model):
return form_data, events
async def process_chat_response(request, response, user, events, metadata, tasks):
    """
    Post-process a chat completion response before it reaches the client.

    Non-streaming responses, and streaming responses whose Content-Type is
    neither ``text/event-stream`` nor ``application/x-ndjson``, are returned
    untouched.

    When ``metadata`` carries a ``session_id`` and an event emitter can be
    obtained for it, the upstream stream is consumed in a background task:
    every chunk is forwarded through the emitter, the accumulated assistant
    content is persisted to the chat, and the title / tags tasks requested in
    ``tasks`` are run once the stream ends. In that case a small dict with the
    task id is returned instead of a stream.

    Otherwise the original stream is passed through, prefixed with ``events``.

    :param request: Incoming request, forwarded to the title/tags generators
    :param response: Upstream response (possibly a StreamingResponse)
    :param user: User on whose behalf the follow-up tasks run
    :param events: Events to deliver before the first streamed chunk
    :param metadata: Dict read for ``session_id``, ``chat_id`` and ``message_id``
    :param tasks: Container of TASKS members to run after the stream completes
    :return: The original response, ``{"status": True, "task_id": ...}``, or a
             wrapping StreamingResponse
    """
    if not isinstance(response, StreamingResponse):
        return response

    if not any(
        content_type in response.headers["Content-Type"]
        for content_type in ["text/event-stream", "application/x-ndjson"]
    ):
        return response

    event_emitter = None
    if "session_id" in metadata:
        event_emitter = get_event_emitter(metadata)

    if event_emitter:

        # Handle as a background task
        async def post_response_handler(response, events):
            try:
                for event in events:
                    await event_emitter(
                        {
                            "type": "chat-completion",
                            "data": event,
                        }
                    )

                content = ""
                async for line in response.body_iterator:
                    line = line.decode("utf-8") if isinstance(line, bytes) else line
                    data = line

                    # Skip empty lines
                    if not data.strip():
                        continue

                    # "data: " is the prefix for each event
                    if not data.startswith("data: "):
                        continue

                    # Remove the prefix
                    data = data[len("data: ") :]

                    try:
                        data = json.loads(data)
                        value = (
                            data.get("choices", [])[0].get("delta", {}).get("content")
                        )

                        if value:
                            content = f"{content}{value}"

                            # Save message in the database
                            Chats.upsert_message_to_chat_by_id_and_message_id(
                                metadata["chat_id"],
                                metadata["message_id"],
                                {
                                    "content": content,
                                },
                            )
                    except Exception:
                        # Best effort: anything that is not a well-formed delta
                        # chunk is dropped, except the terminal "[DONE]" marker
                        # (not valid JSON), which is forwarded as a done event.
                        if "data: [DONE]" in line:
                            data = {"done": True}
                        else:
                            continue

                    await event_emitter(
                        {
                            "type": "chat-completion",
                            "data": data,
                        }
                    )

                message_map = Chats.get_messages_by_chat_id(metadata["chat_id"])
                message = message_map.get(metadata["message_id"])

                if message:
                    messages = get_message_list(message_map, message.get("id"))

                    if TASKS.TITLE_GENERATION in tasks:
                        res = await generate_title(
                            request,
                            {
                                "model": message["model"],
                                "messages": messages,
                                "chat_id": metadata["chat_id"],
                            },
                            user,
                        )

                        # Only dict results are usable; an empty "choices" list
                        # falls back to the same default as a missing content.
                        if res and isinstance(res, dict):
                            choices = res.get("choices") or [{}]
                            title = (
                                choices[0]
                                .get("message", {})
                                .get("content", message.get("content", "New Chat"))
                            )

                            Chats.update_chat_title_by_id(metadata["chat_id"], title)

                            await event_emitter(
                                {
                                    "type": "chat-title",
                                    "data": title,
                                }
                            )

                    if TASKS.TAGS_GENERATION in tasks:
                        res = await generate_chat_tags(
                            request,
                            {
                                "model": message["model"],
                                "messages": messages,
                                "chat_id": metadata["chat_id"],
                            },
                            user,
                        )

                        if res and isinstance(res, dict):
                            choices = res.get("choices") or [{}]
                            tags_string = (
                                choices[0].get("message", {}).get("content") or ""
                            )

                            # Keep only the outermost {...} span of the reply
                            tags_string = tags_string[
                                tags_string.find("{") : tags_string.rfind("}") + 1
                            ]

                            try:
                                tags = json.loads(tags_string).get("tags", [])
                                Chats.update_chat_tags_by_id(
                                    metadata["chat_id"], tags, user
                                )

                                await event_emitter(
                                    {
                                        "type": "chat-tags",
                                        "data": tags,
                                    }
                                )
                            except Exception as e:
                                print(f"Error: {e}")

            except asyncio.CancelledError:
                print("Task was cancelled!")
                await event_emitter({"type": "task-cancelled"})
            finally:
                # Always run the upstream response's background task, even when
                # the handler fails part-way through the stream.
                if response.background is not None:
                    await response.background()

        task_id, _ = create_task(post_response_handler(response, events))
        return {"status": True, "task_id": task_id}

    else:
        # Fallback to the original response
        async def stream_wrapper(original_generator, events):
            def wrap_item(item):
                return f"data: {item}\n\n"

            for event in events:
                yield wrap_item(json.dumps(event))

            async for data in original_generator:
                yield data

        return StreamingResponse(
            stream_wrapper(response.body_iterator, events),
            headers=dict(response.headers),
            # Hand the upstream background task to the wrapper so it still
            # runs after the wrapped stream has been sent.
            background=response.background,
        )

View File

@@ -7,6 +7,34 @@ from pathlib import Path
from typing import Callable, Optional
def get_message_list(messages, message_id):
    """
    Reconstructs a list of messages in order up to the specified message_id.

    The chain is built by following each message's ``parentId`` link from the
    requested message back to the root, then returned root-first.

    :param messages: Message history dict containing all messages, keyed by id
    :param message_id: ID of the message to reconstruct the chain
    :return: List of ordered messages starting from the root to the given
             message; an empty list when the history is empty or the id is
             not present in it
    """
    if not messages:
        return []

    # Find the message by its id
    current_message = messages.get(message_id)
    if not current_message:
        # Callers iterate the result as a message list, so signal "not found"
        # with an empty list rather than an error string.
        return []

    # Walk child -> parent, then reverse once (avoids repeated front inserts)
    message_list = []
    seen_ids = {message_id}  # guards against a cyclic parentId chain

    while current_message:
        message_list.append(current_message)

        # A root message may omit "parentId" entirely
        parent_id = current_message.get("parentId")
        if not parent_id or parent_id in seen_ids:
            break

        seen_ids.add(parent_id)
        current_message = messages.get(parent_id)

    message_list.reverse()
    return message_list
def get_messages_content(messages: list[dict]) -> str:
return "\n".join(
[