mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
feat(sqlalchemy): format backend
This commit is contained in:
@@ -53,7 +53,9 @@ if "sqlite" in SQLALCHEMY_DATABASE_URL:
|
||||
)
|
||||
else:
|
||||
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False)
|
||||
SessionLocal = sessionmaker(
|
||||
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
|
||||
)
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
@@ -66,4 +68,3 @@ def get_session():
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise e
|
||||
|
||||
|
||||
@@ -126,9 +126,7 @@ class AuthsTable:
|
||||
else:
|
||||
return None
|
||||
|
||||
def authenticate_user(
|
||||
self, email: str, password: str
|
||||
) -> Optional[UserModel]:
|
||||
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
|
||||
log.info(f"authenticate_user: {email}")
|
||||
with get_session() as db:
|
||||
try:
|
||||
@@ -144,9 +142,7 @@ class AuthsTable:
|
||||
except:
|
||||
return None
|
||||
|
||||
def authenticate_user_by_api_key(
|
||||
self, api_key: str
|
||||
) -> Optional[UserModel]:
|
||||
def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
|
||||
log.info(f"authenticate_user_by_api_key: {api_key}")
|
||||
with get_session() as db:
|
||||
# if no api_key, return None
|
||||
@@ -159,9 +155,7 @@ class AuthsTable:
|
||||
except:
|
||||
return False
|
||||
|
||||
def authenticate_user_by_trusted_header(
|
||||
self, email: str
|
||||
) -> Optional[UserModel]:
|
||||
def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
|
||||
log.info(f"authenticate_user_by_trusted_header: {email}")
|
||||
with get_session() as db:
|
||||
try:
|
||||
@@ -172,12 +166,12 @@ class AuthsTable:
|
||||
except:
|
||||
return None
|
||||
|
||||
def update_user_password_by_id(
|
||||
self, id: str, new_password: str
|
||||
) -> bool:
|
||||
def update_user_password_by_id(self, id: str, new_password: str) -> bool:
|
||||
with get_session() as db:
|
||||
try:
|
||||
result = db.query(Auth).filter_by(id=id).update({"password": new_password})
|
||||
result = (
|
||||
db.query(Auth).filter_by(id=id).update({"password": new_password})
|
||||
)
|
||||
return True if result == 1 else False
|
||||
except:
|
||||
return False
|
||||
|
||||
@@ -79,9 +79,7 @@ class ChatTitleIdResponse(BaseModel):
|
||||
|
||||
class ChatTable:
|
||||
|
||||
def insert_new_chat(
|
||||
self, user_id: str, form_data: ChatForm
|
||||
) -> Optional[ChatModel]:
|
||||
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
|
||||
with get_session() as db:
|
||||
id = str(uuid.uuid4())
|
||||
chat = ChatModel(
|
||||
@@ -89,7 +87,9 @@ class ChatTable:
|
||||
"id": id,
|
||||
"user_id": user_id,
|
||||
"title": (
|
||||
form_data.chat["title"] if "title" in form_data.chat else "New Chat"
|
||||
form_data.chat["title"]
|
||||
if "title" in form_data.chat
|
||||
else "New Chat"
|
||||
),
|
||||
"chat": json.dumps(form_data.chat),
|
||||
"created_at": int(time.time()),
|
||||
@@ -103,9 +103,7 @@ class ChatTable:
|
||||
db.refresh(result)
|
||||
return ChatModel.model_validate(result) if result else None
|
||||
|
||||
def update_chat_by_id(
|
||||
self, id: str, chat: dict
|
||||
) -> Optional[ChatModel]:
|
||||
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
|
||||
with get_session() as db:
|
||||
try:
|
||||
chat_obj = db.get(Chat, id)
|
||||
@@ -119,9 +117,7 @@ class ChatTable:
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
def insert_shared_chat_by_chat_id(
|
||||
self, chat_id: str
|
||||
) -> Optional[ChatModel]:
|
||||
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
|
||||
with get_session() as db:
|
||||
# Get the existing chat to share
|
||||
chat = db.get(Chat, chat_id)
|
||||
@@ -145,14 +141,14 @@ class ChatTable:
|
||||
db.refresh(shared_result)
|
||||
# Update the original chat with the share_id
|
||||
result = (
|
||||
db.query(Chat).filter_by(id=chat_id).update({"share_id": shared_chat.id})
|
||||
db.query(Chat)
|
||||
.filter_by(id=chat_id)
|
||||
.update({"share_id": shared_chat.id})
|
||||
)
|
||||
|
||||
return shared_chat if (shared_result and result) else None
|
||||
|
||||
def update_shared_chat_by_chat_id(
|
||||
self, chat_id: str
|
||||
) -> Optional[ChatModel]:
|
||||
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
|
||||
with get_session() as db:
|
||||
try:
|
||||
print("update_shared_chat_by_id")
|
||||
@@ -271,9 +267,7 @@ class ChatTable:
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
def get_chat_by_id_and_user_id(
|
||||
self, id: str, user_id: str
|
||||
) -> Optional[ChatModel]:
|
||||
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
|
||||
try:
|
||||
with get_session() as db:
|
||||
chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
|
||||
@@ -293,13 +287,13 @@ class ChatTable:
|
||||
def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
|
||||
with get_session() as db:
|
||||
all_chats = (
|
||||
db.query(Chat).filter_by(user_id=user_id).order_by(Chat.updated_at.desc())
|
||||
db.query(Chat)
|
||||
.filter_by(user_id=user_id)
|
||||
.order_by(Chat.updated_at.desc())
|
||||
)
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def get_archived_chats_by_user_id(
|
||||
self, user_id: str
|
||||
) -> List[ChatModel]:
|
||||
def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
|
||||
with get_session() as db:
|
||||
all_chats = (
|
||||
db.query(Chat)
|
||||
|
||||
@@ -106,7 +106,9 @@ class DocumentsTable:
|
||||
|
||||
def get_docs(self) -> List[DocumentModel]:
|
||||
with get_session() as db:
|
||||
return [DocumentModel.model_validate(doc) for doc in db.query(Document).all()]
|
||||
return [
|
||||
DocumentModel.model_validate(doc) for doc in db.query(Document).all()
|
||||
]
|
||||
|
||||
def update_doc_by_name(
|
||||
self, name: str, form_data: DocumentUpdateForm
|
||||
|
||||
@@ -39,6 +39,7 @@ class FileModel(BaseModel):
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
@@ -142,9 +142,9 @@ class FunctionsTable:
|
||||
with get_session() as db:
|
||||
return [
|
||||
FunctionModel.model_validate(function)
|
||||
for function in db.query(Function).filter_by(
|
||||
type=type, is_active=True
|
||||
).all()
|
||||
for function in db.query(Function)
|
||||
.filter_by(type=type, is_active=True)
|
||||
.all()
|
||||
]
|
||||
else:
|
||||
with get_session() as db:
|
||||
@@ -220,10 +220,12 @@ class FunctionsTable:
|
||||
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
|
||||
try:
|
||||
with get_session() as db:
|
||||
db.query(Function).filter_by(id=id).update({
|
||||
**updated,
|
||||
"updated_at": int(time.time()),
|
||||
})
|
||||
db.query(Function).filter_by(id=id).update(
|
||||
{
|
||||
**updated,
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
db.commit()
|
||||
return self.get_function_by_id(id)
|
||||
except:
|
||||
@@ -232,10 +234,12 @@ class FunctionsTable:
|
||||
def deactivate_all_functions(self) -> Optional[bool]:
|
||||
try:
|
||||
with get_session() as db:
|
||||
db.query(Function).update({
|
||||
"is_active": False,
|
||||
"updated_at": int(time.time()),
|
||||
})
|
||||
db.query(Function).update(
|
||||
{
|
||||
"is_active": False,
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
db.commit()
|
||||
return True
|
||||
except:
|
||||
|
||||
@@ -153,9 +153,7 @@ class ModelsTable:
|
||||
except:
|
||||
return None
|
||||
|
||||
def update_model_by_id(
|
||||
self, id: str, model: ModelForm
|
||||
) -> Optional[ModelModel]:
|
||||
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
|
||||
try:
|
||||
# update only the fields that are present in the model
|
||||
with get_session() as db:
|
||||
|
||||
@@ -83,7 +83,9 @@ class PromptsTable:
|
||||
|
||||
def get_prompts(self) -> List[PromptModel]:
|
||||
with get_session() as db:
|
||||
return [PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()]
|
||||
return [
|
||||
PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
|
||||
]
|
||||
|
||||
def update_prompt_by_command(
|
||||
self, command: str, form_data: PromptForm
|
||||
|
||||
@@ -79,9 +79,7 @@ class ChatTagsResponse(BaseModel):
|
||||
|
||||
class TagTable:
|
||||
|
||||
def insert_new_tag(
|
||||
self, name: str, user_id: str
|
||||
) -> Optional[TagModel]:
|
||||
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
|
||||
id = str(uuid.uuid4())
|
||||
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
|
||||
try:
|
||||
@@ -201,11 +199,13 @@ class TagTable:
|
||||
self, tag_name: str, user_id: str
|
||||
) -> int:
|
||||
with get_session() as db:
|
||||
return db.query(ChatIdTag).filter_by(tag_name=tag_name, user_id=user_id).count()
|
||||
return (
|
||||
db.query(ChatIdTag)
|
||||
.filter_by(tag_name=tag_name, user_id=user_id)
|
||||
.count()
|
||||
)
|
||||
|
||||
def delete_tag_by_tag_name_and_user_id(
|
||||
self, tag_name: str, user_id: str
|
||||
) -> bool:
|
||||
def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
|
||||
try:
|
||||
with get_session() as db:
|
||||
res = (
|
||||
@@ -252,9 +252,7 @@ class TagTable:
|
||||
log.error(f"delete_tag: {e}")
|
||||
return False
|
||||
|
||||
def delete_tags_by_chat_id_and_user_id(
|
||||
self, chat_id: str, user_id: str
|
||||
) -> bool:
|
||||
def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool:
|
||||
tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id)
|
||||
|
||||
for tag in tags:
|
||||
|
||||
@@ -165,9 +165,7 @@ class UsersTable:
|
||||
except:
|
||||
return None
|
||||
|
||||
def update_user_role_by_id(
|
||||
self, id: str, role: str
|
||||
) -> Optional[UserModel]:
|
||||
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
|
||||
with get_session() as db:
|
||||
try:
|
||||
db.query(User).filter_by(id=id).update({"role": role})
|
||||
@@ -193,12 +191,12 @@ class UsersTable:
|
||||
except:
|
||||
return None
|
||||
|
||||
def update_user_last_active_by_id(
|
||||
self, id: str
|
||||
) -> Optional[UserModel]:
|
||||
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
|
||||
with get_session() as db:
|
||||
try:
|
||||
db.query(User).filter_by(id=id).update({"last_active_at": int(time.time())})
|
||||
db.query(User).filter_by(id=id).update(
|
||||
{"last_active_at": int(time.time())}
|
||||
)
|
||||
|
||||
user = db.query(User).filter_by(id=id).first()
|
||||
return UserModel.model_validate(user)
|
||||
@@ -217,9 +215,7 @@ class UsersTable:
|
||||
except:
|
||||
return None
|
||||
|
||||
def update_user_by_id(
|
||||
self, id: str, updated: dict
|
||||
) -> Optional[UserModel]:
|
||||
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
|
||||
with get_session() as db:
|
||||
try:
|
||||
db.query(User).filter_by(id=id).update(updated)
|
||||
|
||||
@@ -78,8 +78,7 @@ async def get_session_user(
|
||||
|
||||
@router.post("/update/profile", response_model=UserResponse)
|
||||
async def update_profile(
|
||||
form_data: UpdateProfileForm,
|
||||
session_user=Depends(get_current_user)
|
||||
form_data: UpdateProfileForm, session_user=Depends(get_current_user)
|
||||
):
|
||||
if session_user:
|
||||
user = Users.update_user_by_id(
|
||||
@@ -101,8 +100,7 @@ async def update_profile(
|
||||
|
||||
@router.post("/update/password", response_model=bool)
|
||||
async def update_password(
|
||||
form_data: UpdatePasswordForm,
|
||||
session_user=Depends(get_current_user)
|
||||
form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
|
||||
):
|
||||
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
|
||||
@@ -269,9 +267,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||
|
||||
|
||||
@router.post("/add", response_model=SigninResponse)
|
||||
async def add_user(
|
||||
form_data: AddUserForm, user=Depends(get_admin_user)
|
||||
):
|
||||
async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
|
||||
|
||||
if not validate_email_format(form_data.email.lower()):
|
||||
raise HTTPException(
|
||||
@@ -316,9 +312,7 @@ async def add_user(
|
||||
|
||||
|
||||
@router.get("/admin/details")
|
||||
async def get_admin_details(
|
||||
request: Request, user=Depends(get_current_user)
|
||||
):
|
||||
async def get_admin_details(request: Request, user=Depends(get_current_user)):
|
||||
if request.app.state.config.SHOW_ADMIN_DETAILS:
|
||||
admin_email = request.app.state.config.ADMIN_EMAIL
|
||||
admin_name = None
|
||||
|
||||
@@ -55,9 +55,7 @@ async def get_session_user_chat_list(
|
||||
|
||||
|
||||
@router.delete("/", response_model=bool)
|
||||
async def delete_all_user_chats(
|
||||
request: Request, user=Depends(get_current_user)
|
||||
):
|
||||
async def delete_all_user_chats(request: Request, user=Depends(get_current_user)):
|
||||
|
||||
if (
|
||||
user.role == "user"
|
||||
@@ -95,9 +93,7 @@ async def get_user_chat_list_by_user_id(
|
||||
|
||||
|
||||
@router.post("/new", response_model=Optional[ChatResponse])
|
||||
async def create_new_chat(
|
||||
form_data: ChatForm, user=Depends(get_current_user)
|
||||
):
|
||||
async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
|
||||
try:
|
||||
chat = Chats.insert_new_chat(user.id, form_data)
|
||||
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
||||
@@ -180,9 +176,7 @@ async def archive_all_chats(user=Depends(get_current_user)):
|
||||
|
||||
|
||||
@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
|
||||
async def get_shared_chat_by_id(
|
||||
share_id: str, user=Depends(get_current_user)
|
||||
):
|
||||
async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
|
||||
if user.role == "pending":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
@@ -225,9 +219,7 @@ async def get_user_chat_list_by_tag_name(
|
||||
)
|
||||
]
|
||||
|
||||
chats = Chats.get_chat_list_by_chat_ids(
|
||||
chat_ids, form_data.skip, form_data.limit
|
||||
)
|
||||
chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit)
|
||||
|
||||
if len(chats) == 0:
|
||||
Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id)
|
||||
@@ -297,9 +289,7 @@ async def update_chat_by_id(
|
||||
|
||||
|
||||
@router.delete("/{id}", response_model=bool)
|
||||
async def delete_chat_by_id(
|
||||
request: Request, id: str, user=Depends(get_current_user)
|
||||
):
|
||||
async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_user)):
|
||||
|
||||
if user.role == "admin":
|
||||
result = Chats.delete_chat_by_id(id)
|
||||
@@ -347,9 +337,7 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
|
||||
|
||||
|
||||
@router.get("/{id}/archive", response_model=Optional[ChatResponse])
|
||||
async def archive_chat_by_id(
|
||||
id: str, user=Depends(get_current_user)
|
||||
):
|
||||
async def archive_chat_by_id(id: str, user=Depends(get_current_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
chat = Chats.toggle_chat_archive_by_id(id)
|
||||
@@ -398,9 +386,7 @@ async def share_chat_by_id(id: str, user=Depends(get_current_user)):
|
||||
|
||||
|
||||
@router.delete("/{id}/share", response_model=Optional[bool])
|
||||
async def delete_shared_chat_by_id(
|
||||
id: str, user=Depends(get_current_user)
|
||||
):
|
||||
async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
if not chat.share_id:
|
||||
@@ -423,9 +409,7 @@ async def delete_shared_chat_by_id(
|
||||
|
||||
|
||||
@router.get("/{id}/tags", response_model=List[TagModel])
|
||||
async def get_chat_tags_by_id(
|
||||
id: str, user=Depends(get_current_user)
|
||||
):
|
||||
async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)):
|
||||
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
|
||||
|
||||
if tags != None:
|
||||
@@ -443,9 +427,7 @@ async def get_chat_tags_by_id(
|
||||
|
||||
@router.post("/{id}/tags", response_model=Optional[ChatIdTagModel])
|
||||
async def add_chat_tag_by_id(
|
||||
id: str,
|
||||
form_data: ChatIdTagForm,
|
||||
user=Depends(get_current_user)
|
||||
id: str, form_data: ChatIdTagForm, user=Depends(get_current_user)
|
||||
):
|
||||
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
|
||||
|
||||
@@ -494,9 +476,7 @@ async def delete_chat_tag_by_id(
|
||||
|
||||
|
||||
@router.delete("/{id}/tags/all", response_model=Optional[bool])
|
||||
async def delete_all_chat_tags_by_id(
|
||||
id: str, user=Depends(get_current_user)
|
||||
):
|
||||
async def delete_all_chat_tags_by_id(id: str, user=Depends(get_current_user)):
|
||||
result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id)
|
||||
|
||||
if result:
|
||||
|
||||
@@ -44,9 +44,7 @@ async def get_documents(user=Depends(get_current_user)):
|
||||
|
||||
|
||||
@router.post("/create", response_model=Optional[DocumentResponse])
|
||||
async def create_new_doc(
|
||||
form_data: DocumentForm, user=Depends(get_admin_user)
|
||||
):
|
||||
async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
|
||||
doc = Documents.get_doc_by_name(form_data.name)
|
||||
if doc == None:
|
||||
doc = Documents.insert_new_doc(user.id, form_data)
|
||||
@@ -76,9 +74,7 @@ async def create_new_doc(
|
||||
|
||||
|
||||
@router.get("/doc", response_model=Optional[DocumentResponse])
|
||||
async def get_doc_by_name(
|
||||
name: str, user=Depends(get_current_user)
|
||||
):
|
||||
async def get_doc_by_name(name: str, user=Depends(get_current_user)):
|
||||
doc = Documents.get_doc_by_name(name)
|
||||
|
||||
if doc:
|
||||
@@ -110,12 +106,8 @@ class TagDocumentForm(BaseModel):
|
||||
|
||||
|
||||
@router.post("/doc/tags", response_model=Optional[DocumentResponse])
|
||||
async def tag_doc_by_name(
|
||||
form_data: TagDocumentForm, user=Depends(get_current_user)
|
||||
):
|
||||
doc = Documents.update_doc_content_by_name(
|
||||
form_data.name, {"tags": form_data.tags}
|
||||
)
|
||||
async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)):
|
||||
doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags})
|
||||
|
||||
if doc:
|
||||
return DocumentResponse(
|
||||
@@ -163,8 +155,6 @@ async def update_doc_by_name(
|
||||
|
||||
|
||||
@router.delete("/doc/delete", response_model=bool)
|
||||
async def delete_doc_by_name(
|
||||
name: str, user=Depends(get_admin_user)
|
||||
):
|
||||
async def delete_doc_by_name(name: str, user=Depends(get_admin_user)):
|
||||
result = Documents.delete_doc_by_name(name)
|
||||
return result
|
||||
|
||||
@@ -50,10 +50,7 @@ router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/")
|
||||
def upload_file(
|
||||
file: UploadFile = File(...),
|
||||
user=Depends(get_verified_user)
|
||||
):
|
||||
def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
|
||||
log.info(f"file.content_type: {file.content_type}")
|
||||
try:
|
||||
unsanitized_filename = file.filename
|
||||
|
||||
@@ -167,9 +167,7 @@ async def delete_memory_by_user_id(user=Depends(get_verified_user)):
|
||||
|
||||
|
||||
@router.delete("/{memory_id}", response_model=bool)
|
||||
async def delete_memory_by_id(
|
||||
memory_id: str, user=Depends(get_verified_user)
|
||||
):
|
||||
async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
|
||||
result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
|
||||
|
||||
if result:
|
||||
|
||||
@@ -29,9 +29,7 @@ async def get_prompts(user=Depends(get_current_user)):
|
||||
|
||||
|
||||
@router.post("/create", response_model=Optional[PromptModel])
|
||||
async def create_new_prompt(
|
||||
form_data: PromptForm, user=Depends(get_admin_user)
|
||||
):
|
||||
async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)):
|
||||
prompt = Prompts.get_prompt_by_command(form_data.command)
|
||||
if prompt == None:
|
||||
prompt = Prompts.insert_new_prompt(user.id, form_data)
|
||||
@@ -54,9 +52,7 @@ async def create_new_prompt(
|
||||
|
||||
|
||||
@router.get("/command/{command}", response_model=Optional[PromptModel])
|
||||
async def get_prompt_by_command(
|
||||
command: str, user=Depends(get_current_user)
|
||||
):
|
||||
async def get_prompt_by_command(command: str, user=Depends(get_current_user)):
|
||||
prompt = Prompts.get_prompt_by_command(f"/{command}")
|
||||
|
||||
if prompt:
|
||||
@@ -95,8 +91,6 @@ async def update_prompt_by_command(
|
||||
|
||||
|
||||
@router.delete("/command/{command}/delete", response_model=bool)
|
||||
async def delete_prompt_by_command(
|
||||
command: str, user=Depends(get_admin_user)
|
||||
):
|
||||
async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)):
|
||||
result = Prompts.delete_prompt_by_command(f"/{command}")
|
||||
return result
|
||||
|
||||
@@ -180,9 +180,7 @@ async def update_toolkit_by_id(
|
||||
|
||||
|
||||
@router.delete("/id/{id}/delete", response_model=bool)
|
||||
async def delete_toolkit_by_id(
|
||||
request: Request, id: str, user=Depends(get_admin_user)
|
||||
):
|
||||
async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)):
|
||||
result = Tools.delete_tool_by_id(id)
|
||||
|
||||
if result:
|
||||
|
||||
@@ -40,9 +40,7 @@ router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/", response_model=List[UserModel])
|
||||
async def get_users(
|
||||
skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
|
||||
):
|
||||
async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)):
|
||||
return Users.get_users(skip, limit)
|
||||
|
||||
|
||||
@@ -70,9 +68,7 @@ async def update_user_permissions(
|
||||
|
||||
|
||||
@router.post("/update/role", response_model=Optional[UserModel])
|
||||
async def update_user_role(
|
||||
form_data: UserRoleUpdateForm, user=Depends(get_admin_user)
|
||||
):
|
||||
async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)):
|
||||
|
||||
if user.id != form_data.id and form_data.id != Users.get_first_user().id:
|
||||
return Users.update_user_role_by_id(form_data.id, form_data.role)
|
||||
@@ -89,9 +85,7 @@ async def update_user_role(
|
||||
|
||||
|
||||
@router.get("/user/settings", response_model=Optional[UserSettings])
|
||||
async def get_user_settings_by_session_user(
|
||||
user=Depends(get_verified_user)
|
||||
):
|
||||
async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
|
||||
user = Users.get_user_by_id(user.id)
|
||||
if user:
|
||||
return user.settings
|
||||
@@ -127,9 +121,7 @@ async def update_user_settings_by_session_user(
|
||||
|
||||
|
||||
@router.get("/user/info", response_model=Optional[dict])
|
||||
async def get_user_info_by_session_user(
|
||||
user=Depends(get_verified_user)
|
||||
):
|
||||
async def get_user_info_by_session_user(user=Depends(get_verified_user)):
|
||||
user = Users.get_user_by_id(user.id)
|
||||
if user:
|
||||
return user.info
|
||||
@@ -154,9 +146,7 @@ async def update_user_info_by_session_user(
|
||||
if user.info is None:
|
||||
user.info = {}
|
||||
|
||||
user = Users.update_user_by_id(
|
||||
user.id, {"info": {**user.info, **form_data}}
|
||||
)
|
||||
user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}})
|
||||
if user:
|
||||
return user.info
|
||||
else:
|
||||
@@ -182,9 +172,7 @@ class UserResponse(BaseModel):
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=UserResponse)
|
||||
async def get_user_by_id(
|
||||
user_id: str, user=Depends(get_verified_user)
|
||||
):
|
||||
async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
|
||||
|
||||
# Check if user_id is a shared chat
|
||||
# If it is, get the user_id from the chat
|
||||
@@ -267,9 +255,7 @@ async def update_user_by_id(
|
||||
|
||||
|
||||
@router.delete("/{user_id}", response_model=bool)
|
||||
async def delete_user_by_id(
|
||||
user_id: str, user=Depends(get_admin_user)
|
||||
):
|
||||
async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
|
||||
if user.id != user_id:
|
||||
result = Auths.delete_auth_by_id(user_id)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user