refac: chat requests

This commit is contained in:
Timothy Jaeryang Baek 2024-12-19 01:00:32 -08:00
parent ea0d507e23
commit 2be9e55545
11 changed files with 752 additions and 424 deletions

View File

@ -30,7 +30,9 @@ from fastapi import (
UploadFile,
status,
applications,
BackgroundTasks,
)
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.middleware.cors import CORSMiddleware
@ -295,6 +297,7 @@ from open_webui.utils.auth import (
from open_webui.utils.oauth import oauth_manager
from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.tasks import stop_task, list_tasks # Import from tasks.py
if SAFE_MODE:
print("SAFE MODE ENABLED")
@ -822,11 +825,11 @@ async def chat_completion(
request: Request,
form_data: dict,
user=Depends(get_verified_user),
bypass_filter: bool = False,
):
if not request.app.state.MODELS:
await get_all_models(request)
tasks = form_data.pop("background_tasks", None)
try:
model_id = form_data.get("model", None)
if model_id not in request.app.state.MODELS:
@ -834,13 +837,14 @@ async def chat_completion(
model = request.app.state.MODELS[model_id]
# Check if user has access to the model
if not bypass_filter and user.role == "user":
if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
try:
check_model_access(user, model)
except Exception as e:
raise e
metadata = {
"user_id": user.id,
"chat_id": form_data.pop("chat_id", None),
"message_id": form_data.pop("id", None),
"session_id": form_data.pop("session_id", None),
@ -859,10 +863,10 @@ async def chat_completion(
)
try:
response = await chat_completion_handler(
request, form_data, user, bypass_filter
response = await chat_completion_handler(request, form_data, user)
return await process_chat_response(
request, response, user, events, metadata, tasks
)
return await process_chat_response(response, events, metadata)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -901,6 +905,20 @@ async def chat_action(
)
@app.post("/api/tasks/stop/{task_id}")
async def stop_task_endpoint(task_id: str, user=Depends(get_verified_user)):
    """
    Stop the background task registered under `task_id`.

    Delegates to `stop_task` from tasks.py and returns its result unchanged.
    A `ValueError` raised by `stop_task` is translated into an HTTP 404 whose
    detail is the error text.
    """
    try:
        return await stop_task(task_id)
    except ValueError as not_found:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(not_found),
        )
@app.get("/api/tasks")
async def list_tasks_endpoint(user=Depends(get_verified_user)):
    """
    Return the IDs reported by `list_tasks` (tasks.py) under the `tasks` key.
    """
    active_task_ids = list_tasks()
    return {"tasks": active_task_ids}
##################################
#
# Config Endpoints

View File

@ -168,6 +168,66 @@ class ChatTable:
except Exception:
return None
def update_chat_title_by_id(self, id: str, title: str) -> Optional[ChatModel]:
    """
    Set the `title` field inside a chat's stored payload.

    :param id: ID of the chat to update
    :param title: New title to store
    :return: Result of `update_chat_by_id`, or None when the chat does not exist
    """
    existing = self.get_chat_by_id(id)
    if existing is not None:
        payload = existing.chat
        payload["title"] = title
        return self.update_chat_by_id(id, payload)
    return None
def update_chat_tags_by_id(
    self, id: str, tags: list[str], user
) -> Optional[ChatModel]:
    """
    Replace the tags attached to a chat with the given list.

    :param id: ID of the chat to retag
    :param tags: Tag names to attach; any name equal to "none" (case-insensitive) is skipped
    :param user: Object whose `id` attribute identifies the tag owner
    :return: The chat as re-read after the update, or None when the chat does not exist
    """
    previous = self.get_chat_by_id(id)
    if previous is None:
        return None

    owner_id = user.id
    self.delete_all_tags_by_id_and_user_id(id, owner_id)

    # `previous` was fetched before the deletion above, so its meta still lists the old tags;
    # any of them that no chat of this user references anymore is removed via `Tags`.
    for old_tag in previous.meta.get("tags", []):
        if self.count_chats_by_tag_name_and_user_id(old_tag, owner_id) == 0:
            Tags.delete_tag_by_name_and_user_id(old_tag, owner_id)

    for new_tag in tags:
        if new_tag.lower() != "none":
            self.add_chat_tag_by_id_and_user_id_and_tag_name(id, owner_id, new_tag)

    return self.get_chat_by_id(id)
def get_messages_by_chat_id(self, id: str) -> Optional[dict]:
    """
    Return the message map stored under `history.messages` of a chat.

    :param id: ID of the chat to read
    :return: The stored messages dict, an empty dict when none are stored,
        or None when the chat does not exist
    """
    record = self.get_chat_by_id(id)
    if record is None:
        return None

    history = record.chat.get("history", {})
    stored_messages = history.get("messages", {})
    # A falsy stored value (missing, empty or null) is normalised to an empty dict
    return stored_messages if stored_messages else {}
def upsert_message_to_chat_by_id_and_message_id(
    self, id: str, message_id: str, message: dict
) -> Optional[ChatModel]:
    """
    Insert a message into a chat's history, or merge it into an existing one.

    When `message_id` already exists, the given fields are merged over the stored
    message (given fields win). Otherwise the message is stored as-is. In both
    cases `history.currentId` is set to `message_id`.

    :param id: ID of the chat to update
    :param message_id: Key of the message inside `history.messages`
    :param message: Message fields to store or merge
    :return: Result of `update_chat_by_id`, or None when the chat does not exist
    """
    chat = self.get_chat_by_id(id)
    if chat is None:
        return None

    chat = chat.chat

    # A chat stored without a history (or with a null history / messages map) must still
    # accept its first message instead of failing on the missing "messages" key.
    history = chat.get("history") or {}
    messages = history.get("messages") or {}

    if message_id in messages:
        messages[message_id] = {**messages[message_id], **message}
    else:
        messages[message_id] = message

    history["messages"] = messages
    history["currentId"] = message_id

    chat["history"] = history
    return self.update_chat_by_id(id, chat)
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
with get_db() as db:
# Get the existing chat to share

View File

@ -82,6 +82,16 @@ async def send_get_request(url, key=None):
return None
async def cleanup_response(
    response: Optional[aiohttp.ClientResponse],
    session: Optional[aiohttp.ClientSession],
):
    """
    Close an aiohttp response and its session, skipping whichever is falsy.

    The response is closed first (synchronously), then the session is awaited closed.
    """
    if response:
        response.close()

    if not session:
        return
    await session.close()
async def send_post_request(
url: str,
payload: Union[str, bytes],
@ -89,14 +99,6 @@ async def send_post_request(
key: Optional[str] = None,
content_type: Optional[str] = None,
):
async def cleanup_response(
response: Optional[aiohttp.ClientResponse],
session: Optional[aiohttp.ClientSession],
):
if response:
response.close()
if session:
await session.close()
r = None
try:

View File

@ -217,15 +217,19 @@ async def disconnect(sid):
def get_event_emitter(request_info):
async def __event_emitter__(event_data):
await sio.emit(
"chat-events",
{
"chat_id": request_info["chat_id"],
"message_id": request_info["message_id"],
"data": event_data,
},
to=request_info["session_id"],
)
user_id = request_info["user_id"]
session_ids = USER_POOL.get(user_id, [])
for session_id in session_ids:
await sio.emit(
"chat-events",
{
"chat_id": request_info["chat_id"],
"message_id": request_info["message_id"],
"data": event_data,
},
to=session_id,
)
return __event_emitter__

View File

@ -0,0 +1,61 @@
# tasks.py
import asyncio
from typing import Dict
from uuid import uuid4
# A dictionary to keep track of active tasks
tasks: Dict[str, asyncio.Task] = {}
def cleanup_task(task_id: str):
    """
    Drop `task_id` from the global `tasks` dictionary.

    IDs that are not registered are ignored, so repeated calls are harmless.
    """
    if task_id in tasks:
        del tasks[task_id]
def create_task(coroutine):
    """
    Schedule `coroutine` as an asyncio task and register it in the global `tasks` dictionary.

    The task is registered under a freshly generated UUID string and is
    unregistered again through `cleanup_task` once it is done.

    :param coroutine: Coroutine object to run; requires a running event loop
    :return: Tuple of (task ID, asyncio.Task)
    """
    task_id = str(uuid4())
    task = asyncio.create_task(coroutine)

    def _unregister(_finished_task):
        # Done callbacks receive the finished task; only the ID is needed here
        cleanup_task(task_id)

    task.add_done_callback(_unregister)
    tasks[task_id] = task
    return task_id, task
def get_task(task_id: str):
    """
    Look up a registered task.

    :param task_id: ID the task was registered under
    :return: The entry stored in `tasks` for that ID, or None when there is none
    """
    try:
        return tasks[task_id]
    except KeyError:
        return None
def list_tasks():
    """
    Return the IDs of all tasks currently held in `tasks`, as a new list.
    """
    return [task_id for task_id in tasks]
async def stop_task(task_id: str):
    """
    Cancel a running task and remove it from the global task list.

    Success is reported once the task is no longer running. That covers both a
    task that let the cancellation propagate and a task that caught
    `asyncio.CancelledError` to clean up and then returned normally.

    :param task_id: ID the task was registered under
    :return: Dict with a boolean `status` and a human-readable `message`
    :raises ValueError: If no task is registered under `task_id`

    Any exception other than `asyncio.CancelledError` raised by the task itself
    propagates to the caller.
    """
    task = tasks.get(task_id)
    if not task:
        raise ValueError(f"Task with ID {task_id} not found.")

    task.cancel()  # Request task cancellation
    try:
        await task  # Wait for the task to handle the cancellation
    except asyncio.CancelledError:
        # The usual outcome: the cancellation propagated out of the task
        pass

    # Getting here means the awaited task has finished, whether or not it
    # swallowed the cancellation, so the registry entry can go.
    tasks.pop(task_id, None)
    return {"status": True, "message": f"Task {task_id} successfully stopped."}

View File

@ -117,7 +117,9 @@ async def generate_chat_completion(
form_data, user, bypass_filter=True
)
return StreamingResponse(
stream_wrapper(response.body_iterator), media_type="text/event-stream"
stream_wrapper(response.body_iterator),
media_type="text/event-stream",
background=response.background,
)
else:
return {
@ -141,6 +143,7 @@ async def generate_chat_completion(
return StreamingResponse(
convert_streaming_response_ollama_to_openai(response),
headers=dict(response.headers),
background=response.background,
)
else:
return convert_response_ollama_to_openai(response)

View File

@ -2,21 +2,31 @@ import time
import logging
import sys
import asyncio
from aiocache import cached
from typing import Any, Optional
import random
import json
import inspect
from uuid import uuid4
from fastapi import Request
from fastapi import BackgroundTasks
from starlette.responses import Response, StreamingResponse
from open_webui.models.chats import Chats
from open_webui.socket.main import (
get_event_call,
get_event_emitter,
)
from open_webui.routers.tasks import generate_queries
from open_webui.routers.tasks import (
generate_queries,
generate_title,
generate_chat_tags,
)
from open_webui.models.users import UserModel
@ -33,6 +43,7 @@ from open_webui.utils.task import (
tools_function_calling_generation_template,
)
from open_webui.utils.misc import (
get_message_list,
add_or_update_system_message,
get_last_user_message,
prepend_to_first_user_message_content,
@ -41,6 +52,8 @@ from open_webui.utils.tools import get_tools
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.tasks import create_task
from open_webui.config import DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
from open_webui.constants import TASKS
@ -504,28 +517,178 @@ async def process_chat_payload(request, form_data, metadata, user, model):
return form_data, events
async def process_chat_response(response, events, metadata):
async def process_chat_response(request, response, user, events, metadata, tasks):
if not isinstance(response, StreamingResponse):
return response
content_type = response.headers["Content-Type"]
is_openai = "text/event-stream" in content_type
is_ollama = "application/x-ndjson" in content_type
if not is_openai and not is_ollama:
if not any(
content_type in response.headers["Content-Type"]
for content_type in ["text/event-stream", "application/x-ndjson"]
):
return response
async def stream_wrapper(original_generator, events):
def wrap_item(item):
return f"data: {item}\n\n" if is_openai else f"{item}\n"
event_emitter = None
if "session_id" in metadata:
event_emitter = get_event_emitter(metadata)
for event in events:
yield wrap_item(json.dumps(event))
if event_emitter:
async for data in original_generator:
yield data
task_id = str(uuid4()) # Create a unique task ID.
return StreamingResponse(
stream_wrapper(response.body_iterator, events),
headers=dict(response.headers),
)
# Handle as a background task
async def post_response_handler(response, events):
try:
for event in events:
await event_emitter(
{
"type": "chat-completion",
"data": event,
}
)
content = ""
async for line in response.body_iterator:
line = line.decode("utf-8") if isinstance(line, bytes) else line
data = line
# Skip empty lines
if not data.strip():
continue
# "data: " is the prefix for each event
if not data.startswith("data: "):
continue
# Remove the prefix
data = data[len("data: ") :]
try:
data = json.loads(data)
value = (
data.get("choices", [])[0].get("delta", {}).get("content")
)
if value:
content = f"{content}{value}"
# Save message in the database
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
{
"content": content,
},
)
except Exception as e:
done = "data: [DONE]" in line
if done:
data = {"done": True}
else:
continue
await event_emitter(
{
"type": "chat-completion",
"data": data,
}
)
message_map = Chats.get_messages_by_chat_id(metadata["chat_id"])
message = message_map.get(metadata["message_id"])
if message:
messages = get_message_list(message_map, message.get("id"))
if TASKS.TITLE_GENERATION in tasks:
res = await generate_title(
request,
{
"model": message["model"],
"messages": messages,
"chat_id": metadata["chat_id"],
},
user,
)
if res:
title = (
res.get("choices", [])[0]
.get("message", {})
.get("content", message.get("content", "New Chat"))
)
Chats.update_chat_title_by_id(metadata["chat_id"], title)
await event_emitter(
{
"type": "chat-title",
"data": title,
}
)
if TASKS.TAGS_GENERATION in tasks:
res = await generate_chat_tags(
request,
{
"model": message["model"],
"messages": messages,
"chat_id": metadata["chat_id"],
},
user,
)
if res:
tags_string = (
res.get("choices", [])[0]
.get("message", {})
.get("content", "")
)
tags_string = tags_string[
tags_string.find("{") : tags_string.rfind("}") + 1
]
try:
tags = json.loads(tags_string).get("tags", [])
Chats.update_chat_tags_by_id(
metadata["chat_id"], tags, user
)
await event_emitter(
{
"type": "chat-tags",
"data": tags,
}
)
except Exception as e:
print(f"Error: {e}")
except asyncio.CancelledError:
print("Task was cancelled!")
await event_emitter({"type": "task-cancelled"})
if response.background is not None:
await response.background()
# background_tasks.add_task(post_response_handler, response, events)
task_id, _ = create_task(post_response_handler(response, events))
return {"status": True, "task_id": task_id}
else:
# Fallback to the original response
async def stream_wrapper(original_generator, events):
def wrap_item(item):
return f"data: {item}\n\n"
for event in events:
yield wrap_item(json.dumps(event))
async for data in original_generator:
yield data
return StreamingResponse(
stream_wrapper(response.body_iterator, events),
headers=dict(response.headers),
)

View File

@ -7,6 +7,34 @@ from pathlib import Path
from typing import Callable, Optional
def get_message_list(messages, message_id):
    """
    Reconstructs a list of messages in order up to the specified message_id.

    The chain is built by following each message's `parentId` link. It ends at
    the first message without a parent, at a parent that is missing from
    `messages`, or when a link would revisit a message already in the chain.

    :param messages: Message history dict containing all messages, keyed by message ID
    :param message_id: ID of the message to reconstruct the chain
    :return: List of ordered messages starting from the root to the given message;
        an empty list when `messages` is empty or `message_id` is not in it
    """
    if not messages:
        return []

    # Find the message by its id
    current_message = messages.get(message_id)
    if not current_message:
        return []

    # Collect leaf-to-root with cheap appends, then flip once at the end
    message_list = []
    visited_ids = {message_id}

    while current_message:
        message_list.append(current_message)

        parent_id = current_message.get("parentId")
        if not parent_id or parent_id in visited_ids:
            # Reached the root, or a cyclic link that would otherwise never terminate
            break

        visited_ids.add(parent_id)
        current_message = messages.get(parent_id)

    message_list.reverse()
    return message_list
def get_messages_content(messages: list[dict]) -> str:
return "\n".join(
[

View File

@ -107,6 +107,42 @@ export const chatAction = async (token: string, action_id: string, body: ChatAct
return res;
};
// Ask the backend to stop the task with the given id.
// Resolves with the parsed JSON body on success. On failure the error is logged
// and rethrown: the payload's `detail` field when it has one, otherwise the error itself.
export const stopTask = async (token: string, id: string) => {
	let failure = null;
	let payload = null;

	try {
		const response = await fetch(`${WEBUI_BASE_URL}/api/tasks/stop/${id}`, {
			method: 'POST',
			headers: {
				Accept: 'application/json',
				'Content-Type': 'application/json',
				...(token && { authorization: `Bearer ${token}` })
			}
		});

		// Non-2xx responses carry their error description in the JSON body
		if (!response.ok) throw await response.json();
		payload = await response.json();
	} catch (err: any) {
		console.log(err);
		failure = 'detail' in err ? err.detail : err;
	}

	if (failure) {
		throw failure;
	}

	return payload;
};
export const getTaskConfig = async (token: string = '') => {
let error = null;

View File

@ -277,21 +277,22 @@ export const generateOpenAIChatCompletion = async (
token: string = '',
body: object,
url: string = OPENAI_API_BASE_URL
): Promise<[Response | null, AbortController]> => {
const controller = new AbortController();
) => {
let error = null;
const res = await fetch(`${url}/chat/completions`, {
signal: controller.signal,
method: 'POST',
headers: {
Authorization: `Bearer ${token}`,
'Content-Type': 'application/json'
},
body: JSON.stringify(body)
}).catch((err) => {
console.log(err);
error = err;
}).then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
error = `OpenAI: ${err?.detail ?? 'Network Problem'}`;
return null;
});
@ -299,7 +300,7 @@ export const generateOpenAIChatCompletion = async (
throw error;
}
return [res, controller];
return res;
};
export const synthesizeOpenAISpeech = async (

View File

@ -69,7 +69,8 @@
generateQueries,
chatAction,
generateMoACompletion,
generateTags
generateTags,
stopTask
} from '$lib/apis';
import Banner from '../common/Banner.svelte';
@ -88,7 +89,6 @@
let controlPane;
let controlPaneComponent;
let stopResponseFlag = false;
let autoScroll = true;
let processing = '';
let messagesContainerElement: HTMLDivElement;
@ -121,6 +121,8 @@
currentId: null
};
let taskId = null;
// Chat Input
let prompt = '';
let chatFiles = [];
@ -202,95 +204,107 @@
};
const chatEventHandler = async (event, cb) => {
console.log(event);
if (event.chat_id === $chatId) {
await tick();
console.log(event);
let message = history.messages[event.message_id];
const type = event?.data?.type ?? null;
const data = event?.data?.data ?? null;
if (message) {
const type = event?.data?.type ?? null;
const data = event?.data?.data ?? null;
if (type === 'status') {
if (message?.statusHistory) {
message.statusHistory.push(data);
} else {
message.statusHistory = [data];
}
} else if (type === 'source' || type === 'citation') {
if (data?.type === 'code_execution') {
// Code execution; update existing code execution by ID, or add new one.
if (!message?.code_executions) {
message.code_executions = [];
}
const existingCodeExecutionIndex = message.code_executions.findIndex(
(execution) => execution.id === data.id
);
if (existingCodeExecutionIndex !== -1) {
message.code_executions[existingCodeExecutionIndex] = data;
if (type === 'status') {
if (message?.statusHistory) {
message.statusHistory.push(data);
} else {
message.code_executions.push(data);
message.statusHistory = [data];
}
} else if (type === 'source' || type === 'citation') {
if (data?.type === 'code_execution') {
// Code execution; update existing code execution by ID, or add new one.
if (!message?.code_executions) {
message.code_executions = [];
}
message.code_executions = message.code_executions;
} else {
// Regular source.
if (message?.sources) {
message.sources.push(data);
const existingCodeExecutionIndex = message.code_executions.findIndex(
(execution) => execution.id === data.id
);
if (existingCodeExecutionIndex !== -1) {
message.code_executions[existingCodeExecutionIndex] = data;
} else {
message.code_executions.push(data);
}
message.code_executions = message.code_executions;
} else {
message.sources = [data];
// Regular source.
if (message?.sources) {
message.sources.push(data);
} else {
message.sources = [data];
}
}
}
} else if (type === 'message') {
message.content += data.content;
} else if (type === 'replace') {
message.content = data.content;
} else if (type === 'action') {
if (data.action === 'continue') {
const continueButton = document.getElementById('continue-response-button');
} else if (type === 'chat-completion') {
chatCompletionEventHandler(data, message, event.chat_id);
} else if (type === 'chat-title') {
chatTitle.set(data);
currentChatPage.set(1);
await chats.set(await getChatList(localStorage.token, $currentChatPage));
} else if (type === 'chat-tags') {
chat = await getChatById(localStorage.token, $chatId);
allTags.set(await getAllTags(localStorage.token));
} else if (type === 'message') {
message.content += data.content;
} else if (type === 'replace') {
message.content = data.content;
} else if (type === 'action') {
if (data.action === 'continue') {
const continueButton = document.getElementById('continue-response-button');
if (continueButton) {
continueButton.click();
if (continueButton) {
continueButton.click();
}
}
}
} else if (type === 'confirmation') {
eventCallback = cb;
} else if (type === 'confirmation') {
eventCallback = cb;
eventConfirmationInput = false;
showEventConfirmation = true;
eventConfirmationInput = false;
showEventConfirmation = true;
eventConfirmationTitle = data.title;
eventConfirmationMessage = data.message;
} else if (type === 'execute') {
eventCallback = cb;
eventConfirmationTitle = data.title;
eventConfirmationMessage = data.message;
} else if (type === 'execute') {
eventCallback = cb;
try {
// Use Function constructor to evaluate code in a safer way
const asyncFunction = new Function(`return (async () => { ${data.code} })()`);
const result = await asyncFunction(); // Await the result of the async function
try {
// Use Function constructor to evaluate code in a safer way
const asyncFunction = new Function(`return (async () => { ${data.code} })()`);
const result = await asyncFunction(); // Await the result of the async function
if (cb) {
cb(result);
if (cb) {
cb(result);
}
} catch (error) {
console.error('Error executing code:', error);
}
} catch (error) {
console.error('Error executing code:', error);
} else if (type === 'input') {
eventCallback = cb;
eventConfirmationInput = true;
showEventConfirmation = true;
eventConfirmationTitle = data.title;
eventConfirmationMessage = data.message;
eventConfirmationInputPlaceholder = data.placeholder;
eventConfirmationInputValue = data?.value ?? '';
} else {
console.log('Unknown message type', data);
}
} else if (type === 'input') {
eventCallback = cb;
eventConfirmationInput = true;
showEventConfirmation = true;
eventConfirmationTitle = data.title;
eventConfirmationMessage = data.message;
eventConfirmationInputPlaceholder = data.placeholder;
eventConfirmationInputValue = data?.value ?? '';
} else {
console.log('Unknown message type', data);
history.messages[event.message_id] = message;
}
history.messages[event.message_id] = message;
}
};
@ -956,6 +970,119 @@
}
};
// Applies one 'chat-completion' event payload to `message` and writes the result back into
// `history.messages`.
//   data    - event payload; the fields read are error, sources, choices, selectedModelId,
//             usage and done (`id` is destructured but not used in this handler)
//   message - the response message object being updated (mutated in place)
//   chatId  - chat id forwarded to chatCompletedHandler once the payload reports `done`
const chatCompletionEventHandler = async (data, message, chatId) => {
const { id, done, choices, sources, selectedModelId, error, usage } = data;
// An error is reported through handleOpenAIError; there is no early return, so the rest of
// the payload is still processed afterwards.
if (error) {
await handleOpenAIError(error, message);
}
if (sources) {
message.sources = sources;
// Only remove status if it was initially set
// NOTE(review): `model` is neither a parameter nor a local of this handler — presumably it
// resolves to a component-level binding; verify. Also assumes message.statusHistory is an
// array whenever this branch runs — TODO confirm.
if (model?.info?.meta?.knowledge ?? false) {
message.statusHistory = message.statusHistory.filter(
(status) => status.action !== 'knowledge_search'
);
}
}
if (choices) {
// Streamed text delta of the first choice; empty string when absent
const value = choices[0]?.delta?.content ?? '';
// A lone newline arriving while the message is still empty is not appended
if (message.content == '' && value == '\n') {
console.log('Empty response');
} else {
message.content += value;
if (navigator.vibrate && ($settings?.hapticFeedback ?? false)) {
navigator.vibrate(5);
}
// Emit chat event for TTS
const messageContentParts = getMessageContentParts(
message.content,
$config?.audio?.tts?.split_on ?? 'punctuation'
);
// Drop the last part before picking the one to dispatch (presumably it may still be
// incomplete while streaming)
messageContentParts.pop();
// dispatch only last sentence and make sure it hasn't been dispatched before
if (
messageContentParts.length > 0 &&
messageContentParts[messageContentParts.length - 1] !== message.lastSentence
) {
message.lastSentence = messageContentParts[messageContentParts.length - 1];
eventTarget.dispatchEvent(
new CustomEvent('chat', {
detail: {
id: message.id,
content: messageContentParts[messageContentParts.length - 1]
}
})
);
}
}
}
if (selectedModelId) {
message.selectedModelId = selectedModelId;
message.arena = true;
}
if (usage) {
message.usage = usage;
}
// Final payload: mark the message done, run the user-configured completion actions
// (notification, auto-copy, auto-playback), then emit the closing TTS/finish events.
if (done) {
message.done = true;
if ($settings.notificationEnabled && !document.hasFocus()) {
new Notification(`${message.model}`, {
body: message.content,
icon: `${WEBUI_BASE_URL}/static/favicon.png`
});
}
if ($settings.responseAutoCopy) {
copyToClipboard(message.content);
}
if ($settings.responseAutoPlayback && !$showCallOverlay) {
await tick();
document.getElementById(`speak-button-${message.id}`)?.click();
}
// Emit chat event for TTS
let lastMessageContentPart =
getMessageContentParts(message.content, $config?.audio?.tts?.split_on ?? 'punctuation')?.at(
-1
) ?? '';
if (lastMessageContentPart) {
eventTarget.dispatchEvent(
new CustomEvent('chat', {
detail: { id: message.id, content: lastMessageContentPart }
})
);
}
eventTarget.dispatchEvent(
new CustomEvent('chat:finish', {
detail: {
id: message.id,
content: message.content
}
})
);
// Store the finished message before chatCompletedHandler is called with the message list
history.messages[message.id] = message;
await chatCompletedHandler(chatId, message.model, message.id, createMessagesList(message.id));
}
// Write the (possibly mutated) message back on every event, not only on completion
history.messages[message.id] = message;
console.log(data);
if (autoScroll) {
scrollToBottom();
}
};
//////////////////////////
// Chat functions
//////////////////////////
@ -1061,6 +1188,7 @@
chatInput?.focus();
saveSessionSelectedModels();
await sendPrompt(userPrompt, userMessageId, { newChat: true });
};
@ -1076,6 +1204,8 @@
history.messages[history.currentId].role === 'user'
) {
await initChatHandler();
} else {
await saveChatHandler($chatId);
}
// If modelId is provided, use it, else use selected model
@ -1122,6 +1252,9 @@
}
await tick();
// Save chat after all messages have been created
await saveChatHandler($chatId);
const _chatId = JSON.parse(JSON.stringify($chatId));
await Promise.all(
selectedModelIds.map(async (modelId, _modelIdx) => {
@ -1178,7 +1311,7 @@
await getWebSearchResults(model.id, parentId, responseMessageId);
}
await sendPromptOpenAI(model, prompt, responseMessageId, _chatId);
await sendPromptSocket(model, responseMessageId, _chatId);
if (chatEventEmitter) clearInterval(chatEventEmitter);
} else {
toast.error($i18n.t(`Model {{modelId}} not found`, { modelId }));
@ -1190,9 +1323,7 @@
chats.set(await getChatList(localStorage.token, $currentChatPage));
};
const sendPromptOpenAI = async (model, userPrompt, responseMessageId, _chatId) => {
let _response = null;
const sendPromptSocket = async (model, responseMessageId, _chatId) => {
const responseMessage = history.messages[responseMessageId];
const userMessage = history.messages[responseMessage.parentId];
@ -1243,7 +1374,6 @@
);
scrollToBottom();
eventTarget.dispatchEvent(
new CustomEvent('chat:start', {
detail: {
@ -1253,278 +1383,133 @@
);
await tick();
try {
const stream =
model?.info?.params?.stream_response ??
$settings?.params?.stream_response ??
params?.stream_response ??
true;
const stream =
model?.info?.params?.stream_response ??
$settings?.params?.stream_response ??
params?.stream_response ??
true;
const [res, controller] = await generateOpenAIChatCompletion(
localStorage.token,
{
stream: stream,
model: model.id,
messages: [
params?.system || $settings.system || (responseMessage?.userContext ?? null)
? {
role: 'system',
content: `${promptTemplate(
params?.system ?? $settings?.system ?? '',
$user.name,
$settings?.userLocation
? await getAndUpdateUserLocation(localStorage.token)
: undefined
)}${
(responseMessage?.userContext ?? null)
? `\n\nUser Context:\n${responseMessage?.userContext ?? ''}`
: ''
}`
}
: undefined,
...createMessagesList(responseMessageId)
]
.filter((message) => message?.content?.trim())
.map((message, idx, arr) => ({
role: message.role,
...((message.files?.filter((file) => file.type === 'image').length > 0 ?? false) &&
message.role === 'user'
? {
content: [
{
type: 'text',
text: message?.merged?.content ?? message.content
},
...message.files
.filter((file) => file.type === 'image')
.map((file) => ({
type: 'image_url',
image_url: {
url: file.url
}
}))
]
}
: {
content: message?.merged?.content ?? message.content
})
})),
params: {
...$settings?.params,
...params,
format: $settings.requestFormat ?? undefined,
keep_alive: $settings.keepAlive ?? undefined,
stop:
(params?.stop ?? $settings?.params?.stop ?? undefined)
? (
params?.stop.split(',').map((token) => token.trim()) ?? $settings.params.stop
).map((str) =>
decodeURIComponent(JSON.parse('"' + str.replace(/\"/g, '\\"') + '"'))
)
const messages = [
params?.system || $settings.system || (responseMessage?.userContext ?? null)
? {
role: 'system',
content: `${promptTemplate(
params?.system ?? $settings?.system ?? '',
$user.name,
$settings?.userLocation
? await getAndUpdateUserLocation(localStorage.token)
: undefined
},
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
files: files.length > 0 ? files : undefined,
session_id: $socket?.id,
chat_id: $chatId,
id: responseMessageId,
...(stream && (model.info?.meta?.capabilities?.usage ?? false)
? {
stream_options: {
include_usage: true
}
}
: {})
},
`${WEBUI_BASE_URL}/api`
);
// Wait until history/message have been updated
await tick();
scrollToBottom();
if (res && res.ok && res.body) {
if (!stream) {
const response = await res.json();
console.log(response);
responseMessage.content = response.choices[0].message.content;
responseMessage.info = { ...response.usage, openai: true };
responseMessage.done = true;
} else {
const textStream = await createOpenAITextStream(res.body, $settings.splitLargeChunks);
for await (const update of textStream) {
const { value, done, sources, selectedModelId, error, usage } = update;
if (error) {
await handleOpenAIError(error, null, model, responseMessage);
break;
}
if (done || stopResponseFlag || _chatId !== $chatId) {
responseMessage.done = true;
history.messages[responseMessageId] = responseMessage;
if (stopResponseFlag) {
controller.abort('User: Stop Response');
}
_response = responseMessage.content;
break;
}
if (usage) {
responseMessage.usage = usage;
}
if (selectedModelId) {
responseMessage.selectedModelId = selectedModelId;
responseMessage.arena = true;
continue;
}
if (sources) {
responseMessage.sources = sources;
// Only remove status if it was initially set
if (model?.info?.meta?.knowledge ?? false) {
responseMessage.statusHistory = responseMessage.statusHistory.filter(
(status) => status.action !== 'knowledge_search'
);
}
continue;
}
if (responseMessage.content == '' && value == '\n') {
continue;
} else {
responseMessage.content += value;
if (navigator.vibrate && ($settings?.hapticFeedback ?? false)) {
navigator.vibrate(5);
}
const messageContentParts = getMessageContentParts(
responseMessage.content,
$config?.audio?.tts?.split_on ?? 'punctuation'
);
messageContentParts.pop();
// dispatch only last sentence and make sure it hasn't been dispatched before
if (
messageContentParts.length > 0 &&
messageContentParts[messageContentParts.length - 1] !== responseMessage.lastSentence
) {
responseMessage.lastSentence = messageContentParts[messageContentParts.length - 1];
eventTarget.dispatchEvent(
new CustomEvent('chat', {
detail: {
id: responseMessageId,
content: messageContentParts[messageContentParts.length - 1]
}
})
);
}
history.messages[responseMessageId] = responseMessage;
}
if (autoScroll) {
scrollToBottom();
}
)}${
(responseMessage?.userContext ?? null)
? `\n\nUser Context:\n${responseMessage?.userContext ?? ''}`
: ''
}`
}
}
: undefined,
...createMessagesList(responseMessageId)
]
.filter((message) => message?.content?.trim())
.map((message, idx, arr) => ({
role: message.role,
...((message.files?.filter((file) => file.type === 'image').length > 0 ?? false) &&
message.role === 'user'
? {
content: [
{
type: 'text',
text: message?.merged?.content ?? message.content
},
...message.files
.filter((file) => file.type === 'image')
.map((file) => ({
type: 'image_url',
image_url: {
url: file.url
}
}))
]
}
: {
content: message?.merged?.content ?? message.content
})
}));
if ($settings.notificationEnabled && !document.hasFocus()) {
const notification = new Notification(`${model.id}`, {
body: responseMessage.content,
icon: `${WEBUI_BASE_URL}/static/favicon.png`
});
}
const res = await generateOpenAIChatCompletion(
localStorage.token,
{
stream: stream,
model: model.id,
messages: messages,
params: {
...$settings?.params,
...params,
if ($settings.responseAutoCopy) {
copyToClipboard(responseMessage.content);
}
format: $settings.requestFormat ?? undefined,
keep_alive: $settings.keepAlive ?? undefined,
stop:
(params?.stop ?? $settings?.params?.stop ?? undefined)
? (params?.stop.split(',').map((token) => token.trim()) ?? $settings.params.stop).map(
(str) => decodeURIComponent(JSON.parse('"' + str.replace(/\"/g, '\\"') + '"'))
)
: undefined
},
if ($settings.responseAutoPlayback && !$showCallOverlay) {
await tick();
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
files: files.length > 0 ? files : undefined,
session_id: $socket?.id,
chat_id: $chatId,
id: responseMessageId,
document.getElementById(`speak-button-${responseMessage.id}`)?.click();
}
} else {
await handleOpenAIError(null, res, model, responseMessage);
}
} catch (error) {
await handleOpenAIError(error, null, model, responseMessage);
...(!$temporaryChatEnabled && messages.length == 1 && selectedModels[0] === model.id
? {
background_tasks: {
title_generation: $settings?.title?.auto ?? true,
tags_generation: $settings?.autoTags ?? true
}
}
: {}),
...(stream && (model.info?.meta?.capabilities?.usage ?? false)
? {
stream_options: {
include_usage: true
}
}
: {})
},
`${WEBUI_BASE_URL}/api`
).catch((error) => {
responseMessage.error = {
content: error
};
responseMessage.done = true;
return null;
});
console.log(res);
if (res) {
taskId = res.task_id;
}
await saveChatHandler(_chatId);
history.messages[responseMessageId] = responseMessage;
await chatCompletedHandler(
_chatId,
model.id,
responseMessageId,
createMessagesList(responseMessageId)
);
stopResponseFlag = false;
// Wait until history/message have been updated
await tick();
scrollToBottom();
let lastMessageContentPart =
getMessageContentParts(
responseMessage.content,
$config?.audio?.tts?.split_on ?? 'punctuation'
)?.at(-1) ?? '';
if (lastMessageContentPart) {
eventTarget.dispatchEvent(
new CustomEvent('chat', {
detail: { id: responseMessageId, content: lastMessageContentPart }
})
);
}
eventTarget.dispatchEvent(
new CustomEvent('chat:finish', {
detail: {
id: responseMessageId,
content: responseMessage.content
}
})
);
if (autoScroll) {
scrollToBottom();
}
const messages = createMessagesList(responseMessageId);
if (messages.length == 2 && selectedModels[0] === model.id) {
window.history.replaceState(history.state, '', `/c/${_chatId}`);
const title = await generateChatTitle(messages);
await setChatTitle(_chatId, title);
if ($settings?.autoTags ?? true) {
await setChatTags(messages);
}
}
return _response;
// if ($settings?.autoTags ?? true) {
// await setChatTags(messages);
// }
// }
};
const handleOpenAIError = async (error, res: Response | null, model, responseMessage) => {
const handleOpenAIError = async (error, responseMessage) => {
let errorMessage = '';
let innerError;
if (error) {
innerError = error;
} else if (res !== null) {
innerError = await res.json();
}
console.error(innerError);
if ('detail' in innerError) {
toast.error(innerError.detail);
@ -1543,12 +1528,7 @@
}
responseMessage.error = {
content:
$i18n.t(`Uh-oh! There was an issue connecting to {{provider}}.`, {
provider: model.name ?? model.id
}) +
'\n' +
errorMessage
content: $i18n.t(`Uh-oh! There was an issue with the response.`) + '\n' + errorMessage
};
responseMessage.done = true;
@ -1562,8 +1542,15 @@
};
const stopResponse = () => {
stopResponseFlag = true;
console.log('stopResponse');
if (taskId) {
const res = stopTask(localStorage.token, taskId).catch((error) => {
return null;
});
if (res) {
taskId = null;
}
}
};
const submitMessage = async (parentId, prompt) => {
@ -1628,12 +1615,7 @@
.at(0);
if (model) {
await sendPromptOpenAI(
model,
history.messages[responseMessage.parentId].content,
responseMessage.id,
_chatId
);
await sendPromptSocket(model, responseMessage.id, _chatId);
}
}
};
@ -1685,38 +1667,6 @@
}
};
const generateChatTitle = async (messages) => {
const lastUserMessage = messages.filter((message) => message.role === 'user').at(-1);
if ($settings?.title?.auto ?? true) {
const modelId = selectedModels[0];
const title = await generateTitle(localStorage.token, modelId, messages, $chatId).catch(
(error) => {
console.error(error);
return lastUserMessage?.content ?? 'New Chat';
}
);
return title ? title : (lastUserMessage?.content ?? 'New Chat');
} else {
return lastUserMessage?.content ?? 'New Chat';
}
};
const setChatTitle = async (_chatId, title) => {
if (_chatId === $chatId) {
chatTitle.set(title);
}
if (!$temporaryChatEnabled) {
chat = await updateChatById(localStorage.token, _chatId, { title: title });
currentChatPage.set(1);
await chats.set(await getChatList(localStorage.token, $currentChatPage));
}
};
const setChatTags = async (messages) => {
if (!$temporaryChatEnabled) {
const currentTags = await getTagsById(localStorage.token, $chatId);
@ -1856,6 +1806,8 @@
currentChatPage.set(1);
await chats.set(await getChatList(localStorage.token, $currentChatPage));
await chatId.set(chat.id);
window.history.replaceState(history.state, '', `/c/${chat.id}`);
} else {
await chatId.set('local');
}