feat(sqlalchemy): Replace peewee with sqlalchemy

This commit is contained in:
Jonathan Rohde
2024-06-18 15:03:31 +02:00
parent 8dac2a2140
commit df09d0830a
47 changed files with 2580 additions and 1003 deletions

View File

@@ -1,14 +1,14 @@
from pydantic import BaseModel
from typing import List, Union, Optional
import time
from typing import Optional
import uuid
import logging
from peewee import *
from sqlalchemy import String, Column, Boolean
from sqlalchemy.orm import Session
from apps.webui.models.users import UserModel, Users
from utils.utils import verify_password
from apps.webui.internal.db import DB
from apps.webui.internal.db import Base
from config import SRC_LOG_LEVELS
@@ -20,14 +20,13 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class Auth(Model):
id = CharField(unique=True)
email = CharField()
password = TextField()
active = BooleanField()
class Auth(Base):
__tablename__ = "auth"
class Meta:
database = DB
id = Column(String, primary_key=True)
email = Column(String)
password = Column(String)
active = Column(Boolean)
class AuthModel(BaseModel):
@@ -94,12 +93,10 @@ class AddUserForm(SignupForm):
class AuthsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Auth])
def insert_new_auth(
self,
db: Session,
email: str,
password: str,
name: str,
@@ -114,24 +111,30 @@ class AuthsTable:
auth = AuthModel(
**{"id": id, "email": email, "password": password, "active": True}
)
result = Auth.create(**auth.model_dump())
result = Auth(**auth.model_dump())
db.add(result)
user = Users.insert_new_user(
id, name, email, profile_image_url, role, oauth_sub
db, id, name, email, profile_image_url, role, oauth_sub
)
db.commit()
db.refresh(result)
if result and user:
return user
else:
return None
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
def authenticate_user(
self, db: Session, email: str, password: str
) -> Optional[UserModel]:
log.info(f"authenticate_user: {email}")
try:
auth = Auth.get(Auth.email == email, Auth.active == True)
auth = db.query(Auth).filter_by(email=email, active=True).first()
if auth:
if verify_password(password, auth.password):
user = Users.get_user_by_id(auth.id)
user = Users.get_user_by_id(db, auth.id)
return user
else:
return None
@@ -140,55 +143,55 @@ class AuthsTable:
except:
return None
def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
def authenticate_user_by_api_key(
self, db: Session, api_key: str
) -> Optional[UserModel]:
log.info(f"authenticate_user_by_api_key: {api_key}")
# if no api_key, return None
if not api_key:
return None
try:
user = Users.get_user_by_api_key(api_key)
user = Users.get_user_by_api_key(db, api_key)
return user if user else None
except:
return False
def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
def authenticate_user_by_trusted_header(
self, db: Session, email: str
) -> Optional[UserModel]:
log.info(f"authenticate_user_by_trusted_header: {email}")
try:
auth = Auth.get(Auth.email == email, Auth.active == True)
auth = db.query(Auth).filter(email=email, active=True).first()
if auth:
user = Users.get_user_by_id(auth.id)
return user
except:
return None
def update_user_password_by_id(self, id: str, new_password: str) -> bool:
def update_user_password_by_id(
self, db: Session, id: str, new_password: str
) -> bool:
try:
query = Auth.update(password=new_password).where(Auth.id == id)
result = query.execute()
result = db.query(Auth).filter_by(id=id).update({"password": new_password})
return True if result == 1 else False
except:
return False
def update_email_by_id(self, id: str, email: str) -> bool:
def update_email_by_id(self, db: Session, id: str, email: str) -> bool:
try:
query = Auth.update(email=email).where(Auth.id == id)
result = query.execute()
result = db.query(Auth).filter_by(id=id).update({"email": email})
return True if result == 1 else False
except:
return False
def delete_auth_by_id(self, id: str) -> bool:
def delete_auth_by_id(self, db: Session, id: str) -> bool:
try:
# Delete User
result = Users.delete_user_by_id(id)
result = Users.delete_user_by_id(db, id)
if result:
# Delete Auth
query = Auth.delete().where(Auth.id == id)
query.execute() # Remove the rows, return number of rows removed.
db.query(Auth).filter_by(id=id).delete()
return True
else:
@@ -197,4 +200,4 @@ class AuthsTable:
return False
Auths = AuthsTable(DB)
Auths = AuthsTable()

View File

@@ -1,36 +1,39 @@
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from typing import List, Union, Optional
from peewee import *
from playhouse.shortcuts import model_to_dict
import json
import uuid
import time
from apps.webui.internal.db import DB
from sqlalchemy import Column, String, BigInteger, Boolean
from sqlalchemy.orm import Session
from apps.webui.internal.db import Base
####################
# Chat DB Schema
####################
class Chat(Model):
id = CharField(unique=True)
user_id = CharField()
title = TextField()
chat = TextField() # Save Chat JSON as Text
class Chat(Base):
__tablename__ = "chat"
created_at = BigIntegerField()
updated_at = BigIntegerField()
id = Column(String, primary_key=True)
user_id = Column(String)
title = Column(String)
chat = Column(String) # Save Chat JSON as Text
share_id = CharField(null=True, unique=True)
archived = BooleanField(default=False)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
class Meta:
database = DB
share_id = Column(String, unique=True, nullable=True)
archived = Column(Boolean, default=False)
class ChatModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
user_id: str
title: str
@@ -75,11 +78,10 @@ class ChatTitleIdResponse(BaseModel):
class ChatTable:
def __init__(self, db):
self.db = db
db.create_tables([Chat])
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
def insert_new_chat(
self, db: Session, user_id: str, form_data: ChatForm
) -> Optional[ChatModel]:
id = str(uuid.uuid4())
chat = ChatModel(
**{
@@ -94,29 +96,36 @@ class ChatTable:
}
)
result = Chat.create(**chat.model_dump())
return chat if result else None
result = Chat(**chat.model_dump())
db.add(result)
db.commit()
db.refresh(result)
return ChatModel.model_validate(result) if result else None
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
def update_chat_by_id(
self, db: Session, id: str, chat: dict
) -> Optional[ChatModel]:
try:
query = Chat.update(
chat=json.dumps(chat),
title=chat["title"] if "title" in chat else "New Chat",
updated_at=int(time.time()),
).where(Chat.id == id)
query.execute()
db.query(Chat).filter_by(id=id).update(
{
"chat": json.dumps(chat),
"title": chat["title"] if "title" in chat else "New Chat",
"updated_at": int(time.time()),
}
)
chat = Chat.get(Chat.id == id)
return ChatModel(**model_to_dict(chat))
return self.get_chat_by_id(db, id)
except:
return None
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
def insert_shared_chat_by_chat_id(
self, db: Session, chat_id: str
) -> Optional[ChatModel]:
# Get the existing chat to share
chat = Chat.get(Chat.id == chat_id)
chat = db.get(Chat, chat_id)
# Check if the chat is already shared
if chat.share_id:
return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
return self.get_chat_by_id_and_user_id(db, chat.share_id, "shared")
# Create a new chat with the same data, but with a new ID
shared_chat = ChatModel(
**{
@@ -128,228 +137,196 @@ class ChatTable:
"updated_at": int(time.time()),
}
)
shared_result = Chat.create(**shared_chat.model_dump())
shared_result = Chat(**shared_chat.model_dump())
db.add(shared_result)
db.commit()
db.refresh(shared_result)
# Update the original chat with the share_id
result = (
Chat.update(share_id=shared_chat.id).where(Chat.id == chat_id).execute()
db.query(Chat).filter_by(id=chat_id).update({"share_id": shared_chat.id})
)
return shared_chat if (shared_result and result) else None
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
def update_shared_chat_by_chat_id(
self, db: Session, chat_id: str
) -> Optional[ChatModel]:
try:
print("update_shared_chat_by_id")
chat = Chat.get(Chat.id == chat_id)
chat = db.get(Chat, chat_id)
print(chat)
query = Chat.update(
title=chat.title,
chat=chat.chat,
).where(Chat.id == chat.share_id)
db.query(Chat).filter_by(id=chat.share_id).update(
{"title": chat.title, "chat": chat.chat}
)
query.execute()
chat = Chat.get(Chat.id == chat.share_id)
return ChatModel(**model_to_dict(chat))
return self.get_chat_by_id(db, chat.share_id)
except:
return None
def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
def delete_shared_chat_by_chat_id(self, db: Session, chat_id: str) -> bool:
try:
query = Chat.delete().where(Chat.user_id == f"shared-{chat_id}")
query.execute() # Remove the rows, return number of rows removed.
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
return True
except:
return False
def update_chat_share_id_by_id(
self, id: str, share_id: Optional[str]
self, db: Session, id: str, share_id: Optional[str]
) -> Optional[ChatModel]:
try:
query = Chat.update(
share_id=share_id,
).where(Chat.id == id)
query.execute()
db.query(Chat).filter_by(id=id).update({"share_id": share_id})
chat = Chat.get(Chat.id == id)
return ChatModel(**model_to_dict(chat))
return self.get_chat_by_id(db, id)
except:
return None
def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
def toggle_chat_archive_by_id(self, db: Session, id: str) -> Optional[ChatModel]:
try:
chat = self.get_chat_by_id(id)
query = Chat.update(
archived=(not chat.archived),
).where(Chat.id == id)
chat = self.get_chat_by_id(db, id)
db.query(Chat).filter_by(id=id).update({"archived": not chat.archived})
query.execute()
chat = Chat.get(Chat.id == id)
return ChatModel(**model_to_dict(chat))
return self.get_chat_by_id(db, id)
except:
return None
def archive_all_chats_by_user_id(self, user_id: str) -> bool:
def archive_all_chats_by_user_id(self, db: Session, user_id: str) -> bool:
try:
chats = self.get_chats_by_user_id(user_id)
for chat in chats:
query = Chat.update(
archived=True,
).where(Chat.id == chat.id)
query.execute()
db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
return True
except:
return False
def get_archived_chat_list_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50
self, db: Session, user_id: str, skip: int = 0, limit: int = 50
) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.archived == True)
.where(Chat.user_id == user_id)
all_chats = (
db.query(Chat)
.filter_by(user_id=user_id, archived=True)
.order_by(Chat.updated_at.desc())
# .limit(limit)
# .offset(skip)
]
# .limit(limit).offset(skip)
.all()
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_list_by_user_id(
self,
db: Session,
user_id: str,
include_archived: bool = False,
skip: int = 0,
limit: int = 50,
) -> List[ChatModel]:
if include_archived:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.user_id == user_id)
.order_by(Chat.updated_at.desc())
# .limit(limit)
# .offset(skip)
]
else:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.archived == False)
.where(Chat.user_id == user_id)
.order_by(Chat.updated_at.desc())
# .limit(limit)
# .offset(skip)
]
query = db.query(Chat).filter_by(user_id=user_id)
if not include_archived:
query = query.filter_by(archived=False)
all_chats = (
query.order_by(Chat.updated_at.desc())
# .limit(limit).offset(skip)
.all()
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_list_by_chat_ids(
self, chat_ids: List[str], skip: int = 0, limit: int = 50
self, db: Session, chat_ids: List[str], skip: int = 0, limit: int = 50
) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.archived == False)
.where(Chat.id.in_(chat_ids))
all_chats = (
db.query(Chat)
.filter(Chat.id.in_(chat_ids))
.filter_by(archived=False)
.order_by(Chat.updated_at.desc())
]
.all()
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
def get_chat_by_id(self, db: Session, id: str) -> Optional[ChatModel]:
try:
chat = Chat.get(Chat.id == id)
return ChatModel(**model_to_dict(chat))
chat = db.get(Chat, id)
return ChatModel.model_validate(chat)
except:
return None
def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
def get_chat_by_share_id(self, db: Session, id: str) -> Optional[ChatModel]:
try:
chat = Chat.get(Chat.share_id == id)
chat = db.query(Chat).filter_by(share_id=id).first()
if chat:
chat = Chat.get(Chat.id == id)
return ChatModel(**model_to_dict(chat))
return self.get_chat_by_id(db, id)
else:
return None
except Exception as e:
return None
def get_chat_by_id_and_user_id(
self, db: Session, id: str, user_id: str
) -> Optional[ChatModel]:
try:
chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
return ChatModel.model_validate(chat)
except:
return None
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
try:
chat = Chat.get(Chat.id == id, Chat.user_id == user_id)
return ChatModel(**model_to_dict(chat))
except:
return None
def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select().order_by(Chat.updated_at.desc())
def get_chats(self, db: Session, skip: int = 0, limit: int = 50) -> List[ChatModel]:
all_chats = (
db.query(Chat)
# .limit(limit).offset(skip)
]
def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.user_id == user_id)
.order_by(Chat.updated_at.desc())
# .limit(limit).offset(skip)
]
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.archived == True)
.where(Chat.user_id == user_id)
def get_chats_by_user_id(self, db: Session, user_id: str) -> List[ChatModel]:
all_chats = (
db.query(Chat).filter_by(user_id=user_id).order_by(Chat.updated_at.desc())
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_archived_chats_by_user_id(
self, db: Session, user_id: str
) -> List[ChatModel]:
all_chats = (
db.query(Chat)
.filter_by(user_id=user_id, archived=True)
.order_by(Chat.updated_at.desc())
]
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def delete_chat_by_id(self, id: str) -> bool:
def delete_chat_by_id(self, db: Session, id: str) -> bool:
try:
query = Chat.delete().where((Chat.id == id))
query.execute() # Remove the rows, return number of rows removed.
db.query(Chat).filter_by(id=id).delete()
return True and self.delete_shared_chat_by_chat_id(id)
return True and self.delete_shared_chat_by_chat_id(db, id)
except:
return False
def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
def delete_chat_by_id_and_user_id(self, db: Session, id: str, user_id: str) -> bool:
try:
query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id))
query.execute() # Remove the rows, return number of rows removed.
db.query(Chat).filter_by(id=id, user_id=user_id).delete()
return True and self.delete_shared_chat_by_chat_id(id)
return True and self.delete_shared_chat_by_chat_id(db, id)
except:
return False
def delete_chats_by_user_id(self, user_id: str) -> bool:
def delete_chats_by_user_id(self, db: Session, user_id: str) -> bool:
try:
self.delete_shared_chats_by_user_id(user_id)
query = Chat.delete().where(Chat.user_id == user_id)
query.execute() # Remove the rows, return number of rows removed.
self.delete_shared_chats_by_user_id(db, user_id)
db.query(Chat).filter_by(user_id=user_id).delete()
return True
except:
return False
def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
def delete_shared_chats_by_user_id(self, db: Session, user_id: str) -> bool:
try:
shared_chat_ids = [
f"shared-{chat.id}"
for chat in Chat.select().where(Chat.user_id == user_id)
]
chats_by_user = db.query(Chat).filter_by(user_id=user_id).all()
shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
query = Chat.delete().where(Chat.user_id << shared_chat_ids)
query.execute() # Remove the rows, return number of rows removed.
db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
return True
except:
return False
Chats = ChatTable(DB)
Chats = ChatTable()

View File

@@ -1,14 +1,12 @@
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
from pydantic import BaseModel, ConfigDict
from typing import List, Optional
import time
import logging
from utils.utils import decode_token
from utils.misc import get_gravatar_url
from sqlalchemy import String, Column, BigInteger
from sqlalchemy.orm import Session
from apps.webui.internal.db import DB
from apps.webui.internal.db import Base
import json
@@ -22,20 +20,21 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class Document(Model):
collection_name = CharField(unique=True)
name = CharField(unique=True)
title = TextField()
filename = TextField()
content = TextField(null=True)
user_id = CharField()
timestamp = BigIntegerField()
class Document(Base):
__tablename__ = "document"
class Meta:
database = DB
collection_name = Column(String, primary_key=True)
name = Column(String, unique=True)
title = Column(String)
filename = Column(String)
content = Column(String, nullable=True)
user_id = Column(String)
timestamp = Column(BigInteger)
class DocumentModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
collection_name: str
name: str
title: str
@@ -72,12 +71,9 @@ class DocumentForm(DocumentUpdateForm):
class DocumentsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Document])
def insert_new_doc(
self, user_id: str, form_data: DocumentForm
self, db: Session, user_id: str, form_data: DocumentForm
) -> Optional[DocumentModel]:
document = DocumentModel(
**{
@@ -88,73 +84,69 @@ class DocumentsTable:
)
try:
result = Document.create(**document.model_dump())
result = Document(**document.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return document
return DocumentModel.model_validate(result)
else:
return None
except:
return None
def get_doc_by_name(self, name: str) -> Optional[DocumentModel]:
def get_doc_by_name(self, db: Session, name: str) -> Optional[DocumentModel]:
try:
document = Document.get(Document.name == name)
return DocumentModel(**model_to_dict(document))
document = db.query(Document).filter_by(name=name).first()
return DocumentModel.model_validate(document) if document else None
except:
return None
def get_docs(self) -> List[DocumentModel]:
return [
DocumentModel(**model_to_dict(doc))
for doc in Document.select()
# .limit(limit).offset(skip)
]
def get_docs(self, db: Session) -> List[DocumentModel]:
return [DocumentModel.model_validate(doc) for doc in db.query(Document).all()]
def update_doc_by_name(
self, name: str, form_data: DocumentUpdateForm
self, db: Session, name: str, form_data: DocumentUpdateForm
) -> Optional[DocumentModel]:
try:
query = Document.update(
title=form_data.title,
name=form_data.name,
timestamp=int(time.time()),
).where(Document.name == name)
query.execute()
doc = Document.get(Document.name == form_data.name)
return DocumentModel(**model_to_dict(doc))
db.query(Document).filter_by(name=name).update(
{
"title": form_data.title,
"name": form_data.name,
"timestamp": int(time.time()),
}
)
return self.get_doc_by_name(db, form_data.name)
except Exception as e:
log.exception(e)
return None
def update_doc_content_by_name(
self, name: str, updated: dict
self, db: Session, name: str, updated: dict
) -> Optional[DocumentModel]:
try:
doc = self.get_doc_by_name(name)
doc = self.get_doc_by_name(db, name)
doc_content = json.loads(doc.content if doc.content else "{}")
doc_content = {**doc_content, **updated}
query = Document.update(
content=json.dumps(doc_content),
timestamp=int(time.time()),
).where(Document.name == name)
query.execute()
db.query(Document).filter_by(name=name).update(
{
"content": json.dumps(doc_content),
"timestamp": int(time.time()),
}
)
doc = Document.get(Document.name == name)
return DocumentModel(**model_to_dict(doc))
return self.get_doc_by_name(db, name)
except Exception as e:
log.exception(e)
return None
def delete_doc_by_name(self, name: str) -> bool:
def delete_doc_by_name(self, db: Session, name: str) -> bool:
try:
query = Document.delete().where((Document.name == name))
query.execute() # Remove the rows, return number of rows removed.
db.query(Document).filter_by(name=name).delete()
return True
except:
return False
Documents = DocumentsTable(DB)
Documents = DocumentsTable()

View File

@@ -1,10 +1,12 @@
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, ConfigDict
from typing import List, Union, Optional
import time
import logging
from apps.webui.internal.db import DB, JSONField
from sqlalchemy import Column, String, BigInteger
from sqlalchemy.orm import Session
from apps.webui.internal.db import JSONField, Base
import json
@@ -18,15 +20,14 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class File(Model):
id = CharField(unique=True)
user_id = CharField()
filename = TextField()
meta = JSONField()
created_at = BigIntegerField()
class File(Base):
__tablename__ = "file"
class Meta:
database = DB
id = Column(String, primary_key=True)
user_id = Column(String)
filename = Column(String)
meta = Column(JSONField)
created_at = Column(BigInteger)
class FileModel(BaseModel):
@@ -36,6 +37,7 @@ class FileModel(BaseModel):
meta: dict
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
@@ -57,11 +59,8 @@ class FileForm(BaseModel):
class FilesTable:
def __init__(self, db):
self.db = db
self.db.create_tables([File])
def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
def insert_new_file(self, db: Session, user_id: str, form_data: FileForm) -> Optional[FileModel]:
file = FileModel(
**{
**form_data.model_dump(),
@@ -71,42 +70,41 @@ class FilesTable:
)
try:
result = File.create(**file.model_dump())
result = File(**file.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return file
return FileModel.model_validate(result)
else:
return None
except Exception as e:
print(f"Error creating tool: {e}")
return None
def get_file_by_id(self, id: str) -> Optional[FileModel]:
def get_file_by_id(self, db: Session, id: str) -> Optional[FileModel]:
try:
file = File.get(File.id == id)
return FileModel(**model_to_dict(file))
file = db.get(File, id)
return FileModel.model_validate(file)
except:
return None
def get_files(self) -> List[FileModel]:
return [FileModel(**model_to_dict(file)) for file in File.select()]
def get_files(self, db: Session) -> List[FileModel]:
return [FileModel.model_validate(file) for file in db.query(File).all()]
def delete_file_by_id(self, id: str) -> bool:
def delete_file_by_id(self, db: Session, id: str) -> bool:
try:
query = File.delete().where((File.id == id))
query.execute() # Remove the rows, return number of rows removed.
db.query(File).filter_by(id=id).delete()
return True
except:
return False
def delete_all_files(self) -> bool:
def delete_all_files(self, db: Session) -> bool:
try:
query = File.delete()
query.execute() # Remove the rows, return number of rows removed.
db.query(File).delete()
return True
except:
return False
Files = FilesTable(DB)
Files = FilesTable()

View File

@@ -1,10 +1,12 @@
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, ConfigDict
from typing import List, Union, Optional
import time
import logging
from apps.webui.internal.db import DB, JSONField
from sqlalchemy import Column, String, Text, BigInteger, Boolean
from sqlalchemy.orm import Session
from apps.webui.internal.db import JSONField, Base
from apps.webui.models.users import Users
import json
@@ -21,20 +23,19 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class Function(Model):
id = CharField(unique=True)
user_id = CharField()
name = TextField()
type = TextField()
content = TextField()
meta = JSONField()
valves = JSONField()
is_active = BooleanField(default=False)
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Function(Base):
__tablename__ = "function"
class Meta:
database = DB
id = Column(String, primary_key=True)
user_id = Column(String)
name = Column(Text)
type = Column(Text)
content = Column(Text)
meta = Column(JSONField)
valves = Column(JSONField)
is_active = Column(Boolean)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
class FunctionMeta(BaseModel):
@@ -53,6 +54,8 @@ class FunctionModel(BaseModel):
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
@@ -82,12 +85,9 @@ class FunctionValves(BaseModel):
class FunctionsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Function])
def insert_new_function(
self, user_id: str, type: str, form_data: FunctionForm
self, db: Session, user_id: str, type: str, form_data: FunctionForm
) -> Optional[FunctionModel]:
function = FunctionModel(
**{
@@ -100,19 +100,22 @@ class FunctionsTable:
)
try:
result = Function.create(**function.model_dump())
result = Function(**function.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return function
return FunctionModel.model_validate(result)
else:
return None
except Exception as e:
print(f"Error creating tool: {e}")
return None
def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
def get_function_by_id(self, db: Session, id: str) -> Optional[FunctionModel]:
try:
function = Function.get(Function.id == id)
return FunctionModel(**model_to_dict(function))
function = db.get(Function, id)
return FunctionModel.model_validate(function)
except:
return None
@@ -211,14 +214,11 @@ class FunctionsTable:
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
try:
query = Function.update(
db.query(Function).filter_by(id=id).update({
**updated,
updated_at=int(time.time()),
).where(Function.id == id)
query.execute()
function = Function.get(Function.id == id)
return FunctionModel(**model_to_dict(function))
"updated_at": int(time.time()),
})
return self.get_function_by_id(db, id)
except:
return None
@@ -235,14 +235,12 @@ class FunctionsTable:
except:
return None
def delete_function_by_id(self, id: str) -> bool:
def delete_function_by_id(self, db: Session, id: str) -> bool:
try:
query = Function.delete().where((Function.id == id))
query.execute() # Remove the rows, return number of rows removed.
db.query(Function).filter_by(id=id).delete()
return True
except:
return False
Functions = FunctionsTable(DB)
Functions = FunctionsTable()

View File

@@ -1,9 +1,10 @@
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, ConfigDict
from typing import List, Union, Optional
from apps.webui.internal.db import DB
from sqlalchemy import Column, String, BigInteger
from sqlalchemy.orm import Session
from apps.webui.internal.db import Base
from apps.webui.models.chats import Chats
import time
@@ -14,15 +15,14 @@ import uuid
####################
class Memory(Model):
id = CharField(unique=True)
user_id = CharField()
content = TextField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Memory(Base):
__tablename__ = "memory"
class Meta:
database = DB
id = Column(String, primary_key=True)
user_id = Column(String)
content = Column(String)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
class MemoryModel(BaseModel):
@@ -32,6 +32,8 @@ class MemoryModel(BaseModel):
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
@@ -39,12 +41,10 @@ class MemoryModel(BaseModel):
class MemoriesTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Memory])
def insert_new_memory(
self,
db: Session,
user_id: str,
content: str,
) -> Optional[MemoryModel]:
@@ -59,74 +59,73 @@ class MemoriesTable:
"updated_at": int(time.time()),
}
)
result = Memory.create(**memory.model_dump())
result = Memory(**memory.dict())
db.add(result)
db.commit()
db.refresh(result)
if result:
return memory
return MemoryModel.model_validate(result)
else:
return None
def update_memory_by_id(
self,
db: Session,
id: str,
content: str,
) -> Optional[MemoryModel]:
try:
memory = Memory.get(Memory.id == id)
memory.content = content
memory.updated_at = int(time.time())
memory.save()
return MemoryModel(**model_to_dict(memory))
db.query(Memory).filter_by(id=id).update(
{"content": content, "updated_at": int(time.time())}
)
return self.get_memory_by_id(db, id)
except:
return None
def get_memories(self) -> List[MemoryModel]:
def get_memories(self, db: Session) -> List[MemoryModel]:
try:
memories = Memory.select()
return [MemoryModel(**model_to_dict(memory)) for memory in memories]
memories = db.query(Memory).all()
return [MemoryModel.model_validate(memory) for memory in memories]
except:
return None
def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]:
def get_memories_by_user_id(self, db: Session, user_id: str) -> List[MemoryModel]:
try:
memories = Memory.select().where(Memory.user_id == user_id)
return [MemoryModel(**model_to_dict(memory)) for memory in memories]
memories = db.query(Memory).filter_by(user_id=user_id).all()
return [MemoryModel.model_validate(memory) for memory in memories]
except:
return None
def get_memory_by_id(self, id) -> Optional[MemoryModel]:
def get_memory_by_id(self, db: Session, id: str) -> Optional[MemoryModel]:
try:
memory = Memory.get(Memory.id == id)
return MemoryModel(**model_to_dict(memory))
memory = db.get(Memory, id)
return MemoryModel.model_validate(memory)
except:
return None
def delete_memory_by_id(self, id: str) -> bool:
def delete_memory_by_id(self, db: Session, id: str) -> bool:
try:
query = Memory.delete().where(Memory.id == id)
query.execute() # Remove the rows, return number of rows removed.
db.query(Memory).filter_by(id=id).delete()
return True
except:
return False
def delete_memories_by_user_id(self, user_id: str) -> bool:
def delete_memories_by_user_id(self, db: Session, user_id: str) -> bool:
try:
query = Memory.delete().where(Memory.user_id == user_id)
query.execute()
db.query(Memory).filter_by(user_id=user_id).delete()
return True
except:
return False
def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
def delete_memory_by_id_and_user_id(
self, db: Session, id: str, user_id: str
) -> bool:
try:
query = Memory.delete().where(Memory.id == id, Memory.user_id == user_id)
query.execute()
db.query(Memory).filter_by(id=id, user_id=user_id).delete()
return True
except:
return False
Memories = MemoriesTable(DB)
Memories = MemoriesTable()

View File

@@ -2,13 +2,11 @@ import json
import logging
from typing import Optional
import peewee as pw
from peewee import *
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, ConfigDict
from sqlalchemy import String, Column, BigInteger
from sqlalchemy.orm import Session
from apps.webui.internal.db import DB, JSONField
from apps.webui.internal.db import Base, JSONField
from typing import List, Union, Optional
from config import SRC_LOG_LEVELS
@@ -46,41 +44,42 @@ class ModelMeta(BaseModel):
pass
class Model(pw.Model):
id = pw.TextField(unique=True)
class Model(Base):
__tablename__ = "model"
id = Column(String, primary_key=True)
"""
The model's id as used in the API. If set to an existing model, it will override the model.
"""
user_id = pw.TextField()
user_id = Column(String)
base_model_id = pw.TextField(null=True)
base_model_id = Column(String, nullable=True)
"""
An optional pointer to the actual model that should be used when proxying requests.
"""
name = pw.TextField()
name = Column(String)
"""
The human-readable display name of the model.
"""
params = JSONField()
params = Column(JSONField)
"""
Holds a JSON encoded blob of parameters, see `ModelParams`.
"""
meta = JSONField()
meta = Column(JSONField)
"""
Holds a JSON encoded blob of metadata, see `ModelMeta`.
"""
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Meta:
database = DB
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
class ModelModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
user_id: str
base_model_id: Optional[str] = None
@@ -115,15 +114,9 @@ class ModelForm(BaseModel):
class ModelsTable:
def __init__(
self,
db: pw.SqliteDatabase | pw.PostgresqlDatabase,
):
self.db = db
self.db.create_tables([Model])
def insert_new_model(
self, form_data: ModelForm, user_id: str
self, db: Session, form_data: ModelForm, user_id: str
) -> Optional[ModelModel]:
model = ModelModel(
**{
@@ -134,46 +127,50 @@ class ModelsTable:
}
)
try:
result = Model.create(**model.model_dump())
result = Model(**model.dict())
db.add(result)
db.commit()
db.refresh(result)
if result:
return model
return ModelModel.model_validate(result)
else:
return None
except Exception as e:
print(e)
return None
def get_all_models(self) -> List[ModelModel]:
return [ModelModel(**model_to_dict(model)) for model in Model.select()]
def get_all_models(self, db: Session) -> List[ModelModel]:
return [ModelModel.model_validate(model) for model in db.query(Model).all()]
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
def get_model_by_id(self, db: Session, id: str) -> Optional[ModelModel]:
try:
model = Model.get(Model.id == id)
return ModelModel(**model_to_dict(model))
model = db.get(Model, id)
return ModelModel.model_validate(model)
except:
return None
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
def update_model_by_id(
self, db: Session, id: str, model: ModelForm
) -> Optional[ModelModel]:
try:
# update only the fields that are present in the model
query = Model.update(**model.model_dump()).where(Model.id == id)
query.execute()
model = Model.get(Model.id == id)
return ModelModel(**model_to_dict(model))
model = db.query(Model).get(id)
model.update(**model.model_dump())
db.commit()
db.refresh(model)
return ModelModel.model_validate(model)
except Exception as e:
print(e)
return None
def delete_model_by_id(self, id: str) -> bool:
def delete_model_by_id(self, db: Session, id: str) -> bool:
try:
query = Model.delete().where(Model.id == id)
query.execute()
db.query(Model).filter_by(id=id).delete()
return True
except:
return False
Models = ModelsTable(DB)
Models = ModelsTable()

View File

@@ -1,13 +1,11 @@
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
from pydantic import BaseModel, ConfigDict
from typing import List, Optional
import time
from utils.utils import decode_token
from utils.misc import get_gravatar_url
from sqlalchemy import String, Column, BigInteger
from sqlalchemy.orm import Session
from apps.webui.internal.db import DB
from apps.webui.internal.db import Base
import json
@@ -16,15 +14,14 @@ import json
####################
class Prompt(Model):
command = CharField(unique=True)
user_id = CharField()
title = TextField()
content = TextField()
timestamp = BigIntegerField()
class Prompt(Base):
__tablename__ = "prompt"
class Meta:
database = DB
command = Column(String, primary_key=True)
user_id = Column(String)
title = Column(String)
content = Column(String)
timestamp = Column(BigInteger)
class PromptModel(BaseModel):
@@ -34,6 +31,8 @@ class PromptModel(BaseModel):
content: str
timestamp: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
@@ -48,12 +47,8 @@ class PromptForm(BaseModel):
class PromptsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Prompt])
def insert_new_prompt(
self, user_id: str, form_data: PromptForm
self, db: Session, user_id: str, form_data: PromptForm
) -> Optional[PromptModel]:
prompt = PromptModel(
**{
@@ -66,53 +61,48 @@ class PromptsTable:
)
try:
result = Prompt.create(**prompt.model_dump())
result = Prompt(**prompt.dict())
db.add(result)
db.commit()
db.refresh(result)
if result:
return prompt
return PromptModel.model_validate(result)
else:
return None
except:
except Exception as e:
return None
def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
def get_prompt_by_command(self, db: Session, command: str) -> Optional[PromptModel]:
try:
prompt = Prompt.get(Prompt.command == command)
return PromptModel(**model_to_dict(prompt))
prompt = db.query(Prompt).filter_by(command=command).first()
return PromptModel.model_validate(prompt)
except:
return None
def get_prompts(self) -> List[PromptModel]:
return [
PromptModel(**model_to_dict(prompt))
for prompt in Prompt.select()
# .limit(limit).offset(skip)
]
def get_prompts(self, db: Session) -> List[PromptModel]:
return [PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()]
def update_prompt_by_command(
self, command: str, form_data: PromptForm
self, db: Session, command: str, form_data: PromptForm
) -> Optional[PromptModel]:
try:
query = Prompt.update(
title=form_data.title,
content=form_data.content,
timestamp=int(time.time()),
).where(Prompt.command == command)
query.execute()
prompt = Prompt.get(Prompt.command == command)
return PromptModel(**model_to_dict(prompt))
db.query(Prompt).filter_by(command=command).update(
{
"title": form_data.title,
"content": form_data.content,
"timestamp": int(time.time()),
}
)
return self.get_prompt_by_command(db, command)
except:
return None
def delete_prompt_by_command(self, command: str) -> bool:
def delete_prompt_by_command(self, db: Session, command: str) -> bool:
try:
query = Prompt.delete().where((Prompt.command == command))
query.execute() # Remove the rows, return number of rows removed.
db.query(Prompt).filter_by(command=command).delete()
return True
except:
return False
Prompts = PromptsTable(DB)
Prompts = PromptsTable()

View File

@@ -1,14 +1,15 @@
from pydantic import BaseModel
from typing import List, Union, Optional
from peewee import *
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, ConfigDict
from typing import List, Optional
import json
import uuid
import time
import logging
from apps.webui.internal.db import DB
from sqlalchemy import String, Column, BigInteger
from sqlalchemy.orm import Session
from apps.webui.internal.db import Base
from config import SRC_LOG_LEVELS
@@ -20,25 +21,23 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class Tag(Model):
id = CharField(unique=True)
name = CharField()
user_id = CharField()
data = TextField(null=True)
class Tag(Base):
__tablename__ = "tag"
class Meta:
database = DB
id = Column(String, primary_key=True)
name = Column(String)
user_id = Column(String)
data = Column(String, nullable=True)
class ChatIdTag(Model):
id = CharField(unique=True)
tag_name = CharField()
chat_id = CharField()
user_id = CharField()
timestamp = BigIntegerField()
class ChatIdTag(Base):
__tablename__ = "chatidtag"
class Meta:
database = DB
id = Column(String, primary_key=True)
tag_name = Column(String)
chat_id = Column(String)
user_id = Column(String)
timestamp = Column(BigInteger)
class TagModel(BaseModel):
@@ -47,6 +46,8 @@ class TagModel(BaseModel):
user_id: str
data: Optional[str] = None
model_config = ConfigDict(from_attributes=True)
class ChatIdTagModel(BaseModel):
id: str
@@ -55,6 +56,8 @@ class ChatIdTagModel(BaseModel):
user_id: str
timestamp: int
model_config = ConfigDict(from_attributes=True)
####################
# Forms
@@ -75,37 +78,39 @@ class ChatTagsResponse(BaseModel):
class TagTable:
def __init__(self, db):
self.db = db
db.create_tables([Tag, ChatIdTag])
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
def insert_new_tag(
self, db: Session, name: str, user_id: str
) -> Optional[TagModel]:
id = str(uuid.uuid4())
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
try:
result = Tag.create(**tag.model_dump())
result = Tag(**tag.dict())
db.add(result)
db.commit()
db.refresh(result)
if result:
return tag
return TagModel.model_validate(result)
else:
return None
except Exception as e:
return None
def get_tag_by_name_and_user_id(
self, name: str, user_id: str
self, db: Session, name: str, user_id: str
) -> Optional[TagModel]:
try:
tag = Tag.get(Tag.name == name, Tag.user_id == user_id)
return TagModel(**model_to_dict(tag))
tag = db.query(Tag).filter(name=name, user_id=user_id).first()
return TagModel.model_validate(tag)
except Exception as e:
return None
def add_tag_to_chat(
self, user_id: str, form_data: ChatIdTagForm
self, db: Session, user_id: str, form_data: ChatIdTagForm
) -> Optional[ChatIdTagModel]:
tag = self.get_tag_by_name_and_user_id(form_data.tag_name, user_id)
tag = self.get_tag_by_name_and_user_id(db, form_data.tag_name, user_id)
if tag == None:
tag = self.insert_new_tag(form_data.tag_name, user_id)
tag = self.insert_new_tag(db, form_data.tag_name, user_id)
id = str(uuid.uuid4())
chatIdTag = ChatIdTagModel(
@@ -118,120 +123,135 @@ class TagTable:
}
)
try:
result = ChatIdTag.create(**chatIdTag.model_dump())
result = ChatIdTag(**chatIdTag.dict())
db.add(result)
db.commit()
db.refresh(result)
if result:
return chatIdTag
return ChatIdTagModel.model_validate(result)
else:
return None
except:
return None
def get_tags_by_user_id(self, user_id: str) -> List[TagModel]:
def get_tags_by_user_id(self, db: Session, user_id: str) -> List[TagModel]:
tag_names = [
ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name
for chat_id_tag in ChatIdTag.select()
.where(ChatIdTag.user_id == user_id)
.order_by(ChatIdTag.timestamp.desc())
chat_id_tag.tag_name
for chat_id_tag in (
db.query(ChatIdTag)
.filter_by(user_id=user_id)
.order_by(ChatIdTag.timestamp.desc())
.all()
)
]
return [
TagModel(**model_to_dict(tag))
for tag in Tag.select()
.where(Tag.user_id == user_id)
.where(Tag.name.in_(tag_names))
TagModel.model_validate(tag)
for tag in (
db.query(Tag)
.filter_by(user_id=user_id)
.filter(Tag.name.in_(tag_names))
.all()
)
]
def get_tags_by_chat_id_and_user_id(
self, chat_id: str, user_id: str
self, db: Session, chat_id: str, user_id: str
) -> List[TagModel]:
tag_names = [
ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name
for chat_id_tag in ChatIdTag.select()
.where((ChatIdTag.user_id == user_id) & (ChatIdTag.chat_id == chat_id))
.order_by(ChatIdTag.timestamp.desc())
chat_id_tag.tag_name
for chat_id_tag in (
db.query(ChatIdTag)
.filter_by(user_id=user_id, chat_id=chat_id)
.order_by(ChatIdTag.timestamp.desc())
.all()
)
]
return [
TagModel(**model_to_dict(tag))
for tag in Tag.select()
.where(Tag.user_id == user_id)
.where(Tag.name.in_(tag_names))
TagModel.model_validate(tag)
for tag in (
db.query(Tag)
.filter_by(user_id=user_id)
.filter(Tag.name.in_(tag_names))
.all()
)
]
def get_chat_ids_by_tag_name_and_user_id(
self, tag_name: str, user_id: str
) -> Optional[ChatIdTagModel]:
self, db: Session, tag_name: str, user_id: str
) -> List[ChatIdTagModel]:
return [
ChatIdTagModel(**model_to_dict(chat_id_tag))
for chat_id_tag in ChatIdTag.select()
.where((ChatIdTag.user_id == user_id) & (ChatIdTag.tag_name == tag_name))
.order_by(ChatIdTag.timestamp.desc())
ChatIdTagModel.model_validate(chat_id_tag)
for chat_id_tag in (
db.query(ChatIdTag)
.filter_by(user_id=user_id, tag_name=tag_name)
.order_by(ChatIdTag.timestamp.desc())
.all()
)
]
def count_chat_ids_by_tag_name_and_user_id(
self, tag_name: str, user_id: str
self, db: Session, tag_name: str, user_id: str
) -> int:
return (
ChatIdTag.select()
.where((ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id))
.count()
)
return db.query(ChatIdTag).filter_by(tag_name=tag_name, user_id=user_id).count()
def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
def delete_tag_by_tag_name_and_user_id(
self, db: Session, tag_name: str, user_id: str
) -> bool:
try:
query = ChatIdTag.delete().where(
(ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id)
res = (
db.query(ChatIdTag)
.filter_by(tag_name=tag_name, user_id=user_id)
.delete()
)
res = query.execute() # Remove the rows, return number of rows removed.
log.debug(f"res: {res}")
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
tag_count = self.count_chat_ids_by_tag_name_and_user_id(
db, tag_name, user_id
)
if tag_count == 0:
# Remove tag item from Tag col as well
query = Tag.delete().where(
(Tag.name == tag_name) & (Tag.user_id == user_id)
)
query.execute() # Remove the rows, return number of rows removed.
db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
return True
except Exception as e:
log.error(f"delete_tag: {e}")
return False
def delete_tag_by_tag_name_and_chat_id_and_user_id(
self, tag_name: str, chat_id: str, user_id: str
self, db: Session, tag_name: str, chat_id: str, user_id: str
) -> bool:
try:
query = ChatIdTag.delete().where(
(ChatIdTag.tag_name == tag_name)
& (ChatIdTag.chat_id == chat_id)
& (ChatIdTag.user_id == user_id)
res = (
db.query(ChatIdTag)
.filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
.delete()
)
res = query.execute() # Remove the rows, return number of rows removed.
log.debug(f"res: {res}")
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
tag_count = self.count_chat_ids_by_tag_name_and_user_id(
db, tag_name, user_id
)
if tag_count == 0:
# Remove tag item from Tag col as well
query = Tag.delete().where(
(Tag.name == tag_name) & (Tag.user_id == user_id)
)
query.execute() # Remove the rows, return number of rows removed.
db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
return True
except Exception as e:
log.error(f"delete_tag: {e}")
return False
def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool:
tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id)
def delete_tags_by_chat_id_and_user_id(
self, db: Session, chat_id: str, user_id: str
) -> bool:
tags = self.get_tags_by_chat_id_and_user_id(db, chat_id, user_id)
for tag in tags:
self.delete_tag_by_tag_name_and_chat_id_and_user_id(
tag.tag_name, chat_id, user_id
db, tag.tag_name, chat_id, user_id
)
return True
Tags = TagTable(DB)
Tags = TagTable()

View File

@@ -1,10 +1,11 @@
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
from pydantic import BaseModel, ConfigDict
from typing import List, Optional
import time
import logging
from apps.webui.internal.db import DB, JSONField
from sqlalchemy import String, Column, BigInteger
from sqlalchemy.orm import Session
from apps.webui.internal.db import Base, JSONField
from apps.webui.models.users import Users
import json
@@ -21,19 +22,18 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class Tool(Model):
id = CharField(unique=True)
user_id = CharField()
name = TextField()
content = TextField()
specs = JSONField()
meta = JSONField()
valves = JSONField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Tool(Base):
__tablename__ = "tool"
class Meta:
database = DB
id = Column(String, primary_key=True)
user_id = Column(String)
name = Column(String)
content = Column(String)
specs = Column(JSONField)
meta = Column(JSONField)
valves = Column(JSONField)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
class ToolMeta(BaseModel):
@@ -51,6 +51,8 @@ class ToolModel(BaseModel):
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
@@ -78,12 +80,9 @@ class ToolValves(BaseModel):
class ToolsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Tool])
def insert_new_tool(
self, user_id: str, form_data: ToolForm, specs: List[dict]
self, db: Session, user_id: str, form_data: ToolForm, specs: List[dict]
) -> Optional[ToolModel]:
tool = ToolModel(
**{
@@ -96,24 +95,27 @@ class ToolsTable:
)
try:
result = Tool.create(**tool.model_dump())
result = Tool(**tool.dict())
db.add(result)
db.commit()
db.refresh(result)
if result:
return tool
return ToolModel.model_validate(result)
else:
return None
except Exception as e:
print(f"Error creating tool: {e}")
return None
def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
def get_tool_by_id(self, db: Session, id: str) -> Optional[ToolModel]:
try:
tool = Tool.get(Tool.id == id)
return ToolModel(**model_to_dict(tool))
tool = db.get(Tool, id)
return ToolModel.model_validate(tool)
except:
return None
def get_tools(self) -> List[ToolModel]:
return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()]
def get_tools(self, db: Session) -> List[ToolModel]:
return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()]
def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
try:
@@ -180,25 +182,19 @@ class ToolsTable:
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
try:
query = Tool.update(
**updated,
updated_at=int(time.time()),
).where(Tool.id == id)
query.execute()
tool = Tool.get(Tool.id == id)
return ToolModel(**model_to_dict(tool))
db.query(Tool).filter_by(id=id).update(
{**updated, "updated_at": int(time.time())}
)
return self.get_tool_by_id(db, id)
except:
return None
def delete_tool_by_id(self, id: str) -> bool:
def delete_tool_by_id(self, db: Session, id: str) -> bool:
try:
query = Tool.delete().where((Tool.id == id))
query.execute() # Remove the rows, return number of rows removed.
db.query(Tool).filter_by(id=id).delete()
return True
except:
return False
Tools = ToolsTable(DB)
Tools = ToolsTable()

View File

@@ -1,11 +1,13 @@
from pydantic import BaseModel, ConfigDict
from peewee import *
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, ConfigDict, parse_obj_as
from typing import List, Union, Optional
import time
from sqlalchemy import String, Column, BigInteger, Text
from sqlalchemy.orm import Session
from utils.misc import get_gravatar_url
from apps.webui.internal.db import DB, JSONField
from apps.webui.internal.db import Base, JSONField
from apps.webui.models.chats import Chats
####################
@@ -13,25 +15,24 @@ from apps.webui.models.chats import Chats
####################
class User(Model):
id = CharField(unique=True)
name = CharField()
email = CharField()
role = CharField()
profile_image_url = TextField()
class User(Base):
__tablename__ = "user"
last_active_at = BigIntegerField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
id = Column(String, primary_key=True)
name = Column(String)
email = Column(String)
role = Column(String)
profile_image_url = Column(String)
api_key = CharField(null=True, unique=True)
settings = JSONField(null=True)
info = JSONField(null=True)
last_active_at = Column(BigInteger)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
oauth_sub = TextField(null=True, unique=True)
api_key = Column(String, nullable=True, unique=True)
settings = Column(JSONField, nullable=True)
info = Column(JSONField, nullable=True)
class Meta:
database = DB
oauth_sub = Column(Text, unique=True)
class UserSettings(BaseModel):
@@ -41,6 +42,8 @@ class UserSettings(BaseModel):
class UserModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
name: str
email: str
@@ -76,12 +79,10 @@ class UserUpdateForm(BaseModel):
class UsersTable:
def __init__(self, db):
self.db = db
self.db.create_tables([User])
def insert_new_user(
self,
db: Session,
id: str,
name: str,
email: str,
@@ -102,30 +103,33 @@ class UsersTable:
"oauth_sub": oauth_sub,
}
)
result = User.create(**user.model_dump())
result = User(**user.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return user
else:
return None
def get_user_by_id(self, id: str) -> Optional[UserModel]:
def get_user_by_id(self, db: Session, id: str) -> Optional[UserModel]:
try:
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except Exception as e:
return None
def get_user_by_api_key(self, db: Session, api_key: str) -> Optional[UserModel]:
try:
user = db.query(User).filter_by(api_key=api_key).first()
return UserModel.model_validate(user)
except:
return None
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
def get_user_by_email(self, db: Session, email: str) -> Optional[UserModel]:
try:
user = User.get(User.api_key == api_key)
return UserModel(**model_to_dict(user))
except:
return None
def get_user_by_email(self, email: str) -> Optional[UserModel]:
try:
user = User.get(User.email == email)
return UserModel(**model_to_dict(user))
user = db.query(User).filter_by(email=email).first()
return UserModel.model_validate(user)
except:
return None
@@ -136,88 +140,94 @@ class UsersTable:
except:
return None
def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
return [
UserModel(**model_to_dict(user))
for user in User.select()
# .limit(limit).offset(skip)
]
def get_users(self, db: Session, skip: int = 0, limit: int = 50) -> List[UserModel]:
users = (
db.query(User)
# .offset(skip).limit(limit)
.all()
)
return [UserModel.model_validate(user) for user in users]
def get_num_users(self) -> Optional[int]:
return User.select().count()
def get_num_users(self, db: Session) -> Optional[int]:
return db.query(User).count()
def get_first_user(self) -> UserModel:
def get_first_user(self, db: Session) -> UserModel:
try:
user = User.select().order_by(User.created_at).first()
return UserModel(**model_to_dict(user))
user = db.query(User).order_by(User.created_at).first()
return UserModel.model_validate(user)
except:
return None
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
def update_user_role_by_id(
self, db: Session, id: str, role: str
) -> Optional[UserModel]:
try:
query = User.update(role=role).where(User.id == id)
query.execute()
db.query(User).filter_by(id=id).update({"role": role})
db.commit()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except:
return None
def update_user_profile_image_url_by_id(
self, id: str, profile_image_url: str
self, db: Session, id: str, profile_image_url: str
) -> Optional[UserModel]:
try:
query = User.update(profile_image_url=profile_image_url).where(
User.id == id
db.query(User).filter_by(id=id).update(
{"profile_image_url": profile_image_url}
)
query.execute()
db.commit()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except:
return None
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
def update_user_last_active_by_id(
self, db: Session, id: str
) -> Optional[UserModel]:
try:
query = User.update(last_active_at=int(time.time())).where(User.id == id)
query.execute()
db.query(User).filter_by(id=id).update({"last_active_at": int(time.time())})
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except:
return None
def update_user_oauth_sub_by_id(
self, id: str, oauth_sub: str
self, db: Session, id: str, oauth_sub: str
) -> Optional[UserModel]:
try:
query = User.update(oauth_sub=oauth_sub).where(User.id == id)
query.execute()
db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except:
return None
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
def update_user_by_id(
self, db: Session, id: str, updated: dict
) -> Optional[UserModel]:
try:
query = User.update(**updated).where(User.id == id)
query.execute()
db.query(User).filter_by(id=id).update(updated)
db.commit()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
except:
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
# return UserModel(**user.dict())
except Exception as e:
return None
def delete_user_by_id(self, id: str) -> bool:
def delete_user_by_id(self, db: Session, id: str) -> bool:
try:
# Delete User Chats
result = Chats.delete_chats_by_user_id(id)
result = Chats.delete_chats_by_user_id(db, id)
if result:
# Delete User
query = User.delete().where(User.id == id)
query.execute() # Remove the rows, return number of rows removed.
db.query(User).filter_by(id=id).delete()
db.commit()
return True
else:
@@ -225,21 +235,20 @@ class UsersTable:
except:
return False
def update_user_api_key_by_id(self, id: str, api_key: str) -> str:
def update_user_api_key_by_id(self, db: Session, id: str, api_key: str) -> str:
try:
query = User.update(api_key=api_key).where(User.id == id)
result = query.execute()
result = db.query(User).filter_by(id=id).update({"api_key": api_key})
db.commit()
return True if result == 1 else False
except:
return False
def get_user_api_key_by_id(self, id: str) -> Optional[str]:
def get_user_api_key_by_id(self, db: Session, id: str) -> Optional[str]:
try:
user = User.get(User.id == id)
user = db.query(User).filter_by(id=id).first()
return user.api_key
except:
except Exception as e:
return None
Users = UsersTable(DB)
Users = UsersTable()