diff --git a/backend/open_webui/apps/webui/models/chats.py b/backend/open_webui/apps/webui/models/chats.py index 4109bfa46..e4ad65db2 100644 --- a/backend/open_webui/apps/webui/models/chats.py +++ b/backend/open_webui/apps/webui/models/chats.py @@ -4,10 +4,13 @@ import uuid from typing import Optional from open_webui.apps.webui.internal.db import Base, get_db +from open_webui.apps.webui.models.tags import TagModel, Tag, Tags + + from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON -from sqlalchemy import or_, func, select - +from sqlalchemy import or_, func, select, and_, text +from sqlalchemy.sql import exists #################### # Chat DB Schema @@ -27,6 +30,9 @@ class Chat(Base): share_id = Column(Text, unique=True, nullable=True) archived = Column(Boolean, default=False) + pinned = Column(Boolean, default=False, nullable=True) + + meta = Column(JSON, server_default="{}") class ChatModel(BaseModel): @@ -42,6 +48,9 @@ class ChatModel(BaseModel): share_id: Optional[str] = None archived: bool = False + pinned: Optional[bool] = False + + meta: dict = {} #################### @@ -66,6 +75,8 @@ class ChatResponse(BaseModel): created_at: int # timestamp in epoch share_id: Optional[str] = None # id of the chat to be shared archived: bool + pinned: Optional[bool] = False + meta: dict = {} class ChatTitleIdResponse(BaseModel): @@ -184,11 +195,24 @@ class ChatTable: except Exception: return None + def toggle_chat_pinned_by_id(self, id: str) -> Optional[ChatModel]: + try: + with get_db() as db: + chat = db.get(Chat, id) + chat.pinned = not chat.pinned + chat.updated_at = int(time.time()) + db.commit() + db.refresh(chat) + return ChatModel.model_validate(chat) + except Exception: + return None + def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: try: with get_db() as db: chat = db.get(Chat, id) chat.archived = not chat.archived + chat.updated_at = int(time.time()) db.commit() db.refresh(chat) return ChatModel.model_validate(chat) @@ -330,6 +354,15 @@ class ChatTable: ) return [ChatModel.model_validate(chat) for chat in all_chats] + def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]: + with get_db() as db: + all_chats = ( + db.query(Chat) + .filter_by(user_id=user_id, pinned=True) + .order_by(Chat.updated_at.desc()) + ) + return [ChatModel.model_validate(chat) for chat in all_chats] + def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]: with get_db() as db: all_chats = ( @@ -383,6 +416,135 @@ class ChatTable: paginated_chats = filtered_chats[skip : skip + limit] return [ChatModel.model_validate(chat) for chat in paginated_chats] + def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str) -> list[TagModel]: + with get_db() as db: + chat = db.get(Chat, id) + tags = chat.meta.get("tags", []) + return [Tags.get_tag_by_name_and_user_id(tag, user_id) for tag in tags] + + def get_chat_list_by_user_id_and_tag_name( + self, user_id: str, tag_name: str, skip: int = 0, limit: int = 50 + ) -> list[ChatModel]: + with get_db() as db: + query = db.query(Chat).filter_by(user_id=user_id) + tag_id = tag_name.replace(" ", "_").lower() + + print(db.bind.dialect.name) + if db.bind.dialect.name == "sqlite": + # SQLite JSON1 querying for tags within the meta JSON field + query = query.filter( + text( + f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)" + ) + ).params(tag_id=tag_id) + elif db.bind.dialect.name == "postgresql": + # PostgreSQL JSON query for tags within the meta JSON field (for `json` type) + query = query.filter( + text( + "EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)" + ) + ).params(tag_id=tag_id) + else: + raise NotImplementedError( + f"Unsupported dialect: {db.bind.dialect.name}" + ) + + all_chats = query.all() + print("all_chats", all_chats) + return [ChatModel.model_validate(chat) for chat in all_chats] + + def add_chat_tag_by_id_and_user_id_and_tag_name( + self, id: str, user_id: str, tag_name: str + ) -> Optional[ChatModel]: + tag = Tags.get_tag_by_name_and_user_id(tag_name, user_id) + if tag is None: + tag = Tags.insert_new_tag(tag_name, user_id) + try: + with get_db() as db: + chat = db.get(Chat, id) + + tag_id = tag.id + if tag_id not in chat.meta.get("tags", []): + chat.meta = { + **chat.meta, + "tags": chat.meta.get("tags", []) + [tag_id], + } + + db.commit() + db.refresh(chat) + return ChatModel.model_validate(chat) + except Exception: + return None + + def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> int: + with get_db() as db: # Assuming `get_db()` returns a session object + query = db.query(Chat).filter_by(user_id=user_id) + + # Normalize the tag_name for consistency + tag_id = tag_name.replace(" ", "_").lower() + + if db.bind.dialect.name == "sqlite": + # SQLite JSON1 support for querying the tags inside the `meta` JSON field + query = query.filter( + text( + f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)" + ) + ).params(tag_id=tag_id) + + elif db.bind.dialect.name == "postgresql": + # PostgreSQL JSONB support for querying the tags inside the `meta` JSON field + query = query.filter( + text( + "EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)" + ) + ).params(tag_id=tag_id) + + else: + raise NotImplementedError( + f"Unsupported dialect: {db.bind.dialect.name}" + ) + + # Get the count of matching records + count = query.count() + + # Debugging output for inspection + print(f"Count of chats for tag '{tag_name}':", count) + + return count + + def delete_tag_by_id_and_user_id_and_tag_name( + self, id: str, user_id: str, tag_name: str + ) -> bool: + try: + with get_db() as db: + chat = db.get(Chat, id) + tags = chat.meta.get("tags", []) + tag_id = tag_name.replace(" ", "_").lower() + + tags = [tag for tag in tags if tag != tag_id] + chat.meta = { + **chat.meta, + "tags": tags, + } + db.commit() + return True + except Exception: + return False + + def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str) -> bool: + try: + with get_db() as db: + chat = db.get(Chat, id) + chat.meta = { + **chat.meta, + "tags": [], + } + db.commit() + + return True + except Exception: + return False + def delete_chat_by_id(self, id: str) -> bool: try: with get_db() as db: diff --git a/backend/open_webui/apps/webui/models/tags.py b/backend/open_webui/apps/webui/models/tags.py index 985273ff1..ef209b565 100644 --- a/backend/open_webui/apps/webui/models/tags.py +++ b/backend/open_webui/apps/webui/models/tags.py @@ -4,53 +4,32 @@ import uuid from typing import Optional from open_webui.apps.webui.internal.db import Base, get_db + + from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, String, Text +from sqlalchemy import BigInteger, Column, String, JSON log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) + #################### # Tag DB Schema #################### - - class Tag(Base): __tablename__ = "tag" - id = Column(String, primary_key=True) name = Column(String) user_id = Column(String) - data = Column(Text, nullable=True) - - -class ChatIdTag(Base): - __tablename__ = "chatidtag" - - id = Column(String, primary_key=True) - tag_name = Column(String) - chat_id = Column(String) - user_id = Column(String) - timestamp = Column(BigInteger) + meta = Column(JSON, nullable=True) class TagModel(BaseModel): id: str name: str user_id: str - data: Optional[str] = None - - model_config = ConfigDict(from_attributes=True) - - -class ChatIdTagModel(BaseModel): - id: str - tag_name: str - chat_id: str - user_id: str - timestamp: int - + meta: Optional[dict] = None model_config = ConfigDict(from_attributes=True) @@ -59,23 +38,15 @@ class ChatIdTagModel(BaseModel): #################### -class ChatIdTagForm(BaseModel): - tag_name: str +class TagChatIdForm(BaseModel): + name: str chat_id: str -class TagChatIdsResponse(BaseModel): - chat_ids: list[str] - - -class ChatTagsResponse(BaseModel): - tags: list[str] - - class TagTable: def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: with get_db() as db: - id = str(uuid.uuid4()) + id = name.replace(" ", "_").lower() tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) try: result = Tag(**tag.model_dump()) @@ -93,170 +64,38 @@ class TagTable: self, name: str, user_id: str ) -> Optional[TagModel]: try: + id = name.replace(" ", "_").lower() with get_db() as db: - tag = db.query(Tag).filter_by(name=name, user_id=user_id).first() + tag = db.query(Tag).filter_by(id=id, user_id=user_id).first() return TagModel.model_validate(tag) except Exception: return None - def add_tag_to_chat( - self, user_id: str, form_data: ChatIdTagForm - ) -> Optional[ChatIdTagModel]: - tag = self.get_tag_by_name_and_user_id(form_data.tag_name, user_id) - if tag is None: - tag = self.insert_new_tag(form_data.tag_name, user_id) - - id = str(uuid.uuid4()) - chatIdTag = ChatIdTagModel( - **{ - "id": id, - "user_id": user_id, - "chat_id": form_data.chat_id, - "tag_name": tag.name, - "timestamp": int(time.time()), - } - ) - try: - with get_db() as db: - result = ChatIdTag(**chatIdTag.model_dump()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return ChatIdTagModel.model_validate(result) - else: - return None - except Exception: - return None - def get_tags_by_user_id(self, user_id: str) -> list[TagModel]: with get_db() as db: - tag_names = [ - chat_id_tag.tag_name - for chat_id_tag in ( - db.query(ChatIdTag) - .filter_by(user_id=user_id) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] - return [ TagModel.model_validate(tag) - for tag in ( - db.query(Tag) - .filter_by(user_id=user_id) - .filter(Tag.name.in_(tag_names)) - .all() - ) + for tag in (db.query(Tag).filter_by(user_id=user_id).all()) ] - def get_tags_by_chat_id_and_user_id( - self, chat_id: str, user_id: str - ) -> list[TagModel]: + def get_tags_by_ids(self, ids: list[str]) -> list[TagModel]: with get_db() as db: - tag_names = [ - chat_id_tag.tag_name - for chat_id_tag in ( - db.query(ChatIdTag) - .filter_by(user_id=user_id, chat_id=chat_id) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] - return [ TagModel.model_validate(tag) - for tag in ( - db.query(Tag) - .filter_by(user_id=user_id) - .filter(Tag.name.in_(tag_names)) - .all() - ) + for tag in (db.query(Tag).filter(Tag.id.in_(ids)).all()) ] - def get_chat_ids_by_tag_name_and_user_id( - self, tag_name: str, user_id: str - ) -> list[ChatIdTagModel]: - with get_db() as db: - return [ - ChatIdTagModel.model_validate(chat_id_tag) - for chat_id_tag in ( - db.query(ChatIdTag) - .filter_by(user_id=user_id, tag_name=tag_name) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] - - def count_chat_ids_by_tag_name_and_user_id( - self, tag_name: str, user_id: str - ) -> int: - with get_db() as db: - return ( - db.query(ChatIdTag) - .filter_by(tag_name=tag_name, user_id=user_id) - .count() - ) - - def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool: + def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool: try: with get_db() as db: - res = ( - db.query(ChatIdTag) - .filter_by(tag_name=tag_name, user_id=user_id) - .delete() - ) + id = name.replace(" ", "_").lower() + res = db.query(Tag).filter_by(id=id, user_id=user_id).delete() log.debug(f"res: {res}") db.commit() - - tag_count = self.count_chat_ids_by_tag_name_and_user_id( - tag_name, user_id - ) - if tag_count == 0: - # Remove tag item from Tag col as well - db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() - db.commit() return True except Exception as e: log.error(f"delete_tag: {e}") return False - def delete_tag_by_tag_name_and_chat_id_and_user_id( - self, tag_name: str, chat_id: str, user_id: str - ) -> bool: - try: - with get_db() as db: - res = ( - db.query(ChatIdTag) - .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id) - .delete() - ) - log.debug(f"res: {res}") - db.commit() - - tag_count = self.count_chat_ids_by_tag_name_and_user_id( - tag_name, user_id - ) - if tag_count == 0: - # Remove tag item from Tag col as well - db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() - db.commit() - - return True - except Exception as e: - log.error(f"delete_tag: {e}") - return False - - def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool: - tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id) - - for tag in tags: - self.delete_tag_by_tag_name_and_chat_id_and_user_id( - tag.tag_name, chat_id, user_id - ) - - return True - Tags = TagTable() diff --git a/backend/open_webui/apps/webui/routers/chats.py b/backend/open_webui/apps/webui/routers/chats.py index f28b15206..6a9c26f8c 100644 --- a/backend/open_webui/apps/webui/routers/chats.py +++ b/backend/open_webui/apps/webui/routers/chats.py @@ -8,12 +8,8 @@ from open_webui.apps.webui.models.chats import ( Chats, ChatTitleIdResponse, ) -from open_webui.apps.webui.models.tags import ( - ChatIdTagForm, - ChatIdTagModel, - TagModel, - Tags, -) +from open_webui.apps.webui.models.tags import TagModel, Tags + from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT from open_webui.constants import ERROR_MESSAGES from open_webui.env import SRC_LOG_LEVELS @@ -126,6 +122,19 @@ async def search_user_chats( ] +############################ +# GetPinnedChats +############################ + + +@router.get("/pinned", response_model=list[ChatResponse]) +async def get_user_pinned_chats(user=Depends(get_verified_user)): + return [ + ChatResponse(**chat.model_dump()) + for chat in Chats.get_pinned_chats_by_user_id(user.id) + ] + + ############################ # GetChats ############################ @@ -152,6 +161,23 @@ async def get_user_archived_chats(user=Depends(get_verified_user)): ] +############################ +# GetAllTags +############################ + + +@router.get("/all/tags", response_model=list[TagModel]) +async def get_all_user_tags(user=Depends(get_verified_user)): + try: + tags = Tags.get_tags_by_user_id(user.id) + return tags + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + ############################ # GetAllChatsInDB ############################ @@ -220,48 +246,28 @@ async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)): ############################ -class TagNameForm(BaseModel): +class TagForm(BaseModel): name: str + + +class TagFilterForm(TagForm): skip: Optional[int] = 0 limit: Optional[int] = 50 @router.post("/tags", response_model=list[ChatTitleIdResponse]) async def get_user_chat_list_by_tag_name( - form_data: TagNameForm, user=Depends(get_verified_user) + form_data: TagFilterForm, user=Depends(get_verified_user) ): - chat_ids = [ - chat_id_tag.chat_id - for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id( - form_data.name, user.id - ) - ] - - chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit) - + chats = Chats.get_chat_list_by_user_id_and_tag_name( + user.id, form_data.name, form_data.skip, form_data.limit + ) if len(chats) == 0: - Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id) + Tags.delete_tag_by_name_and_user_id(form_data.name, user.id) return chats -############################ -# GetAllTags -############################ - - -@router.get("/tags/all", response_model=list[TagModel]) -async def get_all_tags(user=Depends(get_verified_user)): - try: - tags = Tags.get_tags_by_user_id(user.id) - return tags - except Exception as e: - log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) - - ############################ # GetChatById ############################ @@ -324,12 +330,45 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified return result +############################ +# GetPinnedStatusById +############################ + + +@router.get("/{id}/pinned", response_model=Optional[bool]) +async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)): + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + if chat: + return chat.pinned + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() + ) + + +############################ +# PinChatById +############################ + + +@router.post("/{id}/pin", response_model=Optional[ChatResponse]) +async def pin_chat_by_id(id: str, user=Depends(get_verified_user)): + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + if chat: + chat = Chats.toggle_chat_pinned_by_id(id) + return chat + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() + ) + + ############################ # CloneChat ############################ -@router.get("/{id}/clone", response_model=Optional[ChatResponse]) +@router.post("/{id}/clone", response_model=Optional[ChatResponse]) async def clone_chat_by_id(id: str, user=Depends(get_verified_user)): chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: @@ -353,7 +392,7 @@ async def clone_chat_by_id(id: str, user=Depends(get_verified_user)): ############################ -@router.get("/{id}/archive", response_model=Optional[ChatResponse]) +@router.post("/{id}/archive", response_model=Optional[ChatResponse]) async def archive_chat_by_id(id: str, user=Depends(get_verified_user)): chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: @@ -423,10 +462,10 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)): @router.get("/{id}/tags", response_model=list[TagModel]) async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)): - tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id) - - if tags != None: - return tags + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + if chat: + tags = chat.meta.get("tags", []) + return Tags.get_tags_by_ids(tags) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND @@ -438,22 +477,24 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)): ############################ -@router.post("/{id}/tags", response_model=Optional[ChatIdTagModel]) -async def add_chat_tag_by_id( - id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user) +@router.post("/{id}/tags", response_model=list[TagModel]) +async def add_tag_by_id_and_tag_name( + id: str, form_data: TagForm, user=Depends(get_verified_user) ): - tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id) + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + if chat: + tags = chat.meta.get("tags", []) + tag_id = form_data.name.replace(" ", "_").lower() - if form_data.tag_name not in tags: - tag = Tags.add_tag_to_chat(user.id, form_data) - - if tag: - return tag - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.NOT_FOUND, + print(tags, tag_id) + if tag_id not in tags: + Chats.add_chat_tag_by_id_and_user_id_and_tag_name( + id, user.id, form_data.name ) + + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + tags = chat.meta.get("tags", []) + return Tags.get_tags_by_ids(tags) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() @@ -465,16 +506,20 @@ async def add_chat_tag_by_id( ############################ -@router.delete("/{id}/tags", response_model=Optional[bool]) -async def delete_chat_tag_by_id( - id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user) +@router.delete("/{id}/tags", response_model=list[TagModel]) +async def delete_tag_by_id_and_tag_name( + id: str, form_data: TagForm, user=Depends(get_verified_user) ): - result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id( - form_data.tag_name, id, user.id - ) + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + if chat: + Chats.delete_tag_by_id_and_user_id_and_tag_name(id, user.id, form_data.name) - if result: - return result + if Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id) == 0: + Tags.delete_tag_by_name_and_user_id(form_data.name, user.id) + + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + tags = chat.meta.get("tags", []) + return Tags.get_tags_by_ids(tags) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND @@ -488,10 +533,17 @@ async def delete_chat_tag_by_id( @router.delete("/{id}/tags/all", response_model=Optional[bool]) async def delete_all_chat_tags_by_id(id: str, user=Depends(get_verified_user)): - result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id) + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + if chat: + Chats.delete_all_tags_by_id_and_user_id(id, user.id) - if result: - return result + for tag in chat.meta.get("tags", []): + if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0: + Tags.delete_tag_by_name_and_user_id(tag, user.id) + + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + tags = chat.meta.get("tags", []) + return Tags.get_tags_by_ids(tags) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND diff --git a/backend/open_webui/migrations/versions/1af9b942657b_migrate_tags.py b/backend/open_webui/migrations/versions/1af9b942657b_migrate_tags.py new file mode 100644 index 000000000..17ac74373 --- /dev/null +++ b/backend/open_webui/migrations/versions/1af9b942657b_migrate_tags.py @@ -0,0 +1,109 @@ +"""Migrate tags + +Revision ID: 1af9b942657b +Revises: 242a2047eae0 +Create Date: 2024-10-09 21:02:35.241684 + +""" + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.sql import table, select, update, column + +import json + +revision = "1af9b942657b" +down_revision = "242a2047eae0" +branch_labels = None +depends_on = None + + +def upgrade(): + # Step 1: Modify Tag table using batch mode for SQLite support + with op.batch_alter_table("tag", schema=None) as batch_op: + batch_op.create_unique_constraint( + "uq_id_user_id", ["id", "user_id"] + ) # Ensure unique (id, user_id) + batch_op.drop_column("data") + batch_op.add_column(sa.Column("meta", sa.JSON(), nullable=True)) + + tag = table( + "tag", + column("id", sa.String()), + column("name", sa.String()), + column("user_id", sa.String()), + column("meta", sa.JSON()), + ) + + # Step 2: Migrate tags + conn = op.get_bind() + result = conn.execute(sa.select(tag.c.id, tag.c.name, tag.c.user_id)) + + tag_updates = {} + for row in result: + new_id = row.name.replace(" ", "_").lower() + tag_updates[row.id] = new_id + + for tag_id, new_tag_id in tag_updates.items(): + print(f"Updating tag {tag_id} to {new_tag_id}") + if new_tag_id == "pinned": + # delete tag + delete_stmt = sa.delete(tag).where(tag.c.id == tag_id) + conn.execute(delete_stmt) + else: + update_stmt = sa.update(tag).where(tag.c.id == tag_id) + update_stmt = update_stmt.values(id=new_tag_id) + conn.execute(update_stmt) + + # Add columns `pinned` and `meta` to 'chat' + op.add_column("chat", sa.Column("pinned", sa.Boolean(), nullable=True)) + op.add_column( + "chat", sa.Column("meta", sa.JSON(), nullable=False, server_default="{}") + ) + + chatidtag = table( + "chatidtag", column("chat_id", sa.String()), column("tag_name", sa.String()) + ) + chat = table( + "chat", + column("id", sa.String()), + column("pinned", sa.Boolean()), + column("meta", sa.JSON()), + ) + + # Fetch existing tags + conn = op.get_bind() + result = conn.execute(sa.select(chatidtag.c.chat_id, chatidtag.c.tag_name)) + + chat_updates = {} + for row in result: + chat_id = row.chat_id + tag_name = row.tag_name.replace(" ", "_").lower() + + if tag_name == "pinned": + # Specifically handle 'pinned' tag + if chat_id not in chat_updates: + chat_updates[chat_id] = {"pinned": True, "meta": {}} + else: + chat_updates[chat_id]["pinned"] = True + else: + if chat_id not in chat_updates: + chat_updates[chat_id] = {"pinned": False, "meta": {"tags": [tag_name]}} + else: + tags = chat_updates[chat_id]["meta"].get("tags", []) + tags.append(tag_name) + + chat_updates[chat_id]["meta"]["tags"] = tags + + # Update chats based on accumulated changes + for chat_id, updates in chat_updates.items(): + update_stmt = sa.update(chat).where(chat.c.id == chat_id) + update_stmt = update_stmt.values( + meta=updates.get("meta", {}), pinned=updates.get("pinned", False) + ) + conn.execute(update_stmt) + pass + + +def downgrade(): + pass diff --git a/src/lib/apis/chats/index.ts b/src/lib/apis/chats/index.ts index ac15f263d..9b2e16278 100644 --- a/src/lib/apis/chats/index.ts +++ b/src/lib/apis/chats/index.ts @@ -267,7 +267,7 @@ export const getAllUserChats = async (token: string) => { export const getAllChatTags = async (token: string) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/chats/tags/all`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/all/tags`, { method: 'GET', headers: { Accept: 'application/json', @@ -295,6 +295,40 @@ export const getAllChatTags = async (token: string) => { return res; }; +export const getPinnedChatList = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/pinned`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res.map((chat) => ({ + ...chat, + time_range: getTimeRange(chat.updated_at) + })); +}; + export const getChatListByTagName = async (token: string = '', tagName: string) => { let error = null; @@ -396,11 +430,87 @@ export const getChatByShareId = async (token: string, share_id: string) => { return res; }; +export const getChatPinnedStatusById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/pinned`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + + if ('detail' in err) { + error = err.detail; + } else { + error = err; + } + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const toggleChatPinnedStatusById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/pin`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + + if ('detail' in err) { + error = err.detail; + } else { + error = err; + } + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const cloneChatById = async (token: string, id: string) => { let error = null; const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/clone`, { - method: 'GET', + method: 'POST', headers: { Accept: 'application/json', 'Content-Type': 'application/json', @@ -470,7 +580,7 @@ export const archiveChatById = async (token: string, id: string) => { let error = null; const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/archive`, { - method: 'GET', + method: 'POST', headers: { Accept: 'application/json', 'Content-Type': 'application/json', @@ -640,8 +750,7 @@ export const addTagById = async (token: string, id: string, tagName: string) => ...(token && { authorization: `Bearer ${token}` }) }, body: JSON.stringify({ - tag_name: tagName, - chat_id: id + name: tagName }) }) .then(async (res) => { @@ -676,8 +785,7 @@ export const deleteTagById = async (token: string, id: string, tagName: string) ...(token && { authorization: `Bearer ${token}` }) }, body: JSON.stringify({ - tag_name: tagName, - chat_id: id + name: tagName }) }) .then(async (res) => { diff --git a/src/lib/components/chat/Tags.svelte b/src/lib/components/chat/Tags.svelte index e6d01b3b5..9ae8c9ff3 100644 --- a/src/lib/components/chat/Tags.svelte +++ b/src/lib/components/chat/Tags.svelte @@ -25,40 +25,30 @@ let tags = []; const getTags = async () => { - return ( - await getTagsById(localStorage.token, chatId).catch(async (error) => { - return []; - }) - ).filter((tag) => tag.name !== 'pinned'); + return await getTagsById(localStorage.token, chatId).catch(async (error) => { + return []; + }); }; const addTag = async (tagName) => { const res = await addTagById(localStorage.token, chatId, tagName); tags = await getTags(); - await updateChatById(localStorage.token, chatId, { tags: tags }); - _tags.set(await getAllChatTags(localStorage.token)); - await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned')); }; const deleteTag = async (tagName) => { const res = await deleteTagById(localStorage.token, chatId, tagName); tags = await getTags(); - await updateChatById(localStorage.token, chatId, { tags: tags }); await _tags.set(await getAllChatTags(localStorage.token)); if ($_tags.map((t) => t.name).includes(tagName)) { - if (tagName === 'pinned') { - await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned')); - } else { - await chats.set(await getChatListByTagName(localStorage.token, tagName)); - } + await chats.set(await getChatListByTagName(localStorage.token, tagName)); if ($chats.find((chat) => chat.id === chatId)) { dispatch('close'); @@ -67,7 +57,6 @@ // if the tag we deleted is no longer a valid tag, return to main chat list view currentChatPage.set(1); await chats.set(await getChatList(localStorage.token, $currentChatPage)); - await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned')); await scrollPaginationEnabled.set(true); } }; diff --git a/src/lib/components/layout/Navbar/Menu.svelte b/src/lib/components/layout/Navbar/Menu.svelte index 42ecae122..85e7b037b 100644 --- a/src/lib/components/layout/Navbar/Menu.svelte +++ b/src/lib/components/layout/Navbar/Menu.svelte @@ -24,6 +24,7 @@ import Clipboard from '$lib/components/icons/Clipboard.svelte'; import AdjustmentsHorizontal from '$lib/components/icons/AdjustmentsHorizontal.svelte'; import Cube from '$lib/components/icons/Cube.svelte'; + import { getChatById } from '$lib/apis/chats'; const i18n = getContext('i18n'); @@ -81,6 +82,9 @@ }; const downloadJSONExport = async () => { + if (chat.id) { + chat = await getChatById(localStorage.token, chat.id); + } let blob = new Blob([JSON.stringify([chat])], { type: 'application/json' }); diff --git a/src/lib/components/layout/Sidebar.svelte b/src/lib/components/layout/Sidebar.svelte index 3fdadf5ff..31c518754 100644 --- a/src/lib/components/layout/Sidebar.svelte +++ b/src/lib/components/layout/Sidebar.svelte @@ -34,7 +34,8 @@ archiveChatById, cloneChatById, getChatListBySearchText, - createNewChat + createNewChat, + getPinnedChatList } from '$lib/apis/chats'; import { WEBUI_BASE_URL } from '$lib/constants'; @@ -135,7 +136,7 @@ currentChatPage.set(1); await chats.set(await getChatList(localStorage.token, $currentChatPage)); - await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned')); + await pinnedChats.set(await getPinnedChatList(localStorage.token)); } }; @@ -255,7 +256,7 @@ localStorage.sidebar = value; }); - await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned')); + await pinnedChats.set(await getPinnedChatList(localStorage.token)); await initChatList(); window.addEventListener('keydown', onKeyDown); @@ -495,7 +496,7 @@ - {#if $tags.filter((t) => t.name !== 'pinned').length > 0} + {#if $tags.length > 0}