import json
import logging
import time
import uuid
from typing import Optional

from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.apps.webui.models.tags import TagModel, Tag, Tags
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists

log = logging.getLogger(__name__)

####################
# Chat DB Schema
####################


class Chat(Base):
    """SQLAlchemy ORM mapping for the ``chat`` table."""

    __tablename__ = "chat"

    id = Column(String, primary_key=True)
    user_id = Column(String)
    title = Column(Text)
    # Full chat payload. The search queries below read its "messages" array
    # and each message's "content" field.
    chat = Column(JSON)

    created_at = Column(BigInteger)  # timestamp in epoch seconds
    updated_at = Column(BigInteger)  # timestamp in epoch seconds

    # For an original chat: the id of its shared snapshot row (if shared).
    share_id = Column(Text, unique=True, nullable=True)
    archived = Column(Boolean, default=False)
    pinned = Column(Boolean, default=False, nullable=True)

    # Free-form metadata; meta["tags"] holds the chat's list of tag ids.
    meta = Column(JSON, server_default="{}")
    folder_id = Column(Text, nullable=True)


class ChatModel(BaseModel):
    """Pydantic view of a :class:`Chat` row (validated from ORM attributes)."""

    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str
    title: str
    chat: dict

    created_at: int  # timestamp in epoch
    updated_at: int  # timestamp in epoch

    share_id: Optional[str] = None
    archived: bool = False
    pinned: Optional[bool] = False

    meta: dict = {}
    folder_id: Optional[str] = None


####################
# Forms
####################


class ChatForm(BaseModel):
    """Request body carrying a raw chat payload."""

    chat: dict


class ChatImportForm(ChatForm):
    """Chat payload plus the pinned/folder state to apply on import."""

    pinned: Optional[bool] = False
    folder_id: Optional[str] = None


class ChatTitleMessagesForm(BaseModel):
    """A title together with a list of message dicts."""

    title: str
    messages: list[dict]


class ChatTitleForm(BaseModel):
    """Request body carrying only a chat title."""

    title: str


class ChatResponse(BaseModel):
    """Full chat representation returned to API callers."""

    id: str
    user_id: str
    title: str
    chat: dict
    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch
    share_id: Optional[str] = None  # id of the chat to be shared
    archived: bool
    pinned: Optional[bool] = False
    meta: dict = {}
    folder_id: Optional[str] = None


class ChatTitleIdResponse(BaseModel):
    """Lightweight projection of a chat: id, title and timestamps only."""

    id: str
    title: str
    updated_at: int
    created_at: int


class ChatTable:
    """Data-access helpers for the ``chat`` table.

    Sharing convention used throughout this class: sharing a chat creates a
    snapshot row whose ``user_id`` is ``"shared-<original chat id>"``, and the
    original row's ``share_id`` is set to the snapshot's ``id``.
    """

    @staticmethod
    def _tag_name_to_id(tag_name: str) -> str:
        """Normalise a tag name to the id form stored in ``meta["tags"]``."""
        return tag_name.replace(" ", "_").lower()

    @staticmethod
    def _filter_by_tag_id(db, query, tag_id: str):
        """Restrict ``query`` to chats whose ``meta["tags"]`` contains ``tag_id``.

        Raises:
            NotImplementedError: if the dialect is neither SQLite nor PostgreSQL.
        """
        dialect_name = db.bind.dialect.name
        if dialect_name == "sqlite":
            # SQLite JSON1 querying for tags within the meta JSON field
            return query.filter(
                text(
                    "EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)"
                )
            ).params(tag_id=tag_id)
        if dialect_name == "postgresql":
            # PostgreSQL JSON query for tags within the meta JSON field (for `json` type)
            return query.filter(
                text(
                    "EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)"
                )
            ).params(tag_id=tag_id)
        raise NotImplementedError(f"Unsupported dialect: {dialect_name}")

    def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
        """Create a new chat owned by ``user_id`` and return it.

        The title comes from ``form_data.chat["title"]``, defaulting to
        ``"New Chat"``; ``created_at`` and ``updated_at`` are set to now.
        """
        with get_db() as db:
            now = int(time.time())
            chat = ChatModel(
                **{
                    "id": str(uuid.uuid4()),
                    "user_id": user_id,
                    "title": form_data.chat.get("title", "New Chat"),
                    "chat": form_data.chat,
                    "created_at": now,
                    "updated_at": now,
                }
            )

            result = Chat(**chat.model_dump())
            db.add(result)
            db.commit()
            db.refresh(result)
            return ChatModel.model_validate(result) if result else None

    def import_chat(
        self, user_id: str, form_data: ChatImportForm
    ) -> Optional[ChatModel]:
        """Create a chat for ``user_id`` from an import, keeping pinned/folder state.

        Same as :meth:`insert_new_chat` but also stores ``form_data.pinned``
        and ``form_data.folder_id``.
        """
        with get_db() as db:
            now = int(time.time())
            chat = ChatModel(
                **{
                    "id": str(uuid.uuid4()),
                    "user_id": user_id,
                    "title": form_data.chat.get("title", "New Chat"),
                    "chat": form_data.chat,
                    "pinned": form_data.pinned,
                    "folder_id": form_data.folder_id,
                    "created_at": now,
                    "updated_at": now,
                }
            )

            result = Chat(**chat.model_dump())
            db.add(result)
            db.commit()
            db.refresh(result)
            return ChatModel.model_validate(result) if result else None

    def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
        """Replace the payload (and derived title) of chat ``id``.

        Returns the updated chat, or None if it does not exist or the update fails.
        """
        try:
            with get_db() as db:
                chat_item = db.get(Chat, id)
                chat_item.chat = chat
                chat_item.title = chat.get("title", "New Chat")
                chat_item.updated_at = int(time.time())
                db.commit()
                db.refresh(chat_item)

                return ChatModel.model_validate(chat_item)
        except Exception:
            return None

    def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
        """Create (or return the existing) shared snapshot of chat ``chat_id``.

        Returns None if the original chat does not exist.
        """
        with get_db() as db:
            # Get the existing chat to share
            chat = db.get(Chat, chat_id)
            if chat is None:
                return None

            # Already shared: return the existing snapshot. Snapshots are owned
            # by "shared-<chat_id>" (see the ChatModel built below).
            if chat.share_id:
                existing = self.get_chat_by_id_and_user_id(
                    chat.share_id, f"shared-{chat_id}"
                )
                if existing is not None:
                    return existing
                # share_id points at a missing snapshot: create a fresh one.

            # Create a new chat with the same data, but with a new ID
            shared_chat = ChatModel(
                **{
                    "id": str(uuid.uuid4()),
                    "user_id": f"shared-{chat_id}",
                    "title": chat.title,
                    "chat": chat.chat,
                    "created_at": chat.created_at,
                    "updated_at": int(time.time()),
                }
            )
            shared_result = Chat(**shared_chat.model_dump())
            db.add(shared_result)
            db.commit()
            db.refresh(shared_result)

            # Update the original chat with the share_id
            result = (
                db.query(Chat)
                .filter_by(id=chat_id)
                .update({"share_id": shared_chat.id})
            )
            db.commit()
            return shared_chat if (shared_result and result) else None

    def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
        """Refresh the shared snapshot of ``chat_id`` from the original chat.

        Copies the original's title and payload into the snapshot row
        referenced by its ``share_id``. Returns the updated snapshot, or None
        if the chat does not exist, is not shared, or the update fails.
        """
        try:
            with get_db() as db:
                chat = db.get(Chat, chat_id)
                if chat is None or not chat.share_id:
                    return None

                shared_chat = db.get(Chat, chat.share_id)
                if shared_chat is None:
                    return None

                shared_chat.title = chat.title
                shared_chat.chat = chat.chat
                shared_chat.updated_at = int(time.time())
                db.commit()
                db.refresh(shared_chat)

                return ChatModel.model_validate(shared_chat)
        except Exception:
            return None

    def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
        """Delete the shared snapshot(s) of ``chat_id``. Returns False on error."""
        try:
            with get_db() as db:
                db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
                db.commit()

                return True
        except Exception:
            return False

    def update_chat_share_id_by_id(
        self, id: str, share_id: Optional[str]
    ) -> Optional[ChatModel]:
        """Set (or clear, with None) the ``share_id`` of chat ``id``."""
        try:
            with get_db() as db:
                chat = db.get(Chat, id)
                chat.share_id = share_id
                db.commit()
                db.refresh(chat)
                return ChatModel.model_validate(chat)
        except Exception:
            return None

    def toggle_chat_pinned_by_id(self, id: str) -> Optional[ChatModel]:
        """Flip the ``pinned`` flag of chat ``id`` and bump ``updated_at``."""
        try:
            with get_db() as db:
                chat = db.get(Chat, id)
                chat.pinned = not chat.pinned
                chat.updated_at = int(time.time())
                db.commit()
                db.refresh(chat)
                return ChatModel.model_validate(chat)
        except Exception:
            return None

    def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
        """Flip the ``archived`` flag of chat ``id`` and bump ``updated_at``."""
        try:
            with get_db() as db:
                chat = db.get(Chat, id)
                chat.archived = not chat.archived
                chat.updated_at = int(time.time())
                db.commit()
                db.refresh(chat)
                return ChatModel.model_validate(chat)
        except Exception:
            return None

    def archive_all_chats_by_user_id(self, user_id: str) -> bool:
        """Mark every chat owned by ``user_id`` as archived."""
        try:
            with get_db() as db:
                db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
                db.commit()
                return True
        except Exception:
            return False

    def get_archived_chat_list_by_user_id(
        self, user_id: str, skip: int = 0, limit: int = 50
    ) -> list[ChatModel]:
        """Return the user's archived chats, newest first.

        ``skip``/``limit`` are currently not applied (pagination is disabled).
        """
        with get_db() as db:
            all_chats = (
                db.query(Chat)
                .filter_by(user_id=user_id, archived=True)
                .order_by(Chat.updated_at.desc())
                # .limit(limit).offset(skip)
                .all()
            )
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_chat_list_by_user_id(
        self,
        user_id: str,
        include_archived: bool = False,
        skip: int = 0,
        limit: int = 50,
    ) -> list[ChatModel]:
        """Return the user's chats that are not in a folder, newest first.

        Falsy ``skip``/``limit`` values disable offset/limit respectively.
        """
        with get_db() as db:
            query = db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None)
            if not include_archived:
                query = query.filter_by(archived=False)

            query = query.order_by(Chat.updated_at.desc())

            if skip:
                query = query.offset(skip)
            if limit:
                query = query.limit(limit)

            all_chats = query.all()
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_chat_title_id_list_by_user_id(
        self,
        user_id: str,
        include_archived: bool = False,
        skip: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> list[ChatTitleIdResponse]:
        """Return id/title/timestamps of the user's unpinned, folderless chats.

        Ordered newest first; falsy ``skip``/``limit`` disable offset/limit.
        """
        with get_db() as db:
            query = db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None)
            # pinned is nullable, so treat NULL the same as False.
            query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))

            if not include_archived:
                query = query.filter_by(archived=False)

            query = query.order_by(Chat.updated_at.desc()).with_entities(
                Chat.id, Chat.title, Chat.updated_at, Chat.created_at
            )

            if skip:
                query = query.offset(skip)
            if limit:
                query = query.limit(limit)

            all_chats = query.all()

            # Rows are sqlalchemy `Row` tuples (not ORM objects), so they have
            # to be destructured and mapped to a dict before validation.
            return [
                ChatTitleIdResponse.model_validate(
                    {
                        "id": chat[0],
                        "title": chat[1],
                        "updated_at": chat[2],
                        "created_at": chat[3],
                    }
                )
                for chat in all_chats
            ]

    def get_chat_list_by_chat_ids(
        self, chat_ids: list[str], skip: int = 0, limit: int = 50
    ) -> list[ChatModel]:
        """Return the non-archived chats among ``chat_ids``, newest first.

        ``skip``/``limit`` are currently not applied.
        """
        with get_db() as db:
            all_chats = (
                db.query(Chat)
                .filter(Chat.id.in_(chat_ids))
                .filter_by(archived=False)
                .order_by(Chat.updated_at.desc())
                .all()
            )
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
        """Return chat ``id``, or None if it does not exist."""
        try:
            with get_db() as db:
                chat = db.get(Chat, id)
                return ChatModel.model_validate(chat)
        except Exception:
            return None

    def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
        """Return the shared snapshot whose id is ``id``.

        Only succeeds if some chat references ``id`` as its ``share_id``;
        returns None otherwise.
        """
        try:
            with get_db() as db:
                chat = db.query(Chat).filter_by(share_id=id).first()

                if chat:
                    # The snapshot's primary key equals the share id.
                    return self.get_chat_by_id(id)
                else:
                    return None
        except Exception:
            return None

    def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
        """Return chat ``id`` if it is owned by ``user_id``, else None."""
        try:
            with get_db() as db:
                chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
                return ChatModel.model_validate(chat)
        except Exception:
            return None

    def get_chats(self, skip: int = 0, limit: int = 50) -> list[ChatModel]:
        """Return every chat (all users), newest first.

        ``skip``/``limit`` are currently not applied.
        """
        with get_db() as db:
            all_chats = (
                db.query(Chat)
                # .limit(limit).offset(skip)
                .order_by(Chat.updated_at.desc())
            )
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
        """Return all chats owned by ``user_id``, newest first."""
        with get_db() as db:
            all_chats = (
                db.query(Chat)
                .filter_by(user_id=user_id)
                .order_by(Chat.updated_at.desc())
            )
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
        """Return the user's pinned, non-archived chats, newest first."""
        with get_db() as db:
            all_chats = (
                db.query(Chat)
                .filter_by(user_id=user_id, pinned=True, archived=False)
                .order_by(Chat.updated_at.desc())
            )
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
        """Return all of the user's archived chats, newest first."""
        with get_db() as db:
            all_chats = (
                db.query(Chat)
                .filter_by(user_id=user_id, archived=True)
                .order_by(Chat.updated_at.desc())
            )
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_chats_by_user_id_and_search_text(
        self,
        user_id: str,
        search_text: str,
        include_archived: bool = False,
        skip: int = 0,
        limit: int = 60,
    ) -> list[ChatModel]:
        """Search a user's chats by title/message content and tags.

        ``search_text`` is lower-cased, stripped and split on spaces. Words of
        the form ``tag:<name>`` become tag filters: every listed tag must be
        present, and ``tag:none`` instead selects chats with no tags. The
        remaining words are re-joined and matched as one substring against the
        title or any message ``content`` (case-insensitive). Filtering and
        pagination are done in SQL.

        Falls back to :meth:`get_chat_list_by_user_id` when ``search_text`` is
        blank.

        Raises:
            NotImplementedError: if the dialect is neither SQLite nor PostgreSQL.
        """
        search_text = search_text.lower().strip()

        if not search_text:
            return self.get_chat_list_by_user_id(user_id, include_archived, skip, limit)

        search_text_words = search_text.split(" ")

        # search_text might contain 'tag:tag_name' words: extract the tag ids
        # and remove those words from the free-text part.
        tag_ids = [
            word.replace("tag:", "").replace(" ", "_").lower()
            for word in search_text_words
            if word.startswith("tag:")
        ]

        search_text_words = [
            word for word in search_text_words if not word.startswith("tag:")
        ]

        search_text = " ".join(search_text_words)

        with get_db() as db:
            query = db.query(Chat).filter(Chat.user_id == user_id)

            if not include_archived:
                query = query.filter(Chat.archived == False)

            query = query.order_by(Chat.updated_at.desc())

            # Check if the database dialect is either 'sqlite' or 'postgresql'
            dialect_name = db.bind.dialect.name
            if dialect_name == "sqlite":
                # SQLite case: using JSON1 extension for JSON searching
                query = query.filter(
                    (
                        Chat.title.ilike(
                            f"%{search_text}%"
                        )  # Case-insensitive search in title
                        | text(
                            """
                            EXISTS (
                                SELECT 1 
                                FROM json_each(Chat.chat, '$.messages') AS message 
                                WHERE LOWER(message.value->>'content') LIKE '%' || :search_text || '%'
                            )
                            """
                        )
                    ).params(search_text=search_text)
                )

                # Check if there are any tags to filter, it should have all the tags
                if "none" in tag_ids:
                    query = query.filter(
                        text(
                            """
                            NOT EXISTS (
                                SELECT 1
                                FROM json_each(Chat.meta, '$.tags') AS tag
                            )
                            """
                        )
                    )
                elif tag_ids:
                    query = query.filter(
                        and_(
                            *[
                                text(
                                    f"""
                                    EXISTS (
                                        SELECT 1
                                        FROM json_each(Chat.meta, '$.tags') AS tag
                                        WHERE tag.value = :tag_id_{tag_idx}
                                    )
                                    """
                                ).params(**{f"tag_id_{tag_idx}": tag_id})
                                for tag_idx, tag_id in enumerate(tag_ids)
                            ]
                        )
                    )

            elif dialect_name == "postgresql":
                # PostgreSQL relies on proper JSON query for search
                query = query.filter(
                    (
                        Chat.title.ilike(
                            f"%{search_text}%"
                        )  # Case-insensitive search in title
                        | text(
                            """
                            EXISTS (
                                SELECT 1
                                FROM json_array_elements(Chat.chat->'messages') AS message
                                WHERE LOWER(message->>'content') LIKE '%' || :search_text || '%'
                            )
                            """
                        )
                    ).params(search_text=search_text)
                )

                # Check if there are any tags to filter, it should have all the tags
                if "none" in tag_ids:
                    query = query.filter(
                        text(
                            """
                            NOT EXISTS (
                                SELECT 1
                                FROM json_array_elements_text(Chat.meta->'tags') AS tag
                            )
                            """
                        )
                    )
                elif tag_ids:
                    query = query.filter(
                        and_(
                            *[
                                text(
                                    f"""
                                    EXISTS (
                                        SELECT 1
                                        FROM json_array_elements_text(Chat.meta->'tags') AS tag
                                        WHERE tag = :tag_id_{tag_idx}
                                    )
                                    """
                                ).params(**{f"tag_id_{tag_idx}": tag_id})
                                for tag_idx, tag_id in enumerate(tag_ids)
                            ]
                        )
                    )
            else:
                raise NotImplementedError(
                    f"Unsupported dialect: {db.bind.dialect.name}"
                )

            # Perform pagination at the SQL level
            all_chats = query.offset(skip).limit(limit).all()

            # Validate and return chats
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_chats_by_folder_id_and_user_id(
        self, folder_id: str, user_id: str
    ) -> list[ChatModel]:
        """Return the user's unpinned, non-archived chats in ``folder_id``."""
        with get_db() as db:
            query = db.query(Chat).filter_by(folder_id=folder_id, user_id=user_id)
            query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
            query = query.filter_by(archived=False)

            query = query.order_by(Chat.updated_at.desc())

            all_chats = query.all()
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_chats_by_folder_ids_and_user_id(
        self, folder_ids: list[str], user_id: str
    ) -> list[ChatModel]:
        """Return the user's unpinned, non-archived chats in any of ``folder_ids``."""
        with get_db() as db:
            query = db.query(Chat).filter(
                Chat.folder_id.in_(folder_ids), Chat.user_id == user_id
            )
            query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
            query = query.filter_by(archived=False)

            query = query.order_by(Chat.updated_at.desc())

            all_chats = query.all()
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def update_chat_folder_id_by_id_and_user_id(
        self, id: str, user_id: str, folder_id: str
    ) -> Optional[ChatModel]:
        """Move chat ``id`` (owned by ``user_id``) into ``folder_id``.

        Moving a chat also unpins it and bumps ``updated_at``. Returns None if
        the chat does not exist, is not owned by ``user_id``, or the update fails.
        """
        try:
            with get_db() as db:
                # Enforce ownership: the user_id parameter was previously ignored.
                chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
                if chat is None:
                    return None
                chat.folder_id = folder_id
                chat.updated_at = int(time.time())
                chat.pinned = False
                db.commit()
                db.refresh(chat)
                return ChatModel.model_validate(chat)
        except Exception:
            return None

    def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str) -> list[TagModel]:
        """Resolve each tag id stored in chat ``id``'s meta via ``Tags``.

        Returns an empty list if the chat does not exist.
        """
        with get_db() as db:
            chat = db.get(Chat, id)
            if chat is None:
                return []
            tags = chat.meta.get("tags", [])
            return [Tags.get_tag_by_name_and_user_id(tag, user_id) for tag in tags]

    def get_chat_list_by_user_id_and_tag_name(
        self, user_id: str, tag_name: str, skip: int = 0, limit: int = 50
    ) -> list[ChatModel]:
        """Return the user's chats tagged with ``tag_name``.

        ``skip``/``limit`` are currently not applied.

        Raises:
            NotImplementedError: if the dialect is neither SQLite nor PostgreSQL.
        """
        with get_db() as db:
            query = db.query(Chat).filter_by(user_id=user_id)
            tag_id = self._tag_name_to_id(tag_name)
            query = self._filter_by_tag_id(db, query, tag_id)

            all_chats = query.all()
            log.debug("Chats for tag %r: %d", tag_id, len(all_chats))
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def add_chat_tag_by_id_and_user_id_and_tag_name(
        self, id: str, user_id: str, tag_name: str
    ) -> Optional[ChatModel]:
        """Attach tag ``tag_name`` to chat ``id``, creating the tag if needed.

        The tag's id is added to ``meta["tags"]`` (deduplicated). Returns the
        chat, or None on failure.
        """
        tag = Tags.get_tag_by_name_and_user_id(tag_name, user_id)
        if tag is None:
            tag = Tags.insert_new_tag(tag_name, user_id)
        try:
            with get_db() as db:
                chat = db.get(Chat, id)

                tag_id = tag.id
                if tag_id not in chat.meta.get("tags", []):
                    # Reassign meta (rather than mutating it) so the change to
                    # the JSON column is picked up on commit.
                    chat.meta = {
                        **chat.meta,
                        "tags": list(set(chat.meta.get("tags", []) + [tag_id])),
                    }

                db.commit()
                db.refresh(chat)
                return ChatModel.model_validate(chat)
        except Exception:
            return None

    def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> int:
        """Count the user's non-archived chats tagged with ``tag_name``.

        Raises:
            NotImplementedError: if the dialect is neither SQLite nor PostgreSQL.
        """
        with get_db() as db:
            query = db.query(Chat).filter_by(user_id=user_id, archived=False)

            # Normalize the tag_name for consistency
            tag_id = self._tag_name_to_id(tag_name)
            query = self._filter_by_tag_id(db, query, tag_id)

            count = query.count()
            log.debug("Count of chats for tag '%s': %s", tag_name, count)
            return count

    def delete_tag_by_id_and_user_id_and_tag_name(
        self, id: str, user_id: str, tag_name: str
    ) -> bool:
        """Remove tag ``tag_name`` from chat ``id``'s ``meta["tags"]``."""
        try:
            with get_db() as db:
                chat = db.get(Chat, id)
                tag_id = self._tag_name_to_id(tag_name)
                tags = [tag for tag in chat.meta.get("tags", []) if tag != tag_id]

                chat.meta = {
                    **chat.meta,
                    "tags": list(set(tags)),
                }
                db.commit()
                return True
        except Exception:
            return False

    def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str) -> bool:
        """Clear ``meta["tags"]`` of chat ``id``."""
        try:
            with get_db() as db:
                chat = db.get(Chat, id)
                chat.meta = {
                    **chat.meta,
                    "tags": [],
                }
                db.commit()

                return True
        except Exception:
            return False

    def delete_chat_by_id(self, id: str) -> bool:
        """Delete chat ``id`` and its shared snapshot(s)."""
        try:
            with get_db() as db:
                db.query(Chat).filter_by(id=id).delete()
                db.commit()

                return True and self.delete_shared_chat_by_chat_id(id)
        except Exception:
            return False

    def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
        """Delete chat ``id`` if owned by ``user_id``, plus its shared snapshot(s)."""
        try:
            with get_db() as db:
                db.query(Chat).filter_by(id=id, user_id=user_id).delete()
                db.commit()

                return True and self.delete_shared_chat_by_chat_id(id)
        except Exception:
            return False

    def delete_chats_by_user_id(self, user_id: str) -> bool:
        """Delete all of the user's chats and their shared snapshots."""
        try:
            with get_db() as db:
                # Snapshots must be removed first: they are located via the
                # ids of the user's chats, which are deleted next.
                self.delete_shared_chats_by_user_id(user_id)

                db.query(Chat).filter_by(user_id=user_id).delete()
                db.commit()

                return True
        except Exception:
            return False

    def delete_chats_by_user_id_and_folder_id(
        self, user_id: str, folder_id: str
    ) -> bool:
        """Delete the user's chats in ``folder_id``."""
        try:
            with get_db() as db:
                db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).delete()
                db.commit()

                return True
        except Exception:
            return False

    def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
        """Delete the shared snapshots of every chat owned by ``user_id``."""
        try:
            with get_db() as db:
                chats_by_user = db.query(Chat).filter_by(user_id=user_id).all()
                shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]

                db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
                db.commit()

                return True
        except Exception:
            return False


Chats = ChatTable()