feat(sqlalchemy): use session factory instead of context manager

This commit is contained in:
Jonathan Rohde 2024-06-24 13:06:15 +02:00
parent eb01e8d275
commit da403f3e3c
15 changed files with 640 additions and 759 deletions

View File

@ -57,14 +57,4 @@ SessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
) )
Base = declarative_base() Base = declarative_base()
Session = scoped_session(SessionLocal)
@contextmanager
def get_session():
session = scoped_session(SessionLocal)
try:
yield session
session.commit()
except Exception as e:
session.rollback()
raise e

View File

@ -3,12 +3,11 @@ from typing import Optional
import uuid import uuid
import logging import logging
from sqlalchemy import String, Column, Boolean from sqlalchemy import String, Column, Boolean
from sqlalchemy.orm import Session
from apps.webui.models.users import UserModel, Users from apps.webui.models.users import UserModel, Users
from utils.utils import verify_password from utils.utils import verify_password
from apps.webui.internal.db import Base, get_session from apps.webui.internal.db import Base, Session
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
@ -103,101 +102,93 @@ class AuthsTable:
role: str = "pending", role: str = "pending",
oauth_sub: Optional[str] = None, oauth_sub: Optional[str] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
with get_session() as db: log.info("insert_new_auth")
log.info("insert_new_auth")
id = str(uuid.uuid4()) id = str(uuid.uuid4())
auth = AuthModel( auth = AuthModel(
**{"id": id, "email": email, "password": password, "active": True} **{"id": id, "email": email, "password": password, "active": True}
) )
result = Auth(**auth.model_dump()) result = Auth(**auth.model_dump())
db.add(result) Session.add(result)
user = Users.insert_new_user( user = Users.insert_new_user(
id, name, email, profile_image_url, role, oauth_sub id, name, email, profile_image_url, role, oauth_sub)
)
db.commit() Session.commit()
db.refresh(result) Session.refresh(result)
if result and user: if result and user:
return user return user
else: else:
return None return None
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]: def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
log.info(f"authenticate_user: {email}") log.info(f"authenticate_user: {email}")
with get_session() as db: try:
try: auth = Session.query(Auth).filter_by(email=email, active=True).first()
auth = db.query(Auth).filter_by(email=email, active=True).first() if auth:
if auth: if verify_password(password, auth.password):
if verify_password(password, auth.password): user = Users.get_user_by_id(auth.id)
user = Users.get_user_by_id(auth.id) return user
return user
else:
return None
else: else:
return None return None
except: else:
return None return None
except:
return None
def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]: def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_api_key: {api_key}") log.info(f"authenticate_user_by_api_key: {api_key}")
with get_session() as db: # if no api_key, return None
# if no api_key, return None if not api_key:
if not api_key: return None
return None
try: try:
user = Users.get_user_by_api_key(api_key) user = Users.get_user_by_api_key(api_key)
return user if user else None return user if user else None
except: except:
return False return False
def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]: def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_trusted_header: {email}") log.info(f"authenticate_user_by_trusted_header: {email}")
with get_session() as db: try:
try: auth = Session.query(Auth).filter(email=email, active=True).first()
auth = db.query(Auth).filter(email=email, active=True).first() if auth:
if auth: user = Users.get_user_by_id(auth.id)
user = Users.get_user_by_id(auth.id) return user
return user except:
except: return None
return None
def update_user_password_by_id(self, id: str, new_password: str) -> bool: def update_user_password_by_id(self, id: str, new_password: str) -> bool:
with get_session() as db: try:
try: result = (
result = ( Session.query(Auth).filter_by(id=id).update({"password": new_password})
db.query(Auth).filter_by(id=id).update({"password": new_password}) )
) return True if result == 1 else False
return True if result == 1 else False except:
except: return False
return False
def update_email_by_id(self, id: str, email: str) -> bool: def update_email_by_id(self, id: str, email: str) -> bool:
with get_session() as db: try:
try: result = Session.query(Auth).filter_by(id=id).update({"email": email})
result = db.query(Auth).filter_by(id=id).update({"email": email}) return True if result == 1 else False
return True if result == 1 else False except:
except: return False
return False
def delete_auth_by_id(self, id: str) -> bool: def delete_auth_by_id(self, id: str) -> bool:
with get_session() as db: try:
try: # Delete User
# Delete User result = Users.delete_user_by_id(id)
result = Users.delete_user_by_id(id)
if result: if result:
db.query(Auth).filter_by(id=id).delete() Session.query(Auth).filter_by(id=id).delete()
return True return True
else: else:
return False
except:
return False return False
except:
return False
Auths = AuthsTable() Auths = AuthsTable()

View File

@ -6,9 +6,8 @@ import uuid
import time import time
from sqlalchemy import Column, String, BigInteger, Boolean from sqlalchemy import Column, String, BigInteger, Boolean
from sqlalchemy.orm import Session
from apps.webui.internal.db import Base, get_session from apps.webui.internal.db import Base, Session
#################### ####################
@ -80,93 +79,88 @@ class ChatTitleIdResponse(BaseModel):
class ChatTable: class ChatTable:
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]: def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
with get_session() as db: id = str(uuid.uuid4())
id = str(uuid.uuid4()) chat = ChatModel(
chat = ChatModel( **{
**{ "id": id,
"id": id, "user_id": user_id,
"user_id": user_id, "title": (
"title": ( form_data.chat["title"]
form_data.chat["title"] if "title" in form_data.chat
if "title" in form_data.chat else "New Chat"
else "New Chat" ),
), "chat": json.dumps(form_data.chat),
"chat": json.dumps(form_data.chat), "created_at": int(time.time()),
"created_at": int(time.time()), "updated_at": int(time.time()),
"updated_at": int(time.time()), }
} )
)
result = Chat(**chat.model_dump()) result = Chat(**chat.model_dump())
db.add(result) Session.add(result)
db.commit() Session.commit()
db.refresh(result) Session.refresh(result)
return ChatModel.model_validate(result) if result else None return ChatModel.model_validate(result) if result else None
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
with get_session() as db: try:
try: chat_obj = Session.get(Chat, id)
chat_obj = db.get(Chat, id) chat_obj.chat = json.dumps(chat)
chat_obj.chat = json.dumps(chat) chat_obj.title = chat["title"] if "title" in chat else "New Chat"
chat_obj.title = chat["title"] if "title" in chat else "New Chat" chat_obj.updated_at = int(time.time())
chat_obj.updated_at = int(time.time()) Session.commit()
db.commit() Session.refresh(chat_obj)
db.refresh(chat_obj)
return ChatModel.model_validate(chat_obj) return ChatModel.model_validate(chat_obj)
except Exception as e: except Exception as e:
return None return None
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
with get_session() as db: # Get the existing chat to share
# Get the existing chat to share chat = Session.get(Chat, chat_id)
chat = db.get(Chat, chat_id) # Check if the chat is already shared
# Check if the chat is already shared if chat.share_id:
if chat.share_id: return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
return self.get_chat_by_id_and_user_id(chat.share_id, "shared") # Create a new chat with the same data, but with a new ID
# Create a new chat with the same data, but with a new ID shared_chat = ChatModel(
shared_chat = ChatModel( **{
**{ "id": str(uuid.uuid4()),
"id": str(uuid.uuid4()), "user_id": f"shared-{chat_id}",
"user_id": f"shared-{chat_id}", "title": chat.title,
"title": chat.title, "chat": chat.chat,
"chat": chat.chat, "created_at": chat.created_at,
"created_at": chat.created_at, "updated_at": int(time.time()),
"updated_at": int(time.time()), }
} )
) shared_result = Chat(**shared_chat.model_dump())
shared_result = Chat(**shared_chat.model_dump()) Session.add(shared_result)
db.add(shared_result) Session.commit()
db.commit() Session.refresh(shared_result)
db.refresh(shared_result) # Update the original chat with the share_id
# Update the original chat with the share_id result = (
result = ( Session.query(Chat)
db.query(Chat) .filter_by(id=chat_id)
.filter_by(id=chat_id) .update({"share_id": shared_chat.id})
.update({"share_id": shared_chat.id}) )
)
return shared_chat if (shared_result and result) else None return shared_chat if (shared_result and result) else None
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
with get_session() as db: try:
try: print("update_shared_chat_by_id")
print("update_shared_chat_by_id") chat = Session.get(Chat, chat_id)
chat = db.get(Chat, chat_id) print(chat)
print(chat) chat.title = chat.title
chat.title = chat.title chat.chat = chat.chat
chat.chat = chat.chat Session.commit()
db.commit() Session.refresh(chat)
db.refresh(chat)
return self.get_chat_by_id(chat.share_id) return self.get_chat_by_id(chat.share_id)
except: except:
return None return None
def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool: def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
try: try:
with get_session() as db: Session.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
return True return True
except: except:
return False return False
@ -175,30 +169,27 @@ class ChatTable:
self, id: str, share_id: Optional[str] self, id: str, share_id: Optional[str]
) -> Optional[ChatModel]: ) -> Optional[ChatModel]:
try: try:
with get_session() as db: chat = Session.get(Chat, id)
chat = db.get(Chat, id) chat.share_id = share_id
chat.share_id = share_id Session.commit()
db.commit() Session.refresh(chat)
db.refresh(chat) return ChatModel.model_validate(chat)
return chat
except: except:
return None return None
def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
try: try:
with get_session() as db: chat = Session.get(Chat, id)
chat = self.get_chat_by_id(id) chat.archived = not chat.archived
db.query(Chat).filter_by(id=id).update({"archived": not chat.archived}) Session.commit()
Session.refresh(chat)
return self.get_chat_by_id(id) return ChatModel.model_validate(chat)
except: except:
return None return None
def archive_all_chats_by_user_id(self, user_id: str) -> bool: def archive_all_chats_by_user_id(self, user_id: str) -> bool:
try: try:
with get_session() as db: Session.query(Chat).filter_by(user_id=user_id).update({"archived": True})
db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
return True return True
except: except:
return False return False
@ -206,9 +197,8 @@ class ChatTable:
def get_archived_chat_list_by_user_id( def get_archived_chat_list_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50 self, user_id: str, skip: int = 0, limit: int = 50
) -> List[ChatModel]: ) -> List[ChatModel]:
with get_session() as db:
all_chats = ( all_chats = (
db.query(Chat) Session.query(Chat)
.filter_by(user_id=user_id, archived=True) .filter_by(user_id=user_id, archived=True)
.order_by(Chat.updated_at.desc()) .order_by(Chat.updated_at.desc())
# .limit(limit).offset(skip) # .limit(limit).offset(skip)
@ -223,120 +213,108 @@ class ChatTable:
skip: int = 0, skip: int = 0,
limit: int = 50, limit: int = 50,
) -> List[ChatModel]: ) -> List[ChatModel]:
with get_session() as db: query = Session.query(Chat).filter_by(user_id=user_id)
query = db.query(Chat).filter_by(user_id=user_id) if not include_archived:
if not include_archived: query = query.filter_by(archived=False)
query = query.filter_by(archived=False) all_chats = (
all_chats = ( query.order_by(Chat.updated_at.desc())
query.order_by(Chat.updated_at.desc()) # .limit(limit).offset(skip)
# .limit(limit).offset(skip) .all()
.all() )
) return [ChatModel.model_validate(chat) for chat in all_chats]
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_list_by_chat_ids( def get_chat_list_by_chat_ids(
self, chat_ids: List[str], skip: int = 0, limit: int = 50 self, chat_ids: List[str], skip: int = 0, limit: int = 50
) -> List[ChatModel]: ) -> List[ChatModel]:
with get_session() as db: all_chats = (
all_chats = ( Session.query(Chat)
db.query(Chat) .filter(Chat.id.in_(chat_ids))
.filter(Chat.id.in_(chat_ids)) .filter_by(archived=False)
.filter_by(archived=False) .order_by(Chat.updated_at.desc())
.order_by(Chat.updated_at.desc()) .all()
.all() )
) return [ChatModel.model_validate(chat) for chat in all_chats]
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_by_id(self, id: str) -> Optional[ChatModel]: def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
try: try:
with get_session() as db: chat = Session.get(Chat, id)
chat = db.get(Chat, id) return ChatModel.model_validate(chat)
return ChatModel.model_validate(chat)
except: except:
return None return None
def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]: def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
try: try:
with get_session() as db: chat = Session.query(Chat).filter_by(share_id=id).first()
chat = db.query(Chat).filter_by(share_id=id).first()
if chat: if chat:
return self.get_chat_by_id(id) return self.get_chat_by_id(id)
else: else:
return None return None
except Exception as e: except Exception as e:
return None return None
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]: def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
try: try:
with get_session() as db: chat = Session.query(Chat).filter_by(id=id, user_id=user_id).first()
chat = db.query(Chat).filter_by(id=id, user_id=user_id).first() return ChatModel.model_validate(chat)
return ChatModel.model_validate(chat)
except: except:
return None return None
def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]: def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
with get_session() as db: all_chats = (
all_chats = ( Session.query(Chat)
db.query(Chat) # .limit(limit).offset(skip)
# .limit(limit).offset(skip) .order_by(Chat.updated_at.desc())
.order_by(Chat.updated_at.desc()) )
) return [ChatModel.model_validate(chat) for chat in all_chats]
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]: def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
with get_session() as db: all_chats = (
all_chats = ( Session.query(Chat)
db.query(Chat) .filter_by(user_id=user_id)
.filter_by(user_id=user_id) .order_by(Chat.updated_at.desc())
.order_by(Chat.updated_at.desc()) )
) return [ChatModel.model_validate(chat) for chat in all_chats]
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]: def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
with get_session() as db: all_chats = (
all_chats = ( Session.query(Chat)
db.query(Chat) .filter_by(user_id=user_id, archived=True)
.filter_by(user_id=user_id, archived=True) .order_by(Chat.updated_at.desc())
.order_by(Chat.updated_at.desc()) )
) return [ChatModel.model_validate(chat) for chat in all_chats]
return [ChatModel.model_validate(chat) for chat in all_chats]
def delete_chat_by_id(self, id: str) -> bool: def delete_chat_by_id(self, id: str) -> bool:
try: try:
with get_session() as db: Session.query(Chat).filter_by(id=id).delete()
db.query(Chat).filter_by(id=id).delete()
return True and self.delete_shared_chat_by_chat_id(id) return True and self.delete_shared_chat_by_chat_id(id)
except: except:
return False return False
def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try: try:
with get_session() as db: Session.query(Chat).filter_by(id=id, user_id=user_id).delete()
db.query(Chat).filter_by(id=id, user_id=user_id).delete()
return True and self.delete_shared_chat_by_chat_id(id) return True and self.delete_shared_chat_by_chat_id(id)
except: except:
return False return False
def delete_chats_by_user_id(self, user_id: str) -> bool: def delete_chats_by_user_id(self, user_id: str) -> bool:
try: try:
with get_session() as db: self.delete_shared_chats_by_user_id(user_id)
self.delete_shared_chats_by_user_id(user_id)
db.query(Chat).filter_by(user_id=user_id).delete() Session.query(Chat).filter_by(user_id=user_id).delete()
return True return True
except: except:
return False return False
def delete_shared_chats_by_user_id(self, user_id: str) -> bool: def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
try: try:
with get_session() as db: chats_by_user = Session.query(Chat).filter_by(user_id=user_id).all()
chats_by_user = db.query(Chat).filter_by(user_id=user_id).all() shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete() Session.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
return True return True
except: except:

View File

@ -4,9 +4,8 @@ import time
import logging import logging
from sqlalchemy import String, Column, BigInteger from sqlalchemy import String, Column, BigInteger
from sqlalchemy.orm import Session
from apps.webui.internal.db import Base, get_session from apps.webui.internal.db import Base, Session
import json import json
@ -84,46 +83,42 @@ class DocumentsTable:
) )
try: try:
with get_session() as db: result = Document(**document.model_dump())
result = Document(**document.model_dump()) Session.add(result)
db.add(result) Session.commit()
db.commit() Session.refresh(result)
db.refresh(result) if result:
if result: return DocumentModel.model_validate(result)
return DocumentModel.model_validate(result) else:
else: return None
return None
except: except:
return None return None
def get_doc_by_name(self, name: str) -> Optional[DocumentModel]: def get_doc_by_name(self, name: str) -> Optional[DocumentModel]:
try: try:
with get_session() as db: document = Session.query(Document).filter_by(name=name).first()
document = db.query(Document).filter_by(name=name).first() return DocumentModel.model_validate(document) if document else None
return DocumentModel.model_validate(document) if document else None
except: except:
return None return None
def get_docs(self) -> List[DocumentModel]: def get_docs(self) -> List[DocumentModel]:
with get_session() as db: return [
return [ DocumentModel.model_validate(doc) for doc in Session.query(Document).all()
DocumentModel.model_validate(doc) for doc in db.query(Document).all() ]
]
def update_doc_by_name( def update_doc_by_name(
self, name: str, form_data: DocumentUpdateForm self, name: str, form_data: DocumentUpdateForm
) -> Optional[DocumentModel]: ) -> Optional[DocumentModel]:
try: try:
with get_session() as db: Session.query(Document).filter_by(name=name).update(
db.query(Document).filter_by(name=name).update( {
{ "title": form_data.title,
"title": form_data.title, "name": form_data.name,
"name": form_data.name, "timestamp": int(time.time()),
"timestamp": int(time.time()), }
} )
) Session.commit()
db.commit() return self.get_doc_by_name(form_data.name)
return self.get_doc_by_name(form_data.name)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
return None return None
@ -132,27 +127,25 @@ class DocumentsTable:
self, name: str, updated: dict self, name: str, updated: dict
) -> Optional[DocumentModel]: ) -> Optional[DocumentModel]:
try: try:
with get_session() as db: doc = self.get_doc_by_name(name)
doc = self.get_doc_by_name(name) doc_content = json.loads(doc.content if doc.content else "{}")
doc_content = json.loads(doc.content if doc.content else "{}") doc_content = {**doc_content, **updated}
doc_content = {**doc_content, **updated}
db.query(Document).filter_by(name=name).update( Session.query(Document).filter_by(name=name).update(
{ {
"content": json.dumps(doc_content), "content": json.dumps(doc_content),
"timestamp": int(time.time()), "timestamp": int(time.time()),
} }
) )
db.commit() Session.commit()
return self.get_doc_by_name(name) return self.get_doc_by_name(name)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
return None return None
def delete_doc_by_name(self, name: str) -> bool: def delete_doc_by_name(self, name: str) -> bool:
try: try:
with get_session() as db: Session.query(Document).filter_by(name=name).delete()
db.query(Document).filter_by(name=name).delete()
return True return True
except: except:
return False return False

View File

@ -4,9 +4,8 @@ import time
import logging import logging
from sqlalchemy import Column, String, BigInteger from sqlalchemy import Column, String, BigInteger
from sqlalchemy.orm import Session
from apps.webui.internal.db import JSONField, Base, get_session from apps.webui.internal.db import JSONField, Base, Session
import json import json
@ -71,45 +70,38 @@ class FilesTable:
) )
try: try:
with get_session() as db: result = File(**file.model_dump())
result = File(**file.model_dump()) Session.add(result)
db.add(result) Session.commit()
db.commit() Session.refresh(result)
db.refresh(result) if result:
if result: return FileModel.model_validate(result)
return FileModel.model_validate(result) else:
else: return None
return None
except Exception as e: except Exception as e:
print(f"Error creating tool: {e}") print(f"Error creating tool: {e}")
return None return None
def get_file_by_id(self, id: str) -> Optional[FileModel]: def get_file_by_id(self, id: str) -> Optional[FileModel]:
try: try:
with get_session() as db: file = Session.get(File, id)
file = db.get(File, id) return FileModel.model_validate(file)
return FileModel.model_validate(file)
except: except:
return None return None
def get_files(self) -> List[FileModel]: def get_files(self) -> List[FileModel]:
with get_session() as db: return [FileModel.model_validate(file) for file in Session.query(File).all()]
return [FileModel.model_validate(file) for file in db.query(File).all()]
def delete_file_by_id(self, id: str) -> bool: def delete_file_by_id(self, id: str) -> bool:
try: try:
with get_session() as db: Session.query(File).filter_by(id=id).delete()
db.query(File).filter_by(id=id).delete()
db.commit()
return True return True
except: except:
return False return False
def delete_all_files(self) -> bool: def delete_all_files(self) -> bool:
try: try:
with get_session() as db: Session.query(File).delete()
db.query(File).delete()
db.commit()
return True return True
except: except:
return False return False

View File

@ -4,9 +4,8 @@ import time
import logging import logging
from sqlalchemy import Column, String, Text, BigInteger, Boolean from sqlalchemy import Column, String, Text, BigInteger, Boolean
from sqlalchemy.orm import Session
from apps.webui.internal.db import JSONField, Base, get_session from apps.webui.internal.db import JSONField, Base, Session
from apps.webui.models.users import Users from apps.webui.models.users import Users
import json import json
@ -100,64 +99,57 @@ class FunctionsTable:
) )
try: try:
with get_session() as db: result = Function(**function.model_dump())
result = Function(**function.model_dump()) Session.add(result)
db.add(result) Session.commit()
db.commit() Session.refresh(result)
db.refresh(result) if result:
if result: return FunctionModel.model_validate(result)
return FunctionModel.model_validate(result) else:
else: return None
return None
except Exception as e: except Exception as e:
print(f"Error creating tool: {e}") print(f"Error creating tool: {e}")
return None return None
def get_function_by_id(self, id: str) -> Optional[FunctionModel]: def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
try: try:
with get_session() as db: function = Session.get(Function, id)
function = db.get(Function, id) return FunctionModel.model_validate(function)
return FunctionModel.model_validate(function)
except: except:
return None return None
def get_functions(self, active_only=False) -> List[FunctionModel]: def get_functions(self, active_only=False) -> List[FunctionModel]:
if active_only: if active_only:
with get_session() as db: return [
return [ FunctionModel.model_validate(function)
FunctionModel.model_validate(function) for function in Session.query(Function).filter_by(is_active=True).all()
for function in db.query(Function).filter_by(is_active=True).all() ]
]
else: else:
with get_session() as db: return [
return [ FunctionModel.model_validate(function)
FunctionModel.model_validate(function) for function in Session.query(Function).all()
for function in db.query(Function).all() ]
]
def get_functions_by_type( def get_functions_by_type(
self, type: str, active_only=False self, type: str, active_only=False
) -> List[FunctionModel]: ) -> List[FunctionModel]:
if active_only: if active_only:
with get_session() as db: return [
return [ FunctionModel.model_validate(function)
FunctionModel.model_validate(function) for function in Session.query(Function)
for function in db.query(Function) .filter_by(type=type, is_active=True)
.filter_by(type=type, is_active=True) .all()
.all() ]
]
else: else:
with get_session() as db: return [
return [ FunctionModel.model_validate(function)
FunctionModel.model_validate(function) for function in Session.query(Function).filter_by(type=type).all()
for function in db.query(Function).filter_by(type=type).all() ]
]
def get_function_valves_by_id(self, id: str) -> Optional[dict]: def get_function_valves_by_id(self, id: str) -> Optional[dict]:
try: try:
with get_session() as db: function = Session.get(Function, id)
function = db.get(Function, id) return function.valves if function.valves else {}
return function.valves if function.valves else {}
except Exception as e: except Exception as e:
print(f"An error occurred: {e}") print(f"An error occurred: {e}")
return None return None
@ -166,12 +158,12 @@ class FunctionsTable:
self, id: str, valves: dict self, id: str, valves: dict
) -> Optional[FunctionValves]: ) -> Optional[FunctionValves]:
try: try:
with get_session() as db: function = Session.get(Function, id)
db.query(Function).filter_by(id=id).update( function.valves = valves
{"valves": valves, "updated_at": int(time.time())} function.updated_at = int(time.time())
) Session.commit()
db.commit() Session.refresh(function)
return self.get_function_by_id(id) return self.get_function_by_id(id)
except: except:
return None return None
@ -219,36 +211,33 @@ class FunctionsTable:
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]: def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
try: try:
with get_session() as db: Session.query(Function).filter_by(id=id).update(
db.query(Function).filter_by(id=id).update( {
{ **updated,
**updated, "updated_at": int(time.time()),
"updated_at": int(time.time()), }
} )
) Session.commit()
db.commit() return self.get_function_by_id(id)
return self.get_function_by_id(id)
except: except:
return None return None
def deactivate_all_functions(self) -> Optional[bool]: def deactivate_all_functions(self) -> Optional[bool]:
try: try:
with get_session() as db: Session.query(Function).update(
db.query(Function).update( {
{ "is_active": False,
"is_active": False, "updated_at": int(time.time()),
"updated_at": int(time.time()), }
} )
) Session.commit()
db.commit()
return True return True
except: except:
return None return None
def delete_function_by_id(self, id: str) -> bool: def delete_function_by_id(self, id: str) -> bool:
try: try:
with get_session() as db: Session.query(Function).filter_by(id=id).delete()
db.query(Function).filter_by(id=id).delete()
return True return True
except: except:
return False return False

View File

@ -2,10 +2,8 @@ from pydantic import BaseModel, ConfigDict
from typing import List, Union, Optional from typing import List, Union, Optional
from sqlalchemy import Column, String, BigInteger from sqlalchemy import Column, String, BigInteger
from sqlalchemy.orm import Session
from apps.webui.internal.db import Base, get_session from apps.webui.internal.db import Base, Session
from apps.webui.models.chats import Chats
import time import time
import uuid import uuid
@ -58,15 +56,14 @@ class MemoriesTable:
"updated_at": int(time.time()), "updated_at": int(time.time()),
} }
) )
with get_session() as db: result = Memory(**memory.model_dump())
result = Memory(**memory.model_dump()) Session.add(result)
db.add(result) Session.commit()
db.commit() Session.refresh(result)
db.refresh(result) if result:
if result: return MemoryModel.model_validate(result)
return MemoryModel.model_validate(result) else:
else: return None
return None
def update_memory_by_id( def update_memory_by_id(
self, self,
@ -74,62 +71,55 @@ class MemoriesTable:
content: str, content: str,
) -> Optional[MemoryModel]: ) -> Optional[MemoryModel]:
try: try:
with get_session() as db: Session.query(Memory).filter_by(id=id).update(
db.query(Memory).filter_by(id=id).update( {"content": content, "updated_at": int(time.time())}
{"content": content, "updated_at": int(time.time())} )
) Session.commit()
db.commit() return self.get_memory_by_id(id)
return self.get_memory_by_id(id)
except: except:
return None return None
def get_memories(self) -> List[MemoryModel]: def get_memories(self) -> List[MemoryModel]:
try: try:
with get_session() as db: memories = Session.query(Memory).all()
memories = db.query(Memory).all() return [MemoryModel.model_validate(memory) for memory in memories]
return [MemoryModel.model_validate(memory) for memory in memories]
except: except:
return None return None
def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]: def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]:
try: try:
with get_session() as db: memories = Session.query(Memory).filter_by(user_id=user_id).all()
memories = db.query(Memory).filter_by(user_id=user_id).all() return [MemoryModel.model_validate(memory) for memory in memories]
return [MemoryModel.model_validate(memory) for memory in memories]
except: except:
return None return None
def get_memory_by_id(self, id: str) -> Optional[MemoryModel]: def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
try: try:
with get_session() as db: memory = Session.get(Memory, id)
memory = db.get(Memory, id) return MemoryModel.model_validate(memory)
return MemoryModel.model_validate(memory)
except: except:
return None return None
def delete_memory_by_id(self, id: str) -> bool: def delete_memory_by_id(self, id: str) -> bool:
try: try:
with get_session() as db: Session.query(Memory).filter_by(id=id).delete()
db.query(Memory).filter_by(id=id).delete()
return True return True
except: except:
return False return False
def delete_memories_by_user_id(self, db: Session, user_id: str) -> bool: def delete_memories_by_user_id(self, user_id: str) -> bool:
try: try:
with get_session() as db: Session.query(Memory).filter_by(user_id=user_id).delete()
db.query(Memory).filter_by(user_id=user_id).delete()
return True return True
except: except:
return False return False
def delete_memory_by_id_and_user_id( def delete_memory_by_id_and_user_id(
self, db: Session, id: str, user_id: str self, id: str, user_id: str
) -> bool: ) -> bool:
try: try:
with get_session() as db: Session.query(Memory).filter_by(id=id, user_id=user_id).delete()
db.query(Memory).filter_by(id=id, user_id=user_id).delete()
return True return True
except: except:
return False return False

View File

@ -4,9 +4,8 @@ from typing import Optional
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import String, Column, BigInteger from sqlalchemy import String, Column, BigInteger
from sqlalchemy.orm import Session
from apps.webui.internal.db import Base, JSONField, get_session from apps.webui.internal.db import Base, JSONField, Session
from typing import List, Union, Optional from typing import List, Union, Optional
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
@ -127,41 +126,37 @@ class ModelsTable:
} }
) )
try: try:
with get_session() as db: result = Model(**model.model_dump())
result = Model(**model.model_dump()) Session.add(result)
db.add(result) Session.commit()
db.commit() Session.refresh(result)
db.refresh(result)
if result: if result:
return ModelModel.model_validate(result) return ModelModel.model_validate(result)
else: else:
return None return None
except Exception as e: except Exception as e:
print(e) print(e)
return None return None
def get_all_models(self) -> List[ModelModel]: def get_all_models(self) -> List[ModelModel]:
with get_session() as db: return [ModelModel.model_validate(model) for model in Session.query(Model).all()]
return [ModelModel.model_validate(model) for model in db.query(Model).all()]
def get_model_by_id(self, id: str) -> Optional[ModelModel]: def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try: try:
with get_session() as db: model = Session.get(Model, id)
model = db.get(Model, id) return ModelModel.model_validate(model)
return ModelModel.model_validate(model)
except: except:
return None return None
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
try: try:
# update only the fields that are present in the model # update only the fields that are present in the model
with get_session() as db: model = Session.query(Model).get(id)
model = db.query(Model).get(id) model.update(**model.model_dump())
model.update(**model.model_dump()) Session.commit()
db.commit() Session.refresh(model)
db.refresh(model) return ModelModel.model_validate(model)
return ModelModel.model_validate(model)
except Exception as e: except Exception as e:
print(e) print(e)
@ -169,8 +164,7 @@ class ModelsTable:
def delete_model_by_id(self, id: str) -> bool: def delete_model_by_id(self, id: str) -> bool:
try: try:
with get_session() as db: Session.query(Model).filter_by(id=id).delete()
db.query(Model).filter_by(id=id).delete()
return True return True
except: except:
return False return False

View File

@ -3,9 +3,8 @@ from typing import List, Optional
import time import time
from sqlalchemy import String, Column, BigInteger from sqlalchemy import String, Column, BigInteger
from sqlalchemy.orm import Session
from apps.webui.internal.db import Base, get_session from apps.webui.internal.db import Base, Session
import json import json
@ -50,65 +49,59 @@ class PromptsTable:
def insert_new_prompt( def insert_new_prompt(
self, user_id: str, form_data: PromptForm self, user_id: str, form_data: PromptForm
) -> Optional[PromptModel]: ) -> Optional[PromptModel]:
with get_session() as db: prompt = PromptModel(
prompt = PromptModel( **{
**{ "user_id": user_id,
"user_id": user_id, "command": form_data.command,
"command": form_data.command, "title": form_data.title,
"title": form_data.title, "content": form_data.content,
"content": form_data.content, "timestamp": int(time.time()),
"timestamp": int(time.time()), }
} )
)
try: try:
result = Prompt(**prompt.dict()) result = Prompt(**prompt.dict())
db.add(result) Session.add(result)
db.commit() Session.commit()
db.refresh(result) Session.refresh(result)
if result: if result:
return PromptModel.model_validate(result) return PromptModel.model_validate(result)
else: else:
return None
except Exception as e:
return None return None
except Exception as e:
return None
def get_prompt_by_command(self, command: str) -> Optional[PromptModel]: def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
with get_session() as db: try:
try: prompt = Session.query(Prompt).filter_by(command=command).first()
prompt = db.query(Prompt).filter_by(command=command).first() return PromptModel.model_validate(prompt)
return PromptModel.model_validate(prompt) except:
except: return None
return None
def get_prompts(self) -> List[PromptModel]: def get_prompts(self) -> List[PromptModel]:
with get_session() as db: return [
return [ PromptModel.model_validate(prompt) for prompt in Session.query(Prompt).all()
PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all() ]
]
def update_prompt_by_command( def update_prompt_by_command(
self, command: str, form_data: PromptForm self, command: str, form_data: PromptForm
) -> Optional[PromptModel]: ) -> Optional[PromptModel]:
with get_session() as db: try:
try: prompt = Session.query(Prompt).filter_by(command=command).first()
prompt = db.query(Prompt).filter_by(command=command).first() prompt.title = form_data.title
prompt.title = form_data.title prompt.content = form_data.content
prompt.content = form_data.content prompt.timestamp = int(time.time())
prompt.timestamp = int(time.time()) Session.commit()
db.commit() return PromptModel.model_validate(prompt)
return prompt except:
# return self.get_prompt_by_command(command) return None
except:
return None
def delete_prompt_by_command(self, command: str) -> bool: def delete_prompt_by_command(self, command: str) -> bool:
with get_session() as db: try:
try: Session.query(Prompt).filter_by(command=command).delete()
db.query(Prompt).filter_by(command=command).delete() return True
return True except:
except: return False
return False
Prompts = PromptsTable() Prompts = PromptsTable()

View File

@ -7,9 +7,8 @@ import time
import logging import logging
from sqlalchemy import String, Column, BigInteger from sqlalchemy import String, Column, BigInteger
from sqlalchemy.orm import Session
from apps.webui.internal.db import Base, get_session from apps.webui.internal.db import Base, Session
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
@ -83,15 +82,14 @@ class TagTable:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
try: try:
with get_session() as db: result = Tag(**tag.model_dump())
result = Tag(**tag.model_dump()) Session.add(result)
db.add(result) Session.commit()
db.commit() Session.refresh(result)
db.refresh(result) if result:
if result: return TagModel.model_validate(result)
return TagModel.model_validate(result) else:
else: return None
return None
except Exception as e: except Exception as e:
return None return None
@ -99,9 +97,8 @@ class TagTable:
self, name: str, user_id: str self, name: str, user_id: str
) -> Optional[TagModel]: ) -> Optional[TagModel]:
try: try:
with get_session() as db: tag = Session.query(Tag).filter(name=name, user_id=user_id).first()
tag = db.query(Tag).filter(name=name, user_id=user_id).first() return TagModel.model_validate(tag)
return TagModel.model_validate(tag)
except Exception as e: except Exception as e:
return None return None
@ -123,105 +120,99 @@ class TagTable:
} }
) )
try: try:
with get_session() as db: result = ChatIdTag(**chatIdTag.model_dump())
result = ChatIdTag(**chatIdTag.model_dump()) Session.add(result)
db.add(result) Session.commit()
db.commit() Session.refresh(result)
db.refresh(result) if result:
if result: return ChatIdTagModel.model_validate(result)
return ChatIdTagModel.model_validate(result) else:
else: return None
return None
except: except:
return None return None
def get_tags_by_user_id(self, user_id: str) -> List[TagModel]: def get_tags_by_user_id(self, user_id: str) -> List[TagModel]:
with get_session() as db: tag_names = [
tag_names = [ chat_id_tag.tag_name
chat_id_tag.tag_name for chat_id_tag in (
for chat_id_tag in ( Session.query(ChatIdTag)
db.query(ChatIdTag) .filter_by(user_id=user_id)
.filter_by(user_id=user_id) .order_by(ChatIdTag.timestamp.desc())
.order_by(ChatIdTag.timestamp.desc()) .all()
.all() )
) ]
]
return [ return [
TagModel.model_validate(tag) TagModel.model_validate(tag)
for tag in ( for tag in (
db.query(Tag) Session.query(Tag)
.filter_by(user_id=user_id) .filter_by(user_id=user_id)
.filter(Tag.name.in_(tag_names)) .filter(Tag.name.in_(tag_names))
.all() .all()
) )
] ]
def get_tags_by_chat_id_and_user_id( def get_tags_by_chat_id_and_user_id(
self, chat_id: str, user_id: str self, chat_id: str, user_id: str
) -> List[TagModel]: ) -> List[TagModel]:
with get_session() as db: tag_names = [
tag_names = [ chat_id_tag.tag_name
chat_id_tag.tag_name for chat_id_tag in (
for chat_id_tag in ( Session.query(ChatIdTag)
db.query(ChatIdTag) .filter_by(user_id=user_id, chat_id=chat_id)
.filter_by(user_id=user_id, chat_id=chat_id) .order_by(ChatIdTag.timestamp.desc())
.order_by(ChatIdTag.timestamp.desc()) .all()
.all() )
) ]
]
return [ return [
TagModel.model_validate(tag) TagModel.model_validate(tag)
for tag in ( for tag in (
db.query(Tag) Session.query(Tag)
.filter_by(user_id=user_id) .filter_by(user_id=user_id)
.filter(Tag.name.in_(tag_names)) .filter(Tag.name.in_(tag_names))
.all() .all()
) )
] ]
def get_chat_ids_by_tag_name_and_user_id( def get_chat_ids_by_tag_name_and_user_id(
self, tag_name: str, user_id: str self, tag_name: str, user_id: str
) -> List[ChatIdTagModel]: ) -> List[ChatIdTagModel]:
with get_session() as db: return [
return [ ChatIdTagModel.model_validate(chat_id_tag)
ChatIdTagModel.model_validate(chat_id_tag) for chat_id_tag in (
for chat_id_tag in ( Session.query(ChatIdTag)
db.query(ChatIdTag) .filter_by(user_id=user_id, tag_name=tag_name)
.filter_by(user_id=user_id, tag_name=tag_name) .order_by(ChatIdTag.timestamp.desc())
.order_by(ChatIdTag.timestamp.desc()) .all()
.all() )
) ]
]
def count_chat_ids_by_tag_name_and_user_id( def count_chat_ids_by_tag_name_and_user_id(
self, tag_name: str, user_id: str self, tag_name: str, user_id: str
) -> int: ) -> int:
with get_session() as db: return (
return ( Session.query(ChatIdTag)
db.query(ChatIdTag) .filter_by(tag_name=tag_name, user_id=user_id)
.filter_by(tag_name=tag_name, user_id=user_id) .count()
.count() )
)
def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool: def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
try: try:
with get_session() as db: res = (
res = ( Session.query(ChatIdTag)
db.query(ChatIdTag) .filter_by(tag_name=tag_name, user_id=user_id)
.filter_by(tag_name=tag_name, user_id=user_id) .delete()
.delete() )
) log.debug(f"res: {res}")
log.debug(f"res: {res}") Session.commit()
db.commit()
tag_count = self.count_chat_ids_by_tag_name_and_user_id( tag_count = self.count_chat_ids_by_tag_name_and_user_id(
tag_name, user_id tag_name, user_id
) )
if tag_count == 0: if tag_count == 0:
# Remove tag item from Tag col as well # Remove tag item from Tag col as well
db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
return True return True
except Exception as e: except Exception as e:
log.error(f"delete_tag: {e}") log.error(f"delete_tag: {e}")
@ -231,21 +222,20 @@ class TagTable:
self, tag_name: str, chat_id: str, user_id: str self, tag_name: str, chat_id: str, user_id: str
) -> bool: ) -> bool:
try: try:
with get_session() as db: res = (
res = ( Session.query(ChatIdTag)
db.query(ChatIdTag) .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
.filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id) .delete()
.delete() )
) log.debug(f"res: {res}")
log.debug(f"res: {res}") Session.commit()
db.commit()
tag_count = self.count_chat_ids_by_tag_name_and_user_id( tag_count = self.count_chat_ids_by_tag_name_and_user_id(
tag_name, user_id tag_name, user_id
) )
if tag_count == 0: if tag_count == 0:
# Remove tag item from Tag col as well # Remove tag item from Tag col as well
db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
return True return True
except Exception as e: except Exception as e:

View File

@ -3,9 +3,8 @@ from typing import List, Optional
import time import time
import logging import logging
from sqlalchemy import String, Column, BigInteger from sqlalchemy import String, Column, BigInteger
from sqlalchemy.orm import Session
from apps.webui.internal.db import Base, JSONField, get_session from apps.webui.internal.db import Base, JSONField, Session
from apps.webui.models.users import Users from apps.webui.models.users import Users
import json import json
@ -95,48 +94,43 @@ class ToolsTable:
) )
try: try:
with get_session() as db: result = Tool(**tool.model_dump())
result = Tool(**tool.model_dump()) Session.add(result)
db.add(result) Session.commit()
db.commit() Session.refresh(result)
db.refresh(result) if result:
if result: return ToolModel.model_validate(result)
return ToolModel.model_validate(result) else:
else: return None
return None
except Exception as e: except Exception as e:
print(f"Error creating tool: {e}") print(f"Error creating tool: {e}")
return None return None
def get_tool_by_id(self, id: str) -> Optional[ToolModel]: def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
try: try:
with get_session() as db: tool = Session.get(Tool, id)
tool = db.get(Tool, id) return ToolModel.model_validate(tool)
return ToolModel.model_validate(tool)
except: except:
return None return None
def get_tools(self) -> List[ToolModel]: def get_tools(self) -> List[ToolModel]:
with get_session() as db: return [ToolModel.model_validate(tool) for tool in Session.query(Tool).all()]
return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()]
def get_tool_valves_by_id(self, id: str) -> Optional[dict]: def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
try: try:
with get_session() as db: tool = Session.get(Tool, id)
tool = db.get(Tool, id) return tool.valves if tool.valves else {}
return tool.valves if tool.valves else {}
except Exception as e: except Exception as e:
print(f"An error occurred: {e}") print(f"An error occurred: {e}")
return None return None
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
try: try:
with get_session() as db: Session.query(Tool).filter_by(id=id).update(
db.query(Tool).filter_by(id=id).update( {"valves": valves, "updated_at": int(time.time())}
{"valves": valves, "updated_at": int(time.time())} )
) Session.commit()
db.commit() return self.get_tool_by_id(id)
return self.get_tool_by_id(id)
except: except:
return None return None
@ -183,19 +177,18 @@ class ToolsTable:
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
try: try:
with get_session() as db: tool = Session.get(Tool, id)
db.query(Tool).filter_by(id=id).update( tool.update(**updated)
{**updated, "updated_at": int(time.time())} tool.updated_at = int(time.time())
) Session.commit()
db.commit() Session.refresh(tool)
return self.get_tool_by_id(id) return ToolModel.model_validate(tool)
except: except:
return None return None
def delete_tool_by_id(self, id: str) -> bool: def delete_tool_by_id(self, id: str) -> bool:
try: try:
with get_session() as db: Session.query(Tool).filter_by(id=id).delete()
db.query(Tool).filter_by(id=id).delete()
return True return True
except: except:
return False return False

View File

@ -3,11 +3,10 @@ from typing import List, Union, Optional
import time import time
from sqlalchemy import String, Column, BigInteger, Text from sqlalchemy import String, Column, BigInteger, Text
from sqlalchemy.orm import Session
from utils.misc import get_gravatar_url from utils.misc import get_gravatar_url
from apps.webui.internal.db import Base, JSONField, get_session from apps.webui.internal.db import Base, JSONField, Session
from apps.webui.models.chats import Chats from apps.webui.models.chats import Chats
#################### ####################
@ -89,177 +88,161 @@ class UsersTable:
role: str = "pending", role: str = "pending",
oauth_sub: Optional[str] = None, oauth_sub: Optional[str] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
with get_session() as db: user = UserModel(
user = UserModel( **{
**{ "id": id,
"id": id, "name": name,
"name": name, "email": email,
"email": email, "role": role,
"role": role, "profile_image_url": profile_image_url,
"profile_image_url": profile_image_url, "last_active_at": int(time.time()),
"last_active_at": int(time.time()), "created_at": int(time.time()),
"created_at": int(time.time()), "updated_at": int(time.time()),
"updated_at": int(time.time()), "oauth_sub": oauth_sub,
"oauth_sub": oauth_sub, }
} )
) result = User(**user.model_dump())
result = User(**user.model_dump()) Session.add(result)
db.add(result) Session.commit()
db.commit() Session.refresh(result)
db.refresh(result) if result:
if result: return user
return user else:
else: return None
return None
def get_user_by_id(self, id: str) -> Optional[UserModel]: def get_user_by_id(self, id: str) -> Optional[UserModel]:
with get_session() as db: try:
try: user = Session.query(User).filter_by(id=id).first()
user = db.query(User).filter_by(id=id).first() return UserModel.model_validate(user)
return UserModel.model_validate(user) except Exception as e:
except Exception as e: return None
return None
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
with get_session() as db: try:
try: user = Session.query(User).filter_by(api_key=api_key).first()
user = db.query(User).filter_by(api_key=api_key).first() return UserModel.model_validate(user)
return UserModel.model_validate(user) except:
except: return None
return None
def get_user_by_email(self, email: str) -> Optional[UserModel]: def get_user_by_email(self, email: str) -> Optional[UserModel]:
with get_session() as db: try:
try: user = Session.query(User).filter_by(email=email).first()
user = db.query(User).filter_by(email=email).first() return UserModel.model_validate(user)
return UserModel.model_validate(user) except:
except: return None
return None
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]: def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
with get_session() as db: try:
try: user = Session.query(User).filter_by(oauth_sub=sub).first()
user = db.query(User).filter_by(oauth_sub=sub).first() return UserModel.model_validate(user)
return UserModel.model_validate(user) except:
except: return None
return None
def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]: def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
with get_session() as db: users = (
users = ( Session.query(User)
db.query(User) # .offset(skip).limit(limit)
# .offset(skip).limit(limit) .all()
.all() )
) return [UserModel.model_validate(user) for user in users]
return [UserModel.model_validate(user) for user in users]
def get_num_users(self) -> Optional[int]: def get_num_users(self) -> Optional[int]:
with get_session() as db: return Session.query(User).count()
return db.query(User).count()
def get_first_user(self) -> UserModel: def get_first_user(self) -> UserModel:
with get_session() as db: try:
try: user = Session.query(User).order_by(User.created_at).first()
user = db.query(User).order_by(User.created_at).first() return UserModel.model_validate(user)
return UserModel.model_validate(user) except:
except: return None
return None
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
with get_session() as db: try:
try: Session.query(User).filter_by(id=id).update({"role": role})
db.query(User).filter_by(id=id).update({"role": role}) Session.commit()
db.commit()
user = db.query(User).filter_by(id=id).first() user = Session.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except: except:
return None return None
def update_user_profile_image_url_by_id( def update_user_profile_image_url_by_id(
self, id: str, profile_image_url: str self, id: str, profile_image_url: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
with get_session() as db: try:
try: Session.query(User).filter_by(id=id).update(
db.query(User).filter_by(id=id).update( {"profile_image_url": profile_image_url}
{"profile_image_url": profile_image_url} )
) Session.commit()
db.commit()
user = db.query(User).filter_by(id=id).first() user = Session.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except: except:
return None return None
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]: def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
with get_session() as db: try:
try: Session.query(User).filter_by(id=id).update(
db.query(User).filter_by(id=id).update( {"last_active_at": int(time.time())}
{"last_active_at": int(time.time())} )
)
user = db.query(User).filter_by(id=id).first() user = Session.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except: except:
return None return None
def update_user_oauth_sub_by_id( def update_user_oauth_sub_by_id(
self, id: str, oauth_sub: str self, id: str, oauth_sub: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
with get_session() as db: try:
try: Session.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
user = db.query(User).filter_by(id=id).first() user = Session.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except: except:
return None return None
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
with get_session() as db: try:
try: Session.query(User).filter_by(id=id).update(updated)
db.query(User).filter_by(id=id).update(updated) Session.commit()
db.commit()
user = db.query(User).filter_by(id=id).first() user = Session.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
# return UserModel(**user.dict()) # return UserModel(**user.dict())
except Exception as e: except Exception as e:
return None return None
def delete_user_by_id(self, id: str) -> bool: def delete_user_by_id(self, id: str) -> bool:
with get_session() as db: try:
try: # Delete User Chats
# Delete User Chats result = Chats.delete_chats_by_user_id(id)
result = Chats.delete_chats_by_user_id(id)
if result: if result:
# Delete User # Delete User
db.query(User).filter_by(id=id).delete() Session.query(User).filter_by(id=id).delete()
db.commit() Session.commit()
return True return True
else: else:
return False
except:
return False return False
except:
return False
def update_user_api_key_by_id(self, id: str, api_key: str) -> str: def update_user_api_key_by_id(self, id: str, api_key: str) -> str:
with get_session() as db: try:
try: result = Session.query(User).filter_by(id=id).update({"api_key": api_key})
result = db.query(User).filter_by(id=id).update({"api_key": api_key}) Session.commit()
db.commit() return True if result == 1 else False
return True if result == 1 else False except:
except: return False
return False
def get_user_api_key_by_id(self, id: str) -> Optional[str]: def get_user_api_key_by_id(self, id: str) -> Optional[str]:
with get_session() as db: try:
try: user = Session.query(User).filter_by(id=id).first()
user = db.query(User).filter_by(id=id).first() return user.api_key
return user.api_key except Exception as e:
except Exception as e: return None
return None
Users = UsersTable() Users = UsersTable()

View File

@ -29,7 +29,6 @@ from fastapi import HTTPException
from fastapi.middleware.wsgi import WSGIMiddleware from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import text from sqlalchemy import text
from sqlalchemy.orm import Session
from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
@ -57,7 +56,7 @@ from apps.webui.main import (
get_pipe_models, get_pipe_models,
generate_function_chat_completion, generate_function_chat_completion,
) )
from apps.webui.internal.db import get_session, SessionLocal from apps.webui.internal.db import Session, SessionLocal
from pydantic import BaseModel from pydantic import BaseModel
@ -794,6 +793,14 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
@app.middleware("http")
async def remove_session_after_request(request: Request, call_next):
response = await call_next(request)
log.debug("Removing session after request")
Session.commit()
Session.remove()
return response
@app.middleware("http") @app.middleware("http")
async def check_url(request: Request, call_next): async def check_url(request: Request, call_next):
@ -2034,8 +2041,7 @@ async def healthcheck():
@app.get("/health/db") @app.get("/health/db")
async def healthcheck_with_db(): async def healthcheck_with_db():
with get_session() as db: Session.execute(text("SELECT 1;")).all()
result = db.execute(text("SELECT 1;")).all()
return {"status": True} return {"status": True}

View File

@ -90,6 +90,8 @@ class TestChats(AbstractPostgresTest):
def test_get_user_archived_chats(self): def test_get_user_archived_chats(self):
self.chats.archive_all_chats_by_user_id("2") self.chats.archive_all_chats_by_user_id("2")
from apps.webui.internal.db import Session
Session.commit()
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/all/archived")) response = self.fast_api_client.get(self.create_url("/all/archived"))
assert response.status_code == 200 assert response.status_code == 200

View File

@ -9,6 +9,7 @@ from pytest_docker.plugin import get_docker_ip
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from sqlalchemy import text, create_engine from sqlalchemy import text, create_engine
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -50,11 +51,6 @@ class AbstractPostgresTest(AbstractIntegrationTest):
DOCKER_CONTAINER_NAME = "postgres-test-container-will-get-deleted" DOCKER_CONTAINER_NAME = "postgres-test-container-will-get-deleted"
docker_client: DockerClient docker_client: DockerClient
def get_db(self):
from apps.webui.internal.db import SessionLocal
return SessionLocal()
@classmethod @classmethod
def _create_db_url(cls, env_vars_postgres: dict) -> str: def _create_db_url(cls, env_vars_postgres: dict) -> str:
host = get_docker_ip() host = get_docker_ip()
@ -113,21 +109,21 @@ class AbstractPostgresTest(AbstractIntegrationTest):
pytest.fail(f"Could not setup test environment: {ex}") pytest.fail(f"Could not setup test environment: {ex}")
def _check_db_connection(self): def _check_db_connection(self):
from apps.webui.internal.db import Session
retries = 10 retries = 10
while retries > 0: while retries > 0:
try: try:
self.db_session.execute(text("SELECT 1")) Session.execute(text("SELECT 1"))
self.db_session.commit() Session.commit()
break break
except Exception as e: except Exception as e:
self.db_session.rollback() Session.rollback()
log.warning(e) log.warning(e)
time.sleep(3) time.sleep(3)
retries -= 1 retries -= 1
def setup_method(self): def setup_method(self):
super().setup_method() super().setup_method()
self.db_session = self.get_db()
self._check_db_connection() self._check_db_connection()
@classmethod @classmethod
@ -136,8 +132,9 @@ class AbstractPostgresTest(AbstractIntegrationTest):
cls.docker_client.containers.get(cls.DOCKER_CONTAINER_NAME).remove(force=True) cls.docker_client.containers.get(cls.DOCKER_CONTAINER_NAME).remove(force=True)
def teardown_method(self): def teardown_method(self):
from apps.webui.internal.db import Session
# rollback everything not yet committed # rollback everything not yet committed
self.db_session.commit() Session.commit()
# truncate all tables # truncate all tables
tables = [ tables = [
@ -152,5 +149,5 @@ class AbstractPostgresTest(AbstractIntegrationTest):
'"user"', '"user"',
] ]
for table in tables: for table in tables:
self.db_session.execute(text(f"TRUNCATE TABLE {table}")) Session.execute(text(f"TRUNCATE TABLE {table}"))
self.db_session.commit() Session.commit()