From 1be6ad12504a24d1fe4dfd4053d74ac77d29792d Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Wed, 26 Mar 2025 01:10:27 -0700 Subject: [PATCH] feat: `/messages` chat endpoint support --- backend/open_webui/routers/chats.py | 58 +++++++++++++++++++++++++++++ backend/open_webui/socket/main.py | 46 ++++++++++++----------- src/lib/components/chat/Chat.svelte | 12 +----- 3 files changed, 84 insertions(+), 32 deletions(-) diff --git a/backend/open_webui/routers/chats.py b/backend/open_webui/routers/chats.py index 2efd043ef..ec73b8662 100644 --- a/backend/open_webui/routers/chats.py +++ b/backend/open_webui/routers/chats.py @@ -2,6 +2,8 @@ import json import logging from typing import Optional + +from open_webui.socket.main import get_event_emitter from open_webui.models.chats import ( ChatForm, ChatImportForm, @@ -372,6 +374,62 @@ async def update_chat_by_id( ) +############################ +# UpdateChatMessageById +############################ +class MessageForm(BaseModel): + content: str + + +@router.post("/{id}/messages/{message_id}", response_model=Optional[ChatResponse]) +async def update_chat_message_by_id( + id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user) +): + chat = Chats.get_chat_by_id(id) + + if not chat: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + if chat.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + chat = Chats.upsert_message_to_chat_by_id_and_message_id( + id, + message_id, + { + "content": form_data.content, + }, + ) + + event_emitter = get_event_emitter( + { + "user_id": user.id, + "chat_id": id, + "message_id": message_id, + } + ) + + if event_emitter: + event_emitter( + { + "type": "chat:message", + "data": { + "chat_id": id, + "message_id": message_id, + "content": form_data.content, + }, + } + ) + + return ChatResponse(**chat.model_dump()) + + ############################ # DeleteChatById ############################ diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index 8f5a9568b..9fea25ea0 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -269,7 +269,7 @@ async def disconnect(sid): # print(f"Unknown session ID {sid} disconnected") -def get_event_emitter(request_info): +def get_event_emitter(request_info, update_db=True): async def __event_emitter__(event_data): user_id = request_info["user_id"] session_ids = list( @@ -287,31 +287,33 @@ def get_event_emitter(request_info): to=session_id, ) - if "type" in event_data and event_data["type"] == "status": - Chats.add_message_status_to_chat_by_id_and_message_id( - request_info["chat_id"], - request_info["message_id"], - event_data.get("data", {}), - ) - if "type" in event_data and event_data["type"] == "message": - message = Chats.get_message_by_id_and_message_id( - request_info["chat_id"], - request_info["message_id"], - ) + if update_db: + if "type" in event_data and event_data["type"] == "status": + Chats.add_message_status_to_chat_by_id_and_message_id( + request_info["chat_id"], + request_info["message_id"], + event_data.get("data", {}), + ) - content = message.get("content", "") - content += event_data.get("data", {}).get("content", "") + if "type" in event_data and event_data["type"] == "message": + message = Chats.get_message_by_id_and_message_id( + request_info["chat_id"], + request_info["message_id"], + ) - Chats.upsert_message_to_chat_by_id_and_message_id( - request_info["chat_id"], - request_info["message_id"], - { - "content": content, - }, - ) + content = message.get("content", "") + content += event_data.get("data", {}).get("content", "") - if "type" in event_data and event_data["type"] == "replace": + Chats.upsert_message_to_chat_by_id_and_message_id( + request_info["chat_id"], + request_info["message_id"], + { + "content": content, + }, + ) + + if "type" in event_data and event_data["type"] == "replace": content = event_data.get("data", {}).get("content", "") Chats.upsert_message_to_chat_by_id_and_message_id( diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 2fbae7cd6..2892d436c 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -293,18 +293,10 @@ } else if (type === 'chat:tags') { chat = await getChatById(localStorage.token, $chatId); allTags.set(await getAllTags(localStorage.token)); - } else if (type === 'message') { + } else if (type === 'chat:message:delta' || type === 'message') { message.content += data.content; - } else if (type === 'replace') { + } else if (type === 'chat:message' || type === 'replace') { message.content = data.content; - } else if (type === 'action') { - if (data.action === 'continue') { - const continueButton = document.getElementById('continue-response-button'); - - if (continueButton) { - continueButton.click(); - } - } } else if (type === 'confirmation') { eventCallback = cb;