mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
refac: move search to backend
This commit is contained in:
@@ -6,6 +6,8 @@ from typing import Optional
|
||||
from open_webui.apps.webui.internal.db import Base, get_db
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
|
||||
from sqlalchemy import or_, func, select
|
||||
|
||||
|
||||
####################
|
||||
# Chat DB Schema
|
||||
@@ -249,10 +251,10 @@ class ChatTable:
|
||||
Chat.id, Chat.title, Chat.updated_at, Chat.created_at
|
||||
)
|
||||
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
if skip:
|
||||
query = query.offset(skip)
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
all_chats = query.all()
|
||||
|
||||
@@ -337,6 +339,50 @@ class ChatTable:
|
||||
)
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def get_chats_by_user_id_and_search_text(
|
||||
self,
|
||||
user_id: str,
|
||||
search_text: str,
|
||||
include_archived: bool = False,
|
||||
skip: int = 0,
|
||||
limit: int = 60,
|
||||
) -> list[ChatModel]:
|
||||
"""
|
||||
Filters chats based on a search query using Python, allowing pagination using skip and limit.
|
||||
"""
|
||||
search_text = search_text.lower().strip()
|
||||
if not search_text:
|
||||
return self.get_chat_list_by_user_id(user_id, include_archived, skip, limit)
|
||||
|
||||
with get_db() as db:
|
||||
query = db.query(Chat).filter(Chat.user_id == user_id)
|
||||
|
||||
if not include_archived:
|
||||
query = query.filter(Chat.archived == False)
|
||||
|
||||
# Fetch all potentially relevant chats
|
||||
all_chats = query.all()
|
||||
|
||||
# Filter chats using Python
|
||||
filtered_chats = []
|
||||
for chat in all_chats:
|
||||
# Check chat title
|
||||
title_matches = search_text in chat.title.lower()
|
||||
|
||||
# Check chat content in chat JSON
|
||||
content_matches = any(
|
||||
search_text in message.get("content", "").lower()
|
||||
for message in chat.chat.get("messages", [])
|
||||
if "content" in message
|
||||
)
|
||||
|
||||
if title_matches or content_matches:
|
||||
filtered_chats.append(chat)
|
||||
|
||||
# Implementing pagination manually
|
||||
paginated_chats = filtered_chats[skip : skip + limit]
|
||||
return [ChatModel.model_validate(chat) for chat in paginated_chats]
|
||||
|
||||
def delete_chat_by_id(self, id: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
|
||||
@@ -108,6 +108,29 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/search", response_model=list[ChatTitleIdResponse])
|
||||
async def search_user_chats(
|
||||
text: str, page: Optional[int] = None, user=Depends(get_verified_user)
|
||||
):
|
||||
if page is None:
|
||||
page = 1
|
||||
|
||||
limit = 60
|
||||
skip = (page - 1) * limit
|
||||
|
||||
return [
|
||||
ChatTitleIdResponse(**chat.model_dump())
|
||||
for chat in Chats.get_chats_by_user_id_and_search_text(
|
||||
user.id, text, skip=skip, limit=limit
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
############################
|
||||
# GetChats
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/all", response_model=list[ChatResponse])
|
||||
async def get_user_chats(user=Depends(get_verified_user)):
|
||||
return [
|
||||
|
||||
@@ -302,6 +302,12 @@ RESET_CONFIG_ON_START = (
|
||||
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
|
||||
)
|
||||
|
||||
####################################
|
||||
# REDIS
|
||||
####################################
|
||||
|
||||
REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
|
||||
|
||||
####################################
|
||||
# WEBUI_AUTH (Required for security)
|
||||
####################################
|
||||
@@ -343,8 +349,7 @@ ENABLE_WEBSOCKET_SUPPORT = (
|
||||
|
||||
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
|
||||
|
||||
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", "redis://localhost:6379/0")
|
||||
|
||||
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
|
||||
|
||||
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user