wip: prompt history models
This commit is contained in:
223
backend/open_webui/models/prompt_history.py
Normal file
223
backend/open_webui/models/prompt_history.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""Prompt history model for version tracking."""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
import json
|
||||
import difflib
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from open_webui.internal.db import Base, get_db_context
|
||||
from open_webui.models.users import Users, UserResponse
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, Text, JSON, Index
|
||||
|
||||
|
||||
####################
|
||||
# PromptHistory DB Schema
|
||||
####################
|
||||
|
||||
|
||||
class PromptHistory(Base):
|
||||
__tablename__ = "prompt_history"
|
||||
|
||||
id = Column(Text, primary_key=True)
|
||||
prompt_id = Column(Text, nullable=False, index=True)
|
||||
parent_id = Column(Text, nullable=True) # Reference to parent commit
|
||||
snapshot = Column(JSON, nullable=False)
|
||||
user_id = Column(Text, nullable=False)
|
||||
commit_message = Column(Text, nullable=True)
|
||||
created_at = Column(BigInteger, nullable=False)
|
||||
|
||||
|
||||
class PromptHistoryModel(BaseModel):
|
||||
id: str
|
||||
prompt_id: str
|
||||
parent_id: Optional[str] = None
|
||||
snapshot: dict
|
||||
user_id: str
|
||||
commit_message: Optional[str] = None
|
||||
created_at: int
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class PromptHistoryResponse(PromptHistoryModel):
|
||||
"""Response model with user info."""
|
||||
user: Optional[UserResponse] = None
|
||||
|
||||
|
||||
class PromptHistoryTable:
|
||||
def create_history_entry(
|
||||
self,
|
||||
prompt_id: str,
|
||||
snapshot: dict,
|
||||
user_id: str,
|
||||
parent_id: Optional[str] = None,
|
||||
commit_message: Optional[str] = None,
|
||||
db: Optional[Session] = None,
|
||||
) -> Optional[PromptHistoryModel]:
|
||||
"""Create a new history entry (commit) for a prompt."""
|
||||
with get_db_context(db) as db:
|
||||
history = PromptHistory(
|
||||
id=str(uuid.uuid4()),
|
||||
prompt_id=prompt_id,
|
||||
parent_id=parent_id,
|
||||
snapshot=snapshot,
|
||||
user_id=user_id,
|
||||
commit_message=commit_message,
|
||||
created_at=int(time.time()),
|
||||
)
|
||||
db.add(history)
|
||||
db.commit()
|
||||
db.refresh(history)
|
||||
return PromptHistoryModel.model_validate(history)
|
||||
|
||||
def get_history_by_prompt_id(
|
||||
self,
|
||||
prompt_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
db: Optional[Session] = None,
|
||||
) -> list[PromptHistoryResponse]:
|
||||
"""Get all history entries for a prompt, ordered by created_at desc."""
|
||||
with get_db_context(db) as db:
|
||||
entries = (
|
||||
db.query(PromptHistory)
|
||||
.filter(PromptHistory.prompt_id == prompt_id)
|
||||
.order_by(PromptHistory.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Get user info for each entry
|
||||
user_ids = list(set(e.user_id for e in entries))
|
||||
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
|
||||
users_dict = {user.id: user for user in users}
|
||||
|
||||
return [
|
||||
PromptHistoryResponse(
|
||||
**PromptHistoryModel.model_validate(entry).model_dump(),
|
||||
user=users_dict.get(entry.user_id).model_dump() if users_dict.get(entry.user_id) else None,
|
||||
)
|
||||
for entry in entries
|
||||
]
|
||||
|
||||
def get_history_entry_by_id(
|
||||
self,
|
||||
history_id: str,
|
||||
db: Optional[Session] = None,
|
||||
) -> Optional[PromptHistoryModel]:
|
||||
"""Get a specific history entry by ID."""
|
||||
with get_db_context(db) as db:
|
||||
entry = db.query(PromptHistory).filter(PromptHistory.id == history_id).first()
|
||||
if entry:
|
||||
return PromptHistoryModel.model_validate(entry)
|
||||
return None
|
||||
|
||||
def get_latest_history_entry(
|
||||
self,
|
||||
prompt_id: str,
|
||||
db: Optional[Session] = None,
|
||||
) -> Optional[PromptHistoryModel]:
|
||||
"""Get the most recent history entry for a prompt."""
|
||||
with get_db_context(db) as db:
|
||||
entry = (
|
||||
db.query(PromptHistory)
|
||||
.filter(PromptHistory.prompt_id == prompt_id)
|
||||
.order_by(PromptHistory.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
if entry:
|
||||
return PromptHistoryModel.model_validate(entry)
|
||||
return None
|
||||
|
||||
def get_history_count(
|
||||
self,
|
||||
prompt_id: str,
|
||||
db: Optional[Session] = None,
|
||||
) -> int:
|
||||
"""Get the number of history entries for a prompt."""
|
||||
with get_db_context(db) as db:
|
||||
return (
|
||||
db.query(PromptHistory)
|
||||
.filter(PromptHistory.prompt_id == prompt_id)
|
||||
.count()
|
||||
)
|
||||
|
||||
def compute_diff(
|
||||
self,
|
||||
from_id: str,
|
||||
to_id: str,
|
||||
db: Optional[Session] = None,
|
||||
) -> Optional[dict]:
|
||||
"""Compute diff between two history entries."""
|
||||
with get_db_context(db) as db:
|
||||
from_entry = db.query(PromptHistory).filter(PromptHistory.id == from_id).first()
|
||||
to_entry = db.query(PromptHistory).filter(PromptHistory.id == to_id).first()
|
||||
|
||||
if not from_entry or not to_entry:
|
||||
return None
|
||||
|
||||
from_snapshot = from_entry.snapshot
|
||||
to_snapshot = to_entry.snapshot
|
||||
|
||||
# Compute diff for content field
|
||||
from_content = from_snapshot.get("content", "")
|
||||
to_content = to_snapshot.get("content", "")
|
||||
|
||||
diff_lines = list(difflib.unified_diff(
|
||||
from_content.splitlines(keepends=True),
|
||||
to_content.splitlines(keepends=True),
|
||||
fromfile=f"v{from_id[:8]}",
|
||||
tofile=f"v{to_id[:8]}",
|
||||
lineterm="",
|
||||
))
|
||||
|
||||
return {
|
||||
"from_id": from_id,
|
||||
"to_id": to_id,
|
||||
"from_snapshot": from_snapshot,
|
||||
"to_snapshot": to_snapshot,
|
||||
"content_diff": diff_lines,
|
||||
"name_changed": from_snapshot.get("name") != to_snapshot.get("name"),
|
||||
"access_control_changed": from_snapshot.get("access_control") != to_snapshot.get("access_control"),
|
||||
}
|
||||
|
||||
def delete_history_by_prompt_id(
|
||||
self,
|
||||
prompt_id: str,
|
||||
db: Optional[Session] = None,
|
||||
) -> bool:
|
||||
"""Delete all history entries for a prompt."""
|
||||
with get_db_context(db) as db:
|
||||
db.query(PromptHistory).filter(PromptHistory.prompt_id == prompt_id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
def delete_history_entry(
|
||||
self,
|
||||
history_id: str,
|
||||
db: Optional[Session] = None,
|
||||
) -> bool:
|
||||
"""Delete a history entry and reparent its children to grandparent."""
|
||||
with get_db_context(db) as db:
|
||||
entry = db.query(PromptHistory).filter_by(id=history_id).first()
|
||||
if not entry:
|
||||
return False
|
||||
|
||||
# Find children that reference this entry as parent
|
||||
children = db.query(PromptHistory).filter_by(parent_id=history_id).all()
|
||||
|
||||
# Reparent children to grandparent
|
||||
for child in children:
|
||||
child.parent_id = entry.parent_id
|
||||
|
||||
db.delete(entry)
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
|
||||
PromptHistories = PromptHistoryTable()
|
||||
@@ -1,4 +1,5 @@
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -7,7 +8,7 @@ from open_webui.models.groups import Groups
|
||||
from open_webui.models.users import Users, UserResponse
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text, JSON
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
|
||||
|
||||
from open_webui.utils.access_control import has_access
|
||||
|
||||
@@ -19,11 +20,17 @@ from open_webui.utils.access_control import has_access
|
||||
class Prompt(Base):
|
||||
__tablename__ = "prompt"
|
||||
|
||||
command = Column(String, primary_key=True)
|
||||
id = Column(Text, primary_key=True)
|
||||
command = Column(String, unique=True, index=True)
|
||||
user_id = Column(String)
|
||||
title = Column(Text)
|
||||
name = Column(Text)
|
||||
content = Column(Text)
|
||||
timestamp = Column(BigInteger)
|
||||
data = Column(JSON, nullable=True)
|
||||
meta = Column(JSON, nullable=True)
|
||||
is_active = Column(Boolean, default=True)
|
||||
version_id = Column(Text, nullable=True) # Points to active history entry
|
||||
created_at = Column(BigInteger, nullable=True)
|
||||
updated_at = Column(BigInteger, nullable=True)
|
||||
|
||||
access_control = Column(JSON, nullable=True) # Controls data access levels.
|
||||
# Defines access control rules for this entry.
|
||||
@@ -44,13 +51,19 @@ class Prompt(Base):
|
||||
|
||||
|
||||
class PromptModel(BaseModel):
|
||||
id: Optional[str] = None
|
||||
command: str
|
||||
user_id: str
|
||||
title: str
|
||||
name: str
|
||||
content: str
|
||||
timestamp: int # timestamp in epoch
|
||||
|
||||
data: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
is_active: Optional[bool] = True
|
||||
version_id: Optional[str] = None
|
||||
created_at: Optional[int] = None
|
||||
updated_at: Optional[int] = None
|
||||
access_control: Optional[dict] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
@@ -69,21 +82,34 @@ class PromptAccessResponse(PromptUserResponse):
|
||||
|
||||
class PromptForm(BaseModel):
|
||||
command: str
|
||||
title: str
|
||||
name: str # Changed from title
|
||||
content: str
|
||||
data: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
access_control: Optional[dict] = None
|
||||
version_id: Optional[str] = None # Active version
|
||||
commit_message: Optional[str] = None # For history tracking
|
||||
|
||||
|
||||
class PromptsTable:
|
||||
def insert_new_prompt(
|
||||
self, user_id: str, form_data: PromptForm, db: Optional[Session] = None
|
||||
) -> Optional[PromptModel]:
|
||||
now = int(time.time())
|
||||
prompt_id = str(uuid.uuid4())
|
||||
|
||||
prompt = PromptModel(
|
||||
**{
|
||||
"user_id": user_id,
|
||||
**form_data.model_dump(),
|
||||
"timestamp": int(time.time()),
|
||||
}
|
||||
id=prompt_id,
|
||||
user_id=user_id,
|
||||
command=form_data.command,
|
||||
name=form_data.name,
|
||||
content=form_data.content,
|
||||
data=form_data.data or {},
|
||||
meta=form_data.meta or {},
|
||||
access_control=form_data.access_control,
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -92,26 +118,74 @@ class PromptsTable:
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
|
||||
if result:
|
||||
# Create initial history entry
|
||||
from open_webui.models.prompt_history import PromptHistories
|
||||
|
||||
snapshot = {
|
||||
"name": form_data.name,
|
||||
"content": form_data.content,
|
||||
"command": form_data.command,
|
||||
"data": form_data.data or {},
|
||||
"meta": form_data.meta or {},
|
||||
"access_control": form_data.access_control,
|
||||
}
|
||||
|
||||
history_entry = PromptHistories.create_history_entry(
|
||||
prompt_id=prompt_id,
|
||||
snapshot=snapshot,
|
||||
user_id=user_id,
|
||||
parent_id=None, # Initial commit has no parent
|
||||
commit_message=form_data.commit_message or "Initial version",
|
||||
db=db,
|
||||
)
|
||||
|
||||
# Set the initial version as the production version
|
||||
if history_entry:
|
||||
result.version_id = history_entry.id
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
|
||||
return PromptModel.model_validate(result)
|
||||
else:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_prompt_by_id(
|
||||
self, prompt_id: str, db: Optional[Session] = None
|
||||
) -> Optional[PromptModel]:
|
||||
"""Get prompt by UUID."""
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
prompt = db.query(Prompt).filter_by(id=prompt_id).first()
|
||||
if prompt:
|
||||
return PromptModel.model_validate(prompt)
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_prompt_by_command(
|
||||
self, command: str, db: Optional[Session] = None
|
||||
) -> Optional[PromptModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
prompt = db.query(Prompt).filter_by(command=command).first()
|
||||
return PromptModel.model_validate(prompt)
|
||||
if prompt:
|
||||
return PromptModel.model_validate(prompt)
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_prompts(self, db: Optional[Session] = None) -> list[PromptUserResponse]:
|
||||
with get_db_context(db) as db:
|
||||
all_prompts = db.query(Prompt).order_by(Prompt.timestamp.desc()).all()
|
||||
all_prompts = (
|
||||
db.query(Prompt)
|
||||
.filter(Prompt.is_active == True)
|
||||
.order_by(Prompt.updated_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
user_ids = list(set(prompt.user_id for prompt in all_prompts))
|
||||
|
||||
@@ -148,16 +222,101 @@ class PromptsTable:
|
||||
]
|
||||
|
||||
def update_prompt_by_command(
|
||||
self, command: str, form_data: PromptForm, db: Optional[Session] = None
|
||||
self,
|
||||
command: str,
|
||||
form_data: PromptForm,
|
||||
user_id: str,
|
||||
db: Optional[Session] = None
|
||||
) -> Optional[PromptModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
prompt = db.query(Prompt).filter_by(command=command).first()
|
||||
prompt.title = form_data.title
|
||||
if not prompt:
|
||||
return None
|
||||
|
||||
# Get the latest history entry for parent_id
|
||||
from open_webui.models.prompt_history import PromptHistories
|
||||
latest_history = PromptHistories.get_latest_history_entry(prompt.id, db=db)
|
||||
parent_id = latest_history.id if latest_history else None
|
||||
|
||||
# Check if content changed to decide on history creation
|
||||
content_changed = (
|
||||
prompt.name != form_data.name or
|
||||
prompt.content != form_data.content or
|
||||
prompt.access_control != form_data.access_control
|
||||
)
|
||||
|
||||
# Update prompt fields
|
||||
prompt.name = form_data.name
|
||||
prompt.content = form_data.content
|
||||
prompt.data = form_data.data or prompt.data
|
||||
prompt.meta = form_data.meta or prompt.meta
|
||||
prompt.access_control = form_data.access_control
|
||||
prompt.timestamp = int(time.time())
|
||||
if form_data.version_id is not None:
|
||||
prompt.version_id = form_data.version_id
|
||||
prompt.updated_at = int(time.time())
|
||||
|
||||
db.commit()
|
||||
|
||||
# Create history entry only if content changed
|
||||
if content_changed:
|
||||
snapshot = {
|
||||
"name": form_data.name,
|
||||
"content": form_data.content,
|
||||
"command": command,
|
||||
"data": form_data.data or {},
|
||||
"meta": form_data.meta or {},
|
||||
"access_control": form_data.access_control,
|
||||
}
|
||||
|
||||
PromptHistories.create_history_entry(
|
||||
prompt_id=prompt.id,
|
||||
snapshot=snapshot,
|
||||
user_id=user_id,
|
||||
parent_id=parent_id,
|
||||
commit_message=form_data.commit_message,
|
||||
db=db,
|
||||
)
|
||||
|
||||
return PromptModel.model_validate(prompt)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
|
||||
def update_prompt_version(
|
||||
self,
|
||||
command: str,
|
||||
version_id: str,
|
||||
db: Optional[Session] = None,
|
||||
) -> Optional[PromptModel]:
|
||||
"""Set the active version of a prompt and restore content from that version's snapshot."""
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
prompt = db.query(Prompt).filter_by(command=command).first()
|
||||
if not prompt:
|
||||
return None
|
||||
|
||||
# Get the history entry to restore content from
|
||||
from open_webui.models.prompt_history import PromptHistories
|
||||
history_entry = PromptHistories.get_history_entry_by_id(version_id, db=db)
|
||||
|
||||
if not history_entry:
|
||||
return None
|
||||
|
||||
# Restore prompt content from the snapshot
|
||||
snapshot = history_entry.snapshot
|
||||
if snapshot:
|
||||
prompt.name = snapshot.get("name", prompt.name)
|
||||
prompt.content = snapshot.get("content", prompt.content)
|
||||
prompt.data = snapshot.get("data", prompt.data)
|
||||
prompt.meta = snapshot.get("meta", prompt.meta)
|
||||
# Note: command and access_control are not restored from snapshot
|
||||
|
||||
prompt.version_id = version_id
|
||||
prompt.updated_at = int(time.time())
|
||||
db.commit()
|
||||
|
||||
return PromptModel.model_validate(prompt)
|
||||
except Exception:
|
||||
return None
|
||||
@@ -165,12 +324,40 @@ class PromptsTable:
|
||||
def delete_prompt_by_command(
|
||||
self, command: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
"""Soft delete a prompt by setting is_active to False."""
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
db.query(Prompt).filter_by(command=command).delete()
|
||||
db.commit()
|
||||
prompt = db.query(Prompt).filter_by(command=command).first()
|
||||
if prompt:
|
||||
# Delete history first (Requirement: entire history should be deleted)
|
||||
from open_webui.models.prompt_history import PromptHistories
|
||||
PromptHistories.delete_history_by_prompt_id(prompt.id, db=db)
|
||||
|
||||
return True
|
||||
prompt.is_active = False
|
||||
prompt.updated_at = int(time.time())
|
||||
db.commit()
|
||||
return True
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def hard_delete_prompt_by_command(
|
||||
self, command: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
"""Permanently delete a prompt and its history."""
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
prompt = db.query(Prompt).filter_by(command=command).first()
|
||||
if prompt:
|
||||
# Delete history first
|
||||
from open_webui.models.prompt_history import PromptHistories
|
||||
PromptHistories.delete_history_by_prompt_id(prompt.id, db=db)
|
||||
|
||||
# Delete prompt
|
||||
db.query(Prompt).filter_by(command=command).delete()
|
||||
db.commit()
|
||||
return True
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user