From 6703cacb999080be0831121874570c22dd9d92a0 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" <timothyjrbeck@gmail.com> Date: Mon, 14 Oct 2024 22:57:11 -0700 Subject: [PATCH] fix: tag unarchive/archive issue --- backend/open_webui/apps/webui/models/chats.py | 21 +++++----- .../open_webui/apps/webui/routers/chats.py | 38 ++++++++++++++++++- src/lib/apis/chats/index.ts | 2 +- src/lib/components/chat/Tags.svelte | 6 +-- src/lib/components/layout/Sidebar.svelte | 13 +++++-- .../layout/Sidebar/ArchivedChatsModal.svelte | 2 - .../components/layout/Sidebar/ChatItem.svelte | 5 ++- .../components/layout/Sidebar/ChatMenu.svelte | 11 ++---- src/routes/(app)/+layout.svelte | 4 +- 9 files changed, 72 insertions(+), 30 deletions(-) diff --git a/backend/open_webui/apps/webui/models/chats.py b/backend/open_webui/apps/webui/models/chats.py index 33cab3f18..9032d2e35 100644 --- a/backend/open_webui/apps/webui/models/chats.py +++ b/backend/open_webui/apps/webui/models/chats.py @@ -401,10 +401,11 @@ class ChatTable: # search_text might contain 'tag:tag_name' format so we need to extract the tag_name, split the search_text and remove the tags tag_ids = [ - tag_name.replace("tag:", "").replace(" ", "_").lower() - for tag_name in search_text_words - if tag_name.startswith("tag:") + word.replace("tag:", "").replace(" ", "_").lower() + for word in search_text_words + if word.startswith("tag:") ] + search_text_words = [ word for word in search_text_words if not word.startswith("tag:") ] @@ -450,11 +451,11 @@ class ChatTable: EXISTS ( SELECT 1 FROM json_each(Chat.meta, '$.tags') AS tag - WHERE tag.value = :tag_id + WHERE tag.value = :tag_id_{tag_idx} ) """ - ).params(tag_id=tag_id) - for tag_id in tag_ids + ).params(**{f"tag_id_{tag_idx}": tag_id}) + for tag_idx, tag_id in enumerate(tag_ids) ] ) ) @@ -488,11 +489,11 @@ class ChatTable: EXISTS ( SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') AS tag - WHERE tag = :tag_id + WHERE tag = :tag_id_{tag_idx} ) """ - ).params(tag_id=tag_id) - for tag_id in tag_ids + ).params(**{f"tag_id_{tag_idx}": tag_id}) + for tag_idx, tag_id in enumerate(tag_ids) ] ) ) @@ -571,7 +572,7 @@ class ChatTable: def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> int: with get_db() as db: # Assuming `get_db()` returns a session object - query = db.query(Chat).filter_by(user_id=user_id) + query = db.query(Chat).filter_by(user_id=user_id, archived=False) # Normalize the tag_name for consistency tag_id = tag_name.replace(" ", "_").lower() diff --git a/backend/open_webui/apps/webui/routers/chats.py b/backend/open_webui/apps/webui/routers/chats.py index b919d1447..a91d1e11d 100644 --- a/backend/open_webui/apps/webui/routers/chats.py +++ b/backend/open_webui/apps/webui/routers/chats.py @@ -114,13 +114,24 @@ async def search_user_chats( limit = 60 skip = (page - 1) * limit - return [ + chat_list = [ ChatTitleIdResponse(**chat.model_dump()) for chat in Chats.get_chats_by_user_id_and_search_text( user.id, text, skip=skip, limit=limit ) ] + # Delete tag if no chat is found + words = text.strip().split(" ") + if page == 1 and len(words) == 1 and words[0].startswith("tag:"): + tag_id = words[0].replace("tag:", "") + if len(chat_list) == 0: + if Tags.get_tag_by_name_and_user_id(tag_id, user.id): + log.debug(f"deleting tag: {tag_id}") + Tags.delete_tag_by_name_and_user_id(tag_id, user.id) + + return chat_list + ############################ # GetPinnedChats @@ -315,7 +326,13 @@ async def update_chat_by_id( @router.delete("/{id}", response_model=bool) async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)): if user.role == "admin": + chat = Chats.get_chat_by_id(id) + for tag in chat.meta.get("tags", []): + if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0: + Tags.delete_tag_by_name_and_user_id(tag, user.id) + result = Chats.delete_chat_by_id(id) + return result else: if not request.app.state.config.USER_PERMISSIONS.get("chat", {}).get( @@ -326,6 +343,11 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) + chat = Chats.get_chat_by_id(id) + for tag in chat.meta.get("tags", []): + if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0: + Tags.delete_tag_by_name_and_user_id(tag, user.id) + result = Chats.delete_chat_by_id_and_user_id(id, user.id) return result @@ -397,6 +419,20 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)): chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: chat = Chats.toggle_chat_archive_by_id(id) + + # Delete tags if chat is archived + if chat.archived: + for tag_id in chat.meta.get("tags", []): + if Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id) == 0: + log.debug(f"deleting tag: {tag_id}") + Tags.delete_tag_by_name_and_user_id(tag_id, user.id) + else: + for tag_id in chat.meta.get("tags", []): + tag = Tags.get_tag_by_name_and_user_id(tag_id, user.id) + if tag is None: + log.debug(f"inserting tag: {tag_id}") + tag = Tags.insert_new_tag(tag_id, user.id) + return ChatResponse(**chat.model_dump()) else: raise HTTPException( diff --git a/src/lib/apis/chats/index.ts b/src/lib/apis/chats/index.ts index ff89fdf43..6056f6dbf 100644 --- a/src/lib/apis/chats/index.ts +++ b/src/lib/apis/chats/index.ts @@ -267,7 +267,7 @@ export const getAllUserChats = async (token: string) => { return res; }; -export const getAllChatTags = async (token: string) => { +export const getAllTags = async (token: string) => { let error = null; const res = await fetch(`${WEBUI_API_BASE_URL}/chats/all/tags`, { diff --git a/src/lib/components/chat/Tags.svelte b/src/lib/components/chat/Tags.svelte index 1ce3bc658..f71ff0af9 100644 --- a/src/lib/components/chat/Tags.svelte +++ b/src/lib/components/chat/Tags.svelte @@ -2,7 +2,7 @@ import { addTagById, deleteTagById, - getAllChatTags, + getAllTags, getChatList, getChatListByTagName, getTagsById, @@ -37,7 +37,7 @@ tags: tags }); - await _tags.set(await getAllChatTags(localStorage.token)); + await _tags.set(await getAllTags(localStorage.token)); dispatch('add', { name: tagName }); @@ -50,7 +50,7 @@ tags: tags }); - await _tags.set(await getAllChatTags(localStorage.token)); + await _tags.set(await getAllTags(localStorage.token)); dispatch('delete', { name: tagName }); diff --git a/src/lib/components/layout/Sidebar.svelte b/src/lib/components/layout/Sidebar.svelte index 583624a28..cfe5b9976 100644 --- a/src/lib/components/layout/Sidebar.svelte +++ b/src/lib/components/layout/Sidebar.svelte @@ -30,7 +30,7 @@ getChatById, getChatListByTagName, updateChatById, - getAllChatTags, + getAllTags, archiveChatById, cloneChatById, getChatListBySearchText, @@ -77,6 +77,8 @@ const initChatList = async () => { // Reset pagination variables + tags.set(await getAllTags(localStorage.token)); + currentChatPage.set(1); allChatsLoaded = false; await chats.set(await getChatList(localStorage.token, $currentChatPage)); @@ -123,6 +125,10 @@ searchDebounceTimeout = setTimeout(async () => { currentChatPage.set(1); await chats.set(await getChatListBySearchText(localStorage.token, search)); + + if ($chats.length === 0) { + tags.set(await getAllTags(localStorage.token)); + } }, 1000); } }; @@ -134,6 +140,8 @@ }); if (res) { + tags.set(await getAllTags(localStorage.token)); + if ($chatId === id) { await chatId.set(''); await tick(); @@ -143,7 +151,6 @@ allChatsLoaded = false; currentChatPage.set(1); await chats.set(await getChatList(localStorage.token, $currentChatPage)); - await pinnedChats.set(await getPinnedChatList(localStorage.token)); } }; @@ -324,7 +331,7 @@ bind:show={$showArchivedChats} on:change={async () => { await pinnedChats.set(await getPinnedChatList(localStorage.token)); - await chats.set(await getChatList(localStorage.token)); + await initChatList(); }} /> diff --git a/src/lib/components/layout/Sidebar/ArchivedChatsModal.svelte b/src/lib/components/layout/Sidebar/ArchivedChatsModal.svelte index 80e3f1579..b2ae11058 100644 --- a/src/lib/components/layout/Sidebar/ArchivedChatsModal.svelte +++ b/src/lib/components/layout/Sidebar/ArchivedChatsModal.svelte @@ -15,7 +15,6 @@ getArchivedChatList } from '$lib/apis/chats'; import Tooltip from '$lib/components/common/Tooltip.svelte'; - const i18n = getContext('i18n'); export let show = false; @@ -30,7 +29,6 @@ }); chats = await getArchivedChatList(localStorage.token); - dispatch('change'); }; diff --git a/src/lib/components/layout/Sidebar/ChatItem.svelte b/src/lib/components/layout/Sidebar/ChatItem.svelte index 8356b24cc..ff96ed05f 100644 --- a/src/lib/components/layout/Sidebar/ChatItem.svelte +++ b/src/lib/components/layout/Sidebar/ChatItem.svelte @@ -10,6 +10,7 @@ archiveChatById, cloneChatById, deleteChatById, + getAllTags, getChatList, getChatListByTagName, getPinnedChatList, @@ -22,7 +23,8 @@ mobile, pinnedChats, showSidebar, - currentChatPage + currentChatPage, + tags } from '$lib/stores'; import ChatMenu from './ChatMenu.svelte'; @@ -77,6 +79,7 @@ const archiveChatHandler = async (id) => { await archiveChatById(localStorage.token, id); + tags.set(await getAllTags(localStorage.token)); currentChatPage.set(1); await chats.set(await getChatList(localStorage.token, $currentChatPage)); diff --git a/src/lib/components/layout/Sidebar/ChatMenu.svelte b/src/lib/components/layout/Sidebar/ChatMenu.svelte index 306cd3633..38d0f1182 100644 --- a/src/lib/components/layout/Sidebar/ChatMenu.svelte +++ b/src/lib/components/layout/Sidebar/ChatMenu.svelte @@ -15,13 +15,8 @@ import DocumentDuplicate from '$lib/components/icons/DocumentDuplicate.svelte'; import Bookmark from '$lib/components/icons/Bookmark.svelte'; import BookmarkSlash from '$lib/components/icons/BookmarkSlash.svelte'; - import { - addTagById, - deleteTagById, - getChatPinnedStatusById, - getTagsById, - toggleChatPinnedStatusById - } from '$lib/apis/chats'; + import { getChatPinnedStatusById, toggleChatPinnedStatusById } from '$lib/apis/chats'; + import { chats } from '$lib/stores'; const i18n = getContext('i18n'); @@ -146,6 +141,7 @@ type: 'add', name: e.detail.name }); + show = false; }} on:delete={(e) => { @@ -153,6 +149,7 @@ type: 'delete', name: e.detail.name }); + show = false; }} on:close={() => { diff --git a/src/routes/(app)/+layout.svelte b/src/routes/(app)/+layout.svelte index bd711470a..2cc9b5c3d 100644 --- a/src/routes/(app)/+layout.svelte +++ b/src/routes/(app)/+layout.svelte @@ -13,7 +13,7 @@ import { getKnowledgeItems } from '$lib/apis/knowledge'; import { getFunctions } from '$lib/apis/functions'; import { getModels as _getModels, getVersionUpdates } from '$lib/apis'; - import { getAllChatTags } from '$lib/apis/chats'; + import { getAllTags } from '$lib/apis/chats'; import { getPrompts } from '$lib/apis/prompts'; import { getTools } from '$lib/apis/tools'; import { getBanners } from '$lib/apis/configs'; @@ -117,7 +117,7 @@ banners.set(await getBanners(localStorage.token)); })(), (async () => { - tags.set(await getAllChatTags(localStorage.token)); + tags.set(await getAllTags(localStorage.token)); })() ]);