From 6703cacb999080be0831121874570c22dd9d92a0 Mon Sep 17 00:00:00 2001
From: "Timothy J. Baek" <timothyjrbeck@gmail.com>
Date: Mon, 14 Oct 2024 22:57:11 -0700
Subject: [PATCH] fix: tag unarchive/archive issue

---
 backend/open_webui/apps/webui/models/chats.py | 21 +++++-----
 .../open_webui/apps/webui/routers/chats.py    | 38 ++++++++++++++++++-
 src/lib/apis/chats/index.ts                   |  2 +-
 src/lib/components/chat/Tags.svelte           |  6 +--
 src/lib/components/layout/Sidebar.svelte      | 13 +++++--
 .../layout/Sidebar/ArchivedChatsModal.svelte  |  2 -
 .../components/layout/Sidebar/ChatItem.svelte |  5 ++-
 .../components/layout/Sidebar/ChatMenu.svelte | 11 ++----
 src/routes/(app)/+layout.svelte               |  4 +-
 9 files changed, 72 insertions(+), 30 deletions(-)

diff --git a/backend/open_webui/apps/webui/models/chats.py b/backend/open_webui/apps/webui/models/chats.py
index 33cab3f18..9032d2e35 100644
--- a/backend/open_webui/apps/webui/models/chats.py
+++ b/backend/open_webui/apps/webui/models/chats.py
@@ -401,10 +401,11 @@ class ChatTable:
 
         # search_text might contain 'tag:tag_name' format so we need to extract the tag_name, split the search_text and remove the tags
         tag_ids = [
-            tag_name.replace("tag:", "").replace(" ", "_").lower()
-            for tag_name in search_text_words
-            if tag_name.startswith("tag:")
+            word.replace("tag:", "").replace(" ", "_").lower()
+            for word in search_text_words
+            if word.startswith("tag:")
         ]
+
         search_text_words = [
             word for word in search_text_words if not word.startswith("tag:")
         ]
@@ -450,11 +451,11 @@ class ChatTable:
                                     EXISTS (
                                         SELECT 1
                                         FROM json_each(Chat.meta, '$.tags') AS tag
-                                        WHERE tag.value = :tag_id
+                                        WHERE tag.value = :tag_id_{tag_idx}
                                     )
                                     """
-                                ).params(tag_id=tag_id)
-                                for tag_id in tag_ids
+                                ).params(**{f"tag_id_{tag_idx}": tag_id})
+                                for tag_idx, tag_id in enumerate(tag_ids)
                             ]
                         )
                     )
@@ -488,11 +489,11 @@ class ChatTable:
                                     EXISTS (
                                         SELECT 1
                                         FROM json_array_elements_text(Chat.meta->'tags') AS tag
-                                        WHERE tag = :tag_id
+                                        WHERE tag = :tag_id_{tag_idx}
                                     )
                                     """
-                                ).params(tag_id=tag_id)
-                                for tag_id in tag_ids
+                                ).params(**{f"tag_id_{tag_idx}": tag_id})
+                                for tag_idx, tag_id in enumerate(tag_ids)
                             ]
                         )
                     )
@@ -571,7 +572,7 @@ class ChatTable:
 
     def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> int:
         with get_db() as db:  # Assuming `get_db()` returns a session object
-            query = db.query(Chat).filter_by(user_id=user_id)
+            query = db.query(Chat).filter_by(user_id=user_id, archived=False)
 
             # Normalize the tag_name for consistency
             tag_id = tag_name.replace(" ", "_").lower()
diff --git a/backend/open_webui/apps/webui/routers/chats.py b/backend/open_webui/apps/webui/routers/chats.py
index b919d1447..a91d1e11d 100644
--- a/backend/open_webui/apps/webui/routers/chats.py
+++ b/backend/open_webui/apps/webui/routers/chats.py
@@ -114,13 +114,24 @@ async def search_user_chats(
     limit = 60
     skip = (page - 1) * limit
 
-    return [
+    chat_list = [
         ChatTitleIdResponse(**chat.model_dump())
         for chat in Chats.get_chats_by_user_id_and_search_text(
             user.id, text, skip=skip, limit=limit
         )
     ]
 
+    # Delete tag if no chat is found
+    words = text.strip().split(" ")
+    if page == 1 and len(words) == 1 and words[0].startswith("tag:"):
+        tag_id = words[0].replace("tag:", "")
+        if len(chat_list) == 0:
+            if Tags.get_tag_by_name_and_user_id(tag_id, user.id):
+                log.debug(f"deleting tag: {tag_id}")
+                Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
+
+    return chat_list
+
 
 ############################
 # GetPinnedChats
@@ -315,7 +326,13 @@ async def update_chat_by_id(
 @router.delete("/{id}", response_model=bool)
 async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
     if user.role == "admin":
+        chat = Chats.get_chat_by_id(id)
+        for tag in chat.meta.get("tags", []):
+            if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
+                Tags.delete_tag_by_name_and_user_id(tag, user.id)
+
         result = Chats.delete_chat_by_id(id)
+
         return result
     else:
         if not request.app.state.config.USER_PERMISSIONS.get("chat", {}).get(
@@ -326,6 +343,11 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
                 detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
             )
 
+        chat = Chats.get_chat_by_id(id)
+        for tag in chat.meta.get("tags", []):
+            if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
+                Tags.delete_tag_by_name_and_user_id(tag, user.id)
+
         result = Chats.delete_chat_by_id_and_user_id(id, user.id)
         return result
 
@@ -397,6 +419,20 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
     chat = Chats.get_chat_by_id_and_user_id(id, user.id)
     if chat:
         chat = Chats.toggle_chat_archive_by_id(id)
+
+        # Delete tags if chat is archived
+        if chat.archived:
+            for tag_id in chat.meta.get("tags", []):
+                if Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id) == 0:
+                    log.debug(f"deleting tag: {tag_id}")
+                    Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
+        else:
+            for tag_id in chat.meta.get("tags", []):
+                tag = Tags.get_tag_by_name_and_user_id(tag_id, user.id)
+                if tag is None:
+                    log.debug(f"inserting tag: {tag_id}")
+                    tag = Tags.insert_new_tag(tag_id, user.id)
+
         return ChatResponse(**chat.model_dump())
     else:
         raise HTTPException(
diff --git a/src/lib/apis/chats/index.ts b/src/lib/apis/chats/index.ts
index ff89fdf43..6056f6dbf 100644
--- a/src/lib/apis/chats/index.ts
+++ b/src/lib/apis/chats/index.ts
@@ -267,7 +267,7 @@ export const getAllUserChats = async (token: string) => {
 	return res;
 };
 
-export const getAllChatTags = async (token: string) => {
+export const getAllTags = async (token: string) => {
 	let error = null;
 
 	const res = await fetch(`${WEBUI_API_BASE_URL}/chats/all/tags`, {
diff --git a/src/lib/components/chat/Tags.svelte b/src/lib/components/chat/Tags.svelte
index 1ce3bc658..f71ff0af9 100644
--- a/src/lib/components/chat/Tags.svelte
+++ b/src/lib/components/chat/Tags.svelte
@@ -2,7 +2,7 @@
 	import {
 		addTagById,
 		deleteTagById,
-		getAllChatTags,
+		getAllTags,
 		getChatList,
 		getChatListByTagName,
 		getTagsById,
@@ -37,7 +37,7 @@
 			tags: tags
 		});
 
-		await _tags.set(await getAllChatTags(localStorage.token));
+		await _tags.set(await getAllTags(localStorage.token));
 		dispatch('add', {
 			name: tagName
 		});
@@ -50,7 +50,7 @@
 			tags: tags
 		});
 
-		await _tags.set(await getAllChatTags(localStorage.token));
+		await _tags.set(await getAllTags(localStorage.token));
 		dispatch('delete', {
 			name: tagName
 		});
diff --git a/src/lib/components/layout/Sidebar.svelte b/src/lib/components/layout/Sidebar.svelte
index 583624a28..cfe5b9976 100644
--- a/src/lib/components/layout/Sidebar.svelte
+++ b/src/lib/components/layout/Sidebar.svelte
@@ -30,7 +30,7 @@
 		getChatById,
 		getChatListByTagName,
 		updateChatById,
-		getAllChatTags,
+		getAllTags,
 		archiveChatById,
 		cloneChatById,
 		getChatListBySearchText,
@@ -77,6 +77,8 @@
 
 	const initChatList = async () => {
 		// Reset pagination variables
+		tags.set(await getAllTags(localStorage.token));
+
 		currentChatPage.set(1);
 		allChatsLoaded = false;
 		await chats.set(await getChatList(localStorage.token, $currentChatPage));
@@ -123,6 +125,10 @@
 			searchDebounceTimeout = setTimeout(async () => {
 				currentChatPage.set(1);
 				await chats.set(await getChatListBySearchText(localStorage.token, search));
+
+				if ($chats.length === 0) {
+					tags.set(await getAllTags(localStorage.token));
+				}
 			}, 1000);
 		}
 	};
@@ -134,6 +140,8 @@
 		});
 
 		if (res) {
+			tags.set(await getAllTags(localStorage.token));
+
 			if ($chatId === id) {
 				await chatId.set('');
 				await tick();
@@ -143,7 +151,6 @@
 			allChatsLoaded = false;
 			currentChatPage.set(1);
 			await chats.set(await getChatList(localStorage.token, $currentChatPage));
-
 			await pinnedChats.set(await getPinnedChatList(localStorage.token));
 		}
 	};
@@ -324,7 +331,7 @@
 	bind:show={$showArchivedChats}
 	on:change={async () => {
 		await pinnedChats.set(await getPinnedChatList(localStorage.token));
-		await chats.set(await getChatList(localStorage.token));
+		await initChatList();
 	}}
 />
 
diff --git a/src/lib/components/layout/Sidebar/ArchivedChatsModal.svelte b/src/lib/components/layout/Sidebar/ArchivedChatsModal.svelte
index 80e3f1579..b2ae11058 100644
--- a/src/lib/components/layout/Sidebar/ArchivedChatsModal.svelte
+++ b/src/lib/components/layout/Sidebar/ArchivedChatsModal.svelte
@@ -15,7 +15,6 @@
 		getArchivedChatList
 	} from '$lib/apis/chats';
 	import Tooltip from '$lib/components/common/Tooltip.svelte';
-
 	const i18n = getContext('i18n');
 
 	export let show = false;
@@ -30,7 +29,6 @@
 		});
 
 		chats = await getArchivedChatList(localStorage.token);
-
 		dispatch('change');
 	};
 
diff --git a/src/lib/components/layout/Sidebar/ChatItem.svelte b/src/lib/components/layout/Sidebar/ChatItem.svelte
index 8356b24cc..ff96ed05f 100644
--- a/src/lib/components/layout/Sidebar/ChatItem.svelte
+++ b/src/lib/components/layout/Sidebar/ChatItem.svelte
@@ -10,6 +10,7 @@
 		archiveChatById,
 		cloneChatById,
 		deleteChatById,
+		getAllTags,
 		getChatList,
 		getChatListByTagName,
 		getPinnedChatList,
@@ -22,7 +23,8 @@
 		mobile,
 		pinnedChats,
 		showSidebar,
-		currentChatPage
+		currentChatPage,
+		tags
 	} from '$lib/stores';
 
 	import ChatMenu from './ChatMenu.svelte';
@@ -77,6 +79,7 @@
 
 	const archiveChatHandler = async (id) => {
 		await archiveChatById(localStorage.token, id);
+		tags.set(await getAllTags(localStorage.token));
 
 		currentChatPage.set(1);
 		await chats.set(await getChatList(localStorage.token, $currentChatPage));
diff --git a/src/lib/components/layout/Sidebar/ChatMenu.svelte b/src/lib/components/layout/Sidebar/ChatMenu.svelte
index 306cd3633..38d0f1182 100644
--- a/src/lib/components/layout/Sidebar/ChatMenu.svelte
+++ b/src/lib/components/layout/Sidebar/ChatMenu.svelte
@@ -15,13 +15,8 @@
 	import DocumentDuplicate from '$lib/components/icons/DocumentDuplicate.svelte';
 	import Bookmark from '$lib/components/icons/Bookmark.svelte';
 	import BookmarkSlash from '$lib/components/icons/BookmarkSlash.svelte';
-	import {
-		addTagById,
-		deleteTagById,
-		getChatPinnedStatusById,
-		getTagsById,
-		toggleChatPinnedStatusById
-	} from '$lib/apis/chats';
+	import { getChatPinnedStatusById, toggleChatPinnedStatusById } from '$lib/apis/chats';
+	import { chats } from '$lib/stores';
 
 	const i18n = getContext('i18n');
 
@@ -146,6 +141,7 @@
 							type: 'add',
 							name: e.detail.name
 						});
+
 						show = false;
 					}}
 					on:delete={(e) => {
@@ -153,6 +149,7 @@
 							type: 'delete',
 							name: e.detail.name
 						});
+
 						show = false;
 					}}
 					on:close={() => {
diff --git a/src/routes/(app)/+layout.svelte b/src/routes/(app)/+layout.svelte
index bd711470a..2cc9b5c3d 100644
--- a/src/routes/(app)/+layout.svelte
+++ b/src/routes/(app)/+layout.svelte
@@ -13,7 +13,7 @@
 	import { getKnowledgeItems } from '$lib/apis/knowledge';
 	import { getFunctions } from '$lib/apis/functions';
 	import { getModels as _getModels, getVersionUpdates } from '$lib/apis';
-	import { getAllChatTags } from '$lib/apis/chats';
+	import { getAllTags } from '$lib/apis/chats';
 	import { getPrompts } from '$lib/apis/prompts';
 	import { getTools } from '$lib/apis/tools';
 	import { getBanners } from '$lib/apis/configs';
@@ -117,7 +117,7 @@
 					banners.set(await getBanners(localStorage.token));
 				})(),
 				(async () => {
-					tags.set(await getAllChatTags(localStorage.token));
+					tags.set(await getAllTags(localStorage.token));
 				})()
 			]);