refac: move search to backend

This commit is contained in:
Timothy J. Baek
2024-10-08 23:37:37 -07:00
parent e66619262a
commit 37fdb0ea2e
5 changed files with 203 additions and 104 deletions

View File

@@ -6,6 +6,8 @@ from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select
####################
# Chat DB Schema
@@ -249,10 +251,10 @@ class ChatTable:
Chat.id, Chat.title, Chat.updated_at, Chat.created_at
)
if limit:
query = query.limit(limit)
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
all_chats = query.all()
@@ -337,6 +339,50 @@ class ChatTable:
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_user_id_and_search_text(
self,
user_id: str,
search_text: str,
include_archived: bool = False,
skip: int = 0,
limit: int = 60,
) -> list[ChatModel]:
"""
Filters chats based on a search query using Python, allowing pagination using skip and limit.
"""
search_text = search_text.lower().strip()
if not search_text:
return self.get_chat_list_by_user_id(user_id, include_archived, skip, limit)
with get_db() as db:
query = db.query(Chat).filter(Chat.user_id == user_id)
if not include_archived:
query = query.filter(Chat.archived == False)
# Fetch all potentially relevant chats
all_chats = query.all()
# Filter chats using Python
filtered_chats = []
for chat in all_chats:
# Check chat title
title_matches = search_text in chat.title.lower()
# Check chat content in chat JSON
content_matches = any(
search_text in message.get("content", "").lower()
for message in chat.chat.get("messages", [])
if "content" in message
)
if title_matches or content_matches:
filtered_chats.append(chat)
# Implementing pagination manually
paginated_chats = filtered_chats[skip : skip + limit]
return [ChatModel.model_validate(chat) for chat in paginated_chats]
def delete_chat_by_id(self, id: str) -> bool:
try:
with get_db() as db:

View File

@@ -108,6 +108,29 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
############################
@router.get("/search", response_model=list[ChatTitleIdResponse])
async def search_user_chats(
text: str, page: Optional[int] = None, user=Depends(get_verified_user)
):
if page is None:
page = 1
limit = 60
skip = (page - 1) * limit
return [
ChatTitleIdResponse(**chat.model_dump())
for chat in Chats.get_chats_by_user_id_and_search_text(
user.id, text, skip=skip, limit=limit
)
]
############################
# GetChats
############################
@router.get("/all", response_model=list[ChatResponse])
async def get_user_chats(user=Depends(get_verified_user)):
return [

View File

@@ -302,6 +302,12 @@ RESET_CONFIG_ON_START = (
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
)
####################################
# REDIS
####################################
REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
####################################
# WEBUI_AUTH (Required for security)
####################################
@@ -343,8 +349,7 @@ ENABLE_WEBSOCKET_SUPPORT = (
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", "redis://localhost:6379/0")
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")