From 37fdb0ea2e3b4e9fe5a32e6c57c5e3e9c1db7d9f Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 8 Oct 2024 23:37:37 -0700 Subject: [PATCH] refac: move search to backend --- backend/open_webui/apps/webui/models/chats.py | 50 ++++- .../open_webui/apps/webui/routers/chats.py | 23 +++ backend/open_webui/env.py | 9 +- src/lib/apis/chats/index.ts | 35 ++++ src/lib/components/layout/Sidebar.svelte | 190 +++++++++--------- 5 files changed, 203 insertions(+), 104 deletions(-) diff --git a/backend/open_webui/apps/webui/models/chats.py b/backend/open_webui/apps/webui/models/chats.py index 6a79b6ae6..4109bfa46 100644 --- a/backend/open_webui/apps/webui/models/chats.py +++ b/backend/open_webui/apps/webui/models/chats.py @@ -6,6 +6,8 @@ from typing import Optional from open_webui.apps.webui.internal.db import Base, get_db from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON +from sqlalchemy import or_, func, select + #################### # Chat DB Schema @@ -249,10 +251,10 @@ class ChatTable: Chat.id, Chat.title, Chat.updated_at, Chat.created_at ) - if limit: - query = query.limit(limit) if skip: query = query.offset(skip) + if limit: + query = query.limit(limit) all_chats = query.all() @@ -337,6 +339,50 @@ class ChatTable: ) return [ChatModel.model_validate(chat) for chat in all_chats] + def get_chats_by_user_id_and_search_text( + self, + user_id: str, + search_text: str, + include_archived: bool = False, + skip: int = 0, + limit: int = 60, + ) -> list[ChatModel]: + """ + Filters chats based on a search query using Python, allowing pagination using skip and limit. + """ + search_text = search_text.lower().strip() + if not search_text: + return self.get_chat_list_by_user_id(user_id, include_archived, skip, limit) + + with get_db() as db: + query = db.query(Chat).filter(Chat.user_id == user_id) + + if not include_archived: + query = query.filter(Chat.archived == False) + + # Fetch all potentially relevant chats + all_chats = query.all() + + # Filter chats using Python + filtered_chats = [] + for chat in all_chats: + # Check chat title + title_matches = search_text in chat.title.lower() + + # Check chat content in chat JSON + content_matches = any( + search_text in message.get("content", "").lower() + for message in chat.chat.get("messages", []) + if "content" in message + ) + + if title_matches or content_matches: + filtered_chats.append(chat) + + # Implementing pagination manually + paginated_chats = filtered_chats[skip : skip + limit] + return [ChatModel.model_validate(chat) for chat in paginated_chats] + def delete_chat_by_id(self, id: str) -> bool: try: with get_db() as db: diff --git a/backend/open_webui/apps/webui/routers/chats.py b/backend/open_webui/apps/webui/routers/chats.py index 01d99cfd8..f28b15206 100644 --- a/backend/open_webui/apps/webui/routers/chats.py +++ b/backend/open_webui/apps/webui/routers/chats.py @@ -108,6 +108,29 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)): ############################ +@router.get("/search", response_model=list[ChatTitleIdResponse]) +async def search_user_chats( + text: str, page: Optional[int] = None, user=Depends(get_verified_user) +): + if page is None: + page = 1 + + limit = 60 + skip = (page - 1) * limit + + return [ + ChatTitleIdResponse(**chat.model_dump()) + for chat in Chats.get_chats_by_user_id_and_search_text( + user.id, text, skip=skip, limit=limit + ) + ] + + +############################ +# GetChats +############################ + + @router.get("/all", response_model=list[ChatResponse]) async def get_user_chats(user=Depends(get_verified_user)): return [ diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index ab2cef1fe..0f2ecada0 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -302,6 +302,12 @@ RESET_CONFIG_ON_START = ( os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true" ) +#################################### +# REDIS +#################################### + +REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0") + #################################### # WEBUI_AUTH (Required for security) #################################### @@ -343,8 +349,7 @@ ENABLE_WEBSOCKET_SUPPORT = ( WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "") -WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", "redis://localhost:6379/0") - +WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL) AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "") diff --git a/src/lib/apis/chats/index.ts b/src/lib/apis/chats/index.ts index 8f4f81aea..ac15f263d 100644 --- a/src/lib/apis/chats/index.ts +++ b/src/lib/apis/chats/index.ts @@ -167,6 +167,41 @@ export const getAllChats = async (token: string) => { return res; }; +export const getChatListBySearchText = async (token: string, text: string, page: number = 1) => { + let error = null; + + const searchParams = new URLSearchParams(); + searchParams.append('text', text); + searchParams.append('page', `${page}`); + + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/search?${searchParams.toString()}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getAllArchivedChats = async (token: string) => { let error = null; diff --git a/src/lib/components/layout/Sidebar.svelte b/src/lib/components/layout/Sidebar.svelte index 0582eb574..47af56834 100644 --- a/src/lib/components/layout/Sidebar.svelte +++ b/src/lib/components/layout/Sidebar.svelte @@ -32,7 +32,8 @@ updateChatById, getAllChatTags, archiveChatById, - cloneChatById + cloneChatById, + getChatListBySearchText } from '$lib/apis/chats'; import { WEBUI_BASE_URL } from '$lib/constants'; @@ -58,33 +59,11 @@ let selectedTagName = null; - let filteredChatList = []; - // Pagination variables let chatListLoading = false; let allChatsLoaded = false; - $: filteredChatList = $chats.filter((chat) => { - if (search === '') { - return true; - } else { - let title = chat.title.toLowerCase(); - const query = search.toLowerCase(); - - let contentMatches = false; - // Access the messages within chat.chat.messages - if (chat.chat && chat.chat.messages && Array.isArray(chat.chat.messages)) { - contentMatches = chat.chat.messages.some((message) => { - // Check if message.content exists and includes the search query - return message.content && message.content.toLowerCase().includes(query); - }); - } - - return title.includes(query) || contentMatches; - } - }); - - const enablePagination = async () => { + const initChatList = async () => { // Reset pagination variables currentChatPage.set(1); allChatsLoaded = false; @@ -98,7 +77,14 @@ chatListLoading = true; currentChatPage.set($currentChatPage + 1); - const newChatList = await getChatList(localStorage.token, $currentChatPage); + + let newChatList = []; + + if (search) { + newChatList = await getChatListBySearchText(localStorage.token, search, $currentChatPage); + } else { + newChatList = await getChatList(localStorage.token, $currentChatPage); + } // once the bottom of the list has been reached (no results) there is no need to continue querying allChatsLoaded = newChatList.length === 0; @@ -107,6 +93,28 @@ chatListLoading = false; }; + let searchDebounceTimeout; + + const searchDebounceHandler = async () => { + console.log('search', search); + chats.set(null); + selectedTagName = null; + + if (searchDebounceTimeout) { + clearTimeout(searchDebounceTimeout); + } + + if (search === '') { + await initChatList(); + return; + } else { + searchDebounceTimeout = setTimeout(async () => { + currentChatPage.set(1); + await chats.set(await getChatListBySearchText(localStorage.token, search)); + }, 1000); + } + }; + onMount(async () => { mobile.subscribe((e) => { if ($showSidebar && e) { @@ -124,7 +132,7 @@ }); await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned')); - await enablePagination(); + await initChatList(); let touchstart; let touchend; @@ -192,27 +200,6 @@ }; }); - // Helper function to fetch and add chat content to each chat - const enrichChatsWithContent = async (chatList) => { - const enrichedChats = await Promise.all( - chatList.map(async (chat) => { - const chatDetails = await getChatById(localStorage.token, chat.id).catch((error) => null); // Handle error or non-existent chat gracefully - if (chatDetails) { - chat.chat = chatDetails.chat; // Assuming chatDetails.chat contains the chat content - } - return chat; - }) - ); - - await chats.set(enrichedChats); - }; - - const saveSettings = async (updated) => { - await settings.set({ ...$settings, ...updated }); - await updateUserSettings(localStorage.token, { ui: $settings }); - location.href = '/'; - }; - const deleteChatHandler = async (id) => { const res = await deleteChatById(localStorage.token, id).catch((error) => { toast.error(error); @@ -419,11 +406,8 @@ class="w-full rounded-r-xl py-1.5 pl-2.5 pr-4 text-sm bg-transparent dark:text-gray-300 outline-none" placeholder={$i18n.t('Search')} bind:value={search} - on:focus={async () => { - // TODO: migrate backend for more scalable search mechanism - scrollPaginationEnabled.set(false); - await chats.set(await getChatList(localStorage.token)); // when searching, load all chats - enrichChatsWithContent($chats); + on:input={() => { + searchDebounceHandler(); }} /> @@ -437,7 +421,7 @@ : ' '} rounded-md font-medium" on:click={async () => { selectedTagName = null; - await enablePagination(); + await initChatList(); }} > {$i18n.t('all')} @@ -455,10 +439,9 @@ await tags.set(await getAllChatTags(localStorage.token)); // if the tag we deleted is no longer a valid tag, return to main chat list view - await enablePagination(); + await initChatList(); } await chats.set(chatIds); - chatListLoading = false; }} > @@ -501,15 +484,16 @@ {/if}
- {#each filteredChatList as chat, idx} - {#if idx === 0 || (idx > 0 && chat.time_range !== filteredChatList[idx - 1].time_range)} -
- {$i18n.t(chat.time_range)} - -
+
+ {/if} + + { + selectedChatId = chat.id; + }} + on:unselect={() => { + selectedChatId = null; + }} + on:delete={(e) => { + if ((e?.detail ?? '') === 'shift') { + deleteChatHandler(chat.id); + } else { + deleteChat = chat; + showDeleteConfirm = true; + } + }} + /> + {/each} + + {#if $scrollPaginationEnabled && !allChatsLoaded} + { + if (!chatListLoading) { + loadMoreChats(); + } + }} + > +
+ +
Loading...
+
+
{/if} - - { - selectedChatId = chat.id; - }} - on:unselect={() => { - selectedChatId = null; - }} - on:delete={(e) => { - if ((e?.detail ?? '') === 'shift') { - deleteChatHandler(chat.id); - } else { - deleteChat = chat; - showDeleteConfirm = true; - } - }} - /> - {/each} - - {#if $scrollPaginationEnabled && !allChatsLoaded} - { - if (!chatListLoading) { - loadMoreChats(); - } - }} - > -
- -
Loading...
-
-
+ {:else} +
+ +
Loading...
+
{/if}