mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Update chats.py
Add JSONB support to Postgres
This commit is contained in:
parent
bf7a18a0f8
commit
484133de4c
@ -2,7 +2,8 @@ import logging
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from typing import Optional, Dict, Any, List, Union
|
||||
from enum import Enum
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from open_webui.models.tags import TagModel, Tag, Tags
|
||||
@ -12,6 +13,170 @@ from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
|
||||
from sqlalchemy import or_, func, select, and_, text
|
||||
from sqlalchemy.sql import exists
|
||||
from sqlalchemy.sql.elements import TextClause
|
||||
|
||||
# Import JSONB for PostgreSQL support
|
||||
try:
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
except ImportError:
|
||||
JSONB = None
|
||||
|
||||
####################
|
||||
# Database Adapter
|
||||
####################
|
||||
|
||||
class DatabaseType(Enum):
|
||||
SQLITE = "sqlite"
|
||||
POSTGRESQL_JSON = "postgresql_json"
|
||||
POSTGRESQL_JSONB = "postgresql_jsonb"
|
||||
UNSUPPORTED = "unsupported"
|
||||
|
||||
class DatabaseAdapter:
|
||||
"""Centralized database-specific query generation with caching"""
|
||||
|
||||
def __init__(self, db):
|
||||
self.db = db
|
||||
self.dialect = db.bind.dialect.name
|
||||
self._cache: Dict[str, DatabaseType] = {}
|
||||
|
||||
def get_database_type(self, column_name: str = "meta") -> DatabaseType:
|
||||
"""Determine database type with caching"""
|
||||
cache_key = f"{self.dialect}_{column_name}"
|
||||
if cache_key in self._cache:
|
||||
return self._cache[cache_key]
|
||||
|
||||
if self.dialect == "sqlite":
|
||||
result = DatabaseType.SQLITE
|
||||
elif self.dialect == "postgresql":
|
||||
result = DatabaseType.POSTGRESQL_JSONB if self._is_jsonb_column(column_name) else DatabaseType.POSTGRESQL_JSON
|
||||
else:
|
||||
result = DatabaseType.UNSUPPORTED
|
||||
|
||||
self._cache[cache_key] = result
|
||||
return result
|
||||
|
||||
def _is_jsonb_column(self, column_name: str) -> bool:
|
||||
"""Check if column is JSONB type"""
|
||||
if JSONB is None or self.dialect != "postgresql":
|
||||
return False
|
||||
|
||||
try:
|
||||
result = self.db.execute(text("""
|
||||
SELECT data_type FROM information_schema.columns
|
||||
WHERE table_name = 'chat' AND column_name = :column_name
|
||||
"""), {"column_name": column_name})
|
||||
|
||||
row = result.fetchone()
|
||||
return row[0].lower() == 'jsonb' if row else False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _get_function_template(self, db_type: DatabaseType, function_type: str) -> Optional[str]:
|
||||
"""Get function template for specific database type and function"""
|
||||
templates = {
|
||||
DatabaseType.SQLITE: {
|
||||
"tag_exists": "EXISTS (SELECT 1 FROM json_each({column}, '$.tags') WHERE json_each.value = :tag_id)",
|
||||
"has_key": "json_extract({column}, '$.{path}') IS NOT NULL",
|
||||
"array_length": "json_array_length({column}, '$.{path}')",
|
||||
"array_elements": "json_each({column}, '$.{path}')",
|
||||
"content_search": """EXISTS (
|
||||
SELECT 1 FROM json_each({column}, '$.messages') AS message
|
||||
WHERE LOWER(message.value->>'content') LIKE '%' || :search_text || '%'
|
||||
)"""
|
||||
},
|
||||
DatabaseType.POSTGRESQL_JSON: {
|
||||
"tag_exists": "EXISTS (SELECT 1 FROM json_array_elements_text({column}->'tags') elem WHERE elem = :tag_id)",
|
||||
"has_key": "{column} ? '{path}'",
|
||||
"array_length": "json_array_length({column}->'{path}')",
|
||||
"array_elements": "json_array_elements({column}->'{path}')",
|
||||
"content_search": """EXISTS (
|
||||
SELECT 1 FROM json_array_elements({column}->'messages') AS message
|
||||
WHERE LOWER(message->>'content') LIKE '%' || :search_text || '%'
|
||||
)"""
|
||||
},
|
||||
DatabaseType.POSTGRESQL_JSONB: {
|
||||
"tag_exists": "EXISTS (SELECT 1 FROM jsonb_array_elements_text({column}->'tags') elem WHERE elem = :tag_id)",
|
||||
"has_key": "{column} ? '{path}'",
|
||||
"array_length": "jsonb_array_length({column}->'{path}')",
|
||||
"array_elements": "jsonb_array_elements({column}->'{path}')",
|
||||
"content_search": """EXISTS (
|
||||
SELECT 1 FROM jsonb_array_elements({column}->'messages') AS message
|
||||
WHERE LOWER(message->>'content') LIKE '%' || :search_text || '%'
|
||||
)"""
|
||||
}
|
||||
}
|
||||
|
||||
return templates.get(db_type, {}).get(function_type)
|
||||
|
||||
def build_tag_filter(self, column_name: str, tag_ids: List[str], match_all: bool = True) -> Optional[Union[TextClause, and_, or_]]:
|
||||
"""Build optimized tag filtering query"""
|
||||
if not tag_ids:
|
||||
return None
|
||||
|
||||
db_type = self.get_database_type(column_name)
|
||||
template = self._get_function_template(db_type, "tag_exists")
|
||||
|
||||
if not template:
|
||||
return None
|
||||
|
||||
query_template = template.replace("{column}", f"Chat.{column_name}")
|
||||
|
||||
if match_all:
|
||||
return and_(*[
|
||||
text(query_template).params(tag_id=tag_id)
|
||||
for tag_id in tag_ids
|
||||
])
|
||||
else:
|
||||
conditions = []
|
||||
params = {}
|
||||
for idx, tag_id in enumerate(tag_ids):
|
||||
param_name = f"tag_id_{idx}"
|
||||
params[param_name] = tag_id
|
||||
condition_template = query_template.replace(":tag_id", f":{param_name}")
|
||||
conditions.append(text(condition_template))
|
||||
|
||||
return or_(*conditions).params(**params)
|
||||
|
||||
def build_search_filter(self, search_text: str) -> Optional[TextClause]:
|
||||
"""Build content search query"""
|
||||
db_type = self.get_database_type("chat")
|
||||
template = self._get_function_template(db_type, "content_search")
|
||||
|
||||
if not template:
|
||||
return None
|
||||
|
||||
query = template.replace("{column}", "Chat.chat")
|
||||
return text(query).params(search_text=search_text)
|
||||
|
||||
def build_untagged_filter(self, column_name: str = "meta") -> Optional[or_]:
|
||||
"""Build filter for chats without tags"""
|
||||
db_type = self.get_database_type(column_name)
|
||||
|
||||
has_key_template = self._get_function_template(db_type, "has_key")
|
||||
array_length_template = self._get_function_template(db_type, "array_length")
|
||||
|
||||
if not has_key_template or not array_length_template:
|
||||
return None
|
||||
|
||||
has_key = has_key_template.replace("{column}", f"Chat.{column_name}").replace("{path}", "tags")
|
||||
array_length = array_length_template.replace("{column}", f"Chat.{column_name}").replace("{path}", "tags")
|
||||
|
||||
return or_(
|
||||
text(f"NOT ({has_key})"),
|
||||
text(f"{array_length} = 0")
|
||||
)
|
||||
|
||||
####################
|
||||
# Utility Functions
|
||||
####################
|
||||
|
||||
def normalize_tag_name(tag_name: str) -> str:
|
||||
"""Normalize tag name for consistent storage and querying"""
|
||||
return tag_name.replace(" ", "_").lower()
|
||||
|
||||
def normalize_tag_names(tag_names: List[str]) -> List[str]:
|
||||
"""Normalize multiple tag names"""
|
||||
return [normalize_tag_name(tag) for tag in tag_names]
|
||||
|
||||
####################
|
||||
# Chat DB Schema
|
||||
@ -21,13 +186,15 @@ log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
|
||||
|
||||
|
||||
class Chat(Base):
|
||||
__tablename__ = "chat"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String)
|
||||
title = Column(Text)
|
||||
chat = Column(JSON)
|
||||
chat = Column(JSON) # For JSONB support, change to: Column(JSONB) if JSONB else Column(JSON)
|
||||
|
||||
created_at = Column(BigInteger)
|
||||
updated_at = Column(BigInteger)
|
||||
@ -36,7 +203,7 @@ class Chat(Base):
|
||||
archived = Column(Boolean, default=False)
|
||||
pinned = Column(Boolean, default=False, nullable=True)
|
||||
|
||||
meta = Column(JSON, server_default="{}")
|
||||
meta = Column(JSON, server_default="{}") # For JSONB support, change to: Column(JSONB, server_default="{}") if JSONB else Column(JSON, server_default="{}")
|
||||
folder_id = Column(Text, nullable=True)
|
||||
|
||||
|
||||
@ -105,6 +272,337 @@ class ChatTitleIdResponse(BaseModel):
|
||||
|
||||
|
||||
class ChatTable:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def _get_adapter(self, db) -> DatabaseAdapter:
|
||||
"""Get database adapter for the current session"""
|
||||
return DatabaseAdapter(db)
|
||||
|
||||
# Legacy methods for backward compatibility
|
||||
def _is_jsonb_column(self, db, column_name: str) -> bool:
|
||||
"""Legacy method - use adapter instead"""
|
||||
adapter = self._get_adapter(db)
|
||||
return adapter.get_database_type(column_name) == DatabaseType.POSTGRESQL_JSONB
|
||||
|
||||
def _get_json_query_type(self, db, column_name: str = "meta") -> str:
|
||||
"""Legacy method - use adapter instead"""
|
||||
adapter = self._get_adapter(db)
|
||||
db_type = adapter.get_database_type(column_name)
|
||||
return db_type.value
|
||||
|
||||
def check_database_compatibility(self) -> dict:
|
||||
"""Check database compatibility and available features"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
adapter = self._get_adapter(db)
|
||||
dialect_name = db.bind.dialect.name
|
||||
|
||||
meta_type = adapter.get_database_type("meta")
|
||||
chat_type = adapter.get_database_type("chat")
|
||||
|
||||
compatibility = {
|
||||
"database_type": dialect_name,
|
||||
"json_support": meta_type != DatabaseType.UNSUPPORTED,
|
||||
"jsonb_support": meta_type == DatabaseType.POSTGRESQL_JSONB or chat_type == DatabaseType.POSTGRESQL_JSONB,
|
||||
"gin_indexes_support": dialect_name == "postgresql",
|
||||
"tag_filtering_support": meta_type != DatabaseType.UNSUPPORTED,
|
||||
"advanced_search_support": chat_type != DatabaseType.UNSUPPORTED,
|
||||
"meta_column_type": meta_type.value,
|
||||
"chat_column_type": chat_type.value,
|
||||
"features": [],
|
||||
"limitations": [],
|
||||
"recommendations": []
|
||||
}
|
||||
|
||||
# Add features based on database type
|
||||
if dialect_name == "sqlite":
|
||||
compatibility["features"] = ["JSON1 extension", "Basic tag filtering", "Message search"]
|
||||
compatibility["limitations"] = ["No GIN indexes", "Limited JSON optimization"]
|
||||
elif dialect_name == "postgresql":
|
||||
compatibility["features"] = ["Full JSON/JSONB support", "GIN indexes", "Advanced filtering"]
|
||||
if compatibility["jsonb_support"]:
|
||||
compatibility["features"].append("JSONB binary format optimization")
|
||||
|
||||
return compatibility
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error checking database compatibility: {e}")
|
||||
return {"error": str(e), "database_type": "unknown"}
|
||||
|
||||
|
||||
|
||||
def create_gin_indexes(self) -> bool:
|
||||
"""Create GIN indexes on JSONB columns for better query performance"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
adapter = self._get_adapter(db)
|
||||
|
||||
if db.bind.dialect.name != "postgresql":
|
||||
return False
|
||||
|
||||
meta_type = adapter.get_database_type("meta")
|
||||
chat_type = adapter.get_database_type("chat")
|
||||
|
||||
has_jsonb_meta = meta_type == DatabaseType.POSTGRESQL_JSONB
|
||||
has_jsonb_chat = chat_type == DatabaseType.POSTGRESQL_JSONB
|
||||
|
||||
if not (has_jsonb_meta or has_jsonb_chat):
|
||||
return False
|
||||
|
||||
# Create GIN indexes
|
||||
if has_jsonb_meta:
|
||||
try:
|
||||
db.execute(text("""
|
||||
CREATE INDEX IF NOT EXISTS idx_chat_meta_gin
|
||||
ON chat USING GIN (meta)
|
||||
"""))
|
||||
db.execute(text("""
|
||||
CREATE INDEX IF NOT EXISTS idx_chat_meta_tags_gin
|
||||
ON chat USING GIN ((meta->'tags'))
|
||||
"""))
|
||||
db.execute(text("""
|
||||
CREATE INDEX IF NOT EXISTS idx_chat_has_tags
|
||||
ON chat USING BTREE ((meta ? 'tags' AND jsonb_array_length(meta->'tags') > 0))
|
||||
WHERE meta ? 'tags'
|
||||
"""))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if has_jsonb_chat:
|
||||
try:
|
||||
db.execute(text("""
|
||||
CREATE INDEX IF NOT EXISTS idx_chat_chat_gin
|
||||
ON chat USING GIN (chat)
|
||||
"""))
|
||||
db.execute(text("""
|
||||
CREATE INDEX IF NOT EXISTS idx_chat_messages_gin
|
||||
ON chat USING GIN ((chat->'messages'))
|
||||
"""))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def check_gin_indexes(self) -> dict:
|
||||
"""
|
||||
Check which GIN indexes exist on the chat table.
|
||||
Returns a dictionary with index names and their status.
|
||||
"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
if db.bind.dialect.name != "postgresql":
|
||||
return {"error": "GIN indexes are only supported on PostgreSQL"}
|
||||
|
||||
result = db.execute(text("""
|
||||
SELECT indexname, indexdef
|
||||
FROM pg_indexes
|
||||
WHERE tablename = 'chat'
|
||||
AND indexdef LIKE '%USING gin%'
|
||||
"""))
|
||||
|
||||
indexes = {}
|
||||
for row in result:
|
||||
indexes[row[0]] = {
|
||||
"exists": True,
|
||||
"definition": row[1]
|
||||
}
|
||||
|
||||
# Check for expected indexes
|
||||
expected_indexes = [
|
||||
"idx_chat_meta_gin",
|
||||
"idx_chat_chat_gin",
|
||||
"idx_chat_meta_tags_gin",
|
||||
"idx_chat_has_tags",
|
||||
"idx_chat_tag_count",
|
||||
"idx_chat_json_tags",
|
||||
"idx_chat_messages_gin"
|
||||
]
|
||||
|
||||
for idx_name in expected_indexes:
|
||||
if idx_name not in indexes:
|
||||
indexes[idx_name] = {"exists": False}
|
||||
|
||||
return indexes
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error checking GIN indexes: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
def drop_gin_indexes(self) -> bool:
|
||||
"""Drop all GIN indexes on the chat table"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
if db.bind.dialect.name != "postgresql":
|
||||
return False
|
||||
|
||||
indexes_to_drop = [
|
||||
"idx_chat_meta_gin",
|
||||
"idx_chat_chat_gin",
|
||||
"idx_chat_meta_tags_gin",
|
||||
"idx_chat_has_tags",
|
||||
"idx_chat_tag_count",
|
||||
"idx_chat_json_tags",
|
||||
"idx_chat_messages_gin"
|
||||
]
|
||||
|
||||
for idx_name in indexes_to_drop:
|
||||
try:
|
||||
db.execute(text(f"DROP INDEX CONCURRENTLY IF EXISTS {idx_name}"))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def create_tag_indexes(self) -> bool:
|
||||
"""
|
||||
Create specialized indexes optimized specifically for tag operations.
|
||||
This includes both GIN and BTREE indexes for different tag query patterns.
|
||||
"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
if db.bind.dialect.name != "postgresql":
|
||||
log.info("Tag indexes are only supported on PostgreSQL")
|
||||
return False
|
||||
|
||||
has_jsonb_meta = self._is_jsonb_column(db, "meta")
|
||||
indexes_created = []
|
||||
|
||||
if has_jsonb_meta:
|
||||
# JSONB-specific tag indexes
|
||||
tag_indexes = [
|
||||
{
|
||||
"name": "idx_chat_meta_tags_gin",
|
||||
"sql": "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_meta_tags_gin ON chat USING GIN ((meta->'tags'))",
|
||||
"purpose": "Fast tag containment queries (@>, ?, etc.)"
|
||||
},
|
||||
{
|
||||
"name": "idx_chat_has_tags",
|
||||
"sql": "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_has_tags ON chat USING BTREE ((meta ? 'tags' AND jsonb_array_length(meta->'tags') > 0)) WHERE meta ? 'tags'",
|
||||
"purpose": "Fast filtering for chats with/without tags"
|
||||
},
|
||||
{
|
||||
"name": "idx_chat_tag_count",
|
||||
"sql": "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_tag_count ON chat USING BTREE ((jsonb_array_length(meta->'tags'))) WHERE meta ? 'tags'",
|
||||
"purpose": "Fast filtering by number of tags"
|
||||
},
|
||||
{
|
||||
"name": "idx_chat_specific_tags",
|
||||
"sql": "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_specific_tags ON chat USING GIN ((meta->'tags')) WHERE jsonb_array_length(meta->'tags') > 0",
|
||||
"purpose": "Optimized for chats that actually have tags"
|
||||
}
|
||||
]
|
||||
else:
|
||||
# JSON-specific tag indexes (less optimal but still helpful)
|
||||
tag_indexes = [
|
||||
{
|
||||
"name": "idx_chat_json_tags",
|
||||
"sql": "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_json_tags ON chat USING GIN ((meta->'tags')) WHERE meta ? 'tags'",
|
||||
"purpose": "Tag queries for JSON columns"
|
||||
}
|
||||
]
|
||||
|
||||
for index_info in tag_indexes:
|
||||
try:
|
||||
db.execute(text(index_info["sql"]))
|
||||
indexes_created.append(f"{index_info['name']} ({index_info['purpose']})")
|
||||
log.info(f"Created tag index: {index_info['name']}")
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to create {index_info['name']}: {e}")
|
||||
|
||||
db.commit()
|
||||
|
||||
if indexes_created:
|
||||
log.info(f"Successfully created tag indexes: {len(indexes_created)} indexes")
|
||||
for idx in indexes_created:
|
||||
log.info(f" • {idx}")
|
||||
else:
|
||||
log.info("No tag indexes were created")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error creating tag indexes: {e}")
|
||||
return False
|
||||
|
||||
def optimize_tag_queries(self) -> dict:
|
||||
"""
|
||||
Analyze and provide recommendations for tag query optimization.
|
||||
Returns statistics and suggestions for improving tag query performance.
|
||||
"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
if db.bind.dialect.name != "postgresql":
|
||||
return {"error": "Tag optimization is only supported on PostgreSQL"}
|
||||
|
||||
stats = {}
|
||||
|
||||
# Get basic tag statistics
|
||||
result = db.execute(text("""
|
||||
SELECT
|
||||
COUNT(*) as total_chats,
|
||||
COUNT(*) FILTER (WHERE meta ? 'tags') as chats_with_tags,
|
||||
COUNT(*) FILTER (WHERE meta ? 'tags' AND jsonb_array_length(meta->'tags') > 0) as chats_with_actual_tags,
|
||||
AVG(CASE WHEN meta ? 'tags' THEN jsonb_array_length(meta->'tags') ELSE 0 END) as avg_tags_per_chat
|
||||
FROM chat
|
||||
"""))
|
||||
|
||||
row = result.fetchone()
|
||||
if row:
|
||||
stats.update({
|
||||
"total_chats": row[0],
|
||||
"chats_with_tags": row[1],
|
||||
"chats_with_actual_tags": row[2],
|
||||
"avg_tags_per_chat": float(row[3]) if row[3] else 0
|
||||
})
|
||||
|
||||
# Get most common tags
|
||||
result = db.execute(text("""
|
||||
SELECT tag_value, COUNT(*) as usage_count
|
||||
FROM chat, jsonb_array_elements_text(meta->'tags') as tag_value
|
||||
WHERE meta ? 'tags'
|
||||
GROUP BY tag_value
|
||||
ORDER BY usage_count DESC
|
||||
LIMIT 10
|
||||
"""))
|
||||
|
||||
stats["top_tags"] = [{"tag": row[0], "count": row[1]} for row in result]
|
||||
|
||||
# Check index usage
|
||||
indexes = self.check_gin_indexes()
|
||||
tag_indexes = {k: v for k, v in indexes.items() if "tag" in k.lower()}
|
||||
stats["tag_indexes"] = tag_indexes
|
||||
|
||||
# Provide recommendations
|
||||
recommendations = []
|
||||
|
||||
if stats["chats_with_actual_tags"] > 1000:
|
||||
recommendations.append("Consider creating tag-specific indexes for better performance")
|
||||
|
||||
if stats["avg_tags_per_chat"] > 5:
|
||||
recommendations.append("High tag usage detected - GIN indexes will provide significant benefits")
|
||||
|
||||
tag_coverage = stats["chats_with_actual_tags"] / stats["total_chats"] if stats["total_chats"] > 0 else 0
|
||||
if tag_coverage < 0.1:
|
||||
recommendations.append("Low tag usage - consider partial indexes with WHERE clauses")
|
||||
|
||||
stats["recommendations"] = recommendations
|
||||
stats["tag_coverage_percentage"] = round(tag_coverage * 100, 2)
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error analyzing tag queries: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
|
||||
with get_db() as db:
|
||||
id = str(uuid.uuid4())
|
||||
@ -537,8 +1035,10 @@ class ChatTable:
|
||||
with get_db() as db:
|
||||
all_chats = (
|
||||
db.query(Chat)
|
||||
# .limit(limit).offset(skip)
|
||||
.order_by(Chat.updated_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
@ -591,7 +1091,7 @@ class ChatTable:
|
||||
|
||||
# search_text might contain 'tag:tag_name' format so we need to extract the tag_name, split the search_text and remove the tags
|
||||
tag_ids = [
|
||||
word.replace("tag:", "").replace(" ", "_").lower()
|
||||
normalize_tag_name(word.replace("tag:", ""))
|
||||
for word in search_text_words
|
||||
if word.startswith("tag:")
|
||||
]
|
||||
@ -610,115 +1110,33 @@ class ChatTable:
|
||||
|
||||
query = query.order_by(Chat.updated_at.desc())
|
||||
|
||||
# Check if the database dialect is either 'sqlite' or 'postgresql'
|
||||
dialect_name = db.bind.dialect.name
|
||||
if dialect_name == "sqlite":
|
||||
# SQLite case: using JSON1 extension for JSON searching
|
||||
query = query.filter(
|
||||
(
|
||||
Chat.title.ilike(
|
||||
f"%{search_text}%"
|
||||
) # Case-insensitive search in title
|
||||
| text(
|
||||
"""
|
||||
EXISTS (
|
||||
SELECT 1
|
||||
FROM json_each(Chat.chat, '$.messages') AS message
|
||||
WHERE LOWER(message.value->>'content') LIKE '%' || :search_text || '%'
|
||||
)
|
||||
"""
|
||||
)
|
||||
).params(search_text=search_text)
|
||||
)
|
||||
|
||||
# Check if there are any tags to filter, it should have all the tags
|
||||
if "none" in tag_ids:
|
||||
# Use adapter for cleaner query building
|
||||
adapter = self._get_adapter(db)
|
||||
|
||||
# Add search filter if search text provided
|
||||
if search_text:
|
||||
search_filter = adapter.build_search_filter(search_text)
|
||||
if search_filter is not None:
|
||||
query = query.filter(
|
||||
text(
|
||||
"""
|
||||
NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM json_each(Chat.meta, '$.tags') AS tag
|
||||
)
|
||||
"""
|
||||
)
|
||||
Chat.title.ilike(f"%{search_text}%") | search_filter
|
||||
)
|
||||
elif tag_ids:
|
||||
query = query.filter(
|
||||
and_(
|
||||
*[
|
||||
text(
|
||||
f"""
|
||||
EXISTS (
|
||||
SELECT 1
|
||||
FROM json_each(Chat.meta, '$.tags') AS tag
|
||||
WHERE tag.value = :tag_id_{tag_idx}
|
||||
)
|
||||
"""
|
||||
).params(**{f"tag_id_{tag_idx}": tag_id})
|
||||
for tag_idx, tag_id in enumerate(tag_ids)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
elif dialect_name == "postgresql":
|
||||
# PostgreSQL relies on proper JSON query for search
|
||||
query = query.filter(
|
||||
(
|
||||
Chat.title.ilike(
|
||||
f"%{search_text}%"
|
||||
) # Case-insensitive search in title
|
||||
| text(
|
||||
"""
|
||||
EXISTS (
|
||||
SELECT 1
|
||||
FROM json_array_elements(Chat.chat->'messages') AS message
|
||||
WHERE LOWER(message->>'content') LIKE '%' || :search_text || '%'
|
||||
)
|
||||
"""
|
||||
)
|
||||
).params(search_text=search_text)
|
||||
)
|
||||
|
||||
# Check if there are any tags to filter, it should have all the tags
|
||||
if "none" in tag_ids:
|
||||
query = query.filter(
|
||||
text(
|
||||
"""
|
||||
NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM json_array_elements_text(Chat.meta->'tags') AS tag
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
elif tag_ids:
|
||||
query = query.filter(
|
||||
and_(
|
||||
*[
|
||||
text(
|
||||
f"""
|
||||
EXISTS (
|
||||
SELECT 1
|
||||
FROM json_array_elements_text(Chat.meta->'tags') AS tag
|
||||
WHERE tag = :tag_id_{tag_idx}
|
||||
)
|
||||
"""
|
||||
).params(**{f"tag_id_{tag_idx}": tag_id})
|
||||
for tag_idx, tag_id in enumerate(tag_ids)
|
||||
]
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported dialect: {db.bind.dialect.name}"
|
||||
)
|
||||
else:
|
||||
# Fallback to title-only search for unsupported databases
|
||||
query = query.filter(Chat.title.ilike(f"%{search_text}%"))
|
||||
|
||||
# Add tag filters
|
||||
if "none" in tag_ids:
|
||||
untagged_filter = adapter.build_untagged_filter("meta")
|
||||
if untagged_filter is not None:
|
||||
query = query.filter(untagged_filter)
|
||||
elif tag_ids:
|
||||
tag_filter = adapter.build_tag_filter("meta", tag_ids, match_all=True)
|
||||
if tag_filter is not None:
|
||||
query = query.filter(tag_filter)
|
||||
|
||||
# Perform pagination at the SQL level
|
||||
all_chats = query.offset(skip).limit(limit).all()
|
||||
|
||||
log.info(f"The number of chats: {len(all_chats)}")
|
||||
|
||||
# Validate and return chats
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
@ -775,32 +1193,59 @@ class ChatTable:
|
||||
self, user_id: str, tag_name: str, skip: int = 0, limit: int = 50
|
||||
) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
adapter = self._get_adapter(db)
|
||||
query = db.query(Chat).filter_by(user_id=user_id)
|
||||
tag_id = tag_name.replace(" ", "_").lower()
|
||||
tag_id = normalize_tag_name(tag_name)
|
||||
|
||||
log.info(f"DB dialect name: {db.bind.dialect.name}")
|
||||
if db.bind.dialect.name == "sqlite":
|
||||
# SQLite JSON1 querying for tags within the meta JSON field
|
||||
query = query.filter(
|
||||
text(
|
||||
f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)"
|
||||
)
|
||||
).params(tag_id=tag_id)
|
||||
elif db.bind.dialect.name == "postgresql":
|
||||
# PostgreSQL JSON query for tags within the meta JSON field (for `json` type)
|
||||
query = query.filter(
|
||||
text(
|
||||
"EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)"
|
||||
)
|
||||
).params(tag_id=tag_id)
|
||||
# Use adapter to build tag filter
|
||||
tag_filter = adapter.build_tag_filter("meta", [tag_id], match_all=True)
|
||||
if tag_filter is not None:
|
||||
query = query.filter(tag_filter)
|
||||
all_chats = query.all()
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported dialect: {db.bind.dialect.name}"
|
||||
)
|
||||
return []
|
||||
|
||||
all_chats = query.all()
|
||||
log.debug(f"all_chats: {all_chats}")
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
def get_chats_by_multiple_tags(
|
||||
self, user_id: str, tag_names: List[str], match_all: bool = True, skip: int = 0, limit: int = 50
|
||||
) -> list[ChatModel]:
|
||||
"""Get chats that match multiple tags"""
|
||||
with get_db() as db:
|
||||
adapter = self._get_adapter(db)
|
||||
query = db.query(Chat).filter_by(user_id=user_id, archived=False)
|
||||
|
||||
if not tag_names:
|
||||
return []
|
||||
|
||||
# Normalize tag names
|
||||
tag_ids = normalize_tag_names(tag_names)
|
||||
|
||||
# Use adapter to build tag filter
|
||||
tag_filter = adapter.build_tag_filter("meta", tag_ids, match_all=match_all)
|
||||
if tag_filter is not None:
|
||||
query = query.filter(tag_filter)
|
||||
# Apply pagination and ordering
|
||||
query = query.order_by(Chat.updated_at.desc()).offset(skip).limit(limit)
|
||||
all_chats = query.all()
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
else:
|
||||
return []
|
||||
|
||||
def get_chats_without_tags(self, user_id: str, skip: int = 0, limit: int = 50) -> list[ChatModel]:
|
||||
"""Get chats that have no tags"""
|
||||
with get_db() as db:
|
||||
adapter = self._get_adapter(db)
|
||||
query = db.query(Chat).filter_by(user_id=user_id, archived=False)
|
||||
|
||||
# Use adapter to build untagged filter
|
||||
untagged_filter = adapter.build_untagged_filter("meta")
|
||||
if untagged_filter is not None:
|
||||
query = query.filter(untagged_filter)
|
||||
query = query.order_by(Chat.updated_at.desc()).offset(skip).limit(limit)
|
||||
all_chats = query.all()
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
else:
|
||||
return []
|
||||
|
||||
def add_chat_tag_by_id_and_user_id_and_tag_name(
|
||||
self, id: str, user_id: str, tag_name: str
|
||||
@ -826,40 +1271,20 @@ class ChatTable:
|
||||
return None
|
||||
|
||||
def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> int:
|
||||
with get_db() as db: # Assuming `get_db()` returns a session object
|
||||
with get_db() as db:
|
||||
adapter = self._get_adapter(db)
|
||||
query = db.query(Chat).filter_by(user_id=user_id, archived=False)
|
||||
|
||||
# Normalize the tag_name for consistency
|
||||
tag_id = tag_name.replace(" ", "_").lower()
|
||||
|
||||
if db.bind.dialect.name == "sqlite":
|
||||
# SQLite JSON1 support for querying the tags inside the `meta` JSON field
|
||||
query = query.filter(
|
||||
text(
|
||||
f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)"
|
||||
)
|
||||
).params(tag_id=tag_id)
|
||||
|
||||
elif db.bind.dialect.name == "postgresql":
|
||||
# PostgreSQL JSONB support for querying the tags inside the `meta` JSON field
|
||||
query = query.filter(
|
||||
text(
|
||||
"EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)"
|
||||
)
|
||||
).params(tag_id=tag_id)
|
||||
tag_id = normalize_tag_name(tag_name)
|
||||
|
||||
# Use adapter to build tag filter
|
||||
tag_filter = adapter.build_tag_filter("meta", [tag_id], match_all=True)
|
||||
if tag_filter is not None:
|
||||
query = query.filter(tag_filter)
|
||||
return query.count()
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported dialect: {db.bind.dialect.name}"
|
||||
)
|
||||
|
||||
# Get the count of matching records
|
||||
count = query.count()
|
||||
|
||||
# Debugging output for inspection
|
||||
log.info(f"Count of chats for tag '{tag_name}': {count}")
|
||||
|
||||
return count
|
||||
return 0
|
||||
|
||||
def delete_tag_by_id_and_user_id_and_tag_name(
|
||||
self, id: str, user_id: str, tag_name: str
|
||||
@ -868,7 +1293,7 @@ class ChatTable:
|
||||
with get_db() as db:
|
||||
chat = db.get(Chat, id)
|
||||
tags = chat.meta.get("tags", [])
|
||||
tag_id = tag_name.replace(" ", "_").lower()
|
||||
tag_id = normalize_tag_name(tag_name)
|
||||
|
||||
tags = [tag for tag in tags if tag != tag_id]
|
||||
chat.meta = {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user