refac/pref: chat import optimization
Co-Authored-By: G30 <50341825+silentoplayz@users.noreply.github.com>
This commit is contained in:
@@ -92,6 +92,10 @@ class ChatImportForm(ChatForm):
|
||||
updated_at: Optional[int] = None
|
||||
|
||||
|
||||
class ChatsImportForm(BaseModel):
|
||||
chats: list[ChatImportForm]
|
||||
|
||||
|
||||
class ChatTitleMessagesForm(BaseModel):
|
||||
title: str
|
||||
messages: list[dict]
|
||||
@@ -148,42 +152,44 @@ class ChatTable:
|
||||
db.refresh(result)
|
||||
return ChatModel.model_validate(result) if result else None
|
||||
|
||||
def import_chat(
|
||||
def _chat_import_form_to_chat_model(
|
||||
self, user_id: str, form_data: ChatImportForm
|
||||
) -> Optional[ChatModel]:
|
||||
with get_db() as db:
|
||||
id = str(uuid.uuid4())
|
||||
chat = ChatModel(
|
||||
**{
|
||||
"id": id,
|
||||
"user_id": user_id,
|
||||
"title": (
|
||||
form_data.chat["title"]
|
||||
if "title" in form_data.chat
|
||||
else "New Chat"
|
||||
),
|
||||
"chat": form_data.chat,
|
||||
"meta": form_data.meta,
|
||||
"pinned": form_data.pinned,
|
||||
"folder_id": form_data.folder_id,
|
||||
"created_at": (
|
||||
form_data.created_at
|
||||
if form_data.created_at
|
||||
else int(time.time())
|
||||
),
|
||||
"updated_at": (
|
||||
form_data.updated_at
|
||||
if form_data.updated_at
|
||||
else int(time.time())
|
||||
),
|
||||
}
|
||||
)
|
||||
) -> ChatModel:
|
||||
id = str(uuid.uuid4())
|
||||
chat = ChatModel(
|
||||
**{
|
||||
"id": id,
|
||||
"user_id": user_id,
|
||||
"title": (
|
||||
form_data.chat["title"] if "title" in form_data.chat else "New Chat"
|
||||
),
|
||||
"chat": form_data.chat,
|
||||
"meta": form_data.meta,
|
||||
"pinned": form_data.pinned,
|
||||
"folder_id": form_data.folder_id,
|
||||
"created_at": (
|
||||
form_data.created_at if form_data.created_at else int(time.time())
|
||||
),
|
||||
"updated_at": (
|
||||
form_data.updated_at if form_data.updated_at else int(time.time())
|
||||
),
|
||||
}
|
||||
)
|
||||
return chat
|
||||
|
||||
result = Chat(**chat.model_dump())
|
||||
db.add(result)
|
||||
def import_chats(
|
||||
self, user_id: str, chats: list[ChatImportForm]
|
||||
) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
chats = []
|
||||
|
||||
for form_data in chats:
|
||||
chat = self._chat_import_form_to_chat_model(user_id, form_data)
|
||||
chats.append(Chat(**chat.model_dump()))
|
||||
|
||||
db.add_all(chats)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
return ChatModel.model_validate(result) if result else None
|
||||
return [ChatModel.model_validate(chat) for chat in chats]
|
||||
|
||||
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
|
||||
try:
|
||||
|
||||
@@ -7,9 +7,11 @@ from open_webui.socket.main import get_event_emitter
|
||||
from open_webui.models.chats import (
|
||||
ChatForm,
|
||||
ChatImportForm,
|
||||
ChatBulkImportForm,
|
||||
ChatResponse,
|
||||
Chats,
|
||||
ChatTitleIdResponse,
|
||||
ChatsImportForm,
|
||||
)
|
||||
from open_webui.models.tags import TagModel, Tags
|
||||
from open_webui.models.folders import Folders
|
||||
@@ -142,26 +144,15 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
|
||||
|
||||
|
||||
############################
|
||||
# ImportChat
|
||||
# ImportChats
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/import", response_model=Optional[ChatResponse])
|
||||
async def import_chat(form_data: ChatImportForm, user=Depends(get_verified_user)):
|
||||
async def import_chats(form_data: ChatsImportForm, user=Depends(get_verified_user)):
|
||||
try:
|
||||
chat = Chats.import_chat(user.id, form_data)
|
||||
if chat:
|
||||
tags = chat.meta.get("tags", [])
|
||||
for tag_id in tags:
|
||||
tag_id = tag_id.replace(" ", "_").lower()
|
||||
tag_name = " ".join([word.capitalize() for word in tag_id.split("_")])
|
||||
if (
|
||||
tag_id != "none"
|
||||
and Tags.get_tag_by_name_and_user_id(tag_name, user.id) is None
|
||||
):
|
||||
Tags.insert_new_tag(tag_name, user.id)
|
||||
|
||||
return ChatResponse(**chat.model_dump())
|
||||
chats = Chats.import_chats(user.id, form_data.chats)
|
||||
return [ChatResponse(**chat.model_dump()) for chat in chats]
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
|
||||
Reference in New Issue
Block a user