diff --git a/backend/apps/web/internal/migrations/004_add_archived.py b/backend/apps/web/internal/migrations/004_add_archived.py new file mode 100644 index 000000000..d01c06b4e --- /dev/null +++ b/backend/apps/web/internal/migrations/004_add_archived.py @@ -0,0 +1,46 @@ +"""Peewee migrations -- 002_add_local_sharing.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + migrator.add_fields("chat", archived=pw.BooleanField(default=False)) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_fields("chat", "archived") diff --git a/backend/apps/web/models/chats.py b/backend/apps/web/models/chats.py index ef16ce731..0594b5cb0 100644 --- a/backend/apps/web/models/chats.py +++ b/backend/apps/web/models/chats.py @@ -21,6 +21,7 @@ class Chat(Model): chat = TextField() # Save Chat JSON as Text timestamp = DateField() share_id = CharField(null=True, unique=True) + archived = BooleanField(default=False) class Meta: database = DB @@ -33,6 +34,7 @@ class ChatModel(BaseModel): chat: str timestamp: int # timestamp in epoch share_id: Optional[str] = None + archived: bool = False #################### @@ -163,12 +165,27 @@ class ChatTable: except: return None + def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: + try: + chat = self.get_chat_by_id(id) + query = Chat.update( + archived=(not chat.archived), + ).where(Chat.id == id) + + query.execute() + + chat = Chat.get(Chat.id == id) + return ChatModel(**model_to_dict(chat)) + except: + return None + def get_chat_lists_by_user_id( self, user_id: str, skip: int = 0, limit: int = 50 ) -> List[ChatModel]: return [ ChatModel(**model_to_dict(chat)) for chat in Chat.select() + .where(Chat.archived == False) .where(Chat.user_id == user_id) .order_by(Chat.timestamp.desc()) # .limit(limit) @@ -181,6 +198,7 @@ class ChatTable: return [ ChatModel(**model_to_dict(chat)) for chat in Chat.select() + .where(Chat.archived == False) .where(Chat.id.in_(chat_ids)) .order_by(Chat.timestamp.desc()) ] @@ -188,13 +206,16 @@ class ChatTable: def get_all_chats(self) -> List[ChatModel]: return [ ChatModel(**model_to_dict(chat)) - for chat in Chat.select().order_by(Chat.timestamp.desc()) + for chat in Chat.select() + .where(Chat.archived == False) + .order_by(Chat.timestamp.desc()) ] def get_all_chats_by_user_id(self, user_id: str) -> List[ChatModel]: return [ ChatModel(**model_to_dict(chat)) for chat in Chat.select() + .where(Chat.archived == False) .where(Chat.user_id == user_id) .order_by(Chat.timestamp.desc()) ] diff --git a/backend/apps/web/routers/chats.py b/backend/apps/web/routers/chats.py index 2e2bb5b00..8eb89aa50 100644 --- a/backend/apps/web/routers/chats.py +++ b/backend/apps/web/routers/chats.py @@ -189,6 +189,23 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_ return result +############################ +# ArchiveChat +############################ + + +@router.get("/{id}/archive", response_model=Optional[ChatResponse]) +async def archive_chat_by_id(id: str, user=Depends(get_current_user)): + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + if chat: + chat = Chats.toggle_chat_archive_by_id(id) + return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() + ) + + ############################ # ShareChatById ############################ diff --git a/src/lib/apis/chats/index.ts b/src/lib/apis/chats/index.ts index 28b3d4be5..321a5c780 100644 --- a/src/lib/apis/chats/index.ts +++ b/src/lib/apis/chats/index.ts @@ -282,6 +282,38 @@ export const shareChatById = async (token: string, id: string) => { return res; }; +export const archiveChatById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/archive`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const deleteSharedChatById = async (token: string, id: string) => { let error = null; diff --git a/src/lib/components/layout/Sidebar.svelte b/src/lib/components/layout/Sidebar.svelte index ef08cf0f8..f05319422 100644 --- a/src/lib/components/layout/Sidebar.svelte +++ b/src/lib/components/layout/Sidebar.svelte @@ -17,7 +17,8 @@ getChatById, getChatListByTagName, updateChatById, - getAllChatTags + getAllChatTags, + archiveChatById } from '$lib/apis/chats'; import { toast } from 'svelte-sonner'; import { fade, slide } from 'svelte/transition'; @@ -139,6 +140,11 @@ localStorage.setItem('settings', JSON.stringify($settings)); location.href = '/'; }; + + const archiveChatHandler = async (id) => { + await archiveChatById(localStorage.token, id); + await chats.set(await getChatList(localStorage.token)); + }; @@ -594,7 +600,7 @@ aria-label="Archive" class=" self-center dark:hover:text-white transition" on:click={() => { - selectedChatId = chat.id; + archiveChatHandler(chat.id); }} >