open-webui/backend/open_webui/models/chats.py
PVBLIC Foundation af97737b0c
Update chats.py
Change to pass CI
2025-05-30 20:21:21 -07:00

1641 lines
59 KiB
Python

import logging
import json
import time
import uuid
from typing import Optional, Dict, Any, List, Union
from enum import Enum
from open_webui.internal.db import Base, get_db
from open_webui.models.tags import TagModel, Tag, Tags
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists
from sqlalchemy.sql.elements import TextClause
# Import JSONB for PostgreSQL support - gracefully handle absence in other databases
try:
from sqlalchemy.dialects.postgresql import JSONB
except ImportError:
JSONB = None
####################
# Database Adapter
####################
class DatabaseType(Enum):
    """
    Enumeration of supported database types and their JSON capabilities.

    Used by DatabaseAdapter to determine optimal query strategies and
    features available for each database backend. The member values are
    plain strings; ``ChatTable._get_json_query_type`` returns them verbatim.
    """

    SQLITE = "sqlite"  # SQLite with JSON1 extension
    POSTGRESQL_JSON = "postgresql_json"  # PostgreSQL with standard JSON
    POSTGRESQL_JSONB = "postgresql_jsonb"  # PostgreSQL with binary JSONB
    UNSUPPORTED = "unsupported"  # Any other dialect; no JSON query templates
class DatabaseAdapter:
    """Centralized database-specific query generation with caching.

    Wraps a SQLAlchemy session, works out which JSON flavour a column of the
    ``chat`` table uses (SQLite JSON1, PostgreSQL ``json`` or PostgreSQL
    ``jsonb``) and renders raw SQL fragments for tag filtering, untagged-chat
    filtering and message content search.
    """

    def __init__(self, db):
        """
        Initialize adapter with database session and setup caching.

        Args:
            db: SQLAlchemy database session for introspection and queries
        """
        self.db = db
        self.dialect = db.bind.dialect.name
        # Cache database types to avoid repeated introspection queries.
        # The cache lives on this adapter instance only.
        self._cache: Dict[str, DatabaseType] = {}

    def get_database_type(self, column_name: str = "meta") -> DatabaseType:
        """Return the DatabaseType for ``column_name``, caching the answer.

        SQLite maps to ``SQLITE``; PostgreSQL maps to ``POSTGRESQL_JSONB`` or
        ``POSTGRESQL_JSON`` depending on the column's declared type; every
        other dialect maps to ``UNSUPPORTED``.
        """
        cache_key = f"{self.dialect}_{column_name}"
        if cache_key in self._cache:
            return self._cache[cache_key]

        if self.dialect == "sqlite":
            result = DatabaseType.SQLITE
        elif self.dialect == "postgresql":
            # Check if column uses JSONB or standard JSON
            result = (
                DatabaseType.POSTGRESQL_JSONB
                if self._is_jsonb_column(column_name)
                else DatabaseType.POSTGRESQL_JSON
            )
        else:
            result = DatabaseType.UNSUPPORTED

        self._cache[cache_key] = result
        return result

    def _is_jsonb_column(self, column_name: str) -> bool:
        """Return True if ``chat.<column_name>`` is declared as JSONB.

        Returns False when the JSONB type could not be imported, when the
        dialect is not PostgreSQL, when the column is not found, or when the
        catalog query raises.
        """
        if JSONB is None or self.dialect != "postgresql":
            return False
        try:
            # Query PostgreSQL system catalog for column data type
            result = self.db.execute(
                text(
                    """
                    SELECT data_type FROM information_schema.columns
                    WHERE table_name = 'chat' AND column_name = :column_name
                    """
                ),
                {"column_name": column_name},
            )
            row = result.fetchone()
            return row[0].lower() == "jsonb" if row else False
        except Exception:
            # Gracefully handle permission or connection issues
            return False

    def _get_function_template(
        self, db_type: DatabaseType, function_type: str
    ) -> Optional[str]:
        """Return the SQL template for ``function_type`` on ``db_type``.

        Templates contain ``{column}`` and (for some) ``{path}`` placeholders
        that callers substitute with ``str.replace``. Returns None when the
        database type or function is unknown.
        """
        templates = {
            DatabaseType.SQLITE: {
                # SQLite with JSON1 extension functions
                "tag_exists": (
                    "EXISTS (SELECT 1 FROM json_each({column}, '$.tags') "
                    "WHERE json_each.value = :tag_id)"
                ),
                "has_key": "json_extract({column}, '$.{path}') IS NOT NULL",
                "array_length": "json_array_length({column}, '$.{path}')",
                "array_elements": "json_each({column}, '$.{path}')",
                "content_search": """EXISTS (
                    SELECT 1 FROM json_each({column}, '$.messages') AS message
                    WHERE LOWER(message.value->>'content') LIKE '%' || :search_text || '%'
                )""",
            },
            DatabaseType.POSTGRESQL_JSON: {
                # PostgreSQL with standard JSON operators
                "tag_exists": (
                    "EXISTS (SELECT 1 FROM json_array_elements_text({column}->'tags') "
                    "elem WHERE elem = :tag_id)"
                ),
                # The `?` key-existence operator is defined for jsonb only, so
                # plain json columns test the extracted member for NULL instead.
                "has_key": "({column}->'{path}') IS NOT NULL",
                "array_length": "json_array_length({column}->'{path}')",
                "array_elements": "json_array_elements({column}->'{path}')",
                "content_search": """EXISTS (
                    SELECT 1 FROM json_array_elements({column}->'messages') AS message
                    WHERE LOWER(message->>'content') LIKE '%' || :search_text || '%'
                )""",
            },
            DatabaseType.POSTGRESQL_JSONB: {
                # PostgreSQL with optimized JSONB binary format
                "tag_exists": (
                    "EXISTS (SELECT 1 FROM jsonb_array_elements_text({column}->'tags') "
                    "elem WHERE elem = :tag_id)"
                ),
                "has_key": "{column} ? '{path}'",
                "array_length": "jsonb_array_length({column}->'{path}')",
                "array_elements": "jsonb_array_elements({column}->'{path}')",
                "content_search": """EXISTS (
                    SELECT 1 FROM jsonb_array_elements({column}->'messages') AS message
                    WHERE LOWER(message->>'content') LIKE '%' || :search_text || '%'
                )""",
            },
        }
        return templates.get(db_type, {}).get(function_type)

    def build_tag_filter(
        self, column_name: str, tag_ids: List[str], match_all: bool = True
    ) -> Optional[Union[TextClause, and_, or_]]:
        """Build a filter matching chats by tag ids stored in a JSON column.

        Args:
            column_name: JSON column on ``Chat`` holding the ``tags`` array.
            tag_ids: Tag identifiers to look for.
            match_all: True to require every tag (AND), False for any (OR).

        Returns:
            A SQL clause, or None when ``tag_ids`` is empty or the database
            type has no ``tag_exists`` template.
        """
        if not tag_ids:
            return None

        db_type = self.get_database_type(column_name)
        template = self._get_function_template(db_type, "tag_exists")
        if not template:
            return None

        # Replace column placeholder with actual column reference
        query_template = template.replace("{column}", f"Chat.{column_name}")

        # Every clause gets its own bind name. Re-using ":tag_id" for several
        # clauses in one statement would leave a single value bound for all of
        # them, which made the AND variant effectively test only one tag.
        conditions = []
        for idx, tag_id in enumerate(tag_ids):
            param_name = f"tag_id_{idx}"
            condition_sql = query_template.replace(":tag_id", f":{param_name}")
            conditions.append(text(condition_sql).params(**{param_name: tag_id}))

        combine = and_ if match_all else or_
        return combine(*conditions)

    def build_search_filter(self, search_text: str) -> Optional[TextClause]:
        """Build a content search clause over ``Chat.chat``.

        The rendered SQL lower-cases the stored message content but uses
        ``search_text`` as given. Returns None when the database type has no
        ``content_search`` template.
        """
        db_type = self.get_database_type("chat")
        template = self._get_function_template(db_type, "content_search")
        if not template:
            return None
        # Replace column placeholder and bind search parameter
        query = template.replace("{column}", "Chat.chat")
        return text(query).params(search_text=search_text)

    def build_untagged_filter(self, column_name: str = "meta") -> Optional[or_]:
        """Build a filter for chats without tags.

        Matches rows where the ``tags`` key is absent OR the ``tags`` array is
        empty. Returns None when the database type lacks the needed templates.
        """
        db_type = self.get_database_type(column_name)
        has_key_template = self._get_function_template(db_type, "has_key")
        array_length_template = self._get_function_template(db_type, "array_length")
        if not has_key_template or not array_length_template:
            return None

        column_ref = f"Chat.{column_name}"
        has_key = has_key_template.replace("{column}", column_ref).replace(
            "{path}", "tags"
        )
        array_length = array_length_template.replace("{column}", column_ref).replace(
            "{path}", "tags"
        )
        # Return OR condition: no tags key OR empty tags array
        return or_(text(f"NOT ({has_key})"), text(f"{array_length} = 0"))
####################
# Utility Functions
####################
def normalize_tag_name(tag_name: str) -> str:
    """Return ``tag_name`` with spaces turned into underscores, lower-cased."""
    underscored = tag_name.replace(" ", "_")
    return underscored.lower()
def normalize_tag_names(tag_names: List[str]) -> List[str]:
    """Apply ``normalize_tag_name`` to every entry, preserving order."""
    return list(map(normalize_tag_name, tag_names))
####################
# Chat DB Schema
####################
# Module-level logger; its level is taken from the "MODELS" entry of
# SRC_LOG_LEVELS (imported from open_webui.env).
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
class Chat(Base):
    # SQLAlchemy declarative model for the "chat" table.
    __tablename__ = "chat"
    # Primary identifier - callers in this module generate it with uuid4
    id = Column(String, primary_key=True)
    # Owner of the chat; shared copies use the form "shared-<chat id>"
    user_id = Column(String)
    # Human-readable chat title for UI display
    title = Column(Text)
    # Full chat payload (title, history/messages, ...) stored as JSON
    chat = Column(
        JSON
    )  # For JSONB support, change to: Column(JSONB) if JSONB else Column(JSON)
    # Timestamp fields; this module writes int(time.time()) (epoch seconds)
    created_at = Column(BigInteger)  # Creation timestamp
    updated_at = Column(BigInteger)  # Last modification timestamp
    # Id of the shared copy of this chat, if one exists
    share_id = Column(Text, unique=True, nullable=True)
    # Organization and state flags
    archived = Column(Boolean, default=False)  # Hidden from main view
    pinned = Column(Boolean, default=False, nullable=True)  # Pinned to top
    # Extensible metadata storage (tags, custom fields, etc.)
    # For JSONB support, change to:
    # Column(JSONB, server_default="{}") if JSONB else Column(JSON, server_default="{}")
    meta = Column(JSON, server_default="{}")
    # Optional folder organization
    folder_id = Column(Text, nullable=True)
class ChatModel(BaseModel):
    # Pydantic view of a Chat row; from_attributes lets model_validate()
    # read values straight off the ORM object.
    model_config = ConfigDict(from_attributes=True)
    # Core identification and ownership
    id: str
    user_id: str
    title: str
    chat: dict
    # Timestamps (Unix epoch seconds)
    created_at: int  # timestamp in epoch
    updated_at: int  # timestamp in epoch
    # Optional fields with defaults
    share_id: Optional[str] = None  # Public sharing identifier
    archived: bool = False  # Archive status
    pinned: Optional[bool] = False  # Pin status (nullable for compatibility)
    # Extensible fields
    meta: dict = {}  # Metadata including tags
    folder_id: Optional[str] = None  # Folder organization
####################
# Forms
####################
class ChatForm(BaseModel):
    """
    Form model for creating new chats with basic data validation.

    Used for API endpoints that create new chat conversations.
    Validates only that a ``chat`` dict is present; its inner structure
    is not checked here.
    """

    chat: dict
class ChatImportForm(ChatForm):
    """
    Extended form model for importing chats with additional metadata.

    Used for bulk import operations or when creating chats with
    pre-existing metadata, tags, or organizational settings.
    Extends ChatForm with optional metadata fields that are not
    required for basic chat creation but useful for import scenarios.
    """

    meta: Optional[dict] = {}  # Tags and other metadata; may be None
    pinned: Optional[bool] = False  # Pin status
    folder_id: Optional[str] = None  # Folder assignment
class ChatTitleMessagesForm(BaseModel):
    """
    Form model for operations that need title and messages separately.

    Carries a chat title together with a list of message dicts. The shape
    of each message dict is not validated by this model.
    """

    title: str
    messages: list[dict]
class ChatTitleForm(BaseModel):
    """
    Simple form model for chat title updates.

    Carries only the new title; conversation content and metadata are
    not part of this form.
    """

    title: str
class ChatResponse(BaseModel):
    """
    Complete response model for chat API endpoints.

    Mirrors the fields of ChatModel. Unlike ChatModel, ``archived`` has no
    default here and must always be supplied.
    """

    # Core chat identification and data
    id: str
    user_id: str
    title: str
    chat: dict
    # Timestamps for sorting and display
    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch
    # Sharing and organization
    share_id: Optional[str] = None  # id of the chat to be shared
    archived: bool  # archive status (required)
    pinned: Optional[bool] = False  # pin status
    # Extensible metadata and organization
    meta: dict = {}  # tags and other metadata
    folder_id: Optional[str] = None  # folder assignment
class ChatTitleIdResponse(BaseModel):
    """
    Lightweight response model for chat list operations.

    Holds only the id, title and timestamps of a chat, without the chat
    payload. Built by ``ChatTable.get_chat_title_id_list_by_user_id`` from
    a four-column query result.
    """

    id: str
    title: str
    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch
class ChatTable:
    def __init__(self):
        """
        Create a ChatTable; there is nothing to configure.

        The instance keeps no state of its own: every method obtains its
        database session from the get_db() context manager when called.
        """
def _get_adapter(self, db) -> DatabaseAdapter:
    """Build a fresh DatabaseAdapter bound to the given session."""
    adapter = DatabaseAdapter(db)
    return adapter
# Legacy methods for backward compatibility
def _is_jsonb_column(self, db, column_name: str) -> bool:
    """Legacy shim: True when the adapter reports the column as JSONB."""
    detected = self._get_adapter(db).get_database_type(column_name)
    return detected == DatabaseType.POSTGRESQL_JSONB
def _get_json_query_type(self, db, column_name: str = "meta") -> str:
    """Legacy shim: return the adapter's DatabaseType for the column as its string value."""
    return self._get_adapter(db).get_database_type(column_name).value
def check_database_compatibility(self) -> dict:
    """Report which JSON-related features the connected database offers.

    Returns a dict describing the dialect, the detected JSON type of the
    ``meta`` and ``chat`` columns and lists of features/limitations. On any
    exception the error is logged and a two-key error dict is returned.
    """
    try:
        with get_db() as db:
            adapter = self._get_adapter(db)
            dialect_name = db.bind.dialect.name
            meta_type = adapter.get_database_type("meta")
            chat_type = adapter.get_database_type("chat")

            meta_supported = meta_type != DatabaseType.UNSUPPORTED
            chat_supported = chat_type != DatabaseType.UNSUPPORTED
            uses_jsonb = DatabaseType.POSTGRESQL_JSONB in (meta_type, chat_type)

            # Dialect-specific capability lists
            features = []
            limitations = []
            if dialect_name == "sqlite":
                features = [
                    "JSON1 extension",
                    "Basic tag filtering",
                    "Message search",
                ]
                limitations = [
                    "No GIN indexes",
                    "Limited JSON optimization",
                ]
            elif dialect_name == "postgresql":
                features = [
                    "Full JSON/JSONB support",
                    "GIN indexes",
                    "Advanced filtering",
                ]
                if uses_jsonb:
                    features.append("JSONB binary format optimization")

            return {
                "database_type": dialect_name,
                "json_support": meta_supported,
                "jsonb_support": uses_jsonb,
                "gin_indexes_support": dialect_name == "postgresql",
                "tag_filtering_support": meta_supported,
                "advanced_search_support": chat_supported,
                "meta_column_type": meta_type.value,
                "chat_column_type": chat_type.value,
                "features": features,
                "limitations": limitations,
                "recommendations": [],
            }
    except Exception as e:
        log.error(f"Error checking database compatibility: {e}")
        return {"error": str(e), "database_type": "unknown"}
def create_gin_indexes(self) -> bool:
    """Create GIN indexes on JSONB columns for better query performance.

    Only acts on PostgreSQL and only for columns detected as JSONB. Each
    index is created and committed on its own; when one statement fails the
    transaction is rolled back and the remaining indexes are still
    attempted (a failed statement would otherwise leave the PostgreSQL
    transaction aborted and make every later statement fail too).

    Returns:
        True when the database qualified and index creation was attempted,
        False on a non-PostgreSQL dialect, when neither column is JSONB, or
        on an unexpected error.
    """
    meta_indexes = [
        # General meta column index for all JSON operations
        (
            "idx_chat_meta_gin",
            """
            CREATE INDEX IF NOT EXISTS idx_chat_meta_gin
            ON chat USING GIN (meta)
            """,
        ),
        # Specific index for tags array operations
        (
            "idx_chat_meta_tags_gin",
            """
            CREATE INDEX IF NOT EXISTS idx_chat_meta_tags_gin
            ON chat USING GIN ((meta->'tags'))
            """,
        ),
        # BTREE index for tag existence and count operations
        (
            "idx_chat_has_tags",
            """
            CREATE INDEX IF NOT EXISTS idx_chat_has_tags
            ON chat USING BTREE ((meta ? 'tags' AND
            jsonb_array_length(meta->'tags') > 0))
            WHERE meta ? 'tags'
            """,
        ),
    ]
    chat_indexes = [
        # General chat content index for all operations
        (
            "idx_chat_chat_gin",
            """
            CREATE INDEX IF NOT EXISTS idx_chat_chat_gin
            ON chat USING GIN (chat)
            """,
        ),
        # Specific index for message search operations
        (
            "idx_chat_messages_gin",
            """
            CREATE INDEX IF NOT EXISTS idx_chat_messages_gin
            ON chat USING GIN ((chat->'messages'))
            """,
        ),
    ]

    try:
        with get_db() as db:
            adapter = self._get_adapter(db)
            # Only PostgreSQL supports GIN indexes
            if db.bind.dialect.name != "postgresql":
                return False

            meta_type = adapter.get_database_type("meta")
            chat_type = adapter.get_database_type("chat")
            has_jsonb_meta = meta_type == DatabaseType.POSTGRESQL_JSONB
            has_jsonb_chat = chat_type == DatabaseType.POSTGRESQL_JSONB
            # Need at least one JSONB column for GIN indexes
            if not (has_jsonb_meta or has_jsonb_chat):
                return False

            pending = []
            if has_jsonb_meta:
                pending.extend(meta_indexes)
            if has_jsonb_chat:
                pending.extend(chat_indexes)

            for index_name, statement in pending:
                try:
                    db.execute(text(statement))
                    db.commit()
                except Exception as e:
                    # Partial optimization is better than none: reset the
                    # transaction and carry on with the next index.
                    db.rollback()
                    log.warning(f"Failed to create index {index_name}: {e}")
            return True
    except Exception as e:
        log.error(f"Error creating GIN indexes: {e}")
        return False
def check_gin_indexes(self) -> dict:
    """
    Check which of this module's indexes exist on the chat table.

    Returns a dict keyed by index name. Existing indexes map to
    ``{"exists": True, "definition": <indexdef>}``; expected but absent
    ones map to ``{"exists": False}``. Any other GIN index found on the
    table is reported as well. On a non-PostgreSQL dialect or on failure a
    dict with a single ``"error"`` key is returned.
    """
    # Every index name the create_* helpers of this class may produce.
    # Two of them (idx_chat_has_tags, idx_chat_tag_count) are BTREE, so the
    # lookup below must not be limited to GIN definitions.
    expected_indexes = [
        "idx_chat_meta_gin",
        "idx_chat_chat_gin",
        "idx_chat_meta_tags_gin",
        "idx_chat_has_tags",
        "idx_chat_tag_count",
        "idx_chat_specific_tags",
        "idx_chat_json_tags",
        "idx_chat_messages_gin",
    ]
    try:
        with get_db() as db:
            if db.bind.dialect.name != "postgresql":
                return {"error": "GIN indexes are only supported on PostgreSQL"}

            # Query PostgreSQL system catalog for all indexes on chat table
            result = db.execute(
                text(
                    """
                    SELECT indexname, indexdef
                    FROM pg_indexes
                    WHERE tablename = 'chat'
                    """
                )
            )

            indexes = {}
            for row in result:
                index_name, definition = row[0], row[1]
                is_gin = "USING gin" in definition
                if is_gin or index_name in expected_indexes:
                    indexes[index_name] = {"exists": True, "definition": definition}

            # Mark missing indexes
            for idx_name in expected_indexes:
                indexes.setdefault(idx_name, {"exists": False})
            return indexes
    except Exception as e:
        log.error(f"Error checking GIN indexes: {e}")
        return {"error": str(e)}
def drop_gin_indexes(self) -> bool:
    """Drop all indexes on the chat table that this class may have created.

    ``DROP INDEX CONCURRENTLY`` cannot run inside a transaction block, and
    a Session always runs statements in one, so the drops are issued on a
    separate AUTOCOMMIT connection from the same engine.

    Returns:
        True when the drops were attempted on PostgreSQL (individual
        failures are logged and skipped), False on another dialect or on an
        unexpected error.
    """
    # List of all indexes that might have been created
    indexes_to_drop = [
        "idx_chat_meta_gin",
        "idx_chat_chat_gin",
        "idx_chat_meta_tags_gin",
        "idx_chat_has_tags",
        "idx_chat_tag_count",
        "idx_chat_specific_tags",
        "idx_chat_json_tags",
        "idx_chat_messages_gin",
    ]
    try:
        with get_db() as db:
            if db.bind.dialect.name != "postgresql":
                return False
            engine = db.bind.engine

        with engine.connect().execution_options(
            isolation_level="AUTOCOMMIT"
        ) as connection:
            for idx_name in indexes_to_drop:
                try:
                    connection.execute(
                        text(f"DROP INDEX CONCURRENTLY IF EXISTS {idx_name}")
                    )
                except Exception as e:
                    # Continue dropping other indexes even if one fails
                    log.warning(f"Failed to drop index {idx_name}: {e}")
        return True
    except Exception as e:
        log.error(f"Error dropping GIN indexes: {e}")
        return False
def create_tag_indexes(self) -> bool:
    """
    Create specialized indexes optimized specifically for tag operations.

    This includes both GIN and BTREE indexes for different tag query
    patterns. ``CREATE INDEX CONCURRENTLY`` cannot run inside a transaction
    block, so the statements are issued on a separate AUTOCOMMIT connection
    from the same engine instead of through the Session.

    Returns:
        True when creation was attempted on PostgreSQL (individual failures
        are logged and skipped), False on another dialect or on an
        unexpected error.
    """
    try:
        with get_db() as db:
            if db.bind.dialect.name != "postgresql":
                log.info("Tag indexes are only supported on PostgreSQL")
                return False
            has_jsonb_meta = self._is_jsonb_column(db, "meta")
            engine = db.bind.engine

        if has_jsonb_meta:
            # JSONB-specific tag indexes for optimal performance
            tag_indexes = [
                {
                    "name": "idx_chat_meta_tags_gin",
                    "sql": (
                        "CREATE INDEX CONCURRENTLY IF NOT EXISTS "
                        "idx_chat_meta_tags_gin ON chat USING GIN ((meta->'tags'))"
                    ),
                    "purpose": "Fast tag containment queries (@>, ?, etc.)",
                },
                {
                    "name": "idx_chat_has_tags",
                    "sql": (
                        "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_has_tags "
                        "ON chat USING BTREE ((meta ? 'tags' AND "
                        "jsonb_array_length(meta->'tags') > 0)) WHERE meta ? 'tags'"
                    ),
                    "purpose": "Fast filtering for chats with/without tags",
                },
                {
                    "name": "idx_chat_tag_count",
                    "sql": (
                        "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_tag_count "
                        "ON chat USING BTREE ((jsonb_array_length(meta->'tags'))) "
                        "WHERE meta ? 'tags'"
                    ),
                    "purpose": "Fast filtering by number of tags",
                },
                {
                    "name": "idx_chat_specific_tags",
                    "sql": (
                        "CREATE INDEX CONCURRENTLY IF NOT EXISTS "
                        "idx_chat_specific_tags ON chat USING GIN ((meta->'tags')) "
                        "WHERE jsonb_array_length(meta->'tags') > 0"
                    ),
                    "purpose": "Optimized for chats that actually have tags",
                },
            ]
        else:
            # JSON-specific tag indexes (less optimal but still helpful)
            # NOTE(review): the `?` operator and the default GIN operator
            # classes are documented for jsonb only, so this statement is
            # expected to fail on a plain json column - confirm and
            # consider removing. A failure is logged below, not raised.
            tag_indexes = [
                {
                    "name": "idx_chat_json_tags",
                    "sql": (
                        "CREATE INDEX CONCURRENTLY IF NOT EXISTS "
                        "idx_chat_json_tags ON chat USING GIN ((meta->'tags')) "
                        "WHERE meta ? 'tags'"
                    ),
                    "purpose": "Tag queries for JSON columns",
                }
            ]

        # Create each index with error handling
        indexes_created = []
        with engine.connect().execution_options(
            isolation_level="AUTOCOMMIT"
        ) as connection:
            for index_info in tag_indexes:
                try:
                    connection.execute(text(index_info["sql"]))
                    indexes_created.append(
                        f"{index_info['name']} ({index_info['purpose']})"
                    )
                    log.info(f"Created tag index: {index_info['name']}")
                except Exception as e:
                    log.warning(f"Failed to create {index_info['name']}: {e}")

        # Log results
        if indexes_created:
            log.info(
                f"Successfully created tag indexes: {len(indexes_created)} indexes"
            )
            for idx in indexes_created:
                log.info(f"{idx}")
        else:
            log.info("No tag indexes were created")
        return True
    except Exception as e:
        log.error(f"Error creating tag indexes: {e}")
        return False
def optimize_tag_queries(self) -> dict:
    """
    Collect tag usage statistics and derive index recommendations.

    PostgreSQL only. The returned dict holds chat/tag counts, the ten most
    used tags, the status of tag-related indexes, a list of recommendation
    strings and the tag coverage percentage. Any failure is logged and
    reported as ``{"error": ...}``.
    """
    try:
        with get_db() as db:
            if db.bind.dialect.name != "postgresql":
                return {"error": "Tag optimization is only supported on PostgreSQL"}

            stats = {}

            # Overall tag usage numbers
            usage_row = db.execute(
                text(
                    """
                    SELECT
                    COUNT(*) as total_chats,
                    COUNT(*) FILTER (WHERE meta ? 'tags') as chats_with_tags,
                    COUNT(*) FILTER (WHERE meta ? 'tags' AND
                    jsonb_array_length(meta->'tags') > 0) as chats_with_actual_tags,
                    AVG(CASE WHEN meta ? 'tags' THEN
                    jsonb_array_length(meta->'tags') ELSE 0 END) as avg_tags_per_chat
                    FROM chat
                    """
                )
            ).fetchone()
            if usage_row:
                stats["total_chats"] = usage_row[0]
                stats["chats_with_tags"] = usage_row[1]
                stats["chats_with_actual_tags"] = usage_row[2]
                stats["avg_tags_per_chat"] = (
                    float(usage_row[3]) if usage_row[3] else 0
                )

            # Ten most frequently used tags
            top_rows = db.execute(
                text(
                    """
                    SELECT tag_value, COUNT(*) as usage_count
                    FROM chat, jsonb_array_elements_text(meta->'tags') as tag_value
                    WHERE meta ? 'tags'
                    GROUP BY tag_value
                    ORDER BY usage_count DESC
                    LIMIT 10
                    """
                )
            )
            stats["top_tags"] = [
                {"tag": top_row[0], "count": top_row[1]} for top_row in top_rows
            ]

            # Current status of every index whose name mentions "tag"
            stats["tag_indexes"] = {
                name: status
                for name, status in self.check_gin_indexes().items()
                if "tag" in name.lower()
            }

            # Recommendations derived from the usage numbers
            recommendations = []
            if stats["chats_with_actual_tags"] > 1000:
                recommendations.append(
                    "Consider creating tag-specific indexes for better performance"
                )
            if stats["avg_tags_per_chat"] > 5:
                recommendations.append(
                    "High tag usage detected - GIN indexes will provide "
                    "significant benefits"
                )
            if stats["total_chats"] > 0:
                tag_coverage = stats["chats_with_actual_tags"] / stats["total_chats"]
            else:
                tag_coverage = 0
            if tag_coverage < 0.1:
                recommendations.append(
                    "Low tag usage - consider partial indexes with WHERE clauses"
                )

            stats["recommendations"] = recommendations
            stats["tag_coverage_percentage"] = round(tag_coverage * 100, 2)
            return stats
    except Exception as e:
        log.error(f"Error analyzing tag queries: {e}")
        return {"error": str(e)}
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
    """Persist a new chat for ``user_id`` and return it as a ChatModel.

    The id is a fresh uuid4; the title is taken from ``form_data.chat`` and
    falls back to "New Chat" when the payload carries none.
    """
    with get_db() as db:
        payload = form_data.chat
        new_chat = ChatModel(
            id=str(uuid.uuid4()),
            user_id=user_id,
            title=payload.get("title", "New Chat"),
            chat=payload,
            created_at=int(time.time()),
            updated_at=int(time.time()),
        )

        row = Chat(**new_chat.model_dump())
        db.add(row)
        db.commit()
        db.refresh(row)

        if not row:
            return None
        return ChatModel.model_validate(row)
def import_chat(
    self, user_id: str, form_data: ChatImportForm
) -> Optional[ChatModel]:
    """Persist an imported chat, including its metadata, for ``user_id``.

    Works like ``insert_new_chat`` but also stores ``meta``, ``pinned`` and
    ``folder_id`` from the form. ``ChatImportForm.meta`` may be None while
    ``ChatModel.meta`` must be a dict, so a missing value is stored as an
    empty dict instead of failing validation.
    """
    with get_db() as db:
        id = str(uuid.uuid4())
        chat = ChatModel(
            **{
                "id": id,
                "user_id": user_id,
                "title": (
                    form_data.chat["title"]
                    if "title" in form_data.chat
                    else "New Chat"
                ),
                "chat": form_data.chat,
                # None is a valid form value but not a valid ChatModel.meta
                "meta": form_data.meta if form_data.meta is not None else {},
                "pinned": form_data.pinned,
                "folder_id": form_data.folder_id,
                "created_at": int(time.time()),
                "updated_at": int(time.time()),
            }
        )
        result = Chat(**chat.model_dump())
        db.add(result)
        db.commit()
        db.refresh(result)
        return ChatModel.model_validate(result) if result else None
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
    """Replace the payload of chat ``id`` and refresh its title/timestamp.

    Returns the updated ChatModel, or None when anything raises (which
    includes the chat not being found).
    """
    try:
        with get_db() as db:
            row = db.get(Chat, id)
            row.chat = chat
            if "title" in chat:
                row.title = chat["title"]
            else:
                row.title = "New Chat"
            row.updated_at = int(time.time())

            db.commit()
            db.refresh(row)
            return ChatModel.model_validate(row)
    except Exception:
        return None
def update_chat_title_by_id(self, id: str, title: str) -> Optional[ChatModel]:
    """Set the title inside the chat payload and save it via update_chat_by_id."""
    existing = self.get_chat_by_id(id)
    if existing is None:
        return None

    payload = existing.chat
    payload["title"] = title
    return self.update_chat_by_id(id, payload)
def update_chat_tags_by_id(
    self, id: str, tags: list[str], user
) -> Optional[ChatModel]:
    """Replace the tag set of chat ``id`` with ``tags``.

    Existing tags are removed first; any previous tag that no chat of the
    user references afterwards is deleted. Tag names equal to "none"
    (case-insensitive) are skipped. Returns the reloaded chat, or None if
    the chat does not exist.
    """
    current = self.get_chat_by_id(id)
    if current is None:
        return None

    # Detach every tag currently on the chat
    self.delete_all_tags_by_id_and_user_id(id, user.id)

    # Drop tags that became orphaned by the removal above
    for old_tag in current.meta.get("tags", []):
        remaining = self.count_chats_by_tag_name_and_user_id(old_tag, user.id)
        if remaining == 0:
            Tags.delete_tag_by_name_and_user_id(old_tag, user.id)

    # Attach the requested tags
    for tag_name in tags:
        if tag_name.lower() != "none":
            self.add_chat_tag_by_id_and_user_id_and_tag_name(id, user.id, tag_name)

    return self.get_chat_by_id(id)
def get_chat_title_by_id(self, id: str) -> Optional[str]:
    """Return the title stored in the chat payload ("New Chat" if absent), or None for an unknown chat."""
    record = self.get_chat_by_id(id)
    if record is None:
        return None
    payload = record.chat
    return payload.get("title", "New Chat")
def get_messages_by_chat_id(self, id: str) -> Optional[dict]:
    """Return ``chat["history"]["messages"]`` (an empty dict if missing or falsy), or None for an unknown chat."""
    record = self.get_chat_by_id(id)
    if record is None:
        return None

    history = record.chat.get("history", {})
    messages = history.get("messages", {})
    return messages if messages else {}
def get_message_by_id_and_message_id(
    self, id: str, message_id: str
) -> Optional[dict]:
    """Return one message of a chat ({} if the message is missing), or None for an unknown chat."""
    record = self.get_chat_by_id(id)
    if record is None:
        return None

    history = record.chat.get("history", {})
    messages = history.get("messages", {})
    return messages.get(message_id, {})
def upsert_message_to_chat_by_id_and_message_id(
    self, id: str, message_id: str, message: dict
) -> Optional[ChatModel]:
    """Insert or merge ``message`` into the chat's history and save the chat.

    An existing message with the same id is merged key by key (new values
    win); otherwise the message is added. ``history["currentId"]`` is then
    pointed at ``message_id``. A chat whose payload has no history or no
    messages map yet gets one created instead of raising ``KeyError``.

    Returns:
        The updated ChatModel, or None if the chat does not exist or the
        update fails.
    """
    chat = self.get_chat_by_id(id)
    if chat is None:
        return None

    chat = chat.chat
    history = chat.get("history", {})

    # Make sure there is a messages map to write into
    messages = history.get("messages")
    if messages is None:
        messages = {}
        history["messages"] = messages

    # Upsert the message (merge with existing or create new)
    if message_id in messages:
        messages[message_id] = {**messages[message_id], **message}
    else:
        messages[message_id] = message

    # Update current message pointer for conversation flow
    history["currentId"] = message_id
    chat["history"] = history
    return self.update_chat_by_id(id, chat)
def add_message_status_to_chat_by_id_and_message_id(
    self, id: str, message_id: str, status: dict
) -> Optional[ChatModel]:
    """Append ``status`` to a message's ``statusHistory`` and save the chat.

    When the message is not part of the chat's history the chat is saved
    unchanged. Returns None if the chat does not exist.
    """
    record = self.get_chat_by_id(id)
    if record is None:
        return None

    payload = record.chat
    history = payload.get("history", {})

    if message_id in history.get("messages", {}):
        target = history["messages"][message_id]
        entries = target.get("statusHistory", [])
        entries.append(status)
        target["statusHistory"] = entries

    payload["history"] = history
    return self.update_chat_by_id(id, payload)
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
    """Create (or return the existing) shared copy of chat ``chat_id``.

    The copy gets a fresh uuid4 id and the owner ``"shared-<chat_id>"``; the
    original chat's ``share_id`` is then set to the copy's id.

    Returns:
        The shared ChatModel, or None when the original chat does not exist
        or the copy could not be linked.
    """
    with get_db() as db:
        # Get the existing chat to share
        chat = db.get(Chat, chat_id)
        if chat is None:
            return None

        # Check if the chat is already shared. Shared copies are stored
        # under the owner "shared-<chat_id>" (see below), so the lookup has
        # to use that same owner to find the copy.
        if chat.share_id:
            return self.get_chat_by_id_and_user_id(
                chat.share_id, f"shared-{chat_id}"
            )

        # Create a new chat with the same data, but with a new ID
        shared_chat = ChatModel(
            **{
                "id": str(uuid.uuid4()),
                "user_id": f"shared-{chat_id}",
                "title": chat.title,
                "chat": chat.chat,
                "created_at": chat.created_at,
                "updated_at": int(time.time()),
            }
        )
        shared_result = Chat(**shared_chat.model_dump())
        db.add(shared_result)
        db.commit()
        db.refresh(shared_result)

        # Update the original chat with the share_id
        result = (
            db.query(Chat)
            .filter_by(id=chat_id)
            .update({"share_id": shared_chat.id})
        )
        db.commit()
        return shared_chat if (shared_result and result) else None
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
    """Sync the shared copy of ``chat_id`` with the original chat.

    Creates the shared copy when none exists yet. Returns the shared
    ChatModel, or None when anything raises.
    """
    try:
        with get_db() as db:
            original = db.get(Chat, chat_id)
            shared_owner = f"shared-{chat_id}"
            shared_copy = db.query(Chat).filter_by(user_id=shared_owner).first()

            if shared_copy is None:
                return self.insert_shared_chat_by_chat_id(chat_id)

            shared_copy.title = original.title
            shared_copy.chat = original.chat
            shared_copy.updated_at = int(time.time())

            db.commit()
            db.refresh(shared_copy)
            return ChatModel.model_validate(shared_copy)
    except Exception:
        return None
def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
    """Delete the shared copy of ``chat_id``; False if the delete raised."""
    try:
        with get_db() as db:
            shared_copies = db.query(Chat).filter_by(user_id=f"shared-{chat_id}")
            shared_copies.delete()
            db.commit()
    except Exception:
        return False
    else:
        return True
def update_chat_share_id_by_id(
    self, id: str, share_id: Optional[str]
) -> Optional[ChatModel]:
    """Store ``share_id`` on chat ``id``; returns None when anything raises."""
    try:
        with get_db() as db:
            row = db.get(Chat, id)
            row.share_id = share_id

            db.commit()
            db.refresh(row)
            return ChatModel.model_validate(row)
    except Exception:
        return None
def toggle_chat_pinned_by_id(self, id: str) -> Optional[ChatModel]:
    """Flip the pinned flag of chat ``id`` and bump its updated_at.

    A NULL flag counts as not pinned. Returns None when anything raises.
    """
    try:
        with get_db() as db:
            row = db.get(Chat, id)
            row.pinned = False if row.pinned else True
            row.updated_at = int(time.time())

            db.commit()
            db.refresh(row)
            return ChatModel.model_validate(row)
    except Exception:
        return None
def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
    """Flip the archived flag of chat ``id`` and bump its updated_at.

    Returns None when anything raises.
    """
    try:
        with get_db() as db:
            row = db.get(Chat, id)
            row.archived = False if row.archived else True
            row.updated_at = int(time.time())

            db.commit()
            db.refresh(row)
            return ChatModel.model_validate(row)
    except Exception:
        return None
def archive_all_chats_by_user_id(self, user_id: str) -> bool:
    """Mark every chat of ``user_id`` as archived; False if the update raised."""
    try:
        with get_db() as db:
            user_chats = db.query(Chat).filter_by(user_id=user_id)
            user_chats.update({"archived": True})
            db.commit()
    except Exception:
        return False
    else:
        return True
def get_archived_chat_list_by_user_id(
    self,
    user_id: str,
    filter: Optional[dict] = None,
    skip: int = 0,
    limit: int = 50,
) -> list[ChatModel]:
    """Return a page of the user's archived chats.

    Args:
        user_id: Owner of the chats.
        filter: Optional dict with ``query`` (case-insensitive title
            substring), ``order_by`` (name of a ``chat`` table column) and
            ``direction`` ("asc" or "desc").
        skip: Number of rows to skip; 0 disables the offset.
        limit: Maximum number of rows; a falsy value disables the limit.

    Returns:
        Matching chats. When the filter does not specify a usable ordering
        (missing or unknown column, or no direction) the result is ordered
        by ``updated_at`` descending, so paging is always deterministic.

    Raises:
        ValueError: If ``direction`` is neither "asc" nor "desc" for a
            valid ``order_by`` column.
    """
    with get_db() as db:
        query = db.query(Chat).filter_by(user_id=user_id, archived=True)

        ordering = None
        if filter:
            query_key = filter.get("query")
            if query_key:
                query = query.filter(Chat.title.ilike(f"%{query_key}%"))

            order_by = filter.get("order_by")
            direction = filter.get("direction")
            # Only real table columns may be used for ordering; anything
            # else falls back to the default instead of raising.
            if (
                direction
                and isinstance(order_by, str)
                and order_by in Chat.__table__.columns
            ):
                column = getattr(Chat, order_by)
                if direction.lower() == "asc":
                    ordering = column.asc()
                elif direction.lower() == "desc":
                    ordering = column.desc()
                else:
                    raise ValueError("Invalid direction for ordering")

        if ordering is None:
            ordering = Chat.updated_at.desc()
        query = query.order_by(ordering)

        if skip:
            query = query.offset(skip)
        if limit:
            query = query.limit(limit)

        all_chats = query.all()
        return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_list_by_user_id(
    self,
    user_id: str,
    include_archived: bool = False,
    filter: Optional[dict] = None,
    skip: int = 0,
    limit: int = 50,
) -> list[ChatModel]:
    """Return a page of the user's chats.

    Args:
        user_id: Owner of the chats.
        include_archived: When False, archived chats are excluded.
        filter: Optional dict with ``query`` (case-insensitive title
            substring), ``order_by`` (name of a ``chat`` table column) and
            ``direction`` ("asc" or "desc").
        skip: Number of rows to skip; 0 disables the offset.
        limit: Maximum number of rows; a falsy value disables the limit.

    Returns:
        Matching chats. When the filter does not specify a usable ordering
        (missing or unknown column, or no direction) the result is ordered
        by ``updated_at`` descending, so paging is always deterministic.

    Raises:
        ValueError: If ``direction`` is neither "asc" nor "desc" for a
            valid ``order_by`` column.
    """
    with get_db() as db:
        query = db.query(Chat).filter_by(user_id=user_id)
        if not include_archived:
            query = query.filter_by(archived=False)

        ordering = None
        if filter:
            query_key = filter.get("query")
            if query_key:
                query = query.filter(Chat.title.ilike(f"%{query_key}%"))

            order_by = filter.get("order_by")
            direction = filter.get("direction")
            # Only real table columns may be used for ordering; anything
            # else falls back to the default instead of raising.
            if (
                direction
                and isinstance(order_by, str)
                and order_by in Chat.__table__.columns
            ):
                column = getattr(Chat, order_by)
                if direction.lower() == "asc":
                    ordering = column.asc()
                elif direction.lower() == "desc":
                    ordering = column.desc()
                else:
                    raise ValueError("Invalid direction for ordering")

        if ordering is None:
            ordering = Chat.updated_at.desc()
        query = query.order_by(ordering)

        if skip:
            query = query.offset(skip)
        if limit:
            query = query.limit(limit)

        all_chats = query.all()
        return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_title_id_list_by_user_id(
    self,
    user_id: str,
    include_archived: bool = False,
    skip: Optional[int] = None,
    limit: Optional[int] = None,
) -> list[ChatTitleIdResponse]:
    """List id/title/timestamps of the user's unpinned chats outside folders.

    Results are ordered by ``updated_at`` descending. ``skip`` and
    ``limit`` are applied only when truthy.
    """
    with get_db() as db:
        not_pinned = or_(Chat.pinned == False, Chat.pinned == None)
        query = (
            db.query(Chat)
            .filter_by(user_id=user_id)
            .filter_by(folder_id=None)
            .filter(not_pinned)
        )
        if not include_archived:
            query = query.filter_by(archived=False)

        query = query.order_by(Chat.updated_at.desc())
        query = query.with_entities(
            Chat.id, Chat.title, Chat.updated_at, Chat.created_at
        )

        if skip:
            query = query.offset(skip)
        if limit:
            query = query.limit(limit)

        # Each result is a plain row, not a Chat object, so the four
        # selected columns are mapped onto the response fields by position.
        field_names = ("id", "title", "updated_at", "created_at")
        summaries = []
        for row in query.all():
            values = {
                field: row[position] for position, field in enumerate(field_names)
            }
            summaries.append(ChatTitleIdResponse.model_validate(values))
        return summaries
def get_chat_list_by_chat_ids(
    self, chat_ids: list[str], skip: int = 0, limit: int = 50
) -> list[ChatModel]:
    """Return the non-archived chats whose id is in ``chat_ids``, newest update first.

    NOTE(review): ``skip`` and ``limit`` are accepted but not applied to the
    query - every matching chat is returned. Confirm whether callers expect
    pagination here.
    """
    with get_db() as db:
        query = db.query(Chat).filter(Chat.id.in_(chat_ids))
        query = query.filter_by(archived=False)
        query = query.order_by(Chat.updated_at.desc())
        rows = query.all()
        return [ChatModel.model_validate(row) for row in rows]
def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
    """
    Look up a single chat by primary key.

    Args:
        id: Chat id.

    Returns:
        The chat as a ChatModel, or None if the lookup or the model
        validation raises for any reason.
    """
    try:
        with get_db() as db:
            record = db.get(Chat, id)
            return ChatModel.model_validate(record)
    except Exception:
        # Any failure (including validating a missing row) maps to None.
        return None
def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
    """
    Resolve a share id to its chat.

    The share link may have been deleted, so first confirm that some chat
    still carries this ``share_id``; only then is the chat fetched via
    ``get_chat_by_id`` using the same id value.

    Args:
        id: Share id.

    Returns:
        The chat as a ChatModel, or None if nothing is shared under this
        id or any error occurs.
    """
    try:
        with get_db() as db:
            still_shared = db.query(Chat).filter_by(share_id=id).first()
            if not still_shared:
                return None
            return self.get_chat_by_id(id)
    except Exception:
        return None
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
    """
    Look up a chat by id, restricted to the given owner.

    Args:
        id: Chat id.
        user_id: Required owner of the chat.

    Returns:
        The chat as a ChatModel, or None if the lookup or the model
        validation raises for any reason.
    """
    try:
        with get_db() as db:
            record = db.query(Chat).filter_by(id=id, user_id=user_id).first()
            return ChatModel.model_validate(record)
    except Exception:
        # Any failure (including validating a missing row) maps to None.
        return None
def get_chats(self, skip: int = 0, limit: int = 50) -> list[ChatModel]:
    """
    Return a page of all chats, most recently updated first.

    Args:
        skip: Number of rows to skip.
        limit: Maximum number of rows to return.

    Returns:
        The selected chats as ChatModel instances.
    """
    with get_db() as db:
        newest_first = db.query(Chat).order_by(Chat.updated_at.desc())
        page = newest_first.offset(skip).limit(limit).all()
        return [ChatModel.model_validate(row) for row in page]
def get_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
    """
    Return every chat owned by ``user_id``, most recently updated first.

    Args:
        user_id: Owner of the chats.

    Returns:
        All of the user's chats as ChatModel instances.
    """
    with get_db() as db:
        owned = db.query(Chat).filter_by(user_id=user_id)
        rows = owned.order_by(Chat.updated_at.desc()).all()
        return [ChatModel.model_validate(row) for row in rows]
def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
    """
    Return the user's pinned, non-archived chats, newest update first.

    Args:
        user_id: Owner of the chats.

    Returns:
        The matching chats as ChatModel instances.
    """
    with get_db() as db:
        pinned = db.query(Chat).filter_by(
            user_id=user_id, pinned=True, archived=False
        )
        rows = pinned.order_by(Chat.updated_at.desc()).all()
        return [ChatModel.model_validate(row) for row in rows]
def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
    """
    Return the user's archived chats, newest update first.

    Args:
        user_id: Owner of the chats.

    Returns:
        The matching chats as ChatModel instances.
    """
    with get_db() as db:
        archived = db.query(Chat).filter_by(user_id=user_id, archived=True)
        rows = archived.order_by(Chat.updated_at.desc()).all()
        return [ChatModel.model_validate(row) for row in rows]
def get_chats_by_user_id_and_search_text(
    self,
    user_id: str,
    search_text: str,
    include_archived: bool = False,
    skip: int = 0,
    limit: int = 60,
) -> list[ChatModel]:
    """
    Search a user's chats by free text and ``tag:<name>`` tokens.

    The search text is lower-cased and stripped. If it is empty, the call
    is delegated to ``get_chat_list_by_user_id``. Otherwise it is split on
    single spaces: words starting with ``tag:`` become tag filters, the
    remaining words are re-joined and matched against the chat title
    (case-insensitive substring) or the adapter's search filter when the
    adapter provides one. The tag ``none`` selects untagged chats.
    Filtering, ordering (``updated_at`` descending) and pagination are
    all performed in the SQL query.

    Args:
        user_id: Owner of the chats.
        search_text: Raw search string, optionally containing tag tokens.
        include_archived: When False, archived chats are excluded.
        skip: Number of rows to skip.
        limit: Maximum number of rows to return.

    Returns:
        The matching chats as ChatModel instances.
    """
    normalized = search_text.lower().strip()
    if not normalized:
        return self.get_chat_list_by_user_id(
            user_id, include_archived, filter={}, skip=skip, limit=limit
        )

    # Partition the words into tag tokens and plain search words.
    tag_ids = []
    plain_words = []
    for token in normalized.split(" "):
        if token.startswith("tag:"):
            tag_ids.append(normalize_tag_name(token.replace("tag:", "")))
        else:
            plain_words.append(token)
    title_query = " ".join(plain_words)

    with get_db() as db:
        stmt = db.query(Chat).filter(Chat.user_id == user_id)
        if not include_archived:
            stmt = stmt.filter(Chat.archived == False)  # noqa: E712

        adapter = self._get_adapter(db)

        if title_query:
            title_match = Chat.title.ilike(f"%{title_query}%")
            content_match = adapter.build_search_filter(title_query)
            if content_match is None:
                # No adapter filter available: title-only search.
                stmt = stmt.filter(title_match)
            else:
                stmt = stmt.filter(title_match | content_match)

        # "tag:none" takes precedence over any other tag tokens.
        if "none" in tag_ids:
            untagged = adapter.build_untagged_filter("meta")
            if untagged is not None:
                stmt = stmt.filter(untagged)
        elif tag_ids:
            tagged = adapter.build_tag_filter("meta", tag_ids, match_all=True)
            if tagged is not None:
                stmt = stmt.filter(tagged)

        rows = (
            stmt.order_by(Chat.updated_at.desc()).offset(skip).limit(limit).all()
        )
        return [ChatModel.model_validate(row) for row in rows]
def get_chats_by_folder_id_and_user_id(
    self, folder_id: str, user_id: str
) -> list[ChatModel]:
    """
    Return a user's unpinned, non-archived chats in a single folder.

    Args:
        folder_id: Folder to list.
        user_id: Owner of the chats.

    Returns:
        The matching chats as ChatModel instances, newest update first.
    """
    with get_db() as db:
        rows = (
            db.query(Chat)
            .filter_by(folder_id=folder_id, user_id=user_id)
            .filter(
                or_(Chat.pinned == False, Chat.pinned == None)  # noqa: E711,E712
            )
            .filter_by(archived=False)
            .order_by(Chat.updated_at.desc())
            .all()
        )
        return [ChatModel.model_validate(row) for row in rows]
def get_chats_by_folder_ids_and_user_id(
    self, folder_ids: list[str], user_id: str
) -> list[ChatModel]:
    """
    Return a user's unpinned, non-archived chats across several folders.

    Args:
        folder_ids: Folders to list.
        user_id: Owner of the chats.

    Returns:
        The matching chats as ChatModel instances, newest update first.
    """
    with get_db() as db:
        rows = (
            db.query(Chat)
            .filter(Chat.folder_id.in_(folder_ids), Chat.user_id == user_id)
            .filter(
                or_(Chat.pinned == False, Chat.pinned == None)  # noqa: E711,E712
            )
            .filter_by(archived=False)
            .order_by(Chat.updated_at.desc())
            .all()
        )
        return [ChatModel.model_validate(row) for row in rows]
def update_chat_folder_id_by_id_and_user_id(
    self, id: str, user_id: str, folder_id: str
) -> Optional[ChatModel]:
    """
    Move a chat owned by ``user_id`` into ``folder_id``.

    Moving a chat also unpins it and bumps ``updated_at`` to the current
    time. The lookup is scoped to the owner; previously ``user_id`` was
    ignored and any chat id could be moved.

    Args:
        id: Chat id.
        user_id: Required owner of the chat.
        folder_id: Target folder id.

    Returns:
        The updated chat as a ChatModel, or None if no chat with this id
        belongs to the user or the update fails.
    """
    try:
        with get_db() as db:
            chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
            if chat is None:
                return None

            chat.folder_id = folder_id
            chat.updated_at = int(time.time())
            chat.pinned = False
            db.commit()
            db.refresh(chat)
            return ChatModel.model_validate(chat)
    except Exception:
        # Keep the best-effort contract (None on failure) but leave a trace.
        logging.getLogger(__name__).exception(
            "Failed to update folder for chat %s", id
        )
        return None
def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str) -> list[TagModel]:
    """
    Resolve the tag ids stored on a chat to the user's tag records.

    Tag ids are read from the ``"tags"`` entry of the chat's ``meta`` and
    looked up one by one for ``user_id``. Ids with no matching tag record
    are skipped so the result never contains None.

    Args:
        id: Chat id.
        user_id: User whose tag records are looked up.

    Returns:
        The resolved tags; empty if the chat does not exist or has no
        tags.
    """
    with get_db() as db:
        chat = db.get(Chat, id)
        if chat is None:
            return []

        # Tolerate a missing/NULL meta or a null "tags" entry.
        tag_ids = (chat.meta or {}).get("tags") or []
        resolved = (
            Tags.get_tag_by_name_and_user_id(tag_id, user_id) for tag_id in tag_ids
        )
        return [tag for tag in resolved if tag is not None]
def get_chat_list_by_user_id_and_tag_name(
    self, user_id: str, tag_name: str, skip: int = 0, limit: int = 50
) -> list[ChatModel]:
    """
    Return a page of a user's chats carrying the given tag.

    The tag name is normalized before matching. Results are ordered by
    ``updated_at`` descending and paginated with ``skip`` / ``limit``,
    consistent with ``get_chats_by_multiple_tags``. Previously both
    pagination parameters were accepted but ignored and the result order
    was unspecified.

    Args:
        user_id: Owner of the chats.
        tag_name: Tag to filter by.
        skip: Number of rows to skip; ignored when falsy.
        limit: Maximum number of rows to return; ignored when falsy.

    Returns:
        The matching chats as ChatModel instances; empty if the adapter
        cannot build a tag filter.
    """
    with get_db() as db:
        adapter = self._get_adapter(db)
        tag_id = normalize_tag_name(tag_name)

        tag_filter = adapter.build_tag_filter("meta", [tag_id], match_all=True)
        if tag_filter is None:
            return []

        query = (
            db.query(Chat)
            .filter_by(user_id=user_id)
            .filter(tag_filter)
            .order_by(Chat.updated_at.desc())
        )
        if skip:
            query = query.offset(skip)
        if limit:
            query = query.limit(limit)
        return [ChatModel.model_validate(chat) for chat in query.all()]
def get_chats_by_multiple_tags(
    self,
    user_id: str,
    tag_names: List[str],
    match_all: bool = True,
    skip: int = 0,
    limit: int = 50,
) -> list[ChatModel]:
    """
    Return a page of a user's non-archived chats matching several tags.

    Args:
        user_id: Owner of the chats.
        tag_names: Tag names to filter by; they are normalized first.
        match_all: Forwarded to the adapter's tag filter builder.
        skip: Number of rows to skip.
        limit: Maximum number of rows to return.

    Returns:
        The matching chats, newest update first; empty if ``tag_names``
        is empty or the adapter cannot build a tag filter.
    """
    with get_db() as db:
        adapter = self._get_adapter(db)
        base = db.query(Chat).filter_by(user_id=user_id, archived=False)

        if not tag_names:
            return []

        tag_filter = adapter.build_tag_filter(
            "meta", normalize_tag_names(tag_names), match_all=match_all
        )
        if tag_filter is None:
            return []

        rows = (
            base.filter(tag_filter)
            .order_by(Chat.updated_at.desc())
            .offset(skip)
            .limit(limit)
            .all()
        )
        return [ChatModel.model_validate(row) for row in rows]
def get_chats_without_tags(
    self, user_id: str, skip: int = 0, limit: int = 50
) -> list[ChatModel]:
    """
    Return a page of a user's non-archived chats that have no tags.

    Args:
        user_id: Owner of the chats.
        skip: Number of rows to skip.
        limit: Maximum number of rows to return.

    Returns:
        The matching chats, newest update first; empty if the adapter
        cannot build an untagged filter.
    """
    with get_db() as db:
        adapter = self._get_adapter(db)
        base = db.query(Chat).filter_by(user_id=user_id, archived=False)

        no_tags = adapter.build_untagged_filter("meta")
        if no_tags is None:
            return []

        rows = (
            base.filter(no_tags)
            .order_by(Chat.updated_at.desc())
            .offset(skip)
            .limit(limit)
            .all()
        )
        return [ChatModel.model_validate(row) for row in rows]
def add_chat_tag_by_id_and_user_id_and_tag_name(
    self, id: str, user_id: str, tag_name: str
) -> Optional[ChatModel]:
    """
    Attach a tag to a chat, creating the user's tag record if needed.

    The tag's id is added to the ``"tags"`` list in the chat's ``meta``
    unless it is already present.

    Args:
        id: Chat id.
        user_id: User whose tag record is looked up or created.
        tag_name: Name of the tag to attach.

    Returns:
        The refreshed chat as a ChatModel, or None if anything fails.
    """
    tag = Tags.get_tag_by_name_and_user_id(tag_name, user_id)
    if tag is None:
        tag = Tags.insert_new_tag(tag_name, user_id)

    try:
        with get_db() as db:
            chat = db.get(Chat, id)
            new_tag_id = tag.id

            current_tags = chat.meta.get("tags", [])
            if new_tag_id not in current_tags:
                # Assign a new dict rather than mutating meta in place.
                merged = list(set(chat.meta.get("tags", []) + [new_tag_id]))
                chat.meta = {**chat.meta, "tags": merged}

            db.commit()
            db.refresh(chat)
            return ChatModel.model_validate(chat)
    except Exception:
        return None
def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> int:
    """
    Count a user's non-archived chats carrying the given tag.

    Args:
        tag_name: Tag to count; it is normalized before matching.
        user_id: Owner of the chats.

    Returns:
        The number of matching chats, or 0 if the adapter cannot build a
        tag filter.
    """
    with get_db() as db:
        adapter = self._get_adapter(db)
        base = db.query(Chat).filter_by(user_id=user_id, archived=False)

        normalized = normalize_tag_name(tag_name)
        tag_filter = adapter.build_tag_filter("meta", [normalized], match_all=True)
        if tag_filter is None:
            return 0

        return base.filter(tag_filter).count()
def delete_tag_by_id_and_user_id_and_tag_name(
    self, id: str, user_id: str, tag_name: str
) -> bool:
    """
    Remove one tag from a chat owned by ``user_id``.

    The tag name is normalized and removed from the ``"tags"`` list in
    the chat's ``meta``. Remaining tags are de-duplicated while keeping
    their original order. The lookup is scoped to the owner; previously
    ``user_id`` was unused.

    Args:
        id: Chat id.
        user_id: Required owner of the chat.
        tag_name: Name of the tag to remove.

    Returns:
        True if the chat was updated; False if no chat with this id
        belongs to the user or the update fails.
    """
    try:
        with get_db() as db:
            chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
            if chat is None:
                return False

            meta = chat.meta or {}
            tag_id = normalize_tag_name(tag_name)
            # dict.fromkeys de-duplicates but, unlike set(), keeps order.
            remaining = list(
                dict.fromkeys(
                    tag for tag in (meta.get("tags") or []) if tag != tag_id
                )
            )
            # Assign a new dict rather than mutating meta in place.
            chat.meta = {**meta, "tags": remaining}
            db.commit()
            return True
    except Exception:
        # Keep the best-effort contract (False on failure) but leave a trace.
        logging.getLogger(__name__).exception(
            "Failed to delete tag from chat %s", id
        )
        return False
def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str) -> bool:
    """
    Clear every tag from a chat owned by ``user_id``.

    The ``"tags"`` entry of the chat's ``meta`` is replaced with an empty
    list; other ``meta`` entries are kept. The lookup is scoped to the
    owner; previously ``user_id`` was unused.

    Args:
        id: Chat id.
        user_id: Required owner of the chat.

    Returns:
        True if the chat was updated; False if no chat with this id
        belongs to the user or the update fails.
    """
    try:
        with get_db() as db:
            chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
            if chat is None:
                return False

            # Assign a new dict rather than mutating meta in place.
            chat.meta = {**(chat.meta or {}), "tags": []}
            db.commit()
            return True
    except Exception:
        # Keep the best-effort contract (False on failure) but leave a trace.
        logging.getLogger(__name__).exception(
            "Failed to delete all tags from chat %s", id
        )
        return False
def delete_chat_by_id(self, id: str) -> bool:
    """
    Delete a chat by id, then delete its shared copy.

    Args:
        id: Chat id.

    Returns:
        The result of ``delete_shared_chat_by_chat_id`` after the chat
        row is deleted, or False if any step raises.
    """
    try:
        with get_db() as db:
            db.query(Chat).filter(Chat.id == id).delete()
            db.commit()
            # The outcome is whatever the shared-chat cleanup reports.
            return self.delete_shared_chat_by_chat_id(id)
    except Exception:
        return False
def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
    """
    Delete a chat by id for a specific owner, then delete its shared copy.

    Args:
        id: Chat id.
        user_id: Required owner of the chat.

    Returns:
        The result of ``delete_shared_chat_by_chat_id`` after the delete
        is committed, or False if any step raises.
    """
    try:
        with get_db() as db:
            db.query(Chat).filter(Chat.id == id, Chat.user_id == user_id).delete()
            db.commit()
            # The outcome is whatever the shared-chat cleanup reports.
            return self.delete_shared_chat_by_chat_id(id)
    except Exception:
        return False
def delete_chats_by_user_id(self, user_id: str) -> bool:
    """
    Delete every chat owned by a user, including their shared copies.

    Shared copies are removed first via
    ``delete_shared_chats_by_user_id``; its return value is not checked.

    Args:
        user_id: Owner of the chats.

    Returns:
        True once the delete is committed, False if any step raises.
    """
    try:
        with get_db() as db:
            # Shared copies are derived from the user's chats, so remove
            # them before the originals disappear.
            self.delete_shared_chats_by_user_id(user_id)

            db.query(Chat).filter(Chat.user_id == user_id).delete()
            db.commit()
    except Exception:
        return False
    return True
def delete_chats_by_user_id_and_folder_id(
    self, user_id: str, folder_id: str
) -> bool:
    """
    Delete all of a user's chats that live in the given folder.

    Args:
        user_id: Owner of the chats.
        folder_id: Folder whose chats are deleted.

    Returns:
        True once the delete is committed, False if any step raises.
    """
    try:
        with get_db() as db:
            in_folder = db.query(Chat).filter_by(
                user_id=user_id, folder_id=folder_id
            )
            in_folder.delete()
            db.commit()
    except Exception:
        return False
    return True
def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
    """
    Delete the shared copies of every chat owned by a user.

    For each of the user's chats, rows whose ``user_id`` equals
    ``"shared-<chat id>"`` are deleted.

    Args:
        user_id: Owner of the original chats.

    Returns:
        True once the delete is committed, False if any step raises.
    """
    try:
        with get_db() as db:
            owned_chats = db.query(Chat).filter_by(user_id=user_id).all()
            # Shared rows are matched by a synthetic user_id built from
            # the original chat's id.
            shared_owner_ids = [f"shared-{owned.id}" for owned in owned_chats]

            shared_rows = db.query(Chat).filter(Chat.user_id.in_(shared_owner_ids))
            shared_rows.delete()
            db.commit()
    except Exception:
        return False
    return True
# Module-level ChatTable instance, created once at import time.
Chats = ChatTable()