open-webui/backend/open_webui/models/chats.py
PVBLIC Foundation 484133de4c
Update chats.py
Add JSONB support to Postgres
2025-05-27 17:43:04 -07:00

1381 lines
50 KiB
Python

import logging
import json
import time
import uuid
from typing import Optional, Dict, Any, List, Union
from enum import Enum
from open_webui.internal.db import Base, get_db
from open_webui.models.tags import TagModel, Tag, Tags
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists
from sqlalchemy.sql.elements import TextClause
# Import JSONB for PostgreSQL support
try:
from sqlalchemy.dialects.postgresql import JSONB
except ImportError:
JSONB = None
####################
# Database Adapter
####################
class DatabaseType(Enum):
SQLITE = "sqlite"
POSTGRESQL_JSON = "postgresql_json"
POSTGRESQL_JSONB = "postgresql_jsonb"
UNSUPPORTED = "unsupported"
class DatabaseAdapter:
"""Centralized database-specific query generation with caching"""
def __init__(self, db):
self.db = db
self.dialect = db.bind.dialect.name
self._cache: Dict[str, DatabaseType] = {}
def get_database_type(self, column_name: str = "meta") -> DatabaseType:
"""Determine database type with caching"""
cache_key = f"{self.dialect}_{column_name}"
if cache_key in self._cache:
return self._cache[cache_key]
if self.dialect == "sqlite":
result = DatabaseType.SQLITE
elif self.dialect == "postgresql":
result = DatabaseType.POSTGRESQL_JSONB if self._is_jsonb_column(column_name) else DatabaseType.POSTGRESQL_JSON
else:
result = DatabaseType.UNSUPPORTED
self._cache[cache_key] = result
return result
def _is_jsonb_column(self, column_name: str) -> bool:
"""Check if column is JSONB type"""
if JSONB is None or self.dialect != "postgresql":
return False
try:
result = self.db.execute(text("""
SELECT data_type FROM information_schema.columns
WHERE table_name = 'chat' AND column_name = :column_name
"""), {"column_name": column_name})
row = result.fetchone()
return row[0].lower() == 'jsonb' if row else False
except Exception:
return False
def _get_function_template(self, db_type: DatabaseType, function_type: str) -> Optional[str]:
"""Get function template for specific database type and function"""
templates = {
DatabaseType.SQLITE: {
"tag_exists": "EXISTS (SELECT 1 FROM json_each({column}, '$.tags') WHERE json_each.value = :tag_id)",
"has_key": "json_extract({column}, '$.{path}') IS NOT NULL",
"array_length": "json_array_length({column}, '$.{path}')",
"array_elements": "json_each({column}, '$.{path}')",
"content_search": """EXISTS (
SELECT 1 FROM json_each({column}, '$.messages') AS message
WHERE LOWER(message.value->>'content') LIKE '%' || :search_text || '%'
)"""
},
DatabaseType.POSTGRESQL_JSON: {
"tag_exists": "EXISTS (SELECT 1 FROM json_array_elements_text({column}->'tags') elem WHERE elem = :tag_id)",
"has_key": "{column} ? '{path}'",
"array_length": "json_array_length({column}->'{path}')",
"array_elements": "json_array_elements({column}->'{path}')",
"content_search": """EXISTS (
SELECT 1 FROM json_array_elements({column}->'messages') AS message
WHERE LOWER(message->>'content') LIKE '%' || :search_text || '%'
)"""
},
DatabaseType.POSTGRESQL_JSONB: {
"tag_exists": "EXISTS (SELECT 1 FROM jsonb_array_elements_text({column}->'tags') elem WHERE elem = :tag_id)",
"has_key": "{column} ? '{path}'",
"array_length": "jsonb_array_length({column}->'{path}')",
"array_elements": "jsonb_array_elements({column}->'{path}')",
"content_search": """EXISTS (
SELECT 1 FROM jsonb_array_elements({column}->'messages') AS message
WHERE LOWER(message->>'content') LIKE '%' || :search_text || '%'
)"""
}
}
return templates.get(db_type, {}).get(function_type)
def build_tag_filter(self, column_name: str, tag_ids: List[str], match_all: bool = True) -> Optional[Union[TextClause, and_, or_]]:
"""Build optimized tag filtering query"""
if not tag_ids:
return None
db_type = self.get_database_type(column_name)
template = self._get_function_template(db_type, "tag_exists")
if not template:
return None
query_template = template.replace("{column}", f"Chat.{column_name}")
if match_all:
return and_(*[
text(query_template).params(tag_id=tag_id)
for tag_id in tag_ids
])
else:
conditions = []
params = {}
for idx, tag_id in enumerate(tag_ids):
param_name = f"tag_id_{idx}"
params[param_name] = tag_id
condition_template = query_template.replace(":tag_id", f":{param_name}")
conditions.append(text(condition_template))
return or_(*conditions).params(**params)
def build_search_filter(self, search_text: str) -> Optional[TextClause]:
"""Build content search query"""
db_type = self.get_database_type("chat")
template = self._get_function_template(db_type, "content_search")
if not template:
return None
query = template.replace("{column}", "Chat.chat")
return text(query).params(search_text=search_text)
def build_untagged_filter(self, column_name: str = "meta") -> Optional[or_]:
"""Build filter for chats without tags"""
db_type = self.get_database_type(column_name)
has_key_template = self._get_function_template(db_type, "has_key")
array_length_template = self._get_function_template(db_type, "array_length")
if not has_key_template or not array_length_template:
return None
has_key = has_key_template.replace("{column}", f"Chat.{column_name}").replace("{path}", "tags")
array_length = array_length_template.replace("{column}", f"Chat.{column_name}").replace("{path}", "tags")
return or_(
text(f"NOT ({has_key})"),
text(f"{array_length} = 0")
)
####################
# Utility Functions
####################
def normalize_tag_name(tag_name: str) -> str:
"""Normalize tag name for consistent storage and querying"""
return tag_name.replace(" ", "_").lower()
def normalize_tag_names(tag_names: List[str]) -> List[str]:
"""Normalize multiple tag names"""
return [normalize_tag_name(tag) for tag in tag_names]
####################
# Chat DB Schema
####################
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
class Chat(Base):
__tablename__ = "chat"
id = Column(String, primary_key=True)
user_id = Column(String)
title = Column(Text)
chat = Column(JSON) # For JSONB support, change to: Column(JSONB) if JSONB else Column(JSON)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
share_id = Column(Text, unique=True, nullable=True)
archived = Column(Boolean, default=False)
pinned = Column(Boolean, default=False, nullable=True)
meta = Column(JSON, server_default="{}") # For JSONB support, change to: Column(JSONB, server_default="{}") if JSONB else Column(JSON, server_default="{}")
folder_id = Column(Text, nullable=True)
class ChatModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
user_id: str
title: str
chat: dict
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
share_id: Optional[str] = None
archived: bool = False
pinned: Optional[bool] = False
meta: dict = {}
folder_id: Optional[str] = None
####################
# Forms
####################
class ChatForm(BaseModel):
chat: dict
class ChatImportForm(ChatForm):
meta: Optional[dict] = {}
pinned: Optional[bool] = False
folder_id: Optional[str] = None
class ChatTitleMessagesForm(BaseModel):
title: str
messages: list[dict]
class ChatTitleForm(BaseModel):
title: str
class ChatResponse(BaseModel):
id: str
user_id: str
title: str
chat: dict
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
share_id: Optional[str] = None # id of the chat to be shared
archived: bool
pinned: Optional[bool] = False
meta: dict = {}
folder_id: Optional[str] = None
class ChatTitleIdResponse(BaseModel):
id: str
title: str
updated_at: int
created_at: int
class ChatTable:
def __init__(self):
pass
def _get_adapter(self, db) -> DatabaseAdapter:
"""Get database adapter for the current session"""
return DatabaseAdapter(db)
# Legacy methods for backward compatibility
def _is_jsonb_column(self, db, column_name: str) -> bool:
"""Legacy method - use adapter instead"""
adapter = self._get_adapter(db)
return adapter.get_database_type(column_name) == DatabaseType.POSTGRESQL_JSONB
def _get_json_query_type(self, db, column_name: str = "meta") -> str:
"""Legacy method - use adapter instead"""
adapter = self._get_adapter(db)
db_type = adapter.get_database_type(column_name)
return db_type.value
def check_database_compatibility(self) -> dict:
"""Check database compatibility and available features"""
try:
with get_db() as db:
adapter = self._get_adapter(db)
dialect_name = db.bind.dialect.name
meta_type = adapter.get_database_type("meta")
chat_type = adapter.get_database_type("chat")
compatibility = {
"database_type": dialect_name,
"json_support": meta_type != DatabaseType.UNSUPPORTED,
"jsonb_support": meta_type == DatabaseType.POSTGRESQL_JSONB or chat_type == DatabaseType.POSTGRESQL_JSONB,
"gin_indexes_support": dialect_name == "postgresql",
"tag_filtering_support": meta_type != DatabaseType.UNSUPPORTED,
"advanced_search_support": chat_type != DatabaseType.UNSUPPORTED,
"meta_column_type": meta_type.value,
"chat_column_type": chat_type.value,
"features": [],
"limitations": [],
"recommendations": []
}
# Add features based on database type
if dialect_name == "sqlite":
compatibility["features"] = ["JSON1 extension", "Basic tag filtering", "Message search"]
compatibility["limitations"] = ["No GIN indexes", "Limited JSON optimization"]
elif dialect_name == "postgresql":
compatibility["features"] = ["Full JSON/JSONB support", "GIN indexes", "Advanced filtering"]
if compatibility["jsonb_support"]:
compatibility["features"].append("JSONB binary format optimization")
return compatibility
except Exception as e:
log.error(f"Error checking database compatibility: {e}")
return {"error": str(e), "database_type": "unknown"}
def create_gin_indexes(self) -> bool:
"""Create GIN indexes on JSONB columns for better query performance"""
try:
with get_db() as db:
adapter = self._get_adapter(db)
if db.bind.dialect.name != "postgresql":
return False
meta_type = adapter.get_database_type("meta")
chat_type = adapter.get_database_type("chat")
has_jsonb_meta = meta_type == DatabaseType.POSTGRESQL_JSONB
has_jsonb_chat = chat_type == DatabaseType.POSTGRESQL_JSONB
if not (has_jsonb_meta or has_jsonb_chat):
return False
# Create GIN indexes
if has_jsonb_meta:
try:
db.execute(text("""
CREATE INDEX IF NOT EXISTS idx_chat_meta_gin
ON chat USING GIN (meta)
"""))
db.execute(text("""
CREATE INDEX IF NOT EXISTS idx_chat_meta_tags_gin
ON chat USING GIN ((meta->'tags'))
"""))
db.execute(text("""
CREATE INDEX IF NOT EXISTS idx_chat_has_tags
ON chat USING BTREE ((meta ? 'tags' AND jsonb_array_length(meta->'tags') > 0))
WHERE meta ? 'tags'
"""))
except Exception:
pass
if has_jsonb_chat:
try:
db.execute(text("""
CREATE INDEX IF NOT EXISTS idx_chat_chat_gin
ON chat USING GIN (chat)
"""))
db.execute(text("""
CREATE INDEX IF NOT EXISTS idx_chat_messages_gin
ON chat USING GIN ((chat->'messages'))
"""))
except Exception:
pass
db.commit()
return True
except Exception:
return False
def check_gin_indexes(self) -> dict:
"""
Check which GIN indexes exist on the chat table.
Returns a dictionary with index names and their status.
"""
try:
with get_db() as db:
if db.bind.dialect.name != "postgresql":
return {"error": "GIN indexes are only supported on PostgreSQL"}
result = db.execute(text("""
SELECT indexname, indexdef
FROM pg_indexes
WHERE tablename = 'chat'
AND indexdef LIKE '%USING gin%'
"""))
indexes = {}
for row in result:
indexes[row[0]] = {
"exists": True,
"definition": row[1]
}
# Check for expected indexes
expected_indexes = [
"idx_chat_meta_gin",
"idx_chat_chat_gin",
"idx_chat_meta_tags_gin",
"idx_chat_has_tags",
"idx_chat_tag_count",
"idx_chat_json_tags",
"idx_chat_messages_gin"
]
for idx_name in expected_indexes:
if idx_name not in indexes:
indexes[idx_name] = {"exists": False}
return indexes
except Exception as e:
log.error(f"Error checking GIN indexes: {e}")
return {"error": str(e)}
def drop_gin_indexes(self) -> bool:
"""Drop all GIN indexes on the chat table"""
try:
with get_db() as db:
if db.bind.dialect.name != "postgresql":
return False
indexes_to_drop = [
"idx_chat_meta_gin",
"idx_chat_chat_gin",
"idx_chat_meta_tags_gin",
"idx_chat_has_tags",
"idx_chat_tag_count",
"idx_chat_json_tags",
"idx_chat_messages_gin"
]
for idx_name in indexes_to_drop:
try:
db.execute(text(f"DROP INDEX CONCURRENTLY IF EXISTS {idx_name}"))
except Exception:
pass
db.commit()
return True
except Exception:
return False
def create_tag_indexes(self) -> bool:
"""
Create specialized indexes optimized specifically for tag operations.
This includes both GIN and BTREE indexes for different tag query patterns.
"""
try:
with get_db() as db:
if db.bind.dialect.name != "postgresql":
log.info("Tag indexes are only supported on PostgreSQL")
return False
has_jsonb_meta = self._is_jsonb_column(db, "meta")
indexes_created = []
if has_jsonb_meta:
# JSONB-specific tag indexes
tag_indexes = [
{
"name": "idx_chat_meta_tags_gin",
"sql": "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_meta_tags_gin ON chat USING GIN ((meta->'tags'))",
"purpose": "Fast tag containment queries (@>, ?, etc.)"
},
{
"name": "idx_chat_has_tags",
"sql": "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_has_tags ON chat USING BTREE ((meta ? 'tags' AND jsonb_array_length(meta->'tags') > 0)) WHERE meta ? 'tags'",
"purpose": "Fast filtering for chats with/without tags"
},
{
"name": "idx_chat_tag_count",
"sql": "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_tag_count ON chat USING BTREE ((jsonb_array_length(meta->'tags'))) WHERE meta ? 'tags'",
"purpose": "Fast filtering by number of tags"
},
{
"name": "idx_chat_specific_tags",
"sql": "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_specific_tags ON chat USING GIN ((meta->'tags')) WHERE jsonb_array_length(meta->'tags') > 0",
"purpose": "Optimized for chats that actually have tags"
}
]
else:
# JSON-specific tag indexes (less optimal but still helpful)
tag_indexes = [
{
"name": "idx_chat_json_tags",
"sql": "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_json_tags ON chat USING GIN ((meta->'tags')) WHERE meta ? 'tags'",
"purpose": "Tag queries for JSON columns"
}
]
for index_info in tag_indexes:
try:
db.execute(text(index_info["sql"]))
indexes_created.append(f"{index_info['name']} ({index_info['purpose']})")
log.info(f"Created tag index: {index_info['name']}")
except Exception as e:
log.warning(f"Failed to create {index_info['name']}: {e}")
db.commit()
if indexes_created:
log.info(f"Successfully created tag indexes: {len(indexes_created)} indexes")
for idx in indexes_created:
log.info(f"{idx}")
else:
log.info("No tag indexes were created")
return True
except Exception as e:
log.error(f"Error creating tag indexes: {e}")
return False
def optimize_tag_queries(self) -> dict:
"""
Analyze and provide recommendations for tag query optimization.
Returns statistics and suggestions for improving tag query performance.
"""
try:
with get_db() as db:
if db.bind.dialect.name != "postgresql":
return {"error": "Tag optimization is only supported on PostgreSQL"}
stats = {}
# Get basic tag statistics
result = db.execute(text("""
SELECT
COUNT(*) as total_chats,
COUNT(*) FILTER (WHERE meta ? 'tags') as chats_with_tags,
COUNT(*) FILTER (WHERE meta ? 'tags' AND jsonb_array_length(meta->'tags') > 0) as chats_with_actual_tags,
AVG(CASE WHEN meta ? 'tags' THEN jsonb_array_length(meta->'tags') ELSE 0 END) as avg_tags_per_chat
FROM chat
"""))
row = result.fetchone()
if row:
stats.update({
"total_chats": row[0],
"chats_with_tags": row[1],
"chats_with_actual_tags": row[2],
"avg_tags_per_chat": float(row[3]) if row[3] else 0
})
# Get most common tags
result = db.execute(text("""
SELECT tag_value, COUNT(*) as usage_count
FROM chat, jsonb_array_elements_text(meta->'tags') as tag_value
WHERE meta ? 'tags'
GROUP BY tag_value
ORDER BY usage_count DESC
LIMIT 10
"""))
stats["top_tags"] = [{"tag": row[0], "count": row[1]} for row in result]
# Check index usage
indexes = self.check_gin_indexes()
tag_indexes = {k: v for k, v in indexes.items() if "tag" in k.lower()}
stats["tag_indexes"] = tag_indexes
# Provide recommendations
recommendations = []
if stats["chats_with_actual_tags"] > 1000:
recommendations.append("Consider creating tag-specific indexes for better performance")
if stats["avg_tags_per_chat"] > 5:
recommendations.append("High tag usage detected - GIN indexes will provide significant benefits")
tag_coverage = stats["chats_with_actual_tags"] / stats["total_chats"] if stats["total_chats"] > 0 else 0
if tag_coverage < 0.1:
recommendations.append("Low tag usage - consider partial indexes with WHERE clauses")
stats["recommendations"] = recommendations
stats["tag_coverage_percentage"] = round(tag_coverage * 100, 2)
return stats
except Exception as e:
log.error(f"Error analyzing tag queries: {e}")
return {"error": str(e)}
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
with get_db() as db:
id = str(uuid.uuid4())
chat = ChatModel(
**{
"id": id,
"user_id": user_id,
"title": (
form_data.chat["title"]
if "title" in form_data.chat
else "New Chat"
),
"chat": form_data.chat,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
result = Chat(**chat.model_dump())
db.add(result)
db.commit()
db.refresh(result)
return ChatModel.model_validate(result) if result else None
def import_chat(
self, user_id: str, form_data: ChatImportForm
) -> Optional[ChatModel]:
with get_db() as db:
id = str(uuid.uuid4())
chat = ChatModel(
**{
"id": id,
"user_id": user_id,
"title": (
form_data.chat["title"]
if "title" in form_data.chat
else "New Chat"
),
"chat": form_data.chat,
"meta": form_data.meta,
"pinned": form_data.pinned,
"folder_id": form_data.folder_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
result = Chat(**chat.model_dump())
db.add(result)
db.commit()
db.refresh(result)
return ChatModel.model_validate(result) if result else None
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
try:
with get_db() as db:
chat_item = db.get(Chat, id)
chat_item.chat = chat
chat_item.title = chat["title"] if "title" in chat else "New Chat"
chat_item.updated_at = int(time.time())
db.commit()
db.refresh(chat_item)
return ChatModel.model_validate(chat_item)
except Exception:
return None
def update_chat_title_by_id(self, id: str, title: str) -> Optional[ChatModel]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
chat = chat.chat
chat["title"] = title
return self.update_chat_by_id(id, chat)
def update_chat_tags_by_id(
self, id: str, tags: list[str], user
) -> Optional[ChatModel]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
self.delete_all_tags_by_id_and_user_id(id, user.id)
for tag in chat.meta.get("tags", []):
if self.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
Tags.delete_tag_by_name_and_user_id(tag, user.id)
for tag_name in tags:
if tag_name.lower() == "none":
continue
self.add_chat_tag_by_id_and_user_id_and_tag_name(id, user.id, tag_name)
return self.get_chat_by_id(id)
def get_chat_title_by_id(self, id: str) -> Optional[str]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
return chat.chat.get("title", "New Chat")
def get_messages_by_chat_id(self, id: str) -> Optional[dict]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
return chat.chat.get("history", {}).get("messages", {}) or {}
def get_message_by_id_and_message_id(
self, id: str, message_id: str
) -> Optional[dict]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
return chat.chat.get("history", {}).get("messages", {}).get(message_id, {})
def upsert_message_to_chat_by_id_and_message_id(
self, id: str, message_id: str, message: dict
) -> Optional[ChatModel]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
chat = chat.chat
history = chat.get("history", {})
if message_id in history.get("messages", {}):
history["messages"][message_id] = {
**history["messages"][message_id],
**message,
}
else:
history["messages"][message_id] = message
history["currentId"] = message_id
chat["history"] = history
return self.update_chat_by_id(id, chat)
def add_message_status_to_chat_by_id_and_message_id(
self, id: str, message_id: str, status: dict
) -> Optional[ChatModel]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
chat = chat.chat
history = chat.get("history", {})
if message_id in history.get("messages", {}):
status_history = history["messages"][message_id].get("statusHistory", [])
status_history.append(status)
history["messages"][message_id]["statusHistory"] = status_history
chat["history"] = history
return self.update_chat_by_id(id, chat)
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
with get_db() as db:
# Get the existing chat to share
chat = db.get(Chat, chat_id)
# Check if the chat is already shared
if chat.share_id:
return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
# Create a new chat with the same data, but with a new ID
shared_chat = ChatModel(
**{
"id": str(uuid.uuid4()),
"user_id": f"shared-{chat_id}",
"title": chat.title,
"chat": chat.chat,
"created_at": chat.created_at,
"updated_at": int(time.time()),
}
)
shared_result = Chat(**shared_chat.model_dump())
db.add(shared_result)
db.commit()
db.refresh(shared_result)
# Update the original chat with the share_id
result = (
db.query(Chat)
.filter_by(id=chat_id)
.update({"share_id": shared_chat.id})
)
db.commit()
return shared_chat if (shared_result and result) else None
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
try:
with get_db() as db:
chat = db.get(Chat, chat_id)
shared_chat = (
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").first()
)
if shared_chat is None:
return self.insert_shared_chat_by_chat_id(chat_id)
shared_chat.title = chat.title
shared_chat.chat = chat.chat
shared_chat.updated_at = int(time.time())
db.commit()
db.refresh(shared_chat)
return ChatModel.model_validate(shared_chat)
except Exception:
return None
def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
try:
with get_db() as db:
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
db.commit()
return True
except Exception:
return False
def update_chat_share_id_by_id(
self, id: str, share_id: Optional[str]
) -> Optional[ChatModel]:
try:
with get_db() as db:
chat = db.get(Chat, id)
chat.share_id = share_id
db.commit()
db.refresh(chat)
return ChatModel.model_validate(chat)
except Exception:
return None
def toggle_chat_pinned_by_id(self, id: str) -> Optional[ChatModel]:
try:
with get_db() as db:
chat = db.get(Chat, id)
chat.pinned = not chat.pinned
chat.updated_at = int(time.time())
db.commit()
db.refresh(chat)
return ChatModel.model_validate(chat)
except Exception:
return None
def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
try:
with get_db() as db:
chat = db.get(Chat, id)
chat.archived = not chat.archived
chat.updated_at = int(time.time())
db.commit()
db.refresh(chat)
return ChatModel.model_validate(chat)
except Exception:
return None
def archive_all_chats_by_user_id(self, user_id: str) -> bool:
try:
with get_db() as db:
db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
db.commit()
return True
except Exception:
return False
def get_archived_chat_list_by_user_id(
self,
user_id: str,
filter: Optional[dict] = None,
skip: int = 0,
limit: int = 50,
) -> list[ChatModel]:
with get_db() as db:
query = db.query(Chat).filter_by(user_id=user_id, archived=True)
if filter:
query_key = filter.get("query")
if query_key:
query = query.filter(Chat.title.ilike(f"%{query_key}%"))
order_by = filter.get("order_by")
direction = filter.get("direction")
if order_by and direction and getattr(Chat, order_by):
if direction.lower() == "asc":
query = query.order_by(getattr(Chat, order_by).asc())
elif direction.lower() == "desc":
query = query.order_by(getattr(Chat, order_by).desc())
else:
raise ValueError("Invalid direction for ordering")
else:
query = query.order_by(Chat.updated_at.desc())
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
all_chats = query.all()
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_list_by_user_id(
self,
user_id: str,
include_archived: bool = False,
filter: Optional[dict] = None,
skip: int = 0,
limit: int = 50,
) -> list[ChatModel]:
with get_db() as db:
query = db.query(Chat).filter_by(user_id=user_id)
if not include_archived:
query = query.filter_by(archived=False)
if filter:
query_key = filter.get("query")
if query_key:
query = query.filter(Chat.title.ilike(f"%{query_key}%"))
order_by = filter.get("order_by")
direction = filter.get("direction")
if order_by and direction and getattr(Chat, order_by):
if direction.lower() == "asc":
query = query.order_by(getattr(Chat, order_by).asc())
elif direction.lower() == "desc":
query = query.order_by(getattr(Chat, order_by).desc())
else:
raise ValueError("Invalid direction for ordering")
else:
query = query.order_by(Chat.updated_at.desc())
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
all_chats = query.all()
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_title_id_list_by_user_id(
self,
user_id: str,
include_archived: bool = False,
skip: Optional[int] = None,
limit: Optional[int] = None,
) -> list[ChatTitleIdResponse]:
with get_db() as db:
query = db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None)
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
if not include_archived:
query = query.filter_by(archived=False)
query = query.order_by(Chat.updated_at.desc()).with_entities(
Chat.id, Chat.title, Chat.updated_at, Chat.created_at
)
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
all_chats = query.all()
# result has to be destructured from sqlalchemy `row` and mapped to a dict since the `ChatModel`is not the returned dataclass.
return [
ChatTitleIdResponse.model_validate(
{
"id": chat[0],
"title": chat[1],
"updated_at": chat[2],
"created_at": chat[3],
}
)
for chat in all_chats
]
def get_chat_list_by_chat_ids(
self, chat_ids: list[str], skip: int = 0, limit: int = 50
) -> list[ChatModel]:
with get_db() as db:
all_chats = (
db.query(Chat)
.filter(Chat.id.in_(chat_ids))
.filter_by(archived=False)
.order_by(Chat.updated_at.desc())
.all()
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
try:
with get_db() as db:
chat = db.get(Chat, id)
return ChatModel.model_validate(chat)
except Exception:
return None
def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
try:
with get_db() as db:
# it is possible that the shared link was deleted. hence,
# we check if the chat is still shared by checking if a chat with the share_id exists
chat = db.query(Chat).filter_by(share_id=id).first()
if chat:
return self.get_chat_by_id(id)
else:
return None
except Exception:
return None
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
try:
with get_db() as db:
chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
return ChatModel.model_validate(chat)
except Exception:
return None
def get_chats(self, skip: int = 0, limit: int = 50) -> list[ChatModel]:
with get_db() as db:
all_chats = (
db.query(Chat)
.order_by(Chat.updated_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
with get_db() as db:
all_chats = (
db.query(Chat)
.filter_by(user_id=user_id)
.order_by(Chat.updated_at.desc())
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
with get_db() as db:
all_chats = (
db.query(Chat)
.filter_by(user_id=user_id, pinned=True, archived=False)
.order_by(Chat.updated_at.desc())
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
with get_db() as db:
all_chats = (
db.query(Chat)
.filter_by(user_id=user_id, archived=True)
.order_by(Chat.updated_at.desc())
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_user_id_and_search_text(
self,
user_id: str,
search_text: str,
include_archived: bool = False,
skip: int = 0,
limit: int = 60,
) -> list[ChatModel]:
"""
Filters chats based on a search query using Python, allowing pagination using skip and limit.
"""
search_text = search_text.lower().strip()
if not search_text:
return self.get_chat_list_by_user_id(
user_id, include_archived, filter={}, skip=skip, limit=limit
)
search_text_words = search_text.split(" ")
# search_text might contain 'tag:tag_name' format so we need to extract the tag_name, split the search_text and remove the tags
tag_ids = [
normalize_tag_name(word.replace("tag:", ""))
for word in search_text_words
if word.startswith("tag:")
]
search_text_words = [
word for word in search_text_words if not word.startswith("tag:")
]
search_text = " ".join(search_text_words)
with get_db() as db:
query = db.query(Chat).filter(Chat.user_id == user_id)
if not include_archived:
query = query.filter(Chat.archived == False)
query = query.order_by(Chat.updated_at.desc())
# Use adapter for cleaner query building
adapter = self._get_adapter(db)
# Add search filter if search text provided
if search_text:
search_filter = adapter.build_search_filter(search_text)
if search_filter is not None:
query = query.filter(
Chat.title.ilike(f"%{search_text}%") | search_filter
)
else:
# Fallback to title-only search for unsupported databases
query = query.filter(Chat.title.ilike(f"%{search_text}%"))
# Add tag filters
if "none" in tag_ids:
untagged_filter = adapter.build_untagged_filter("meta")
if untagged_filter is not None:
query = query.filter(untagged_filter)
elif tag_ids:
tag_filter = adapter.build_tag_filter("meta", tag_ids, match_all=True)
if tag_filter is not None:
query = query.filter(tag_filter)
# Perform pagination at the SQL level
all_chats = query.offset(skip).limit(limit).all()
# Validate and return chats
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_folder_id_and_user_id(
self, folder_id: str, user_id: str
) -> list[ChatModel]:
with get_db() as db:
query = db.query(Chat).filter_by(folder_id=folder_id, user_id=user_id)
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
query = query.filter_by(archived=False)
query = query.order_by(Chat.updated_at.desc())
all_chats = query.all()
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_folder_ids_and_user_id(
self, folder_ids: list[str], user_id: str
) -> list[ChatModel]:
with get_db() as db:
query = db.query(Chat).filter(
Chat.folder_id.in_(folder_ids), Chat.user_id == user_id
)
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
query = query.filter_by(archived=False)
query = query.order_by(Chat.updated_at.desc())
all_chats = query.all()
return [ChatModel.model_validate(chat) for chat in all_chats]
def update_chat_folder_id_by_id_and_user_id(
self, id: str, user_id: str, folder_id: str
) -> Optional[ChatModel]:
try:
with get_db() as db:
chat = db.get(Chat, id)
chat.folder_id = folder_id
chat.updated_at = int(time.time())
chat.pinned = False
db.commit()
db.refresh(chat)
return ChatModel.model_validate(chat)
except Exception:
return None
def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str) -> list[TagModel]:
with get_db() as db:
chat = db.get(Chat, id)
tags = chat.meta.get("tags", [])
return [Tags.get_tag_by_name_and_user_id(tag, user_id) for tag in tags]
def get_chat_list_by_user_id_and_tag_name(
self, user_id: str, tag_name: str, skip: int = 0, limit: int = 50
) -> list[ChatModel]:
with get_db() as db:
adapter = self._get_adapter(db)
query = db.query(Chat).filter_by(user_id=user_id)
tag_id = normalize_tag_name(tag_name)
# Use adapter to build tag filter
tag_filter = adapter.build_tag_filter("meta", [tag_id], match_all=True)
if tag_filter is not None:
query = query.filter(tag_filter)
all_chats = query.all()
return [ChatModel.model_validate(chat) for chat in all_chats]
else:
return []
def get_chats_by_multiple_tags(
self, user_id: str, tag_names: List[str], match_all: bool = True, skip: int = 0, limit: int = 50
) -> list[ChatModel]:
"""Get chats that match multiple tags"""
with get_db() as db:
adapter = self._get_adapter(db)
query = db.query(Chat).filter_by(user_id=user_id, archived=False)
if not tag_names:
return []
# Normalize tag names
tag_ids = normalize_tag_names(tag_names)
# Use adapter to build tag filter
tag_filter = adapter.build_tag_filter("meta", tag_ids, match_all=match_all)
if tag_filter is not None:
query = query.filter(tag_filter)
# Apply pagination and ordering
query = query.order_by(Chat.updated_at.desc()).offset(skip).limit(limit)
all_chats = query.all()
return [ChatModel.model_validate(chat) for chat in all_chats]
else:
return []
def get_chats_without_tags(self, user_id: str, skip: int = 0, limit: int = 50) -> list[ChatModel]:
"""Get chats that have no tags"""
with get_db() as db:
adapter = self._get_adapter(db)
query = db.query(Chat).filter_by(user_id=user_id, archived=False)
# Use adapter to build untagged filter
untagged_filter = adapter.build_untagged_filter("meta")
if untagged_filter is not None:
query = query.filter(untagged_filter)
query = query.order_by(Chat.updated_at.desc()).offset(skip).limit(limit)
all_chats = query.all()
return [ChatModel.model_validate(chat) for chat in all_chats]
else:
return []
def add_chat_tag_by_id_and_user_id_and_tag_name(
self, id: str, user_id: str, tag_name: str
) -> Optional[ChatModel]:
tag = Tags.get_tag_by_name_and_user_id(tag_name, user_id)
if tag is None:
tag = Tags.insert_new_tag(tag_name, user_id)
try:
with get_db() as db:
chat = db.get(Chat, id)
tag_id = tag.id
if tag_id not in chat.meta.get("tags", []):
chat.meta = {
**chat.meta,
"tags": list(set(chat.meta.get("tags", []) + [tag_id])),
}
db.commit()
db.refresh(chat)
return ChatModel.model_validate(chat)
except Exception:
return None
def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> int:
with get_db() as db:
adapter = self._get_adapter(db)
query = db.query(Chat).filter_by(user_id=user_id, archived=False)
# Normalize the tag_name for consistency
tag_id = normalize_tag_name(tag_name)
# Use adapter to build tag filter
tag_filter = adapter.build_tag_filter("meta", [tag_id], match_all=True)
if tag_filter is not None:
query = query.filter(tag_filter)
return query.count()
else:
return 0
def delete_tag_by_id_and_user_id_and_tag_name(
self, id: str, user_id: str, tag_name: str
) -> bool:
try:
with get_db() as db:
chat = db.get(Chat, id)
tags = chat.meta.get("tags", [])
tag_id = normalize_tag_name(tag_name)
tags = [tag for tag in tags if tag != tag_id]
chat.meta = {
**chat.meta,
"tags": list(set(tags)),
}
db.commit()
return True
except Exception:
return False
def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try:
with get_db() as db:
chat = db.get(Chat, id)
chat.meta = {
**chat.meta,
"tags": [],
}
db.commit()
return True
except Exception:
return False
def delete_chat_by_id(self, id: str) -> bool:
try:
with get_db() as db:
db.query(Chat).filter_by(id=id).delete()
db.commit()
return True and self.delete_shared_chat_by_chat_id(id)
except Exception:
return False
def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try:
with get_db() as db:
db.query(Chat).filter_by(id=id, user_id=user_id).delete()
db.commit()
return True and self.delete_shared_chat_by_chat_id(id)
except Exception:
return False
def delete_chats_by_user_id(self, user_id: str) -> bool:
try:
with get_db() as db:
self.delete_shared_chats_by_user_id(user_id)
db.query(Chat).filter_by(user_id=user_id).delete()
db.commit()
return True
except Exception:
return False
def delete_chats_by_user_id_and_folder_id(
self, user_id: str, folder_id: str
) -> bool:
try:
with get_db() as db:
db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).delete()
db.commit()
return True
except Exception:
return False
def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
try:
with get_db() as db:
chats_by_user = db.query(Chat).filter_by(user_id=user_id).all()
shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
db.commit()
return True
except Exception:
return False
Chats = ChatTable()