refac: tags

This commit is contained in:
Timothy J. Baek 2024-10-10 23:22:53 -07:00
parent 4adc57fd34
commit acb5dcf30a
10 changed files with 555 additions and 291 deletions

View File

@ -4,10 +4,13 @@ import uuid
from typing import Optional from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.apps.webui.models.tags import TagModel, Tag, Tags
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists
#################### ####################
# Chat DB Schema # Chat DB Schema
@ -27,6 +30,9 @@ class Chat(Base):
share_id = Column(Text, unique=True, nullable=True) share_id = Column(Text, unique=True, nullable=True)
archived = Column(Boolean, default=False) archived = Column(Boolean, default=False)
pinned = Column(Boolean, default=False, nullable=True)
meta = Column(JSON, server_default="{}")
class ChatModel(BaseModel): class ChatModel(BaseModel):
@ -42,6 +48,9 @@ class ChatModel(BaseModel):
share_id: Optional[str] = None share_id: Optional[str] = None
archived: bool = False archived: bool = False
pinned: Optional[bool] = False
meta: dict = {}
#################### ####################
@ -66,6 +75,8 @@ class ChatResponse(BaseModel):
created_at: int # timestamp in epoch created_at: int # timestamp in epoch
share_id: Optional[str] = None # id of the chat to be shared share_id: Optional[str] = None # id of the chat to be shared
archived: bool archived: bool
pinned: Optional[bool] = False
meta: dict = {}
class ChatTitleIdResponse(BaseModel): class ChatTitleIdResponse(BaseModel):
@ -184,11 +195,24 @@ class ChatTable:
except Exception: except Exception:
return None return None
def toggle_chat_pinned_by_id(self, id: str) -> Optional[ChatModel]:
try:
with get_db() as db:
chat = db.get(Chat, id)
chat.pinned = not chat.pinned
chat.updated_at = int(time.time())
db.commit()
db.refresh(chat)
return ChatModel.model_validate(chat)
except Exception:
return None
def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
try: try:
with get_db() as db: with get_db() as db:
chat = db.get(Chat, id) chat = db.get(Chat, id)
chat.archived = not chat.archived chat.archived = not chat.archived
chat.updated_at = int(time.time())
db.commit() db.commit()
db.refresh(chat) db.refresh(chat)
return ChatModel.model_validate(chat) return ChatModel.model_validate(chat)
@ -330,6 +354,15 @@ class ChatTable:
) )
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
with get_db() as db:
all_chats = (
db.query(Chat)
.filter_by(user_id=user_id, pinned=True)
.order_by(Chat.updated_at.desc())
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]: def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
with get_db() as db: with get_db() as db:
all_chats = ( all_chats = (
@ -383,6 +416,135 @@ class ChatTable:
paginated_chats = filtered_chats[skip : skip + limit] paginated_chats = filtered_chats[skip : skip + limit]
return [ChatModel.model_validate(chat) for chat in paginated_chats] return [ChatModel.model_validate(chat) for chat in paginated_chats]
def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str) -> list[TagModel]:
with get_db() as db:
chat = db.get(Chat, id)
tags = chat.meta.get("tags", [])
return [Tags.get_tag_by_name_and_user_id(tag, user_id) for tag in tags]
def get_chat_list_by_user_id_and_tag_name(
self, user_id: str, tag_name: str, skip: int = 0, limit: int = 50
) -> list[ChatModel]:
with get_db() as db:
query = db.query(Chat).filter_by(user_id=user_id)
tag_id = tag_name.replace(" ", "_").lower()
print(db.bind.dialect.name)
if db.bind.dialect.name == "sqlite":
# SQLite JSON1 querying for tags within the meta JSON field
query = query.filter(
text(
f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)"
)
).params(tag_id=tag_id)
elif db.bind.dialect.name == "postgresql":
# PostgreSQL JSON query for tags within the meta JSON field (for `json` type)
query = query.filter(
text(
"EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)"
)
).params(tag_id=tag_id)
else:
raise NotImplementedError(
f"Unsupported dialect: {db.bind.dialect.name}"
)
all_chats = query.all()
print("all_chats", all_chats)
return [ChatModel.model_validate(chat) for chat in all_chats]
def add_chat_tag_by_id_and_user_id_and_tag_name(
self, id: str, user_id: str, tag_name: str
) -> Optional[ChatModel]:
tag = Tags.get_tag_by_name_and_user_id(tag_name, user_id)
if tag is None:
tag = Tags.insert_new_tag(tag_name, user_id)
try:
with get_db() as db:
chat = db.get(Chat, id)
tag_id = tag.id
if tag_id not in chat.meta.get("tags", []):
chat.meta = {
**chat.meta,
"tags": chat.meta.get("tags", []) + [tag_id],
}
db.commit()
db.refresh(chat)
return ChatModel.model_validate(chat)
except Exception:
return None
def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> int:
with get_db() as db: # Assuming `get_db()` returns a session object
query = db.query(Chat).filter_by(user_id=user_id)
# Normalize the tag_name for consistency
tag_id = tag_name.replace(" ", "_").lower()
if db.bind.dialect.name == "sqlite":
# SQLite JSON1 support for querying the tags inside the `meta` JSON field
query = query.filter(
text(
f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)"
)
).params(tag_id=tag_id)
elif db.bind.dialect.name == "postgresql":
# PostgreSQL JSONB support for querying the tags inside the `meta` JSON field
query = query.filter(
text(
"EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)"
)
).params(tag_id=tag_id)
else:
raise NotImplementedError(
f"Unsupported dialect: {db.bind.dialect.name}"
)
# Get the count of matching records
count = query.count()
# Debugging output for inspection
print(f"Count of chats for tag '{tag_name}':", count)
return count
def delete_tag_by_id_and_user_id_and_tag_name(
self, id: str, user_id: str, tag_name: str
) -> bool:
try:
with get_db() as db:
chat = db.get(Chat, id)
tags = chat.meta.get("tags", [])
tag_id = tag_name.replace(" ", "_").lower()
tags = [tag for tag in tags if tag != tag_id]
chat.meta = {
**chat.meta,
"tags": tags,
}
db.commit()
return True
except Exception:
return False
def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try:
with get_db() as db:
chat = db.get(Chat, id)
chat.meta = {
**chat.meta,
"tags": [],
}
db.commit()
return True
except Exception:
return False
def delete_chat_by_id(self, id: str) -> bool: def delete_chat_by_id(self, id: str) -> bool:
try: try:
with get_db() as db: with get_db() as db:

View File

@ -4,53 +4,32 @@ import uuid
from typing import Optional from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text from sqlalchemy import BigInteger, Column, String, JSON
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])
#################### ####################
# Tag DB Schema # Tag DB Schema
#################### ####################
class Tag(Base): class Tag(Base):
__tablename__ = "tag" __tablename__ = "tag"
id = Column(String, primary_key=True) id = Column(String, primary_key=True)
name = Column(String) name = Column(String)
user_id = Column(String) user_id = Column(String)
data = Column(Text, nullable=True) meta = Column(JSON, nullable=True)
class ChatIdTag(Base):
__tablename__ = "chatidtag"
id = Column(String, primary_key=True)
tag_name = Column(String)
chat_id = Column(String)
user_id = Column(String)
timestamp = Column(BigInteger)
class TagModel(BaseModel): class TagModel(BaseModel):
id: str id: str
name: str name: str
user_id: str user_id: str
data: Optional[str] = None meta: Optional[dict] = None
model_config = ConfigDict(from_attributes=True)
class ChatIdTagModel(BaseModel):
id: str
tag_name: str
chat_id: str
user_id: str
timestamp: int
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
@ -59,23 +38,15 @@ class ChatIdTagModel(BaseModel):
#################### ####################
class ChatIdTagForm(BaseModel): class TagChatIdForm(BaseModel):
tag_name: str name: str
chat_id: str chat_id: str
class TagChatIdsResponse(BaseModel):
chat_ids: list[str]
class ChatTagsResponse(BaseModel):
tags: list[str]
class TagTable: class TagTable:
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
with get_db() as db: with get_db() as db:
id = str(uuid.uuid4()) id = name.replace(" ", "_").lower()
tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
try: try:
result = Tag(**tag.model_dump()) result = Tag(**tag.model_dump())
@ -93,170 +64,38 @@ class TagTable:
self, name: str, user_id: str self, name: str, user_id: str
) -> Optional[TagModel]: ) -> Optional[TagModel]:
try: try:
id = name.replace(" ", "_").lower()
with get_db() as db: with get_db() as db:
tag = db.query(Tag).filter_by(name=name, user_id=user_id).first() tag = db.query(Tag).filter_by(id=id, user_id=user_id).first()
return TagModel.model_validate(tag) return TagModel.model_validate(tag)
except Exception: except Exception:
return None return None
def add_tag_to_chat(
self, user_id: str, form_data: ChatIdTagForm
) -> Optional[ChatIdTagModel]:
tag = self.get_tag_by_name_and_user_id(form_data.tag_name, user_id)
if tag is None:
tag = self.insert_new_tag(form_data.tag_name, user_id)
id = str(uuid.uuid4())
chatIdTag = ChatIdTagModel(
**{
"id": id,
"user_id": user_id,
"chat_id": form_data.chat_id,
"tag_name": tag.name,
"timestamp": int(time.time()),
}
)
try:
with get_db() as db:
result = ChatIdTag(**chatIdTag.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return ChatIdTagModel.model_validate(result)
else:
return None
except Exception:
return None
def get_tags_by_user_id(self, user_id: str) -> list[TagModel]: def get_tags_by_user_id(self, user_id: str) -> list[TagModel]:
with get_db() as db: with get_db() as db:
tag_names = [
chat_id_tag.tag_name
for chat_id_tag in (
db.query(ChatIdTag)
.filter_by(user_id=user_id)
.order_by(ChatIdTag.timestamp.desc())
.all()
)
]
return [ return [
TagModel.model_validate(tag) TagModel.model_validate(tag)
for tag in ( for tag in (db.query(Tag).filter_by(user_id=user_id).all())
db.query(Tag)
.filter_by(user_id=user_id)
.filter(Tag.name.in_(tag_names))
.all()
)
] ]
def get_tags_by_chat_id_and_user_id( def get_tags_by_ids(self, ids: list[str]) -> list[TagModel]:
self, chat_id: str, user_id: str
) -> list[TagModel]:
with get_db() as db: with get_db() as db:
tag_names = [
chat_id_tag.tag_name
for chat_id_tag in (
db.query(ChatIdTag)
.filter_by(user_id=user_id, chat_id=chat_id)
.order_by(ChatIdTag.timestamp.desc())
.all()
)
]
return [ return [
TagModel.model_validate(tag) TagModel.model_validate(tag)
for tag in ( for tag in (db.query(Tag).filter(Tag.id.in_(ids)).all())
db.query(Tag)
.filter_by(user_id=user_id)
.filter(Tag.name.in_(tag_names))
.all()
)
] ]
def get_chat_ids_by_tag_name_and_user_id( def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool:
self, tag_name: str, user_id: str
) -> list[ChatIdTagModel]:
with get_db() as db:
return [
ChatIdTagModel.model_validate(chat_id_tag)
for chat_id_tag in (
db.query(ChatIdTag)
.filter_by(user_id=user_id, tag_name=tag_name)
.order_by(ChatIdTag.timestamp.desc())
.all()
)
]
def count_chat_ids_by_tag_name_and_user_id(
self, tag_name: str, user_id: str
) -> int:
with get_db() as db:
return (
db.query(ChatIdTag)
.filter_by(tag_name=tag_name, user_id=user_id)
.count()
)
def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
try: try:
with get_db() as db: with get_db() as db:
res = ( id = name.replace(" ", "_").lower()
db.query(ChatIdTag) res = db.query(Tag).filter_by(id=id, user_id=user_id).delete()
.filter_by(tag_name=tag_name, user_id=user_id)
.delete()
)
log.debug(f"res: {res}") log.debug(f"res: {res}")
db.commit() db.commit()
tag_count = self.count_chat_ids_by_tag_name_and_user_id(
tag_name, user_id
)
if tag_count == 0:
# Remove tag item from Tag col as well
db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
db.commit()
return True return True
except Exception as e: except Exception as e:
log.error(f"delete_tag: {e}") log.error(f"delete_tag: {e}")
return False return False
def delete_tag_by_tag_name_and_chat_id_and_user_id(
self, tag_name: str, chat_id: str, user_id: str
) -> bool:
try:
with get_db() as db:
res = (
db.query(ChatIdTag)
.filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
.delete()
)
log.debug(f"res: {res}")
db.commit()
tag_count = self.count_chat_ids_by_tag_name_and_user_id(
tag_name, user_id
)
if tag_count == 0:
# Remove tag item from Tag col as well
db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
db.commit()
return True
except Exception as e:
log.error(f"delete_tag: {e}")
return False
def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool:
tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id)
for tag in tags:
self.delete_tag_by_tag_name_and_chat_id_and_user_id(
tag.tag_name, chat_id, user_id
)
return True
Tags = TagTable() Tags = TagTable()

View File

@ -8,12 +8,8 @@ from open_webui.apps.webui.models.chats import (
Chats, Chats,
ChatTitleIdResponse, ChatTitleIdResponse,
) )
from open_webui.apps.webui.models.tags import ( from open_webui.apps.webui.models.tags import TagModel, Tags
ChatIdTagForm,
ChatIdTagModel,
TagModel,
Tags,
)
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
@ -126,6 +122,19 @@ async def search_user_chats(
] ]
############################
# GetPinnedChats
############################
@router.get("/pinned", response_model=list[ChatResponse])
async def get_user_pinned_chats(user=Depends(get_verified_user)):
return [
ChatResponse(**chat.model_dump())
for chat in Chats.get_pinned_chats_by_user_id(user.id)
]
############################ ############################
# GetChats # GetChats
############################ ############################
@ -152,6 +161,23 @@ async def get_user_archived_chats(user=Depends(get_verified_user)):
] ]
############################
# GetAllTags
############################
@router.get("/all/tags", response_model=list[TagModel])
async def get_all_user_tags(user=Depends(get_verified_user)):
try:
tags = Tags.get_tags_by_user_id(user.id)
return tags
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################ ############################
# GetAllChatsInDB # GetAllChatsInDB
############################ ############################
@ -220,48 +246,28 @@ async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
############################ ############################
class TagNameForm(BaseModel): class TagForm(BaseModel):
name: str name: str
class TagFilterForm(TagForm):
skip: Optional[int] = 0 skip: Optional[int] = 0
limit: Optional[int] = 50 limit: Optional[int] = 50
@router.post("/tags", response_model=list[ChatTitleIdResponse]) @router.post("/tags", response_model=list[ChatTitleIdResponse])
async def get_user_chat_list_by_tag_name( async def get_user_chat_list_by_tag_name(
form_data: TagNameForm, user=Depends(get_verified_user) form_data: TagFilterForm, user=Depends(get_verified_user)
): ):
chat_ids = [ chats = Chats.get_chat_list_by_user_id_and_tag_name(
chat_id_tag.chat_id user.id, form_data.name, form_data.skip, form_data.limit
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id( )
form_data.name, user.id
)
]
chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit)
if len(chats) == 0: if len(chats) == 0:
Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id) Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
return chats return chats
############################
# GetAllTags
############################
@router.get("/tags/all", response_model=list[TagModel])
async def get_all_tags(user=Depends(get_verified_user)):
try:
tags = Tags.get_tags_by_user_id(user.id)
return tags
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################ ############################
# GetChatById # GetChatById
############################ ############################
@ -324,12 +330,45 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
return result return result
############################
# GetPinnedStatusById
############################
@router.get("/{id}/pinned", response_model=Optional[bool])
async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
return chat.pinned
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# PinChatById
############################
@router.post("/{id}/pin", response_model=Optional[ChatResponse])
async def pin_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
chat = Chats.toggle_chat_pinned_by_id(id)
return chat
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
)
############################ ############################
# CloneChat # CloneChat
############################ ############################
@router.get("/{id}/clone", response_model=Optional[ChatResponse]) @router.post("/{id}/clone", response_model=Optional[ChatResponse])
async def clone_chat_by_id(id: str, user=Depends(get_verified_user)): async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
@ -353,7 +392,7 @@ async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
############################ ############################
@router.get("/{id}/archive", response_model=Optional[ChatResponse]) @router.post("/{id}/archive", response_model=Optional[ChatResponse])
async def archive_chat_by_id(id: str, user=Depends(get_verified_user)): async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
@ -423,10 +462,10 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
@router.get("/{id}/tags", response_model=list[TagModel]) @router.get("/{id}/tags", response_model=list[TagModel])
async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)): async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
if tags != None: tags = chat.meta.get("tags", [])
return tags return Tags.get_tags_by_ids(tags)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@ -438,22 +477,24 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
############################ ############################
@router.post("/{id}/tags", response_model=Optional[ChatIdTagModel]) @router.post("/{id}/tags", response_model=list[TagModel])
async def add_chat_tag_by_id( async def add_tag_by_id_and_tag_name(
id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user) id: str, form_data: TagForm, user=Depends(get_verified_user)
): ):
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
tags = chat.meta.get("tags", [])
tag_id = form_data.name.replace(" ", "_").lower()
if form_data.tag_name not in tags: print(tags, tag_id)
tag = Tags.add_tag_to_chat(user.id, form_data) if tag_id not in tags:
Chats.add_chat_tag_by_id_and_user_id_and_tag_name(
if tag: id, user.id, form_data.name
return tag
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
) )
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids(tags)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
@ -465,16 +506,20 @@ async def add_chat_tag_by_id(
############################ ############################
@router.delete("/{id}/tags", response_model=Optional[bool]) @router.delete("/{id}/tags", response_model=list[TagModel])
async def delete_chat_tag_by_id( async def delete_tag_by_id_and_tag_name(
id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user) id: str, form_data: TagForm, user=Depends(get_verified_user)
): ):
result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id( chat = Chats.get_chat_by_id_and_user_id(id, user.id)
form_data.tag_name, id, user.id if chat:
) Chats.delete_tag_by_id_and_user_id_and_tag_name(id, user.id, form_data.name)
if result: if Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id) == 0:
return result Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids(tags)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@ -488,10 +533,17 @@ async def delete_chat_tag_by_id(
@router.delete("/{id}/tags/all", response_model=Optional[bool]) @router.delete("/{id}/tags/all", response_model=Optional[bool])
async def delete_all_chat_tags_by_id(id: str, user=Depends(get_verified_user)): async def delete_all_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
Chats.delete_all_tags_by_id_and_user_id(id, user.id)
if result: for tag in chat.meta.get("tags", []):
return result if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
Tags.delete_tag_by_name_and_user_id(tag, user.id)
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids(tags)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND

View File

@ -0,0 +1,109 @@
"""Migrate tags
Revision ID: 1af9b942657b
Revises: 242a2047eae0
Create Date: 2024-10-09 21:02:35.241684
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, select, update, column
import json
revision = "1af9b942657b"
down_revision = "242a2047eae0"
branch_labels = None
depends_on = None
def upgrade():
# Step 1: Modify Tag table using batch mode for SQLite support
with op.batch_alter_table("tag", schema=None) as batch_op:
batch_op.create_unique_constraint(
"uq_id_user_id", ["id", "user_id"]
) # Ensure unique (id, user_id)
batch_op.drop_column("data")
batch_op.add_column(sa.Column("meta", sa.JSON(), nullable=True))
tag = table(
"tag",
column("id", sa.String()),
column("name", sa.String()),
column("user_id", sa.String()),
column("meta", sa.JSON()),
)
# Step 2: Migrate tags
conn = op.get_bind()
result = conn.execute(sa.select(tag.c.id, tag.c.name, tag.c.user_id))
tag_updates = {}
for row in result:
new_id = row.name.replace(" ", "_").lower()
tag_updates[row.id] = new_id
for tag_id, new_tag_id in tag_updates.items():
print(f"Updating tag {tag_id} to {new_tag_id}")
if new_tag_id == "pinned":
# delete tag
delete_stmt = sa.delete(tag).where(tag.c.id == tag_id)
conn.execute(delete_stmt)
else:
update_stmt = sa.update(tag).where(tag.c.id == tag_id)
update_stmt = update_stmt.values(id=new_tag_id)
conn.execute(update_stmt)
# Add columns `pinned` and `meta` to 'chat'
op.add_column("chat", sa.Column("pinned", sa.Boolean(), nullable=True))
op.add_column(
"chat", sa.Column("meta", sa.JSON(), nullable=False, server_default="{}")
)
chatidtag = table(
"chatidtag", column("chat_id", sa.String()), column("tag_name", sa.String())
)
chat = table(
"chat",
column("id", sa.String()),
column("pinned", sa.Boolean()),
column("meta", sa.JSON()),
)
# Fetch existing tags
conn = op.get_bind()
result = conn.execute(sa.select(chatidtag.c.chat_id, chatidtag.c.tag_name))
chat_updates = {}
for row in result:
chat_id = row.chat_id
tag_name = row.tag_name.replace(" ", "_").lower()
if tag_name == "pinned":
# Specifically handle 'pinned' tag
if chat_id not in chat_updates:
chat_updates[chat_id] = {"pinned": True, "meta": {}}
else:
chat_updates[chat_id]["pinned"] = True
else:
if chat_id not in chat_updates:
chat_updates[chat_id] = {"pinned": False, "meta": {"tags": [tag_name]}}
else:
tags = chat_updates[chat_id]["meta"].get("tags", [])
tags.append(tag_name)
chat_updates[chat_id]["meta"]["tags"] = tags
# Update chats based on accumulated changes
for chat_id, updates in chat_updates.items():
update_stmt = sa.update(chat).where(chat.c.id == chat_id)
update_stmt = update_stmt.values(
meta=updates.get("meta", {}), pinned=updates.get("pinned", False)
)
conn.execute(update_stmt)
pass
def downgrade():
pass

View File

@ -267,7 +267,7 @@ export const getAllUserChats = async (token: string) => {
export const getAllChatTags = async (token: string) => { export const getAllChatTags = async (token: string) => {
let error = null; let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/tags/all`, { const res = await fetch(`${WEBUI_API_BASE_URL}/chats/all/tags`, {
method: 'GET', method: 'GET',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',
@ -295,6 +295,40 @@ export const getAllChatTags = async (token: string) => {
return res; return res;
}; };
export const getPinnedChatList = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/pinned`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res.map((chat) => ({
...chat,
time_range: getTimeRange(chat.updated_at)
}));
};
export const getChatListByTagName = async (token: string = '', tagName: string) => { export const getChatListByTagName = async (token: string = '', tagName: string) => {
let error = null; let error = null;
@ -396,11 +430,87 @@ export const getChatByShareId = async (token: string, share_id: string) => {
return res; return res;
}; };
export const getChatPinnedStatusById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/pinned`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err;
if ('detail' in err) {
error = err.detail;
} else {
error = err;
}
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const toggleChatPinnedStatusById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/pin`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err;
if ('detail' in err) {
error = err.detail;
} else {
error = err;
}
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const cloneChatById = async (token: string, id: string) => { export const cloneChatById = async (token: string, id: string) => {
let error = null; let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/clone`, { const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/clone`, {
method: 'GET', method: 'POST',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',
'Content-Type': 'application/json', 'Content-Type': 'application/json',
@ -470,7 +580,7 @@ export const archiveChatById = async (token: string, id: string) => {
let error = null; let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/archive`, { const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/archive`, {
method: 'GET', method: 'POST',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',
'Content-Type': 'application/json', 'Content-Type': 'application/json',
@ -640,8 +750,7 @@ export const addTagById = async (token: string, id: string, tagName: string) =>
...(token && { authorization: `Bearer ${token}` }) ...(token && { authorization: `Bearer ${token}` })
}, },
body: JSON.stringify({ body: JSON.stringify({
tag_name: tagName, name: tagName
chat_id: id
}) })
}) })
.then(async (res) => { .then(async (res) => {
@ -676,8 +785,7 @@ export const deleteTagById = async (token: string, id: string, tagName: string)
...(token && { authorization: `Bearer ${token}` }) ...(token && { authorization: `Bearer ${token}` })
}, },
body: JSON.stringify({ body: JSON.stringify({
tag_name: tagName, name: tagName
chat_id: id
}) })
}) })
.then(async (res) => { .then(async (res) => {

View File

@ -25,40 +25,30 @@
let tags = []; let tags = [];
const getTags = async () => { const getTags = async () => {
return ( return await getTagsById(localStorage.token, chatId).catch(async (error) => {
await getTagsById(localStorage.token, chatId).catch(async (error) => { return [];
return []; });
})
).filter((tag) => tag.name !== 'pinned');
}; };
const addTag = async (tagName) => { const addTag = async (tagName) => {
const res = await addTagById(localStorage.token, chatId, tagName); const res = await addTagById(localStorage.token, chatId, tagName);
tags = await getTags(); tags = await getTags();
await updateChatById(localStorage.token, chatId, { await updateChatById(localStorage.token, chatId, {
tags: tags tags: tags
}); });
_tags.set(await getAllChatTags(localStorage.token)); _tags.set(await getAllChatTags(localStorage.token));
await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned'));
}; };
const deleteTag = async (tagName) => { const deleteTag = async (tagName) => {
const res = await deleteTagById(localStorage.token, chatId, tagName); const res = await deleteTagById(localStorage.token, chatId, tagName);
tags = await getTags(); tags = await getTags();
await updateChatById(localStorage.token, chatId, { await updateChatById(localStorage.token, chatId, {
tags: tags tags: tags
}); });
await _tags.set(await getAllChatTags(localStorage.token)); await _tags.set(await getAllChatTags(localStorage.token));
if ($_tags.map((t) => t.name).includes(tagName)) { if ($_tags.map((t) => t.name).includes(tagName)) {
if (tagName === 'pinned') { await chats.set(await getChatListByTagName(localStorage.token, tagName));
await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned'));
} else {
await chats.set(await getChatListByTagName(localStorage.token, tagName));
}
if ($chats.find((chat) => chat.id === chatId)) { if ($chats.find((chat) => chat.id === chatId)) {
dispatch('close'); dispatch('close');
@ -67,7 +57,6 @@
// if the tag we deleted is no longer a valid tag, return to main chat list view // if the tag we deleted is no longer a valid tag, return to main chat list view
currentChatPage.set(1); currentChatPage.set(1);
await chats.set(await getChatList(localStorage.token, $currentChatPage)); await chats.set(await getChatList(localStorage.token, $currentChatPage));
await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned'));
await scrollPaginationEnabled.set(true); await scrollPaginationEnabled.set(true);
} }
}; };

View File

@ -24,6 +24,7 @@
import Clipboard from '$lib/components/icons/Clipboard.svelte'; import Clipboard from '$lib/components/icons/Clipboard.svelte';
import AdjustmentsHorizontal from '$lib/components/icons/AdjustmentsHorizontal.svelte'; import AdjustmentsHorizontal from '$lib/components/icons/AdjustmentsHorizontal.svelte';
import Cube from '$lib/components/icons/Cube.svelte'; import Cube from '$lib/components/icons/Cube.svelte';
import { getChatById } from '$lib/apis/chats';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
@ -81,6 +82,9 @@
}; };
const downloadJSONExport = async () => { const downloadJSONExport = async () => {
if (chat.id) {
chat = await getChatById(localStorage.token, chat.id);
}
let blob = new Blob([JSON.stringify([chat])], { let blob = new Blob([JSON.stringify([chat])], {
type: 'application/json' type: 'application/json'
}); });

View File

@ -34,7 +34,8 @@
archiveChatById, archiveChatById,
cloneChatById, cloneChatById,
getChatListBySearchText, getChatListBySearchText,
createNewChat createNewChat,
getPinnedChatList
} from '$lib/apis/chats'; } from '$lib/apis/chats';
import { WEBUI_BASE_URL } from '$lib/constants'; import { WEBUI_BASE_URL } from '$lib/constants';
@ -135,7 +136,7 @@
currentChatPage.set(1); currentChatPage.set(1);
await chats.set(await getChatList(localStorage.token, $currentChatPage)); await chats.set(await getChatList(localStorage.token, $currentChatPage));
await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned')); await pinnedChats.set(await getPinnedChatList(localStorage.token));
} }
}; };
@ -255,7 +256,7 @@
localStorage.sidebar = value; localStorage.sidebar = value;
}); });
await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned')); await pinnedChats.set(await getPinnedChatList(localStorage.token));
await initChatList(); await initChatList();
window.addEventListener('keydown', onKeyDown); window.addEventListener('keydown', onKeyDown);
@ -495,7 +496,7 @@
</div> </div>
</div> </div>
{#if $tags.filter((t) => t.name !== 'pinned').length > 0} {#if $tags.length > 0}
<div class="px-3.5 mb-1 flex gap-0.5 flex-wrap"> <div class="px-3.5 mb-1 flex gap-0.5 flex-wrap">
<button <button
class="px-2.5 py-[1px] text-xs transition {selectedTagName === null class="px-2.5 py-[1px] text-xs transition {selectedTagName === null
@ -508,7 +509,7 @@
> >
{$i18n.t('all')} {$i18n.t('all')}
</button> </button>
{#each $tags.filter((t) => t.name !== 'pinned') as tag} {#each $tags as tag}
<button <button
class="px-2.5 py-[1px] text-xs transition {selectedTagName === tag.name class="px-2.5 py-[1px] text-xs transition {selectedTagName === tag.name
? 'bg-gray-100 dark:bg-gray-900' ? 'bg-gray-100 dark:bg-gray-900'
@ -516,14 +517,15 @@
on:click={async () => { on:click={async () => {
selectedTagName = tag.name; selectedTagName = tag.name;
scrollPaginationEnabled.set(false); scrollPaginationEnabled.set(false);
let chatIds = await getChatListByTagName(localStorage.token, tag.name);
if (chatIds.length === 0) {
await tags.set(await getAllChatTags(localStorage.token));
let taggedChatList = await getChatListByTagName(localStorage.token, tag.name);
if (taggedChatList.length === 0) {
await tags.set(await getAllChatTags(localStorage.token));
// if the tag we deleted is no longer a valid tag, return to main chat list view // if the tag we deleted is no longer a valid tag, return to main chat list view
await initChatList(); await initChatList();
} else {
await chats.set(taggedChatList);
} }
await chats.set(chatIds);
chatListLoading = false; chatListLoading = false;
}} }}
> >

View File

@ -12,6 +12,7 @@
deleteChatById, deleteChatById,
getChatList, getChatList,
getChatListByTagName, getChatListByTagName,
getPinnedChatList,
updateChatById updateChatById
} from '$lib/apis/chats'; } from '$lib/apis/chats';
import { import {
@ -55,7 +56,7 @@
currentChatPage.set(1); currentChatPage.set(1);
await chats.set(await getChatList(localStorage.token, $currentChatPage)); await chats.set(await getChatList(localStorage.token, $currentChatPage));
await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned')); await pinnedChats.set(await getPinnedChatList(localStorage.token));
} }
}; };
@ -70,7 +71,7 @@
currentChatPage.set(1); currentChatPage.set(1);
await chats.set(await getChatList(localStorage.token, $currentChatPage)); await chats.set(await getChatList(localStorage.token, $currentChatPage));
await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned')); await pinnedChats.set(await getPinnedChatList(localStorage.token));
} }
}; };
@ -79,7 +80,7 @@
currentChatPage.set(1); currentChatPage.set(1);
await chats.set(await getChatList(localStorage.token, $currentChatPage)); await chats.set(await getChatList(localStorage.token, $currentChatPage));
await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned')); await pinnedChats.set(await getPinnedChatList(localStorage.token));
}; };
const focusEdit = async (node: HTMLInputElement) => { const focusEdit = async (node: HTMLInputElement) => {
@ -256,7 +257,7 @@
dispatch('unselect'); dispatch('unselect');
}} }}
on:change={async () => { on:change={async () => {
await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned')); await pinnedChats.set(await getPinnedChatList(localStorage.token));
}} }}
> >
<button <button

View File

@ -15,7 +15,13 @@
import DocumentDuplicate from '$lib/components/icons/DocumentDuplicate.svelte'; import DocumentDuplicate from '$lib/components/icons/DocumentDuplicate.svelte';
import Bookmark from '$lib/components/icons/Bookmark.svelte'; import Bookmark from '$lib/components/icons/Bookmark.svelte';
import BookmarkSlash from '$lib/components/icons/BookmarkSlash.svelte'; import BookmarkSlash from '$lib/components/icons/BookmarkSlash.svelte';
import { addTagById, deleteTagById, getTagsById } from '$lib/apis/chats'; import {
addTagById,
deleteTagById,
getChatPinnedStatusById,
getTagsById,
toggleChatPinnedStatusById
} from '$lib/apis/chats';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
@ -32,20 +38,12 @@
let pinned = false; let pinned = false;
const pinHandler = async () => { const pinHandler = async () => {
if (pinned) { await toggleChatPinnedStatusById(localStorage.token, chatId);
await deleteTagById(localStorage.token, chatId, 'pinned');
} else {
await addTagById(localStorage.token, chatId, 'pinned');
}
dispatch('change'); dispatch('change');
}; };
const checkPinned = async () => { const checkPinned = async () => {
pinned = ( pinned = await getChatPinnedStatusById(localStorage.token, chatId);
await getTagsById(localStorage.token, chatId).catch(async (error) => {
return [];
})
).find((tag) => tag.name === 'pinned');
}; };
$: if (show) { $: if (show) {