refac/enh: db session sharing
This commit is contained in:
@@ -4,7 +4,8 @@ import uuid
|
||||
from typing import Optional
|
||||
from functools import lru_cache
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from sqlalchemy.orm import Session
|
||||
from open_webui.internal.db import Base, get_db, get_db_context
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.utils.access_control import has_access
|
||||
from open_webui.models.users import User, UserModel, Users, UserResponse
|
||||
@@ -211,11 +212,9 @@ class NoteTable:
|
||||
return query
|
||||
|
||||
def insert_new_note(
|
||||
self,
|
||||
form_data: NoteForm,
|
||||
user_id: str,
|
||||
self, user_id: str, form_data: NoteForm, db: Optional[Session] = None
|
||||
) -> Optional[NoteModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
note = NoteModel(
|
||||
**{
|
||||
"id": str(uuid.uuid4()),
|
||||
@@ -233,9 +232,9 @@ class NoteTable:
|
||||
return note
|
||||
|
||||
def get_notes(
|
||||
self, skip: Optional[int] = None, limit: Optional[int] = None
|
||||
self, skip: int = 0, limit: int = 50, db: Optional[Session] = None
|
||||
) -> list[NoteModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
query = db.query(Note).order_by(Note.updated_at.desc())
|
||||
if skip is not None:
|
||||
query = query.offset(skip)
|
||||
@@ -333,10 +332,11 @@ class NoteTable:
|
||||
self,
|
||||
user_id: str,
|
||||
permission: str = "read",
|
||||
skip: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
db: Optional[Session] = None,
|
||||
) -> list[NoteModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
user_group_ids = [
|
||||
group.id for group in Groups.get_groups_by_member_id(user_id)
|
||||
]
|
||||
@@ -354,15 +354,17 @@ class NoteTable:
|
||||
notes = query.all()
|
||||
return [NoteModel.model_validate(note) for note in notes]
|
||||
|
||||
def get_note_by_id(self, id: str) -> Optional[NoteModel]:
|
||||
with get_db() as db:
|
||||
def get_note_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[NoteModel]:
|
||||
with get_db_context(db) as db:
|
||||
note = db.query(Note).filter(Note.id == id).first()
|
||||
return NoteModel.model_validate(note) if note else None
|
||||
|
||||
def update_note_by_id(
|
||||
self, id: str, form_data: NoteUpdateForm
|
||||
self, id: str, form_data: NoteUpdateForm, db: Optional[Session] = None
|
||||
) -> Optional[NoteModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
note = db.query(Note).filter(Note.id == id).first()
|
||||
if not note:
|
||||
return None
|
||||
@@ -384,11 +386,14 @@ class NoteTable:
|
||||
db.commit()
|
||||
return NoteModel.model_validate(note) if note else None
|
||||
|
||||
def delete_note_by_id(self, id: str):
|
||||
with get_db() as db:
|
||||
db.query(Note).filter(Note.id == id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
def delete_note_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
db.query(Note).filter(Note.id == id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
Notes = NoteTable()
|
||||
|
||||
Reference in New Issue
Block a user