Update chats.py

Add JSONB support to Postgres
This commit is contained in:
PVBLIC Foundation 2025-05-27 17:43:04 -07:00 committed by GitHub
parent bf7a18a0f8
commit 484133de4c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,7 +2,8 @@ import logging
import json
import time
import uuid
from typing import Optional
from typing import Optional, Dict, Any, List, Union
from enum import Enum
from open_webui.internal.db import Base, get_db
from open_webui.models.tags import TagModel, Tag, Tags
@ -12,6 +13,170 @@ from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists
from sqlalchemy.sql.elements import TextClause
# Import JSONB for PostgreSQL support
try:
from sqlalchemy.dialects.postgresql import JSONB
except ImportError:
JSONB = None
####################
# Database Adapter
####################
class DatabaseType(Enum):
SQLITE = "sqlite"
POSTGRESQL_JSON = "postgresql_json"
POSTGRESQL_JSONB = "postgresql_jsonb"
UNSUPPORTED = "unsupported"
class DatabaseAdapter:
"""Centralized database-specific query generation with caching"""
def __init__(self, db):
self.db = db
self.dialect = db.bind.dialect.name
self._cache: Dict[str, DatabaseType] = {}
def get_database_type(self, column_name: str = "meta") -> DatabaseType:
"""Determine database type with caching"""
cache_key = f"{self.dialect}_{column_name}"
if cache_key in self._cache:
return self._cache[cache_key]
if self.dialect == "sqlite":
result = DatabaseType.SQLITE
elif self.dialect == "postgresql":
result = DatabaseType.POSTGRESQL_JSONB if self._is_jsonb_column(column_name) else DatabaseType.POSTGRESQL_JSON
else:
result = DatabaseType.UNSUPPORTED
self._cache[cache_key] = result
return result
def _is_jsonb_column(self, column_name: str) -> bool:
"""Check if column is JSONB type"""
if JSONB is None or self.dialect != "postgresql":
return False
try:
result = self.db.execute(text("""
SELECT data_type FROM information_schema.columns
WHERE table_name = 'chat' AND column_name = :column_name
"""), {"column_name": column_name})
row = result.fetchone()
return row[0].lower() == 'jsonb' if row else False
except Exception:
return False
def _get_function_template(self, db_type: DatabaseType, function_type: str) -> Optional[str]:
"""Get function template for specific database type and function"""
templates = {
DatabaseType.SQLITE: {
"tag_exists": "EXISTS (SELECT 1 FROM json_each({column}, '$.tags') WHERE json_each.value = :tag_id)",
"has_key": "json_extract({column}, '$.{path}') IS NOT NULL",
"array_length": "json_array_length({column}, '$.{path}')",
"array_elements": "json_each({column}, '$.{path}')",
"content_search": """EXISTS (
SELECT 1 FROM json_each({column}, '$.messages') AS message
WHERE LOWER(message.value->>'content') LIKE '%' || :search_text || '%'
)"""
},
DatabaseType.POSTGRESQL_JSON: {
"tag_exists": "EXISTS (SELECT 1 FROM json_array_elements_text({column}->'tags') elem WHERE elem = :tag_id)",
"has_key": "{column} ? '{path}'",
"array_length": "json_array_length({column}->'{path}')",
"array_elements": "json_array_elements({column}->'{path}')",
"content_search": """EXISTS (
SELECT 1 FROM json_array_elements({column}->'messages') AS message
WHERE LOWER(message->>'content') LIKE '%' || :search_text || '%'
)"""
},
DatabaseType.POSTGRESQL_JSONB: {
"tag_exists": "EXISTS (SELECT 1 FROM jsonb_array_elements_text({column}->'tags') elem WHERE elem = :tag_id)",
"has_key": "{column} ? '{path}'",
"array_length": "jsonb_array_length({column}->'{path}')",
"array_elements": "jsonb_array_elements({column}->'{path}')",
"content_search": """EXISTS (
SELECT 1 FROM jsonb_array_elements({column}->'messages') AS message
WHERE LOWER(message->>'content') LIKE '%' || :search_text || '%'
)"""
}
}
return templates.get(db_type, {}).get(function_type)
def build_tag_filter(self, column_name: str, tag_ids: List[str], match_all: bool = True) -> Optional[Union[TextClause, and_, or_]]:
"""Build optimized tag filtering query"""
if not tag_ids:
return None
db_type = self.get_database_type(column_name)
template = self._get_function_template(db_type, "tag_exists")
if not template:
return None
query_template = template.replace("{column}", f"Chat.{column_name}")
if match_all:
return and_(*[
text(query_template).params(tag_id=tag_id)
for tag_id in tag_ids
])
else:
conditions = []
params = {}
for idx, tag_id in enumerate(tag_ids):
param_name = f"tag_id_{idx}"
params[param_name] = tag_id
condition_template = query_template.replace(":tag_id", f":{param_name}")
conditions.append(text(condition_template))
return or_(*conditions).params(**params)
def build_search_filter(self, search_text: str) -> Optional[TextClause]:
"""Build content search query"""
db_type = self.get_database_type("chat")
template = self._get_function_template(db_type, "content_search")
if not template:
return None
query = template.replace("{column}", "Chat.chat")
return text(query).params(search_text=search_text)
def build_untagged_filter(self, column_name: str = "meta") -> Optional[or_]:
"""Build filter for chats without tags"""
db_type = self.get_database_type(column_name)
has_key_template = self._get_function_template(db_type, "has_key")
array_length_template = self._get_function_template(db_type, "array_length")
if not has_key_template or not array_length_template:
return None
has_key = has_key_template.replace("{column}", f"Chat.{column_name}").replace("{path}", "tags")
array_length = array_length_template.replace("{column}", f"Chat.{column_name}").replace("{path}", "tags")
return or_(
text(f"NOT ({has_key})"),
text(f"{array_length} = 0")
)
####################
# Utility Functions
####################
def normalize_tag_name(tag_name: str) -> str:
"""Normalize tag name for consistent storage and querying"""
return tag_name.replace(" ", "_").lower()
def normalize_tag_names(tag_names: List[str]) -> List[str]:
"""Normalize multiple tag names"""
return [normalize_tag_name(tag) for tag in tag_names]
####################
# Chat DB Schema
@ -21,13 +186,15 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
class Chat(Base):
__tablename__ = "chat"
id = Column(String, primary_key=True)
user_id = Column(String)
title = Column(Text)
chat = Column(JSON)
chat = Column(JSON) # For JSONB support, change to: Column(JSONB) if JSONB else Column(JSON)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
@ -36,7 +203,7 @@ class Chat(Base):
archived = Column(Boolean, default=False)
pinned = Column(Boolean, default=False, nullable=True)
meta = Column(JSON, server_default="{}")
meta = Column(JSON, server_default="{}") # For JSONB support, change to: Column(JSONB, server_default="{}") if JSONB else Column(JSON, server_default="{}")
folder_id = Column(Text, nullable=True)
@ -105,6 +272,337 @@ class ChatTitleIdResponse(BaseModel):
class ChatTable:
def __init__(self):
pass
def _get_adapter(self, db) -> DatabaseAdapter:
"""Get database adapter for the current session"""
return DatabaseAdapter(db)
# Legacy methods for backward compatibility
def _is_jsonb_column(self, db, column_name: str) -> bool:
"""Legacy method - use adapter instead"""
adapter = self._get_adapter(db)
return adapter.get_database_type(column_name) == DatabaseType.POSTGRESQL_JSONB
def _get_json_query_type(self, db, column_name: str = "meta") -> str:
"""Legacy method - use adapter instead"""
adapter = self._get_adapter(db)
db_type = adapter.get_database_type(column_name)
return db_type.value
def check_database_compatibility(self) -> dict:
"""Check database compatibility and available features"""
try:
with get_db() as db:
adapter = self._get_adapter(db)
dialect_name = db.bind.dialect.name
meta_type = adapter.get_database_type("meta")
chat_type = adapter.get_database_type("chat")
compatibility = {
"database_type": dialect_name,
"json_support": meta_type != DatabaseType.UNSUPPORTED,
"jsonb_support": meta_type == DatabaseType.POSTGRESQL_JSONB or chat_type == DatabaseType.POSTGRESQL_JSONB,
"gin_indexes_support": dialect_name == "postgresql",
"tag_filtering_support": meta_type != DatabaseType.UNSUPPORTED,
"advanced_search_support": chat_type != DatabaseType.UNSUPPORTED,
"meta_column_type": meta_type.value,
"chat_column_type": chat_type.value,
"features": [],
"limitations": [],
"recommendations": []
}
# Add features based on database type
if dialect_name == "sqlite":
compatibility["features"] = ["JSON1 extension", "Basic tag filtering", "Message search"]
compatibility["limitations"] = ["No GIN indexes", "Limited JSON optimization"]
elif dialect_name == "postgresql":
compatibility["features"] = ["Full JSON/JSONB support", "GIN indexes", "Advanced filtering"]
if compatibility["jsonb_support"]:
compatibility["features"].append("JSONB binary format optimization")
return compatibility
except Exception as e:
log.error(f"Error checking database compatibility: {e}")
return {"error": str(e), "database_type": "unknown"}
def create_gin_indexes(self) -> bool:
"""Create GIN indexes on JSONB columns for better query performance"""
try:
with get_db() as db:
adapter = self._get_adapter(db)
if db.bind.dialect.name != "postgresql":
return False
meta_type = adapter.get_database_type("meta")
chat_type = adapter.get_database_type("chat")
has_jsonb_meta = meta_type == DatabaseType.POSTGRESQL_JSONB
has_jsonb_chat = chat_type == DatabaseType.POSTGRESQL_JSONB
if not (has_jsonb_meta or has_jsonb_chat):
return False
# Create GIN indexes
if has_jsonb_meta:
try:
db.execute(text("""
CREATE INDEX IF NOT EXISTS idx_chat_meta_gin
ON chat USING GIN (meta)
"""))
db.execute(text("""
CREATE INDEX IF NOT EXISTS idx_chat_meta_tags_gin
ON chat USING GIN ((meta->'tags'))
"""))
db.execute(text("""
CREATE INDEX IF NOT EXISTS idx_chat_has_tags
ON chat USING BTREE ((meta ? 'tags' AND jsonb_array_length(meta->'tags') > 0))
WHERE meta ? 'tags'
"""))
except Exception:
pass
if has_jsonb_chat:
try:
db.execute(text("""
CREATE INDEX IF NOT EXISTS idx_chat_chat_gin
ON chat USING GIN (chat)
"""))
db.execute(text("""
CREATE INDEX IF NOT EXISTS idx_chat_messages_gin
ON chat USING GIN ((chat->'messages'))
"""))
except Exception:
pass
db.commit()
return True
except Exception:
return False
def check_gin_indexes(self) -> dict:
"""
Check which GIN indexes exist on the chat table.
Returns a dictionary with index names and their status.
"""
try:
with get_db() as db:
if db.bind.dialect.name != "postgresql":
return {"error": "GIN indexes are only supported on PostgreSQL"}
result = db.execute(text("""
SELECT indexname, indexdef
FROM pg_indexes
WHERE tablename = 'chat'
AND indexdef LIKE '%USING gin%'
"""))
indexes = {}
for row in result:
indexes[row[0]] = {
"exists": True,
"definition": row[1]
}
# Check for expected indexes
expected_indexes = [
"idx_chat_meta_gin",
"idx_chat_chat_gin",
"idx_chat_meta_tags_gin",
"idx_chat_has_tags",
"idx_chat_tag_count",
"idx_chat_json_tags",
"idx_chat_messages_gin"
]
for idx_name in expected_indexes:
if idx_name not in indexes:
indexes[idx_name] = {"exists": False}
return indexes
except Exception as e:
log.error(f"Error checking GIN indexes: {e}")
return {"error": str(e)}
def drop_gin_indexes(self) -> bool:
"""Drop all GIN indexes on the chat table"""
try:
with get_db() as db:
if db.bind.dialect.name != "postgresql":
return False
indexes_to_drop = [
"idx_chat_meta_gin",
"idx_chat_chat_gin",
"idx_chat_meta_tags_gin",
"idx_chat_has_tags",
"idx_chat_tag_count",
"idx_chat_json_tags",
"idx_chat_messages_gin"
]
for idx_name in indexes_to_drop:
try:
db.execute(text(f"DROP INDEX CONCURRENTLY IF EXISTS {idx_name}"))
except Exception:
pass
db.commit()
return True
except Exception:
return False
def create_tag_indexes(self) -> bool:
"""
Create specialized indexes optimized specifically for tag operations.
This includes both GIN and BTREE indexes for different tag query patterns.
"""
try:
with get_db() as db:
if db.bind.dialect.name != "postgresql":
log.info("Tag indexes are only supported on PostgreSQL")
return False
has_jsonb_meta = self._is_jsonb_column(db, "meta")
indexes_created = []
if has_jsonb_meta:
# JSONB-specific tag indexes
tag_indexes = [
{
"name": "idx_chat_meta_tags_gin",
"sql": "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_meta_tags_gin ON chat USING GIN ((meta->'tags'))",
"purpose": "Fast tag containment queries (@>, ?, etc.)"
},
{
"name": "idx_chat_has_tags",
"sql": "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_has_tags ON chat USING BTREE ((meta ? 'tags' AND jsonb_array_length(meta->'tags') > 0)) WHERE meta ? 'tags'",
"purpose": "Fast filtering for chats with/without tags"
},
{
"name": "idx_chat_tag_count",
"sql": "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_tag_count ON chat USING BTREE ((jsonb_array_length(meta->'tags'))) WHERE meta ? 'tags'",
"purpose": "Fast filtering by number of tags"
},
{
"name": "idx_chat_specific_tags",
"sql": "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_specific_tags ON chat USING GIN ((meta->'tags')) WHERE jsonb_array_length(meta->'tags') > 0",
"purpose": "Optimized for chats that actually have tags"
}
]
else:
# JSON-specific tag indexes (less optimal but still helpful)
tag_indexes = [
{
"name": "idx_chat_json_tags",
"sql": "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_json_tags ON chat USING GIN ((meta->'tags')) WHERE meta ? 'tags'",
"purpose": "Tag queries for JSON columns"
}
]
for index_info in tag_indexes:
try:
db.execute(text(index_info["sql"]))
indexes_created.append(f"{index_info['name']} ({index_info['purpose']})")
log.info(f"Created tag index: {index_info['name']}")
except Exception as e:
log.warning(f"Failed to create {index_info['name']}: {e}")
db.commit()
if indexes_created:
log.info(f"Successfully created tag indexes: {len(indexes_created)} indexes")
for idx in indexes_created:
log.info(f"{idx}")
else:
log.info("No tag indexes were created")
return True
except Exception as e:
log.error(f"Error creating tag indexes: {e}")
return False
def optimize_tag_queries(self) -> dict:
"""
Analyze and provide recommendations for tag query optimization.
Returns statistics and suggestions for improving tag query performance.
"""
try:
with get_db() as db:
if db.bind.dialect.name != "postgresql":
return {"error": "Tag optimization is only supported on PostgreSQL"}
stats = {}
# Get basic tag statistics
result = db.execute(text("""
SELECT
COUNT(*) as total_chats,
COUNT(*) FILTER (WHERE meta ? 'tags') as chats_with_tags,
COUNT(*) FILTER (WHERE meta ? 'tags' AND jsonb_array_length(meta->'tags') > 0) as chats_with_actual_tags,
AVG(CASE WHEN meta ? 'tags' THEN jsonb_array_length(meta->'tags') ELSE 0 END) as avg_tags_per_chat
FROM chat
"""))
row = result.fetchone()
if row:
stats.update({
"total_chats": row[0],
"chats_with_tags": row[1],
"chats_with_actual_tags": row[2],
"avg_tags_per_chat": float(row[3]) if row[3] else 0
})
# Get most common tags
result = db.execute(text("""
SELECT tag_value, COUNT(*) as usage_count
FROM chat, jsonb_array_elements_text(meta->'tags') as tag_value
WHERE meta ? 'tags'
GROUP BY tag_value
ORDER BY usage_count DESC
LIMIT 10
"""))
stats["top_tags"] = [{"tag": row[0], "count": row[1]} for row in result]
# Check index usage
indexes = self.check_gin_indexes()
tag_indexes = {k: v for k, v in indexes.items() if "tag" in k.lower()}
stats["tag_indexes"] = tag_indexes
# Provide recommendations
recommendations = []
if stats["chats_with_actual_tags"] > 1000:
recommendations.append("Consider creating tag-specific indexes for better performance")
if stats["avg_tags_per_chat"] > 5:
recommendations.append("High tag usage detected - GIN indexes will provide significant benefits")
tag_coverage = stats["chats_with_actual_tags"] / stats["total_chats"] if stats["total_chats"] > 0 else 0
if tag_coverage < 0.1:
recommendations.append("Low tag usage - consider partial indexes with WHERE clauses")
stats["recommendations"] = recommendations
stats["tag_coverage_percentage"] = round(tag_coverage * 100, 2)
return stats
except Exception as e:
log.error(f"Error analyzing tag queries: {e}")
return {"error": str(e)}
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
with get_db() as db:
id = str(uuid.uuid4())
@ -537,8 +1035,10 @@ class ChatTable:
with get_db() as db:
all_chats = (
db.query(Chat)
# .limit(limit).offset(skip)
.order_by(Chat.updated_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return [ChatModel.model_validate(chat) for chat in all_chats]
@ -591,7 +1091,7 @@ class ChatTable:
# search_text might contain 'tag:tag_name' format so we need to extract the tag_name, split the search_text and remove the tags
tag_ids = [
word.replace("tag:", "").replace(" ", "_").lower()
normalize_tag_name(word.replace("tag:", ""))
for word in search_text_words
if word.startswith("tag:")
]
@ -610,115 +1110,33 @@ class ChatTable:
query = query.order_by(Chat.updated_at.desc())
# Check if the database dialect is either 'sqlite' or 'postgresql'
dialect_name = db.bind.dialect.name
if dialect_name == "sqlite":
# SQLite case: using JSON1 extension for JSON searching
query = query.filter(
(
Chat.title.ilike(
f"%{search_text}%"
) # Case-insensitive search in title
| text(
"""
EXISTS (
SELECT 1
FROM json_each(Chat.chat, '$.messages') AS message
WHERE LOWER(message.value->>'content') LIKE '%' || :search_text || '%'
)
"""
)
).params(search_text=search_text)
)
# Check if there are any tags to filter, it should have all the tags
if "none" in tag_ids:
# Use adapter for cleaner query building
adapter = self._get_adapter(db)
# Add search filter if search text provided
if search_text:
search_filter = adapter.build_search_filter(search_text)
if search_filter is not None:
query = query.filter(
text(
"""
NOT EXISTS (
SELECT 1
FROM json_each(Chat.meta, '$.tags') AS tag
)
"""
)
Chat.title.ilike(f"%{search_text}%") | search_filter
)
elif tag_ids:
query = query.filter(
and_(
*[
text(
f"""
EXISTS (
SELECT 1
FROM json_each(Chat.meta, '$.tags') AS tag
WHERE tag.value = :tag_id_{tag_idx}
)
"""
).params(**{f"tag_id_{tag_idx}": tag_id})
for tag_idx, tag_id in enumerate(tag_ids)
]
)
)
elif dialect_name == "postgresql":
# PostgreSQL relies on proper JSON query for search
query = query.filter(
(
Chat.title.ilike(
f"%{search_text}%"
) # Case-insensitive search in title
| text(
"""
EXISTS (
SELECT 1
FROM json_array_elements(Chat.chat->'messages') AS message
WHERE LOWER(message->>'content') LIKE '%' || :search_text || '%'
)
"""
)
).params(search_text=search_text)
)
# Check if there are any tags to filter, it should have all the tags
if "none" in tag_ids:
query = query.filter(
text(
"""
NOT EXISTS (
SELECT 1
FROM json_array_elements_text(Chat.meta->'tags') AS tag
)
"""
)
)
elif tag_ids:
query = query.filter(
and_(
*[
text(
f"""
EXISTS (
SELECT 1
FROM json_array_elements_text(Chat.meta->'tags') AS tag
WHERE tag = :tag_id_{tag_idx}
)
"""
).params(**{f"tag_id_{tag_idx}": tag_id})
for tag_idx, tag_id in enumerate(tag_ids)
]
)
)
else:
raise NotImplementedError(
f"Unsupported dialect: {db.bind.dialect.name}"
)
else:
# Fallback to title-only search for unsupported databases
query = query.filter(Chat.title.ilike(f"%{search_text}%"))
# Add tag filters
if "none" in tag_ids:
untagged_filter = adapter.build_untagged_filter("meta")
if untagged_filter is not None:
query = query.filter(untagged_filter)
elif tag_ids:
tag_filter = adapter.build_tag_filter("meta", tag_ids, match_all=True)
if tag_filter is not None:
query = query.filter(tag_filter)
# Perform pagination at the SQL level
all_chats = query.offset(skip).limit(limit).all()
log.info(f"The number of chats: {len(all_chats)}")
# Validate and return chats
return [ChatModel.model_validate(chat) for chat in all_chats]
@ -775,32 +1193,59 @@ class ChatTable:
self, user_id: str, tag_name: str, skip: int = 0, limit: int = 50
) -> list[ChatModel]:
with get_db() as db:
adapter = self._get_adapter(db)
query = db.query(Chat).filter_by(user_id=user_id)
tag_id = tag_name.replace(" ", "_").lower()
tag_id = normalize_tag_name(tag_name)
log.info(f"DB dialect name: {db.bind.dialect.name}")
if db.bind.dialect.name == "sqlite":
# SQLite JSON1 querying for tags within the meta JSON field
query = query.filter(
text(
f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)"
)
).params(tag_id=tag_id)
elif db.bind.dialect.name == "postgresql":
# PostgreSQL JSON query for tags within the meta JSON field (for `json` type)
query = query.filter(
text(
"EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)"
)
).params(tag_id=tag_id)
# Use adapter to build tag filter
tag_filter = adapter.build_tag_filter("meta", [tag_id], match_all=True)
if tag_filter is not None:
query = query.filter(tag_filter)
all_chats = query.all()
return [ChatModel.model_validate(chat) for chat in all_chats]
else:
raise NotImplementedError(
f"Unsupported dialect: {db.bind.dialect.name}"
)
return []
all_chats = query.all()
log.debug(f"all_chats: {all_chats}")
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_multiple_tags(
self, user_id: str, tag_names: List[str], match_all: bool = True, skip: int = 0, limit: int = 50
) -> list[ChatModel]:
"""Get chats that match multiple tags"""
with get_db() as db:
adapter = self._get_adapter(db)
query = db.query(Chat).filter_by(user_id=user_id, archived=False)
if not tag_names:
return []
# Normalize tag names
tag_ids = normalize_tag_names(tag_names)
# Use adapter to build tag filter
tag_filter = adapter.build_tag_filter("meta", tag_ids, match_all=match_all)
if tag_filter is not None:
query = query.filter(tag_filter)
# Apply pagination and ordering
query = query.order_by(Chat.updated_at.desc()).offset(skip).limit(limit)
all_chats = query.all()
return [ChatModel.model_validate(chat) for chat in all_chats]
else:
return []
def get_chats_without_tags(self, user_id: str, skip: int = 0, limit: int = 50) -> list[ChatModel]:
"""Get chats that have no tags"""
with get_db() as db:
adapter = self._get_adapter(db)
query = db.query(Chat).filter_by(user_id=user_id, archived=False)
# Use adapter to build untagged filter
untagged_filter = adapter.build_untagged_filter("meta")
if untagged_filter is not None:
query = query.filter(untagged_filter)
query = query.order_by(Chat.updated_at.desc()).offset(skip).limit(limit)
all_chats = query.all()
return [ChatModel.model_validate(chat) for chat in all_chats]
else:
return []
def add_chat_tag_by_id_and_user_id_and_tag_name(
self, id: str, user_id: str, tag_name: str
@ -826,40 +1271,20 @@ class ChatTable:
return None
def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> int:
with get_db() as db: # Assuming `get_db()` returns a session object
with get_db() as db:
adapter = self._get_adapter(db)
query = db.query(Chat).filter_by(user_id=user_id, archived=False)
# Normalize the tag_name for consistency
tag_id = tag_name.replace(" ", "_").lower()
if db.bind.dialect.name == "sqlite":
# SQLite JSON1 support for querying the tags inside the `meta` JSON field
query = query.filter(
text(
f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)"
)
).params(tag_id=tag_id)
elif db.bind.dialect.name == "postgresql":
# PostgreSQL JSONB support for querying the tags inside the `meta` JSON field
query = query.filter(
text(
"EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)"
)
).params(tag_id=tag_id)
tag_id = normalize_tag_name(tag_name)
# Use adapter to build tag filter
tag_filter = adapter.build_tag_filter("meta", [tag_id], match_all=True)
if tag_filter is not None:
query = query.filter(tag_filter)
return query.count()
else:
raise NotImplementedError(
f"Unsupported dialect: {db.bind.dialect.name}"
)
# Get the count of matching records
count = query.count()
# Debugging output for inspection
log.info(f"Count of chats for tag '{tag_name}': {count}")
return count
return 0
def delete_tag_by_id_and_user_id_and_tag_name(
self, id: str, user_id: str, tag_name: str
@ -868,7 +1293,7 @@ class ChatTable:
with get_db() as db:
chat = db.get(Chat, id)
tags = chat.meta.get("tags", [])
tag_id = tag_name.replace(" ", "_").lower()
tag_id = normalize_tag_name(tag_name)
tags = [tag for tag in tags if tag != tag_id]
chat.meta = {