feat(sqlalchemy): format backend

This commit is contained in:
Jonathan Rohde
2024-06-24 09:57:08 +02:00
parent 320e658595
commit c134eab27a
21 changed files with 232 additions and 289 deletions

View File

@@ -126,9 +126,7 @@ class AuthsTable:
else:
return None
def authenticate_user(
self, email: str, password: str
) -> Optional[UserModel]:
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
log.info(f"authenticate_user: {email}")
with get_session() as db:
try:
@@ -144,9 +142,7 @@ class AuthsTable:
except:
return None
def authenticate_user_by_api_key(
self, api_key: str
) -> Optional[UserModel]:
def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_api_key: {api_key}")
with get_session() as db:
# if no api_key, return None
@@ -159,9 +155,7 @@ class AuthsTable:
except:
return False
def authenticate_user_by_trusted_header(
self, email: str
) -> Optional[UserModel]:
def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_trusted_header: {email}")
with get_session() as db:
try:
@@ -172,12 +166,12 @@ class AuthsTable:
except:
return None
def update_user_password_by_id(
self, id: str, new_password: str
) -> bool:
def update_user_password_by_id(self, id: str, new_password: str) -> bool:
with get_session() as db:
try:
result = db.query(Auth).filter_by(id=id).update({"password": new_password})
result = (
db.query(Auth).filter_by(id=id).update({"password": new_password})
)
return True if result == 1 else False
except:
return False

View File

@@ -79,9 +79,7 @@ class ChatTitleIdResponse(BaseModel):
class ChatTable:
def insert_new_chat(
self, user_id: str, form_data: ChatForm
) -> Optional[ChatModel]:
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
with get_session() as db:
id = str(uuid.uuid4())
chat = ChatModel(
@@ -89,7 +87,9 @@ class ChatTable:
"id": id,
"user_id": user_id,
"title": (
form_data.chat["title"] if "title" in form_data.chat else "New Chat"
form_data.chat["title"]
if "title" in form_data.chat
else "New Chat"
),
"chat": json.dumps(form_data.chat),
"created_at": int(time.time()),
@@ -103,9 +103,7 @@ class ChatTable:
db.refresh(result)
return ChatModel.model_validate(result) if result else None
def update_chat_by_id(
self, id: str, chat: dict
) -> Optional[ChatModel]:
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
with get_session() as db:
try:
chat_obj = db.get(Chat, id)
@@ -119,9 +117,7 @@ class ChatTable:
except Exception as e:
return None
def insert_shared_chat_by_chat_id(
self, chat_id: str
) -> Optional[ChatModel]:
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
with get_session() as db:
# Get the existing chat to share
chat = db.get(Chat, chat_id)
@@ -145,14 +141,14 @@ class ChatTable:
db.refresh(shared_result)
# Update the original chat with the share_id
result = (
db.query(Chat).filter_by(id=chat_id).update({"share_id": shared_chat.id})
db.query(Chat)
.filter_by(id=chat_id)
.update({"share_id": shared_chat.id})
)
return shared_chat if (shared_result and result) else None
def update_shared_chat_by_chat_id(
self, chat_id: str
) -> Optional[ChatModel]:
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
with get_session() as db:
try:
print("update_shared_chat_by_id")
@@ -271,9 +267,7 @@ class ChatTable:
except Exception as e:
return None
def get_chat_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[ChatModel]:
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
try:
with get_session() as db:
chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
@@ -293,13 +287,13 @@ class ChatTable:
def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
with get_session() as db:
all_chats = (
db.query(Chat).filter_by(user_id=user_id).order_by(Chat.updated_at.desc())
db.query(Chat)
.filter_by(user_id=user_id)
.order_by(Chat.updated_at.desc())
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_archived_chats_by_user_id(
self, user_id: str
) -> List[ChatModel]:
def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
with get_session() as db:
all_chats = (
db.query(Chat)

View File

@@ -106,7 +106,9 @@ class DocumentsTable:
def get_docs(self) -> List[DocumentModel]:
with get_session() as db:
return [DocumentModel.model_validate(doc) for doc in db.query(Document).all()]
return [
DocumentModel.model_validate(doc) for doc in db.query(Document).all()
]
def update_doc_by_name(
self, name: str, form_data: DocumentUpdateForm

View File

@@ -39,6 +39,7 @@ class FileModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
####################
# Forms
####################

View File

@@ -142,9 +142,9 @@ class FunctionsTable:
with get_session() as db:
return [
FunctionModel.model_validate(function)
for function in db.query(Function).filter_by(
type=type, is_active=True
).all()
for function in db.query(Function)
.filter_by(type=type, is_active=True)
.all()
]
else:
with get_session() as db:
@@ -220,10 +220,12 @@ class FunctionsTable:
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
try:
with get_session() as db:
db.query(Function).filter_by(id=id).update({
**updated,
"updated_at": int(time.time()),
})
db.query(Function).filter_by(id=id).update(
{
**updated,
"updated_at": int(time.time()),
}
)
db.commit()
return self.get_function_by_id(id)
except:
@@ -232,10 +234,12 @@ class FunctionsTable:
def deactivate_all_functions(self) -> Optional[bool]:
try:
with get_session() as db:
db.query(Function).update({
"is_active": False,
"updated_at": int(time.time()),
})
db.query(Function).update(
{
"is_active": False,
"updated_at": int(time.time()),
}
)
db.commit()
return True
except:

View File

@@ -153,9 +153,7 @@ class ModelsTable:
except:
return None
def update_model_by_id(
self, id: str, model: ModelForm
) -> Optional[ModelModel]:
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
try:
# update only the fields that are present in the model
with get_session() as db:

View File

@@ -83,7 +83,9 @@ class PromptsTable:
def get_prompts(self) -> List[PromptModel]:
with get_session() as db:
return [PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()]
return [
PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
]
def update_prompt_by_command(
self, command: str, form_data: PromptForm

View File

@@ -79,9 +79,7 @@ class ChatTagsResponse(BaseModel):
class TagTable:
def insert_new_tag(
self, name: str, user_id: str
) -> Optional[TagModel]:
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
id = str(uuid.uuid4())
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
try:
@@ -201,11 +199,13 @@ class TagTable:
self, tag_name: str, user_id: str
) -> int:
with get_session() as db:
return db.query(ChatIdTag).filter_by(tag_name=tag_name, user_id=user_id).count()
return (
db.query(ChatIdTag)
.filter_by(tag_name=tag_name, user_id=user_id)
.count()
)
def delete_tag_by_tag_name_and_user_id(
self, tag_name: str, user_id: str
) -> bool:
def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
try:
with get_session() as db:
res = (
@@ -252,9 +252,7 @@ class TagTable:
log.error(f"delete_tag: {e}")
return False
def delete_tags_by_chat_id_and_user_id(
self, chat_id: str, user_id: str
) -> bool:
def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool:
tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id)
for tag in tags:

View File

@@ -165,9 +165,7 @@ class UsersTable:
except:
return None
def update_user_role_by_id(
self, id: str, role: str
) -> Optional[UserModel]:
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
with get_session() as db:
try:
db.query(User).filter_by(id=id).update({"role": role})
@@ -193,12 +191,12 @@ class UsersTable:
except:
return None
def update_user_last_active_by_id(
self, id: str
) -> Optional[UserModel]:
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
with get_session() as db:
try:
db.query(User).filter_by(id=id).update({"last_active_at": int(time.time())})
db.query(User).filter_by(id=id).update(
{"last_active_at": int(time.time())}
)
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
@@ -217,9 +215,7 @@ class UsersTable:
except:
return None
def update_user_by_id(
self, id: str, updated: dict
) -> Optional[UserModel]:
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
with get_session() as db:
try:
db.query(User).filter_by(id=id).update(updated)