refac/enh: db session sharing

This commit is contained in:
Timothy Jaeryang Baek
2025-12-28 22:00:44 +04:00
parent d4de26bd05
commit 2041ab483e
20 changed files with 600 additions and 562 deletions

View File

@@ -4,7 +4,8 @@ import uuid
from typing import Optional
from functools import lru_cache
from open_webui.internal.db import Base, get_db
from sqlalchemy.orm import Session
from open_webui.internal.db import Base, get_db, get_db_context
from open_webui.models.groups import Groups
from open_webui.utils.access_control import has_access
from open_webui.models.users import User, UserModel, Users, UserResponse
@@ -211,11 +212,9 @@ class NoteTable:
return query
def insert_new_note(
self,
form_data: NoteForm,
user_id: str,
self, user_id: str, form_data: NoteForm, db: Optional[Session] = None
) -> Optional[NoteModel]:
with get_db() as db:
with get_db_context(db) as db:
note = NoteModel(
**{
"id": str(uuid.uuid4()),
@@ -233,9 +232,9 @@ class NoteTable:
return note
def get_notes(
self, skip: Optional[int] = None, limit: Optional[int] = None
self, skip: int = 0, limit: int = 50, db: Optional[Session] = None
) -> list[NoteModel]:
with get_db() as db:
with get_db_context(db) as db:
query = db.query(Note).order_by(Note.updated_at.desc())
if skip is not None:
query = query.offset(skip)
@@ -333,10 +332,11 @@ class NoteTable:
self,
user_id: str,
permission: str = "read",
skip: Optional[int] = None,
limit: Optional[int] = None,
skip: int = 0,
limit: int = 50,
db: Optional[Session] = None,
) -> list[NoteModel]:
with get_db() as db:
with get_db_context(db) as db:
user_group_ids = [
group.id for group in Groups.get_groups_by_member_id(user_id)
]
@@ -354,15 +354,17 @@ class NoteTable:
notes = query.all()
return [NoteModel.model_validate(note) for note in notes]
def get_note_by_id(self, id: str) -> Optional[NoteModel]:
with get_db() as db:
def get_note_by_id(
self, id: str, db: Optional[Session] = None
) -> Optional[NoteModel]:
with get_db_context(db) as db:
note = db.query(Note).filter(Note.id == id).first()
return NoteModel.model_validate(note) if note else None
def update_note_by_id(
self, id: str, form_data: NoteUpdateForm
self, id: str, form_data: NoteUpdateForm, db: Optional[Session] = None
) -> Optional[NoteModel]:
with get_db() as db:
with get_db_context(db) as db:
note = db.query(Note).filter(Note.id == id).first()
if not note:
return None
@@ -384,11 +386,14 @@ class NoteTable:
db.commit()
return NoteModel.model_validate(note) if note else None
def delete_note_by_id(self, id: str):
with get_db() as db:
db.query(Note).filter(Note.id == id).delete()
db.commit()
return True
def delete_note_by_id(self, id: str, db: Optional[Session] = None) -> bool:
try:
with get_db_context(db) as db:
db.query(Note).filter(Note.id == id).delete()
db.commit()
return True
except Exception:
return False
Notes = NoteTable()