feat(sqlalchemy): format backend

This commit is contained in:
Jonathan Rohde 2024-06-24 13:55:18 +02:00
parent 2fb27adbf6
commit d88bd51e3c
8 changed files with 153 additions and 142 deletions

View File

@ -85,9 +85,7 @@ class ChatTable:
"id": id, "id": id,
"user_id": user_id, "user_id": user_id,
"title": ( "title": (
form_data.chat["title"] form_data.chat["title"] if "title" in form_data.chat else "New Chat"
if "title" in form_data.chat
else "New Chat"
), ),
"chat": json.dumps(form_data.chat), "chat": json.dumps(form_data.chat),
"created_at": int(time.time()), "created_at": int(time.time()),
@ -197,14 +195,14 @@ class ChatTable:
def get_archived_chat_list_by_user_id( def get_archived_chat_list_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50 self, user_id: str, skip: int = 0, limit: int = 50
) -> List[ChatModel]: ) -> List[ChatModel]:
all_chats = ( all_chats = (
Session.query(Chat) Session.query(Chat)
.filter_by(user_id=user_id, archived=True) .filter_by(user_id=user_id, archived=True)
.order_by(Chat.updated_at.desc()) .order_by(Chat.updated_at.desc())
# .limit(limit).offset(skip) # .limit(limit).offset(skip)
.all() .all()
) )
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_list_by_user_id( def get_chat_list_by_user_id(
self, self,

View File

@ -115,9 +115,7 @@ class MemoriesTable:
except: except:
return False return False
def delete_memory_by_id_and_user_id( def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
self, id: str, user_id: str
) -> bool:
try: try:
Session.query(Memory).filter_by(id=id, user_id=user_id).delete() Session.query(Memory).filter_by(id=id, user_id=user_id).delete()
return True return True

View File

@ -140,7 +140,9 @@ class ModelsTable:
return None return None
def get_all_models(self) -> List[ModelModel]: def get_all_models(self) -> List[ModelModel]:
return [ModelModel.model_validate(model) for model in Session.query(Model).all()] return [
ModelModel.model_validate(model) for model in Session.query(Model).all()
]
def get_model_by_id(self, id: str) -> Optional[ModelModel]: def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try: try:

View File

@ -207,9 +207,7 @@ class TagTable:
log.debug(f"res: {res}") log.debug(f"res: {res}")
Session.commit() Session.commit()
tag_count = self.count_chat_ids_by_tag_name_and_user_id( tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
tag_name, user_id
)
if tag_count == 0: if tag_count == 0:
# Remove tag item from Tag col as well # Remove tag item from Tag col as well
Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
@ -230,9 +228,7 @@ class TagTable:
log.debug(f"res: {res}") log.debug(f"res: {res}")
Session.commit() Session.commit()
tag_count = self.count_chat_ids_by_tag_name_and_user_id( tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
tag_name, user_id
)
if tag_count == 0: if tag_count == 0:
# Remove tag item from Tag col as well # Remove tag item from Tag col as well
Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()

View File

@ -793,6 +793,7 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
@app.middleware("http") @app.middleware("http")
async def commit_session_after_request(request: Request, call_next): async def commit_session_after_request(request: Request, call_next):
response = await call_next(request) response = await call_next(request)

View File

@ -5,6 +5,7 @@ Revises:
Create Date: 2024-06-24 13:15:33.808998 Create Date: 2024-06-24 13:15:33.808998
""" """
from typing import Sequence, Union from typing import Sequence, Union
from alembic import op from alembic import op
@ -13,7 +14,7 @@ import apps.webui.internal.db
from migrations.util import get_existing_tables from migrations.util import get_existing_tables
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = '7e5b5dc7342b' revision: str = "7e5b5dc7342b"
down_revision: Union[str, None] = None down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
@ -24,163 +25,175 @@ def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
if "auth" not in existing_tables: if "auth" not in existing_tables:
op.create_table('auth', op.create_table(
sa.Column('id', sa.String(), nullable=False), "auth",
sa.Column('email', sa.String(), nullable=True), sa.Column("id", sa.String(), nullable=False),
sa.Column('password', sa.Text(), nullable=True), sa.Column("email", sa.String(), nullable=True),
sa.Column('active', sa.Boolean(), nullable=True), sa.Column("password", sa.Text(), nullable=True),
sa.PrimaryKeyConstraint('id') sa.Column("active", sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint("id"),
) )
if "chat" not in existing_tables: if "chat" not in existing_tables:
op.create_table('chat', op.create_table(
sa.Column('id', sa.String(), nullable=False), "chat",
sa.Column('user_id', sa.String(), nullable=True), sa.Column("id", sa.String(), nullable=False),
sa.Column('title', sa.Text(), nullable=True), sa.Column("user_id", sa.String(), nullable=True),
sa.Column('chat', sa.Text(), nullable=True), sa.Column("title", sa.Text(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True), sa.Column("chat", sa.Text(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True), sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column('share_id', sa.Text(), nullable=True), sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column('archived', sa.Boolean(), nullable=True), sa.Column("share_id", sa.Text(), nullable=True),
sa.PrimaryKeyConstraint('id'), sa.Column("archived", sa.Boolean(), nullable=True),
sa.UniqueConstraint('share_id') sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("share_id"),
) )
if "chatidtag" not in existing_tables: if "chatidtag" not in existing_tables:
op.create_table('chatidtag', op.create_table(
sa.Column('id', sa.String(), nullable=False), "chatidtag",
sa.Column('tag_name', sa.String(), nullable=True), sa.Column("id", sa.String(), nullable=False),
sa.Column('chat_id', sa.String(), nullable=True), sa.Column("tag_name", sa.String(), nullable=True),
sa.Column('user_id', sa.String(), nullable=True), sa.Column("chat_id", sa.String(), nullable=True),
sa.Column('timestamp', sa.BigInteger(), nullable=True), sa.Column("user_id", sa.String(), nullable=True),
sa.PrimaryKeyConstraint('id') sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
) )
if "document" not in existing_tables: if "document" not in existing_tables:
op.create_table('document', op.create_table(
sa.Column('collection_name', sa.String(), nullable=False), "document",
sa.Column('name', sa.String(), nullable=True), sa.Column("collection_name", sa.String(), nullable=False),
sa.Column('title', sa.Text(), nullable=True), sa.Column("name", sa.String(), nullable=True),
sa.Column('filename', sa.Text(), nullable=True), sa.Column("title", sa.Text(), nullable=True),
sa.Column('content', sa.Text(), nullable=True), sa.Column("filename", sa.Text(), nullable=True),
sa.Column('user_id', sa.String(), nullable=True), sa.Column("content", sa.Text(), nullable=True),
sa.Column('timestamp', sa.BigInteger(), nullable=True), sa.Column("user_id", sa.String(), nullable=True),
sa.PrimaryKeyConstraint('collection_name'), sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.UniqueConstraint('name') sa.PrimaryKeyConstraint("collection_name"),
sa.UniqueConstraint("name"),
) )
if "file" not in existing_tables: if "file" not in existing_tables:
op.create_table('file', op.create_table(
sa.Column('id', sa.String(), nullable=False), "file",
sa.Column('user_id', sa.String(), nullable=True), sa.Column("id", sa.String(), nullable=False),
sa.Column('filename', sa.Text(), nullable=True), sa.Column("user_id", sa.String(), nullable=True),
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), sa.Column("filename", sa.Text(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True), sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
sa.PrimaryKeyConstraint('id') sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
) )
if "function" not in existing_tables: if "function" not in existing_tables:
op.create_table('function', op.create_table(
sa.Column('id', sa.String(), nullable=False), "function",
sa.Column('user_id', sa.String(), nullable=True), sa.Column("id", sa.String(), nullable=False),
sa.Column('name', sa.Text(), nullable=True), sa.Column("user_id", sa.String(), nullable=True),
sa.Column('type', sa.Text(), nullable=True), sa.Column("name", sa.Text(), nullable=True),
sa.Column('content', sa.Text(), nullable=True), sa.Column("type", sa.Text(), nullable=True),
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), sa.Column("content", sa.Text(), nullable=True),
sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True), sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=True), sa.Column("valves", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True), sa.Column("is_active", sa.Boolean(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True), sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id') sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
) )
if "memory" not in existing_tables: if "memory" not in existing_tables:
op.create_table('memory', op.create_table(
sa.Column('id', sa.String(), nullable=False), "memory",
sa.Column('user_id', sa.String(), nullable=True), sa.Column("id", sa.String(), nullable=False),
sa.Column('content', sa.Text(), nullable=True), sa.Column("user_id", sa.String(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True), sa.Column("content", sa.Text(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True), sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id') sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
) )
if "model" not in existing_tables: if "model" not in existing_tables:
op.create_table('model', op.create_table(
sa.Column('id', sa.Text(), nullable=False), "model",
sa.Column('user_id', sa.Text(), nullable=True), sa.Column("id", sa.Text(), nullable=False),
sa.Column('base_model_id', sa.Text(), nullable=True), sa.Column("user_id", sa.Text(), nullable=True),
sa.Column('name', sa.Text(), nullable=True), sa.Column("base_model_id", sa.Text(), nullable=True),
sa.Column('params', apps.webui.internal.db.JSONField(), nullable=True), sa.Column("name", sa.Text(), nullable=True),
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), sa.Column("params", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True), sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True), sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id') sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
) )
if "prompt" not in existing_tables: if "prompt" not in existing_tables:
op.create_table('prompt', op.create_table(
sa.Column('command', sa.String(), nullable=False), "prompt",
sa.Column('user_id', sa.String(), nullable=True), sa.Column("command", sa.String(), nullable=False),
sa.Column('title', sa.Text(), nullable=True), sa.Column("user_id", sa.String(), nullable=True),
sa.Column('content', sa.Text(), nullable=True), sa.Column("title", sa.Text(), nullable=True),
sa.Column('timestamp', sa.BigInteger(), nullable=True), sa.Column("content", sa.Text(), nullable=True),
sa.PrimaryKeyConstraint('command') sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("command"),
) )
if "tag" not in existing_tables: if "tag" not in existing_tables:
op.create_table('tag', op.create_table(
sa.Column('id', sa.String(), nullable=False), "tag",
sa.Column('name', sa.String(), nullable=True), sa.Column("id", sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True), sa.Column("name", sa.String(), nullable=True),
sa.Column('data', sa.Text(), nullable=True), sa.Column("user_id", sa.String(), nullable=True),
sa.PrimaryKeyConstraint('id') sa.Column("data", sa.Text(), nullable=True),
sa.PrimaryKeyConstraint("id"),
) )
if "tool" not in existing_tables: if "tool" not in existing_tables:
op.create_table('tool', op.create_table(
sa.Column('id', sa.String(), nullable=False), "tool",
sa.Column('user_id', sa.String(), nullable=True), sa.Column("id", sa.String(), nullable=False),
sa.Column('name', sa.Text(), nullable=True), sa.Column("user_id", sa.String(), nullable=True),
sa.Column('content', sa.Text(), nullable=True), sa.Column("name", sa.Text(), nullable=True),
sa.Column('specs', apps.webui.internal.db.JSONField(), nullable=True), sa.Column("content", sa.Text(), nullable=True),
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), sa.Column("specs", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True), sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True), sa.Column("valves", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True), sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id') sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
) )
if "user" not in existing_tables: if "user" not in existing_tables:
op.create_table('user', op.create_table(
sa.Column('id', sa.String(), nullable=False), "user",
sa.Column('name', sa.String(), nullable=True), sa.Column("id", sa.String(), nullable=False),
sa.Column('email', sa.String(), nullable=True), sa.Column("name", sa.String(), nullable=True),
sa.Column('role', sa.String(), nullable=True), sa.Column("email", sa.String(), nullable=True),
sa.Column('profile_image_url', sa.Text(), nullable=True), sa.Column("role", sa.String(), nullable=True),
sa.Column('last_active_at', sa.BigInteger(), nullable=True), sa.Column("profile_image_url", sa.Text(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True), sa.Column("last_active_at", sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True), sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column('api_key', sa.String(), nullable=True), sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column('settings', apps.webui.internal.db.JSONField(), nullable=True), sa.Column("api_key", sa.String(), nullable=True),
sa.Column('info', apps.webui.internal.db.JSONField(), nullable=True), sa.Column("settings", apps.webui.internal.db.JSONField(), nullable=True),
sa.PrimaryKeyConstraint('id'), sa.Column("info", apps.webui.internal.db.JSONField(), nullable=True),
sa.UniqueConstraint('api_key') sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("api_key"),
) )
# ### end Alembic commands ### # ### end Alembic commands ###
def downgrade() -> None: def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.drop_table('user') op.drop_table("user")
op.drop_table('tool') op.drop_table("tool")
op.drop_table('tag') op.drop_table("tag")
op.drop_table('prompt') op.drop_table("prompt")
op.drop_table('model') op.drop_table("model")
op.drop_table('memory') op.drop_table("memory")
op.drop_table('function') op.drop_table("function")
op.drop_table('file') op.drop_table("file")
op.drop_table('document') op.drop_table("document")
op.drop_table('chatidtag') op.drop_table("chatidtag")
op.drop_table('chat') op.drop_table("chat")
op.drop_table('auth') op.drop_table("auth")
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@ -91,6 +91,7 @@ class TestChats(AbstractPostgresTest):
def test_get_user_archived_chats(self): def test_get_user_archived_chats(self):
self.chats.archive_all_chats_by_user_id("2") self.chats.archive_all_chats_by_user_id("2")
from apps.webui.internal.db import Session from apps.webui.internal.db import Session
Session.commit() Session.commit()
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/all/archived")) response = self.fast_api_client.get(self.create_url("/all/archived"))

View File

@ -110,6 +110,7 @@ class AbstractPostgresTest(AbstractIntegrationTest):
def _check_db_connection(self): def _check_db_connection(self):
from apps.webui.internal.db import Session from apps.webui.internal.db import Session
retries = 10 retries = 10
while retries > 0: while retries > 0:
try: try:
@ -133,6 +134,7 @@ class AbstractPostgresTest(AbstractIntegrationTest):
def teardown_method(self): def teardown_method(self):
from apps.webui.internal.db import Session from apps.webui.internal.db import Session
# rollback everything not yet committed # rollback everything not yet committed
Session.commit() Session.commit()