mirror of
https://github.com/open-webui/open-webui
synced 2025-04-08 06:35:04 +00:00
refac: convert chat.chat to json data type
This commit is contained in:
parent
a04f22d55f
commit
d7a00af576
@ -5,7 +5,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from open_webui.apps.webui.internal.db import Base, get_db
|
from open_webui.apps.webui.internal.db import Base, get_db
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text
|
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# Chat DB Schema
|
# Chat DB Schema
|
||||||
@ -18,7 +18,7 @@ class Chat(Base):
|
|||||||
id = Column(String, primary_key=True)
|
id = Column(String, primary_key=True)
|
||||||
user_id = Column(String)
|
user_id = Column(String)
|
||||||
title = Column(Text)
|
title = Column(Text)
|
||||||
chat = Column(Text) # Save Chat JSON as Text
|
chat = Column(JSON)
|
||||||
|
|
||||||
created_at = Column(BigInteger)
|
created_at = Column(BigInteger)
|
||||||
updated_at = Column(BigInteger)
|
updated_at = Column(BigInteger)
|
||||||
@ -33,7 +33,7 @@ class ChatModel(BaseModel):
|
|||||||
id: str
|
id: str
|
||||||
user_id: str
|
user_id: str
|
||||||
title: str
|
title: str
|
||||||
chat: str
|
chat: dict
|
||||||
|
|
||||||
created_at: int # timestamp in epoch
|
created_at: int # timestamp in epoch
|
||||||
updated_at: int # timestamp in epoch
|
updated_at: int # timestamp in epoch
|
||||||
@ -86,7 +86,7 @@ class ChatTable:
|
|||||||
if "title" in form_data.chat
|
if "title" in form_data.chat
|
||||||
else "New Chat"
|
else "New Chat"
|
||||||
),
|
),
|
||||||
"chat": json.dumps(form_data.chat),
|
"chat": form_data.chat,
|
||||||
"created_at": int(time.time()),
|
"created_at": int(time.time()),
|
||||||
"updated_at": int(time.time()),
|
"updated_at": int(time.time()),
|
||||||
}
|
}
|
||||||
@ -101,14 +101,14 @@ class ChatTable:
|
|||||||
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
|
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
chat_obj = db.get(Chat, id)
|
chat_item = db.get(Chat, id)
|
||||||
chat_obj.chat = json.dumps(chat)
|
chat_item.chat = chat
|
||||||
chat_obj.title = chat["title"] if "title" in chat else "New Chat"
|
chat_item.title = chat["title"] if "title" in chat else "New Chat"
|
||||||
chat_obj.updated_at = int(time.time())
|
chat_item.updated_at = int(time.time())
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(chat_obj)
|
db.refresh(chat_item)
|
||||||
|
|
||||||
return ChatModel.model_validate(chat_obj)
|
return ChatModel.model_validate(chat_item)
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -95,7 +95,7 @@ async def get_user_chat_list_by_user_id(
|
|||||||
async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
|
async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
|
||||||
try:
|
try:
|
||||||
chat = Chats.insert_new_chat(user.id, form_data)
|
chat = Chats.insert_new_chat(user.id, form_data)
|
||||||
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
return ChatResponse(**chat.model_dump())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@ -111,7 +111,7 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
|
|||||||
@router.get("/all", response_model=list[ChatResponse])
|
@router.get("/all", response_model=list[ChatResponse])
|
||||||
async def get_user_chats(user=Depends(get_verified_user)):
|
async def get_user_chats(user=Depends(get_verified_user)):
|
||||||
return [
|
return [
|
||||||
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
ChatResponse(**chat.model_dump())
|
||||||
for chat in Chats.get_chats_by_user_id(user.id)
|
for chat in Chats.get_chats_by_user_id(user.id)
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -124,7 +124,7 @@ async def get_user_chats(user=Depends(get_verified_user)):
|
|||||||
@router.get("/all/archived", response_model=list[ChatResponse])
|
@router.get("/all/archived", response_model=list[ChatResponse])
|
||||||
async def get_user_archived_chats(user=Depends(get_verified_user)):
|
async def get_user_archived_chats(user=Depends(get_verified_user)):
|
||||||
return [
|
return [
|
||||||
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
ChatResponse(**chat.model_dump())
|
||||||
for chat in Chats.get_archived_chats_by_user_id(user.id)
|
for chat in Chats.get_archived_chats_by_user_id(user.id)
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -141,10 +141,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
|
|||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||||
)
|
)
|
||||||
return [
|
return [ChatResponse(**chat.model_dump()) for chat in Chats.get_chats()]
|
||||||
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
|
||||||
for chat in Chats.get_chats()
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
@ -187,7 +184,8 @@ async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
|
|||||||
chat = Chats.get_chat_by_id(share_id)
|
chat = Chats.get_chat_by_id(share_id)
|
||||||
|
|
||||||
if chat:
|
if chat:
|
||||||
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
return ChatResponse(**chat.model_dump())
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||||
@ -251,7 +249,8 @@ async def get_chat_by_id(id: str, user=Depends(get_verified_user)):
|
|||||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||||
|
|
||||||
if chat:
|
if chat:
|
||||||
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
return ChatResponse(**chat.model_dump())
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||||
@ -269,10 +268,9 @@ async def update_chat_by_id(
|
|||||||
):
|
):
|
||||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||||
if chat:
|
if chat:
|
||||||
updated_chat = {**json.loads(chat.chat), **form_data.chat}
|
updated_chat = {**chat.chat, **form_data.chat}
|
||||||
|
|
||||||
chat = Chats.update_chat_by_id(id, updated_chat)
|
chat = Chats.update_chat_by_id(id, updated_chat)
|
||||||
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
return ChatResponse(**chat.model_dump())
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
@ -312,16 +310,15 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
|
|||||||
async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
|
async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||||
if chat:
|
if chat:
|
||||||
chat_body = json.loads(chat.chat)
|
|
||||||
updated_chat = {
|
updated_chat = {
|
||||||
**chat_body,
|
**chat.chat,
|
||||||
"originalChatId": chat.id,
|
"originalChatId": chat.id,
|
||||||
"branchPointMessageId": chat_body["history"]["currentId"],
|
"branchPointMessageId": chat.chat["history"]["currentId"],
|
||||||
"title": f"Clone of {chat.title}",
|
"title": f"Clone of {chat.title}",
|
||||||
}
|
}
|
||||||
|
|
||||||
chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
|
chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
|
||||||
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
return ChatResponse(**chat.model_dump())
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
|
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
|
||||||
@ -338,7 +335,7 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
|
|||||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||||
if chat:
|
if chat:
|
||||||
chat = Chats.toggle_chat_archive_by_id(id)
|
chat = Chats.toggle_chat_archive_by_id(id)
|
||||||
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
return ChatResponse(**chat.model_dump())
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
|
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
|
||||||
@ -356,9 +353,7 @@ async def share_chat_by_id(id: str, user=Depends(get_verified_user)):
|
|||||||
if chat:
|
if chat:
|
||||||
if chat.share_id:
|
if chat.share_id:
|
||||||
shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
|
shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
|
||||||
return ChatResponse(
|
return ChatResponse(**shared_chat.model_dump())
|
||||||
**{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
|
|
||||||
)
|
|
||||||
|
|
||||||
shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id)
|
shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id)
|
||||||
if not shared_chat:
|
if not shared_chat:
|
||||||
@ -366,10 +361,8 @@ async def share_chat_by_id(id: str, user=Depends(get_verified_user)):
|
|||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail=ERROR_MESSAGES.DEFAULT(),
|
detail=ERROR_MESSAGES.DEFAULT(),
|
||||||
)
|
)
|
||||||
|
return ChatResponse(**shared_chat.model_dump())
|
||||||
|
|
||||||
return ChatResponse(
|
|
||||||
**{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
@ -0,0 +1,82 @@
|
|||||||
|
"""Update chat table
|
||||||
|
|
||||||
|
Revision ID: 242a2047eae0
|
||||||
|
Revises: 6a39f3d8e55c
|
||||||
|
Create Date: 2024-10-09 21:02:35.241684
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.sql import table, select, update
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
revision = "242a2047eae0"
|
||||||
|
down_revision = "6a39f3d8e55c"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# Step 1: Rename current 'chat' column to 'old_chat'
|
||||||
|
op.alter_column("chat", "chat", new_column_name="old_chat", existing_type=sa.Text)
|
||||||
|
|
||||||
|
# Step 2: Add new 'chat' column of type JSON
|
||||||
|
op.add_column("chat", sa.Column("chat", sa.JSON(), nullable=True))
|
||||||
|
|
||||||
|
# Step 3: Migrate data from 'old_chat' to 'chat'
|
||||||
|
chat_table = table(
|
||||||
|
"chat",
|
||||||
|
sa.Column("id", sa.String, primary_key=True),
|
||||||
|
sa.Column("old_chat", sa.Text),
|
||||||
|
sa.Column("chat", sa.JSON()),
|
||||||
|
)
|
||||||
|
|
||||||
|
# - Selecting all data from the table
|
||||||
|
connection = op.get_bind()
|
||||||
|
results = connection.execute(select(chat_table.c.id, chat_table.c.old_chat))
|
||||||
|
for row in results:
|
||||||
|
try:
|
||||||
|
# Convert text JSON to actual JSON object, assuming the text is in JSON format
|
||||||
|
json_data = json.loads(row.old_chat)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
json_data = None # Handle cases where the text cannot be converted to JSON
|
||||||
|
|
||||||
|
connection.execute(
|
||||||
|
sa.update(chat_table)
|
||||||
|
.where(chat_table.c.id == row.id)
|
||||||
|
.values(chat=json_data)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 4: Drop 'old_chat' column
|
||||||
|
op.drop_column("chat", "old_chat")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# Step 1: Add 'old_chat' column back as Text
|
||||||
|
op.add_column("chat", sa.Column("old_chat", sa.Text(), nullable=True))
|
||||||
|
|
||||||
|
# Step 2: Convert 'chat' JSON data back to text and store in 'old_chat'
|
||||||
|
chat_table = table(
|
||||||
|
"chat",
|
||||||
|
sa.Column("id", sa.String, primary_key=True),
|
||||||
|
sa.Column("chat", sa.JSON()),
|
||||||
|
sa.Column("old_chat", sa.Text()),
|
||||||
|
)
|
||||||
|
|
||||||
|
connection = op.get_bind()
|
||||||
|
results = connection.execute(select(chat_table.c.id, chat_table.c.chat))
|
||||||
|
for row in results:
|
||||||
|
text_data = json.dumps(row.chat) if row.chat is not None else None
|
||||||
|
connection.execute(
|
||||||
|
sa.update(chat_table)
|
||||||
|
.where(chat_table.c.id == row.id)
|
||||||
|
.values(old_chat=text_data)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 3: Remove the new 'chat' JSON column
|
||||||
|
op.drop_column("chat", "chat")
|
||||||
|
|
||||||
|
# Step 4: Rename 'old_chat' back to 'chat'
|
||||||
|
op.alter_column("chat", "old_chat", new_column_name="chat", existing_type=sa.Text)
|
Loading…
Reference in New Issue
Block a user