From d7a00af5768ad6058a0a86b4e3b025c3a0e1e628 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 8 Oct 2024 22:02:48 -0700 Subject: [PATCH] refac: convert chat.chat to json data type --- backend/open_webui/apps/webui/models/chats.py | 20 ++--- .../open_webui/apps/webui/routers/chats.py | 39 ++++----- .../242a2047eae0_update_chat_table.py | 82 +++++++++++++++++++ 3 files changed, 108 insertions(+), 33 deletions(-) create mode 100644 backend/open_webui/migrations/versions/242a2047eae0_update_chat_table.py diff --git a/backend/open_webui/apps/webui/models/chats.py b/backend/open_webui/apps/webui/models/chats.py index f364dcc70..6a79b6ae6 100644 --- a/backend/open_webui/apps/webui/models/chats.py +++ b/backend/open_webui/apps/webui/models/chats.py @@ -5,7 +5,7 @@ from typing import Optional from open_webui.apps.webui.internal.db import Base, get_db from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Boolean, Column, String, Text +from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON #################### # Chat DB Schema @@ -18,7 +18,7 @@ class Chat(Base): id = Column(String, primary_key=True) user_id = Column(String) title = Column(Text) - chat = Column(Text) # Save Chat JSON as Text + chat = Column(JSON) created_at = Column(BigInteger) updated_at = Column(BigInteger) @@ -33,7 +33,7 @@ class ChatModel(BaseModel): id: str user_id: str title: str - chat: str + chat: dict created_at: int # timestamp in epoch updated_at: int # timestamp in epoch @@ -86,7 +86,7 @@ class ChatTable: if "title" in form_data.chat else "New Chat" ), - "chat": json.dumps(form_data.chat), + "chat": form_data.chat, "created_at": int(time.time()), "updated_at": int(time.time()), } @@ -101,14 +101,14 @@ class ChatTable: def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: try: with get_db() as db: - chat_obj = db.get(Chat, id) - chat_obj.chat = json.dumps(chat) - chat_obj.title = chat["title"] if "title" in chat else "New Chat" - chat_obj.updated_at = int(time.time()) + chat_item = db.get(Chat, id) + chat_item.chat = chat + chat_item.title = chat["title"] if "title" in chat else "New Chat" + chat_item.updated_at = int(time.time()) db.commit() - db.refresh(chat_obj) + db.refresh(chat_item) - return ChatModel.model_validate(chat_obj) + return ChatModel.model_validate(chat_item) except Exception: return None diff --git a/backend/open_webui/apps/webui/routers/chats.py b/backend/open_webui/apps/webui/routers/chats.py index ca7e95baf..01d99cfd8 100644 --- a/backend/open_webui/apps/webui/routers/chats.py +++ b/backend/open_webui/apps/webui/routers/chats.py @@ -95,7 +95,7 @@ async def get_user_chat_list_by_user_id( async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)): try: chat = Chats.insert_new_chat(user.id, form_data) - return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + return ChatResponse(**chat.model_dump()) except Exception as e: log.exception(e) raise HTTPException( @@ -111,7 +111,7 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)): @router.get("/all", response_model=list[ChatResponse]) async def get_user_chats(user=Depends(get_verified_user)): return [ - ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + ChatResponse(**chat.model_dump()) for chat in Chats.get_chats_by_user_id(user.id) ] @@ -124,7 +124,7 @@ async def get_user_chats(user=Depends(get_verified_user)): @router.get("/all/archived", response_model=list[ChatResponse]) async def get_user_archived_chats(user=Depends(get_verified_user)): return [ - ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + ChatResponse(**chat.model_dump()) for chat in Chats.get_archived_chats_by_user_id(user.id) ] @@ -141,10 +141,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)): status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - return [ - ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) - for chat in Chats.get_chats() - ] + return [ChatResponse(**chat.model_dump()) for chat in Chats.get_chats()] ############################ @@ -187,7 +184,8 @@ async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)): chat = Chats.get_chat_by_id(share_id) if chat: - return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + return ChatResponse(**chat.model_dump()) + else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND @@ -251,7 +249,8 @@ async def get_chat_by_id(id: str, user=Depends(get_verified_user)): chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: - return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + return ChatResponse(**chat.model_dump()) + else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND @@ -269,10 +268,9 @@ async def update_chat_by_id( ): chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: - updated_chat = {**json.loads(chat.chat), **form_data.chat} - + updated_chat = {**chat.chat, **form_data.chat} chat = Chats.update_chat_by_id(id, updated_chat) - return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + return ChatResponse(**chat.model_dump()) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -312,16 +310,15 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified async def clone_chat_by_id(id: str, user=Depends(get_verified_user)): chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: - chat_body = json.loads(chat.chat) updated_chat = { - **chat_body, + **chat.chat, "originalChatId": chat.id, - "branchPointMessageId": chat_body["history"]["currentId"], + "branchPointMessageId": chat.chat["history"]["currentId"], "title": f"Clone of {chat.title}", } chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat})) - return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + return ChatResponse(**chat.model_dump()) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() @@ -338,7 +335,7 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)): chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: chat = Chats.toggle_chat_archive_by_id(id) - return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + return ChatResponse(**chat.model_dump()) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() @@ -356,9 +353,7 @@ async def share_chat_by_id(id: str, user=Depends(get_verified_user)): if chat: if chat.share_id: shared_chat = Chats.update_shared_chat_by_chat_id(chat.id) - return ChatResponse( - **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)} - ) + return ChatResponse(**shared_chat.model_dump()) shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id) if not shared_chat: @@ -366,10 +361,8 @@ async def share_chat_by_id(id: str, user=Depends(get_verified_user)): status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=ERROR_MESSAGES.DEFAULT(), ) + return ChatResponse(**shared_chat.model_dump()) - return ChatResponse( - **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)} - ) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, diff --git a/backend/open_webui/migrations/versions/242a2047eae0_update_chat_table.py b/backend/open_webui/migrations/versions/242a2047eae0_update_chat_table.py new file mode 100644 index 000000000..596703dc2 --- /dev/null +++ b/backend/open_webui/migrations/versions/242a2047eae0_update_chat_table.py @@ -0,0 +1,82 @@ +"""Update chat table + +Revision ID: 242a2047eae0 +Revises: 6a39f3d8e55c +Create Date: 2024-10-09 21:02:35.241684 + +""" + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.sql import table, select, update + +import json + +revision = "242a2047eae0" +down_revision = "6a39f3d8e55c" +branch_labels = None +depends_on = None + + +def upgrade(): + # Step 1: Rename current 'chat' column to 'old_chat' + op.alter_column("chat", "chat", new_column_name="old_chat", existing_type=sa.Text) + + # Step 2: Add new 'chat' column of type JSON + op.add_column("chat", sa.Column("chat", sa.JSON(), nullable=True)) + + # Step 3: Migrate data from 'old_chat' to 'chat' + chat_table = table( + "chat", + sa.Column("id", sa.String, primary_key=True), + sa.Column("old_chat", sa.Text), + sa.Column("chat", sa.JSON()), + ) + + # - Selecting all data from the table + connection = op.get_bind() + results = connection.execute(select(chat_table.c.id, chat_table.c.old_chat)) + for row in results: + try: + # Convert text JSON to actual JSON object, assuming the text is in JSON format + json_data = json.loads(row.old_chat) + except json.JSONDecodeError: + json_data = None # Handle cases where the text cannot be converted to JSON + + connection.execute( + sa.update(chat_table) + .where(chat_table.c.id == row.id) + .values(chat=json_data) + ) + + # Step 4: Drop 'old_chat' column + op.drop_column("chat", "old_chat") + + +def downgrade(): + # Step 1: Add 'old_chat' column back as Text + op.add_column("chat", sa.Column("old_chat", sa.Text(), nullable=True)) + + # Step 2: Convert 'chat' JSON data back to text and store in 'old_chat' + chat_table = table( + "chat", + sa.Column("id", sa.String, primary_key=True), + sa.Column("chat", sa.JSON()), + sa.Column("old_chat", sa.Text()), + ) + + connection = op.get_bind() + results = connection.execute(select(chat_table.c.id, chat_table.c.chat)) + for row in results: + text_data = json.dumps(row.chat) if row.chat is not None else None + connection.execute( + sa.update(chat_table) + .where(chat_table.c.id == row.id) + .values(old_chat=text_data) + ) + + # Step 3: Remove the new 'chat' JSON column + op.drop_column("chat", "chat") + + # Step 4: Rename 'old_chat' back to 'chat' + op.alter_column("chat", "old_chat", new_column_name="chat", existing_type=sa.Text)