refac: move search to backend

This commit is contained in:
Timothy J. Baek 2024-10-08 23:37:37 -07:00
parent e66619262a
commit 37fdb0ea2e
5 changed files with 203 additions and 104 deletions

View File

@ -6,6 +6,8 @@ from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db from open_webui.apps.webui.internal.db import Base, get_db
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select
#################### ####################
# Chat DB Schema # Chat DB Schema
@ -249,10 +251,10 @@ class ChatTable:
Chat.id, Chat.title, Chat.updated_at, Chat.created_at Chat.id, Chat.title, Chat.updated_at, Chat.created_at
) )
if limit:
query = query.limit(limit)
if skip: if skip:
query = query.offset(skip) query = query.offset(skip)
if limit:
query = query.limit(limit)
all_chats = query.all() all_chats = query.all()
@ -337,6 +339,50 @@ class ChatTable:
) )
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_user_id_and_search_text(
self,
user_id: str,
search_text: str,
include_archived: bool = False,
skip: int = 0,
limit: int = 60,
) -> list[ChatModel]:
"""
Filters chats based on a search query using Python, allowing pagination using skip and limit.
"""
search_text = search_text.lower().strip()
if not search_text:
return self.get_chat_list_by_user_id(user_id, include_archived, skip, limit)
with get_db() as db:
query = db.query(Chat).filter(Chat.user_id == user_id)
if not include_archived:
query = query.filter(Chat.archived == False)
# Fetch all potentially relevant chats
all_chats = query.all()
# Filter chats using Python
filtered_chats = []
for chat in all_chats:
# Check chat title
title_matches = search_text in chat.title.lower()
# Check chat content in chat JSON
content_matches = any(
search_text in message.get("content", "").lower()
for message in chat.chat.get("messages", [])
if "content" in message
)
if title_matches or content_matches:
filtered_chats.append(chat)
# Implementing pagination manually
paginated_chats = filtered_chats[skip : skip + limit]
return [ChatModel.model_validate(chat) for chat in paginated_chats]
def delete_chat_by_id(self, id: str) -> bool: def delete_chat_by_id(self, id: str) -> bool:
try: try:
with get_db() as db: with get_db() as db:

View File

@ -108,6 +108,29 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
############################ ############################
@router.get("/search", response_model=list[ChatTitleIdResponse])
async def search_user_chats(
text: str, page: Optional[int] = None, user=Depends(get_verified_user)
):
if page is None:
page = 1
limit = 60
skip = (page - 1) * limit
return [
ChatTitleIdResponse(**chat.model_dump())
for chat in Chats.get_chats_by_user_id_and_search_text(
user.id, text, skip=skip, limit=limit
)
]
############################
# GetChats
############################
@router.get("/all", response_model=list[ChatResponse]) @router.get("/all", response_model=list[ChatResponse])
async def get_user_chats(user=Depends(get_verified_user)): async def get_user_chats(user=Depends(get_verified_user)):
return [ return [

View File

@ -302,6 +302,12 @@ RESET_CONFIG_ON_START = (
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true" os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
) )
####################################
# REDIS
####################################
REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
#################################### ####################################
# WEBUI_AUTH (Required for security) # WEBUI_AUTH (Required for security)
#################################### ####################################
@ -343,8 +349,7 @@ ENABLE_WEBSOCKET_SUPPORT = (
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "") WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", "redis://localhost:6379/0") WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "") AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")

View File

@ -167,6 +167,41 @@ export const getAllChats = async (token: string) => {
return res; return res;
}; };
export const getChatListBySearchText = async (token: string, text: string, page: number = 1) => {
let error = null;
const searchParams = new URLSearchParams();
searchParams.append('text', text);
searchParams.append('page', `${page}`);
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/search?${searchParams.toString()}`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getAllArchivedChats = async (token: string) => { export const getAllArchivedChats = async (token: string) => {
let error = null; let error = null;

View File

@ -32,7 +32,8 @@
updateChatById, updateChatById,
getAllChatTags, getAllChatTags,
archiveChatById, archiveChatById,
cloneChatById cloneChatById,
getChatListBySearchText
} from '$lib/apis/chats'; } from '$lib/apis/chats';
import { WEBUI_BASE_URL } from '$lib/constants'; import { WEBUI_BASE_URL } from '$lib/constants';
@ -58,33 +59,11 @@
let selectedTagName = null; let selectedTagName = null;
let filteredChatList = [];
// Pagination variables // Pagination variables
let chatListLoading = false; let chatListLoading = false;
let allChatsLoaded = false; let allChatsLoaded = false;
$: filteredChatList = $chats.filter((chat) => { const initChatList = async () => {
if (search === '') {
return true;
} else {
let title = chat.title.toLowerCase();
const query = search.toLowerCase();
let contentMatches = false;
// Access the messages within chat.chat.messages
if (chat.chat && chat.chat.messages && Array.isArray(chat.chat.messages)) {
contentMatches = chat.chat.messages.some((message) => {
// Check if message.content exists and includes the search query
return message.content && message.content.toLowerCase().includes(query);
});
}
return title.includes(query) || contentMatches;
}
});
const enablePagination = async () => {
// Reset pagination variables // Reset pagination variables
currentChatPage.set(1); currentChatPage.set(1);
allChatsLoaded = false; allChatsLoaded = false;
@ -98,7 +77,14 @@
chatListLoading = true; chatListLoading = true;
currentChatPage.set($currentChatPage + 1); currentChatPage.set($currentChatPage + 1);
const newChatList = await getChatList(localStorage.token, $currentChatPage);
let newChatList = [];
if (search) {
newChatList = await getChatListBySearchText(localStorage.token, search, $currentChatPage);
} else {
newChatList = await getChatList(localStorage.token, $currentChatPage);
}
// once the bottom of the list has been reached (no results) there is no need to continue querying // once the bottom of the list has been reached (no results) there is no need to continue querying
allChatsLoaded = newChatList.length === 0; allChatsLoaded = newChatList.length === 0;
@ -107,6 +93,28 @@
chatListLoading = false; chatListLoading = false;
}; };
let searchDebounceTimeout;
const searchDebounceHandler = async () => {
console.log('search', search);
chats.set(null);
selectedTagName = null;
if (searchDebounceTimeout) {
clearTimeout(searchDebounceTimeout);
}
if (search === '') {
await initChatList();
return;
} else {
searchDebounceTimeout = setTimeout(async () => {
currentChatPage.set(1);
await chats.set(await getChatListBySearchText(localStorage.token, search));
}, 1000);
}
};
onMount(async () => { onMount(async () => {
mobile.subscribe((e) => { mobile.subscribe((e) => {
if ($showSidebar && e) { if ($showSidebar && e) {
@ -124,7 +132,7 @@
}); });
await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned')); await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned'));
await enablePagination(); await initChatList();
let touchstart; let touchstart;
let touchend; let touchend;
@ -192,27 +200,6 @@
}; };
}); });
// Helper function to fetch and add chat content to each chat
const enrichChatsWithContent = async (chatList) => {
const enrichedChats = await Promise.all(
chatList.map(async (chat) => {
const chatDetails = await getChatById(localStorage.token, chat.id).catch((error) => null); // Handle error or non-existent chat gracefully
if (chatDetails) {
chat.chat = chatDetails.chat; // Assuming chatDetails.chat contains the chat content
}
return chat;
})
);
await chats.set(enrichedChats);
};
const saveSettings = async (updated) => {
await settings.set({ ...$settings, ...updated });
await updateUserSettings(localStorage.token, { ui: $settings });
location.href = '/';
};
const deleteChatHandler = async (id) => { const deleteChatHandler = async (id) => {
const res = await deleteChatById(localStorage.token, id).catch((error) => { const res = await deleteChatById(localStorage.token, id).catch((error) => {
toast.error(error); toast.error(error);
@ -419,11 +406,8 @@
class="w-full rounded-r-xl py-1.5 pl-2.5 pr-4 text-sm bg-transparent dark:text-gray-300 outline-none" class="w-full rounded-r-xl py-1.5 pl-2.5 pr-4 text-sm bg-transparent dark:text-gray-300 outline-none"
placeholder={$i18n.t('Search')} placeholder={$i18n.t('Search')}
bind:value={search} bind:value={search}
on:focus={async () => { on:input={() => {
// TODO: migrate backend for more scalable search mechanism searchDebounceHandler();
scrollPaginationEnabled.set(false);
await chats.set(await getChatList(localStorage.token)); // when searching, load all chats
enrichChatsWithContent($chats);
}} }}
/> />
</div> </div>
@ -437,7 +421,7 @@
: ' '} rounded-md font-medium" : ' '} rounded-md font-medium"
on:click={async () => { on:click={async () => {
selectedTagName = null; selectedTagName = null;
await enablePagination(); await initChatList();
}} }}
> >
{$i18n.t('all')} {$i18n.t('all')}
@ -455,10 +439,9 @@
await tags.set(await getAllChatTags(localStorage.token)); await tags.set(await getAllChatTags(localStorage.token));
// if the tag we deleted is no longer a valid tag, return to main chat list view // if the tag we deleted is no longer a valid tag, return to main chat list view
await enablePagination(); await initChatList();
} }
await chats.set(chatIds); await chats.set(chatIds);
chatListLoading = false; chatListLoading = false;
}} }}
> >
@ -501,8 +484,9 @@
{/if} {/if}
<div class="pl-2 my-2 flex-1 flex flex-col space-y-1 overflow-y-auto scrollbar-hidden"> <div class="pl-2 my-2 flex-1 flex flex-col space-y-1 overflow-y-auto scrollbar-hidden">
{#each filteredChatList as chat, idx} {#if $chats}
{#if idx === 0 || (idx > 0 && chat.time_range !== filteredChatList[idx - 1].time_range)} {#each $chats as chat, idx}
{#if idx === 0 || (idx > 0 && chat.time_range !== $chats[idx - 1].time_range)}
<div <div
class="w-full pl-2.5 text-xs text-gray-500 dark:text-gray-500 font-medium {idx === 0 class="w-full pl-2.5 text-xs text-gray-500 dark:text-gray-500 font-medium {idx === 0
? '' ? ''
@ -565,6 +549,12 @@
</div> </div>
</Loader> </Loader>
{/if} {/if}
{:else}
<div class="w-full flex justify-center py-1 text-xs animate-pulse items-center gap-2">
<Spinner className=" size-4" />
<div class=" ">Loading...</div>
</div>
{/if}
</div> </div>
</div> </div>