diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 7d3a84722..807c87dcc 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -30,7 +30,9 @@ from fastapi import ( UploadFile, status, applications, + BackgroundTasks, ) + from fastapi.openapi.docs import get_swagger_ui_html from fastapi.middleware.cors import CORSMiddleware @@ -295,6 +297,7 @@ from open_webui.utils.auth import ( from open_webui.utils.oauth import oauth_manager from open_webui.utils.security_headers import SecurityHeadersMiddleware +from open_webui.tasks import stop_task, list_tasks # Import from tasks.py if SAFE_MODE: print("SAFE MODE ENABLED") @@ -822,11 +825,11 @@ async def chat_completion( request: Request, form_data: dict, user=Depends(get_verified_user), - bypass_filter: bool = False, ): if not request.app.state.MODELS: await get_all_models(request) + tasks = form_data.pop("background_tasks", None) try: model_id = form_data.get("model", None) if model_id not in request.app.state.MODELS: @@ -834,13 +837,14 @@ async def chat_completion( model = request.app.state.MODELS[model_id] # Check if user has access to the model - if not bypass_filter and user.role == "user": + if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user": try: check_model_access(user, model) except Exception as e: raise e metadata = { + "user_id": user.id, "chat_id": form_data.pop("chat_id", None), "message_id": form_data.pop("id", None), "session_id": form_data.pop("session_id", None), @@ -859,10 +863,10 @@ async def chat_completion( ) try: - response = await chat_completion_handler( - request, form_data, user, bypass_filter + response = await chat_completion_handler(request, form_data, user) + return await process_chat_response( + request, response, user, events, metadata, tasks ) - return await process_chat_response(response, events, metadata) except Exception as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -901,6 +905,20 @@ async def chat_action( ) +@app.post("/api/tasks/stop/{task_id}") +async def stop_task_endpoint(task_id: str, user=Depends(get_verified_user)): + try: + result = await stop_task(task_id) # Use the function from tasks.py + return result + except ValueError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) + + +@app.get("/api/tasks") +async def list_tasks_endpoint(user=Depends(get_verified_user)): + return {"tasks": list_tasks()} # Use the function from tasks.py + + ################################## # # Config Endpoints diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py index 3e621a150..75e5114b8 100644 --- a/backend/open_webui/models/chats.py +++ b/backend/open_webui/models/chats.py @@ -168,6 +168,66 @@ class ChatTable: except Exception: return None + def update_chat_title_by_id(self, id: str, title: str) -> Optional[ChatModel]: + chat = self.get_chat_by_id(id) + if chat is None: + return None + + chat = chat.chat + chat["title"] = title + + return self.update_chat_by_id(id, chat) + + def update_chat_tags_by_id( + self, id: str, tags: list[str], user + ) -> Optional[ChatModel]: + chat = self.get_chat_by_id(id) + if chat is None: + return None + + self.delete_all_tags_by_id_and_user_id(id, user.id) + + for tag in chat.meta.get("tags", []): + if self.count_chats_by_tag_name_and_user_id(tag, user.id) == 0: + Tags.delete_tag_by_name_and_user_id(tag, user.id) + + for tag_name in tags: + if tag_name.lower() == "none": + continue + + self.add_chat_tag_by_id_and_user_id_and_tag_name(id, user.id, 
tag_name) + return self.get_chat_by_id(id) + + def get_messages_by_chat_id(self, id: str) -> Optional[dict]: + chat = self.get_chat_by_id(id) + if chat is None: + return None + + return chat.chat.get("history", {}).get("messages", {}) or {} + + def upsert_message_to_chat_by_id_and_message_id( + self, id: str, message_id: str, message: dict + ) -> Optional[ChatModel]: + chat = self.get_chat_by_id(id) + if chat is None: + return None + + chat = chat.chat + history = chat.get("history", {}) + + if message_id in history.get("messages", {}): + history["messages"][message_id] = { + **history["messages"][message_id], + **message, + } + else: + history["messages"][message_id] = message + + history["currentId"] = message_id + + chat["history"] = history + return self.update_chat_by_id(id, chat) + def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: with get_db() as db: # Get the existing chat to share diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 0c152c1f0..275146c72 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -82,6 +82,16 @@ async def send_get_request(url, key=None): return None +async def cleanup_response( + response: Optional[aiohttp.ClientResponse], + session: Optional[aiohttp.ClientSession], +): + if response: + response.close() + if session: + await session.close() + + async def send_post_request( url: str, payload: Union[str, bytes], @@ -89,14 +99,6 @@ async def send_post_request( key: Optional[str] = None, content_type: Optional[str] = None, ): - async def cleanup_response( - response: Optional[aiohttp.ClientResponse], - session: Optional[aiohttp.ClientSession], - ): - if response: - response.close() - if session: - await session.close() r = None try: diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index d043ce066..4bdcbabed 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -217,15 +217,19 @@ async def disconnect(sid): def get_event_emitter(request_info): async def __event_emitter__(event_data): - await sio.emit( - "chat-events", - { - "chat_id": request_info["chat_id"], - "message_id": request_info["message_id"], - "data": event_data, - }, - to=request_info["session_id"], - ) + user_id = request_info["user_id"] + session_ids = USER_POOL.get(user_id, []) + + for session_id in session_ids: + await sio.emit( + "chat-events", + { + "chat_id": request_info["chat_id"], + "message_id": request_info["message_id"], + "data": event_data, + }, + to=session_id, + ) return __event_emitter__ diff --git a/backend/open_webui/tasks.py b/backend/open_webui/tasks.py new file mode 100644 index 000000000..2740ecb5a --- /dev/null +++ b/backend/open_webui/tasks.py @@ -0,0 +1,61 @@ +# tasks.py +import asyncio +from typing import Dict +from uuid import uuid4 + +# A dictionary to keep track of active tasks +tasks: Dict[str, asyncio.Task] = {} + + +def cleanup_task(task_id: str): + """ + Remove a completed or canceled task from the global `tasks` dictionary. + """ + tasks.pop(task_id, None) # Remove the task if it exists + + +def create_task(coroutine): + """ + Create a new asyncio task and add it to the global task dictionary. 
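+    Returns a (task_id, task) tuple; a done callback removes the entry
+    from the global `tasks` registry once the task completes or is cancelled.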
+ """ + task_id = str(uuid4()) # Generate a unique ID for the task + task = asyncio.create_task(coroutine) # Create the task + + # Add a done callback for cleanup + task.add_done_callback(lambda t: cleanup_task(task_id)) + + tasks[task_id] = task + return task_id, task + + +def get_task(task_id: str): + """ + Retrieve a task by its task ID. + """ + return tasks.get(task_id) + + +def list_tasks(): + """ + List all currently active task IDs. + """ + return list(tasks.keys()) + + +async def stop_task(task_id: str): + """ + Cancel a running task and remove it from the global task list. + """ + task = tasks.get(task_id) + if not task: + raise ValueError(f"Task with ID {task_id} not found.") + + task.cancel() # Request task cancellation + try: + await task # Wait for the task to handle the cancellation + except asyncio.CancelledError: + # Task successfully canceled + tasks.pop(task_id, None) # Remove it from the dictionary + return {"status": True, "message": f"Task {task_id} successfully stopped."} + + return {"status": False, "message": f"Failed to stop task {task_id}."} diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index 56904d1d8..c81e56afb 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -117,7 +117,9 @@ async def generate_chat_completion( form_data, user, bypass_filter=True ) return StreamingResponse( - stream_wrapper(response.body_iterator), media_type="text/event-stream" + stream_wrapper(response.body_iterator), + media_type="text/event-stream", + background=response.background, ) else: return { @@ -141,6 +143,7 @@ async def generate_chat_completion( return StreamingResponse( convert_streaming_response_ollama_to_openai(response), headers=dict(response.headers), + background=response.background, ) else: return convert_response_ollama_to_openai(response) diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 4115c7c2b..9ed493401 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -2,21 +2,31 @@ import time import logging import sys +import asyncio from aiocache import cached from typing import Any, Optional import random import json import inspect +from uuid import uuid4 + from fastapi import Request +from fastapi import BackgroundTasks + from starlette.responses import Response, StreamingResponse +from open_webui.models.chats import Chats from open_webui.socket.main import ( get_event_call, get_event_emitter, ) -from open_webui.routers.tasks import generate_queries +from open_webui.routers.tasks import ( + generate_queries, + generate_title, + generate_chat_tags, +) from open_webui.models.users import UserModel @@ -33,6 +43,7 @@ from open_webui.utils.task import ( tools_function_calling_generation_template, ) from open_webui.utils.misc import ( + get_message_list, add_or_update_system_message, get_last_user_message, prepend_to_first_user_message_content, @@ -41,6 +52,8 @@ from open_webui.utils.tools import get_tools from open_webui.utils.plugin import load_function_module_by_id +from open_webui.tasks import create_task + from open_webui.config import DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL from open_webui.constants import TASKS @@ -504,28 +517,178 @@ async def process_chat_payload(request, form_data, metadata, user, model): return form_data, events -async def process_chat_response(response, events, metadata): +async def 
process_chat_response(request, response, user, events, metadata, tasks): if not isinstance(response, StreamingResponse): return response - content_type = response.headers["Content-Type"] - is_openai = "text/event-stream" in content_type - is_ollama = "application/x-ndjson" in content_type - - if not is_openai and not is_ollama: + if not any( + content_type in response.headers["Content-Type"] + for content_type in ["text/event-stream", "application/x-ndjson"] + ): return response - async def stream_wrapper(original_generator, events): - def wrap_item(item): - return f"data: {item}\n\n" if is_openai else f"{item}\n" + event_emitter = None + if "session_id" in metadata: + event_emitter = get_event_emitter(metadata) - for event in events: - yield wrap_item(json.dumps(event)) + if event_emitter: - async for data in original_generator: - yield data + task_id = str(uuid4()) # Create a unique task ID. - return StreamingResponse( - stream_wrapper(response.body_iterator, events), - headers=dict(response.headers), - ) + # Handle as a background task + async def post_response_handler(response, events): + try: + for event in events: + await event_emitter( + { + "type": "chat-completion", + "data": event, + } + ) + + content = "" + async for line in response.body_iterator: + line = line.decode("utf-8") if isinstance(line, bytes) else line + data = line + + # Skip empty lines + if not data.strip(): + continue + + # "data: " is the prefix for each event + if not data.startswith("data: "): + continue + + # Remove the prefix + data = data[len("data: ") :] + + try: + data = json.loads(data) + value = ( + data.get("choices", [])[0].get("delta", {}).get("content") + ) + + if value: + content = f"{content}{value}" + + # Save message in the database + Chats.upsert_message_to_chat_by_id_and_message_id( + metadata["chat_id"], + metadata["message_id"], + { + "content": content, + }, + ) + + except Exception as e: + done = "data: [DONE]" in line + + if done: + data = {"done": True} + else: + continue + + await event_emitter( + { + "type": "chat-completion", + "data": data, + } + ) + + message_map = Chats.get_messages_by_chat_id(metadata["chat_id"]) + message = message_map.get(metadata["message_id"]) + + if message: + messages = get_message_list(message_map, message.get("id")) + + if TASKS.TITLE_GENERATION in tasks: + res = await generate_title( + request, + { + "model": message["model"], + "messages": messages, + "chat_id": metadata["chat_id"], + }, + user, + ) + + if res: + title = ( + res.get("choices", [])[0] + .get("message", {}) + .get("content", message.get("content", "New Chat")) + ) + + Chats.update_chat_title_by_id(metadata["chat_id"], title) + + await event_emitter( + { + "type": "chat-title", + "data": title, + } + ) + + if TASKS.TAGS_GENERATION in tasks: + res = await generate_chat_tags( + request, + { + "model": message["model"], + "messages": messages, + "chat_id": metadata["chat_id"], + }, + user, + ) + + if res: + tags_string = ( + res.get("choices", [])[0] + .get("message", {}) + .get("content", "") + ) + + tags_string = tags_string[ + tags_string.find("{") : tags_string.rfind("}") + 1 + ] + + try: + tags = json.loads(tags_string).get("tags", []) + Chats.update_chat_tags_by_id( + metadata["chat_id"], tags, user + ) + + await event_emitter( + { + "type": "chat-tags", + "data": tags, + } + ) + except Exception as e: + print(f"Error: {e}") + + except asyncio.CancelledError: + print("Task was cancelled!") + await event_emitter({"type": "task-cancelled"}) + + if response.background is not None: + await 
response.background() + + # background_tasks.add_task(post_response_handler, response, events) + task_id, _ = create_task(post_response_handler(response, events)) + return {"status": True, "task_id": task_id} + + else: + # Fallback to the original response + async def stream_wrapper(original_generator, events): + def wrap_item(item): + return f"data: {item}\n\n" + + for event in events: + yield wrap_item(json.dumps(event)) + + async for data in original_generator: + yield data + + return StreamingResponse( + stream_wrapper(response.body_iterator, events), + headers=dict(response.headers), + ) diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py index aba696f60..08abde0cc 100644 --- a/backend/open_webui/utils/misc.py +++ b/backend/open_webui/utils/misc.py @@ -7,6 +7,34 @@ from pathlib import Path from typing import Callable, Optional +def get_message_list(messages, message_id): + """ + Reconstructs a list of messages in order up to the specified message_id. + + :param message_id: ID of the message to reconstruct the chain + :param messages: Message history dict containing all messages + :return: List of ordered messages starting from the root to the given message + """ + + # Find the message by its id + current_message = messages.get(message_id) + + if not current_message: + return f"Message ID {message_id} not found in the history." + + # Reconstruct the chain by following the parentId links + message_list = [] + + while current_message: + message_list.insert( + 0, current_message + ) # Insert the message at the beginning of the list + parent_id = current_message["parentId"] + current_message = messages.get(parent_id) if parent_id else None + + return message_list + + def get_messages_content(messages: list[dict]) -> str: return "\n".join( [ diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index d06fbf3d7..26cfe3bef 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -107,6 +107,42 @@ export const chatAction = async (token: string, action_id: string, body: ChatAct return res; }; + + +export const stopTask = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/tasks/stop/${id}`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = err; + } + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + + + export const getTaskConfig = async (token: string = '') => { let error = null; diff --git a/src/lib/apis/openai/index.ts b/src/lib/apis/openai/index.ts index 1988dc0c3..b9b116d6d 100644 --- a/src/lib/apis/openai/index.ts +++ b/src/lib/apis/openai/index.ts @@ -277,29 +277,30 @@ export const generateOpenAIChatCompletion = async ( token: string = '', body: object, url: string = OPENAI_API_BASE_URL -): Promise<[Response | null, AbortController]> => { - const controller = new AbortController(); +) => { let error = null; const res = await fetch(`${url}/chat/completions`, { - signal: controller.signal, method: 'POST', headers: { Authorization: `Bearer ${token}`, 'Content-Type': 'application/json' }, body: JSON.stringify(body) - }).catch((err) => { - console.log(err); - error = err; - return null; + }).then(async (res) => { + if (!res.ok) throw await res.json(); + return 
res.json(); + }) + .catch((err) => { + error = `OpenAI: ${err?.detail ?? 'Network Problem'}`; + return null; }); if (error) { throw error; } - return [res, controller]; + return res; }; export const synthesizeOpenAISpeech = async ( diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index ebe3c6f30..6cceae496 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -69,7 +69,8 @@ generateQueries, chatAction, generateMoACompletion, - generateTags + generateTags, + stopTask } from '$lib/apis'; import Banner from '../common/Banner.svelte'; @@ -88,7 +89,6 @@ let controlPane; let controlPaneComponent; - let stopResponseFlag = false; let autoScroll = true; let processing = ''; let messagesContainerElement: HTMLDivElement; @@ -121,6 +121,8 @@ currentId: null }; + let taskId = null; + // Chat Input let prompt = ''; let chatFiles = []; @@ -202,95 +204,107 @@ }; const chatEventHandler = async (event, cb) => { + console.log(event); + if (event.chat_id === $chatId) { await tick(); - console.log(event); let message = history.messages[event.message_id]; - const type = event?.data?.type ?? null; - const data = event?.data?.data ?? null; + if (message) { + const type = event?.data?.type ?? null; + const data = event?.data?.data ?? null; - if (type === 'status') { - if (message?.statusHistory) { - message.statusHistory.push(data); - } else { - message.statusHistory = [data]; - } - } else if (type === 'source' || type === 'citation') { - if (data?.type === 'code_execution') { - // Code execution; update existing code execution by ID, or add new one. - if (!message?.code_executions) { - message.code_executions = []; - } - - const existingCodeExecutionIndex = message.code_executions.findIndex( - (execution) => execution.id === data.id - ); - - if (existingCodeExecutionIndex !== -1) { - message.code_executions[existingCodeExecutionIndex] = data; + if (type === 'status') { + if (message?.statusHistory) { + message.statusHistory.push(data); } else { - message.code_executions.push(data); + message.statusHistory = [data]; } + } else if (type === 'source' || type === 'citation') { + if (data?.type === 'code_execution') { + // Code execution; update existing code execution by ID, or add new one. + if (!message?.code_executions) { + message.code_executions = []; + } - message.code_executions = message.code_executions; - } else { - // Regular source. - if (message?.sources) { - message.sources.push(data); + const existingCodeExecutionIndex = message.code_executions.findIndex( + (execution) => execution.id === data.id + ); + + if (existingCodeExecutionIndex !== -1) { + message.code_executions[existingCodeExecutionIndex] = data; + } else { + message.code_executions.push(data); + } + + message.code_executions = message.code_executions; } else { - message.sources = [data]; + // Regular source. 
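+					// Regular source; append to the message's sources, creating the array on first use.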
+ if (message?.sources) { + message.sources.push(data); + } else { + message.sources = [data]; + } } - } - } else if (type === 'message') { - message.content += data.content; - } else if (type === 'replace') { - message.content = data.content; - } else if (type === 'action') { - if (data.action === 'continue') { - const continueButton = document.getElementById('continue-response-button'); + } else if (type === 'chat-completion') { + chatCompletionEventHandler(data, message, event.chat_id); + } else if (type === 'chat-title') { + chatTitle.set(data); + currentChatPage.set(1); + await chats.set(await getChatList(localStorage.token, $currentChatPage)); + } else if (type === 'chat-tags') { + chat = await getChatById(localStorage.token, $chatId); + allTags.set(await getAllTags(localStorage.token)); + } else if (type === 'message') { + message.content += data.content; + } else if (type === 'replace') { + message.content = data.content; + } else if (type === 'action') { + if (data.action === 'continue') { + const continueButton = document.getElementById('continue-response-button'); - if (continueButton) { - continueButton.click(); + if (continueButton) { + continueButton.click(); + } } - } - } else if (type === 'confirmation') { - eventCallback = cb; + } else if (type === 'confirmation') { + eventCallback = cb; - eventConfirmationInput = false; - showEventConfirmation = true; + eventConfirmationInput = false; + showEventConfirmation = true; - eventConfirmationTitle = data.title; - eventConfirmationMessage = data.message; - } else if (type === 'execute') { - eventCallback = cb; + eventConfirmationTitle = data.title; + eventConfirmationMessage = data.message; + } else if (type === 'execute') { + eventCallback = cb; - try { - // Use Function constructor to evaluate code in a safer way - const asyncFunction = new Function(`return (async () => { ${data.code} })()`); - const result = await asyncFunction(); // Await the result of the async function + try { + // Use Function constructor to evaluate code in a safer way + const asyncFunction = new Function(`return (async () => { ${data.code} })()`); + const result = await asyncFunction(); // Await the result of the async function - if (cb) { - cb(result); + if (cb) { + cb(result); + } + } catch (error) { + console.error('Error executing code:', error); } - } catch (error) { - console.error('Error executing code:', error); + } else if (type === 'input') { + eventCallback = cb; + + eventConfirmationInput = true; + showEventConfirmation = true; + + eventConfirmationTitle = data.title; + eventConfirmationMessage = data.message; + eventConfirmationInputPlaceholder = data.placeholder; + eventConfirmationInputValue = data?.value ?? ''; + } else { + console.log('Unknown message type', data); } - } else if (type === 'input') { - eventCallback = cb; - eventConfirmationInput = true; - showEventConfirmation = true; - - eventConfirmationTitle = data.title; - eventConfirmationMessage = data.message; - eventConfirmationInputPlaceholder = data.placeholder; - eventConfirmationInputValue = data?.value ?? 
''; - } else { - console.log('Unknown message type', data); + history.messages[event.message_id] = message; } - - history.messages[event.message_id] = message; } }; @@ -956,6 +970,119 @@ } }; + const chatCompletionEventHandler = async (data, message, chatId) => { + const { id, done, choices, sources, selectedModelId, error, usage } = data; + + if (error) { + await handleOpenAIError(error, message); + } + + if (sources) { + message.sources = sources; + // Only remove status if it was initially set + if (model?.info?.meta?.knowledge ?? false) { + message.statusHistory = message.statusHistory.filter( + (status) => status.action !== 'knowledge_search' + ); + } + } + + if (choices) { + const value = choices[0]?.delta?.content ?? ''; + if (message.content == '' && value == '\n') { + console.log('Empty response'); + } else { + message.content += value; + + if (navigator.vibrate && ($settings?.hapticFeedback ?? false)) { + navigator.vibrate(5); + } + + // Emit chat event for TTS + const messageContentParts = getMessageContentParts( + message.content, + $config?.audio?.tts?.split_on ?? 'punctuation' + ); + messageContentParts.pop(); + // dispatch only last sentence and make sure it hasn't been dispatched before + if ( + messageContentParts.length > 0 && + messageContentParts[messageContentParts.length - 1] !== message.lastSentence + ) { + message.lastSentence = messageContentParts[messageContentParts.length - 1]; + eventTarget.dispatchEvent( + new CustomEvent('chat', { + detail: { + id: message.id, + content: messageContentParts[messageContentParts.length - 1] + } + }) + ); + } + } + } + + if (selectedModelId) { + message.selectedModelId = selectedModelId; + message.arena = true; + } + + if (usage) { + message.usage = usage; + } + + if (done) { + message.done = true; + + if ($settings.notificationEnabled && !document.hasFocus()) { + new Notification(`${message.model}`, { + body: message.content, + icon: `${WEBUI_BASE_URL}/static/favicon.png` + }); + } + + if ($settings.responseAutoCopy) { + copyToClipboard(message.content); + } + + if ($settings.responseAutoPlayback && !$showCallOverlay) { + await tick(); + document.getElementById(`speak-button-${message.id}`)?.click(); + } + + // Emit chat event for TTS + let lastMessageContentPart = + getMessageContentParts(message.content, $config?.audio?.tts?.split_on ?? 'punctuation')?.at( + -1 + ) ?? 
''; + if (lastMessageContentPart) { + eventTarget.dispatchEvent( + new CustomEvent('chat', { + detail: { id: message.id, content: lastMessageContentPart } + }) + ); + } + eventTarget.dispatchEvent( + new CustomEvent('chat:finish', { + detail: { + id: message.id, + content: message.content + } + }) + ); + + history.messages[message.id] = message; + await chatCompletedHandler(chatId, message.model, message.id, createMessagesList(message.id)); + } + + history.messages[message.id] = message; + + console.log(data); + if (autoScroll) { + scrollToBottom(); + } + }; + ////////////////////////// // Chat functions ////////////////////////// @@ -1061,6 +1188,7 @@ chatInput?.focus(); saveSessionSelectedModels(); + await sendPrompt(userPrompt, userMessageId, { newChat: true }); }; @@ -1076,6 +1204,8 @@ history.messages[history.currentId].role === 'user' ) { await initChatHandler(); + } else { + await saveChatHandler($chatId); } // If modelId is provided, use it, else use selected model @@ -1122,6 +1252,9 @@ } await tick(); + // Save chat after all messages have been created + await saveChatHandler($chatId); + const _chatId = JSON.parse(JSON.stringify($chatId)); await Promise.all( selectedModelIds.map(async (modelId, _modelIdx) => { @@ -1178,7 +1311,7 @@ await getWebSearchResults(model.id, parentId, responseMessageId); } - await sendPromptOpenAI(model, prompt, responseMessageId, _chatId); + await sendPromptSocket(model, responseMessageId, _chatId); if (chatEventEmitter) clearInterval(chatEventEmitter); } else { toast.error($i18n.t(`Model {{modelId}} not found`, { modelId })); @@ -1190,9 +1323,7 @@ chats.set(await getChatList(localStorage.token, $currentChatPage)); }; - const sendPromptOpenAI = async (model, userPrompt, responseMessageId, _chatId) => { - let _response = null; - + const sendPromptSocket = async (model, responseMessageId, _chatId) => { const responseMessage = history.messages[responseMessageId]; const userMessage = history.messages[responseMessage.parentId]; @@ -1243,7 +1374,6 @@ ); scrollToBottom(); - eventTarget.dispatchEvent( new CustomEvent('chat:start', { detail: { @@ -1253,278 +1383,133 @@ ); await tick(); - try { - const stream = - model?.info?.params?.stream_response ?? - $settings?.params?.stream_response ?? - params?.stream_response ?? - true; + const stream = + model?.info?.params?.stream_response ?? + $settings?.params?.stream_response ?? + params?.stream_response ?? + true; - const [res, controller] = await generateOpenAIChatCompletion( - localStorage.token, - { - stream: stream, - model: model.id, - messages: [ - params?.system || $settings.system || (responseMessage?.userContext ?? null) - ? { - role: 'system', - content: `${promptTemplate( - params?.system ?? $settings?.system ?? '', - $user.name, - $settings?.userLocation - ? await getAndUpdateUserLocation(localStorage.token) - : undefined - )}${ - (responseMessage?.userContext ?? null) - ? `\n\nUser Context:\n${responseMessage?.userContext ?? ''}` - : '' - }` - } - : undefined, - ...createMessagesList(responseMessageId) - ] - .filter((message) => message?.content?.trim()) - .map((message, idx, arr) => ({ - role: message.role, - ...((message.files?.filter((file) => file.type === 'image').length > 0 ?? false) && - message.role === 'user' - ? { - content: [ - { - type: 'text', - text: message?.merged?.content ?? 
message.content - }, - ...message.files - .filter((file) => file.type === 'image') - .map((file) => ({ - type: 'image_url', - image_url: { - url: file.url - } - })) - ] - } - : { - content: message?.merged?.content ?? message.content - }) - })), - - params: { - ...$settings?.params, - ...params, - - format: $settings.requestFormat ?? undefined, - keep_alive: $settings.keepAlive ?? undefined, - stop: - (params?.stop ?? $settings?.params?.stop ?? undefined) - ? ( - params?.stop.split(',').map((token) => token.trim()) ?? $settings.params.stop - ).map((str) => - decodeURIComponent(JSON.parse('"' + str.replace(/\"/g, '\\"') + '"')) - ) + const messages = [ + params?.system || $settings.system || (responseMessage?.userContext ?? null) + ? { + role: 'system', + content: `${promptTemplate( + params?.system ?? $settings?.system ?? '', + $user.name, + $settings?.userLocation + ? await getAndUpdateUserLocation(localStorage.token) : undefined - }, - - tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, - files: files.length > 0 ? files : undefined, - session_id: $socket?.id, - chat_id: $chatId, - id: responseMessageId, - - ...(stream && (model.info?.meta?.capabilities?.usage ?? false) - ? { - stream_options: { - include_usage: true - } - } - : {}) - }, - `${WEBUI_BASE_URL}/api` - ); - - // Wait until history/message have been updated - await tick(); - - scrollToBottom(); - - if (res && res.ok && res.body) { - if (!stream) { - const response = await res.json(); - console.log(response); - - responseMessage.content = response.choices[0].message.content; - responseMessage.info = { ...response.usage, openai: true }; - responseMessage.done = true; - } else { - const textStream = await createOpenAITextStream(res.body, $settings.splitLargeChunks); - - for await (const update of textStream) { - const { value, done, sources, selectedModelId, error, usage } = update; - if (error) { - await handleOpenAIError(error, null, model, responseMessage); - break; - } - - if (done || stopResponseFlag || _chatId !== $chatId) { - responseMessage.done = true; - history.messages[responseMessageId] = responseMessage; - - if (stopResponseFlag) { - controller.abort('User: Stop Response'); - } - _response = responseMessage.content; - break; - } - - if (usage) { - responseMessage.usage = usage; - } - - if (selectedModelId) { - responseMessage.selectedModelId = selectedModelId; - responseMessage.arena = true; - continue; - } - - if (sources) { - responseMessage.sources = sources; - // Only remove status if it was initially set - if (model?.info?.meta?.knowledge ?? false) { - responseMessage.statusHistory = responseMessage.statusHistory.filter( - (status) => status.action !== 'knowledge_search' - ); - } - continue; - } - - if (responseMessage.content == '' && value == '\n') { - continue; - } else { - responseMessage.content += value; - - if (navigator.vibrate && ($settings?.hapticFeedback ?? false)) { - navigator.vibrate(5); - } - - const messageContentParts = getMessageContentParts( - responseMessage.content, - $config?.audio?.tts?.split_on ?? 
'punctuation' - ); - messageContentParts.pop(); - - // dispatch only last sentence and make sure it hasn't been dispatched before - if ( - messageContentParts.length > 0 && - messageContentParts[messageContentParts.length - 1] !== responseMessage.lastSentence - ) { - responseMessage.lastSentence = messageContentParts[messageContentParts.length - 1]; - eventTarget.dispatchEvent( - new CustomEvent('chat', { - detail: { - id: responseMessageId, - content: messageContentParts[messageContentParts.length - 1] - } - }) - ); - } - - history.messages[responseMessageId] = responseMessage; - } - - if (autoScroll) { - scrollToBottom(); - } + )}${ + (responseMessage?.userContext ?? null) + ? `\n\nUser Context:\n${responseMessage?.userContext ?? ''}` + : '' + }` } - } + : undefined, + ...createMessagesList(responseMessageId) + ] + .filter((message) => message?.content?.trim()) + .map((message, idx, arr) => ({ + role: message.role, + ...((message.files?.filter((file) => file.type === 'image').length > 0 ?? false) && + message.role === 'user' + ? { + content: [ + { + type: 'text', + text: message?.merged?.content ?? message.content + }, + ...message.files + .filter((file) => file.type === 'image') + .map((file) => ({ + type: 'image_url', + image_url: { + url: file.url + } + })) + ] + } + : { + content: message?.merged?.content ?? message.content + }) + })); - if ($settings.notificationEnabled && !document.hasFocus()) { - const notification = new Notification(`${model.id}`, { - body: responseMessage.content, - icon: `${WEBUI_BASE_URL}/static/favicon.png` - }); - } + const res = await generateOpenAIChatCompletion( + localStorage.token, + { + stream: stream, + model: model.id, + messages: messages, + params: { + ...$settings?.params, + ...params, - if ($settings.responseAutoCopy) { - copyToClipboard(responseMessage.content); - } + format: $settings.requestFormat ?? undefined, + keep_alive: $settings.keepAlive ?? undefined, + stop: + (params?.stop ?? $settings?.params?.stop ?? undefined) + ? (params?.stop.split(',').map((token) => token.trim()) ?? $settings.params.stop).map( + (str) => decodeURIComponent(JSON.parse('"' + str.replace(/\"/g, '\\"') + '"')) + ) + : undefined + }, - if ($settings.responseAutoPlayback && !$showCallOverlay) { - await tick(); + tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, + files: files.length > 0 ? files : undefined, + session_id: $socket?.id, + chat_id: $chatId, + id: responseMessageId, - document.getElementById(`speak-button-${responseMessage.id}`)?.click(); - } - } else { - await handleOpenAIError(null, res, model, responseMessage); - } - } catch (error) { - await handleOpenAIError(error, null, model, responseMessage); + ...(!$temporaryChatEnabled && messages.length == 1 && selectedModels[0] === model.id + ? { + background_tasks: { + title_generation: $settings?.title?.auto ?? true, + tags_generation: $settings?.autoTags ?? true + } + } + : {}), + + ...(stream && (model.info?.meta?.capabilities?.usage ?? false) + ? 
{ + stream_options: { + include_usage: true + } + } + : {}) + }, + `${WEBUI_BASE_URL}/api` + ).catch((error) => { + responseMessage.error = { + content: error + }; + responseMessage.done = true; + return null; + }); + + console.log(res); + + if (res) { + taskId = res.task_id; } - await saveChatHandler(_chatId); - - history.messages[responseMessageId] = responseMessage; - - await chatCompletedHandler( - _chatId, - model.id, - responseMessageId, - createMessagesList(responseMessageId) - ); - - stopResponseFlag = false; + // Wait until history/message have been updated await tick(); + scrollToBottom(); - let lastMessageContentPart = - getMessageContentParts( - responseMessage.content, - $config?.audio?.tts?.split_on ?? 'punctuation' - )?.at(-1) ?? ''; - if (lastMessageContentPart) { - eventTarget.dispatchEvent( - new CustomEvent('chat', { - detail: { id: responseMessageId, content: lastMessageContentPart } - }) - ); - } - - eventTarget.dispatchEvent( - new CustomEvent('chat:finish', { - detail: { - id: responseMessageId, - content: responseMessage.content - } - }) - ); - - if (autoScroll) { - scrollToBottom(); - } - - const messages = createMessagesList(responseMessageId); - if (messages.length == 2 && selectedModels[0] === model.id) { - window.history.replaceState(history.state, '', `/c/${_chatId}`); - - const title = await generateChatTitle(messages); - await setChatTitle(_chatId, title); - - if ($settings?.autoTags ?? true) { - await setChatTags(messages); - } - } - - return _response; + // if ($settings?.autoTags ?? true) { + // await setChatTags(messages); + // } + // } }; - const handleOpenAIError = async (error, res: Response | null, model, responseMessage) => { + const handleOpenAIError = async (error, responseMessage) => { let errorMessage = ''; let innerError; if (error) { innerError = error; - } else if (res !== null) { - innerError = await res.json(); } + console.error(innerError); if ('detail' in innerError) { toast.error(innerError.detail); @@ -1543,12 +1528,7 @@ } responseMessage.error = { - content: - $i18n.t(`Uh-oh! There was an issue connecting to {{provider}}.`, { - provider: model.name ?? model.id - }) + - '\n' + - errorMessage + content: $i18n.t(`Uh-oh! There was an issue with the response.`) + '\n' + errorMessage }; responseMessage.done = true; @@ -1562,8 +1542,15 @@ }; const stopResponse = () => { - stopResponseFlag = true; - console.log('stopResponse'); + if (taskId) { + const res = stopTask(localStorage.token, taskId).catch((error) => { + return null; + }); + + if (res) { + taskId = null; + } + } }; const submitMessage = async (parentId, prompt) => { @@ -1628,12 +1615,7 @@ .at(0); if (model) { - await sendPromptOpenAI( - model, - history.messages[responseMessage.parentId].content, - responseMessage.id, - _chatId - ); + await sendPromptSocket(model, responseMessage.id, _chatId); } } }; @@ -1685,38 +1667,6 @@ } }; - const generateChatTitle = async (messages) => { - const lastUserMessage = messages.filter((message) => message.role === 'user').at(-1); - - if ($settings?.title?.auto ?? true) { - const modelId = selectedModels[0]; - - const title = await generateTitle(localStorage.token, modelId, messages, $chatId).catch( - (error) => { - console.error(error); - return lastUserMessage?.content ?? 'New Chat'; - } - ); - - return title ? title : (lastUserMessage?.content ?? 'New Chat'); - } else { - return lastUserMessage?.content ?? 
'New Chat'; - } - }; - - const setChatTitle = async (_chatId, title) => { - if (_chatId === $chatId) { - chatTitle.set(title); - } - - if (!$temporaryChatEnabled) { - chat = await updateChatById(localStorage.token, _chatId, { title: title }); - - currentChatPage.set(1); - await chats.set(await getChatList(localStorage.token, $currentChatPage)); - } - }; - const setChatTags = async (messages) => { if (!$temporaryChatEnabled) { const currentTags = await getTagsById(localStorage.token, $chatId); @@ -1856,6 +1806,8 @@ currentChatPage.set(1); await chats.set(await getChatList(localStorage.token, $currentChatPage)); await chatId.set(chat.id); + + window.history.replaceState(history.state, '', `/c/${chat.id}`); } else { await chatId.set('local'); }
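
A minimal sketch of how the new task registry in backend/open_webui/tasks.py is meant to be used — create_task, list_tasks, and stop_task come from the diff above, while the slow_job coroutine and the standalone asyncio.run harness are illustrative assumptions only:

import asyncio

from open_webui.tasks import create_task, list_tasks, stop_task


async def slow_job():
    # Stand-in for a long-running chat completion handler (illustrative).
    await asyncio.sleep(60)


async def main():
    # create_task registers the coroutine under a fresh UUID and returns
    # (task_id, task); its done callback removes the registry entry.
    task_id, _ = create_task(slow_job())
    assert task_id in list_tasks()

    # Same path POST /api/tasks/stop/{task_id} takes: cancel the task,
    # await the cancellation, and report {"status": True, ...} on success.
    result = await stop_task(task_id)
    print(result)


asyncio.run(main())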