From 493e3068d806520b7fde713cd52740c236758a61 Mon Sep 17 00:00:00 2001 From: Peter De-Ath Date: Thu, 13 Jun 2024 02:01:50 +0100 Subject: [PATCH] enh: ability to edit memories --- backend/apps/webui/models/memories.py | 14 +++++ backend/apps/webui/routers/memories.py | 23 ++++++++ src/lib/apis/memories/index.ts | 31 +++++++++++ .../Personalization/AddMemoryModal.svelte | 52 +++++++++++++++++-- .../Personalization/ManageModal.svelte | 26 ++++++++-- src/lib/i18n/locales/en-GB/translation.json | 2 + 6 files changed, 141 insertions(+), 7 deletions(-) diff --git a/backend/apps/webui/models/memories.py b/backend/apps/webui/models/memories.py index 70e5577e9..0266cc8b2 100644 --- a/backend/apps/webui/models/memories.py +++ b/backend/apps/webui/models/memories.py @@ -64,6 +64,20 @@ class MemoriesTable: return memory else: return None + + def update_memory( + self, + id: str, + content: str, + ) -> Optional[MemoryModel]: + try: + memory = Memory.get(Memory.id == id) + memory.content = content + memory.updated_at = int(time.time()) + memory.save() + return MemoryModel(**model_to_dict(memory)) + except: + return None def get_memories(self) -> List[MemoryModel]: try: diff --git a/backend/apps/webui/routers/memories.py b/backend/apps/webui/routers/memories.py index 6448ebe1e..927c28b46 100644 --- a/backend/apps/webui/routers/memories.py +++ b/backend/apps/webui/routers/memories.py @@ -43,6 +43,8 @@ async def get_memories(user=Depends(get_verified_user)): class AddMemoryForm(BaseModel): content: str +class MemoryUpdateModel(BaseModel): + content: Optional[str] = None @router.post("/add", response_model=Optional[MemoryModel]) async def add_memory( @@ -62,6 +64,27 @@ async def add_memory( return memory +@router.patch("/{memory_id}", response_model=Optional[MemoryModel]) +async def update_memory( + memory_id: str, request: Request, form_data: MemoryUpdateModel, user=Depends(get_verified_user) +): + memory = Memories.update_memory(memory_id, form_data.content) + if memory is None: + raise HTTPException(status_code=404, detail="Memory not found") + + if form_data.content is not None: + memory_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content) + collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}") + collection.upsert( + documents=[form_data.content], + ids=[memory.id], + embeddings=[memory_embedding], + metadatas=[{"created_at": memory.created_at, "updated_at": memory.updated_at}], + ) + + return memory + + ############################ # QueryMemory ############################ diff --git a/src/lib/apis/memories/index.ts b/src/lib/apis/memories/index.ts index 6cbb89f14..cc4abb176 100644 --- a/src/lib/apis/memories/index.ts +++ b/src/lib/apis/memories/index.ts @@ -59,6 +59,37 @@ export const addNewMemory = async (token: string, content: string) => { return res; }; +export const updateMemoryById = async (token: string, id: string, content: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/memories/${id}`, { + method: 'PATCH', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + content: content + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const queryMemory = async (token: string, content: string) => { let error = null; diff --git a/src/lib/components/chat/Settings/Personalization/AddMemoryModal.svelte b/src/lib/components/chat/Settings/Personalization/AddMemoryModal.svelte index 445b7f667..ff9476308 100644 --- a/src/lib/components/chat/Settings/Personalization/AddMemoryModal.svelte +++ b/src/lib/components/chat/Settings/Personalization/AddMemoryModal.svelte @@ -2,21 +2,60 @@ import { createEventDispatcher, getContext } from 'svelte'; import Modal from '$lib/components/common/Modal.svelte'; - import { addNewMemory } from '$lib/apis/memories'; + import { addNewMemory, updateMemoryById } from '$lib/apis/memories'; import { toast } from 'svelte-sonner'; const dispatch = createEventDispatcher(); export let show; + export let memory = {}; + + let showUpdateBtn = false; const i18n = getContext('i18n'); let loading = false; let content = ''; + let isMemoryLoaded = false; + + $: { + if (memory && memory.id && !isMemoryLoaded) { + showUpdateBtn = true; + content = memory.content; + isMemoryLoaded = true; + } + if (!show) { + showUpdateBtn = false; + isMemoryLoaded = false; + memory = {}; + content = ''; + } + } const submitHandler = async () => { loading = true; + if (memory && memory.id) { + const res = await updateMemoryById(localStorage.token, memory.id, content).catch((error) => { + toast.error(error); + + return null; + }); + + if (res) { + console.log(res); + toast.success('Memory updated successfully'); + content = ''; + show = false; + isMemoryLoaded = false; + memory = {}; + dispatch('save'); + } + + loading = false; + return; + } + const res = await addNewMemory(localStorage.token, content).catch((error) => { toast.error(error); @@ -38,7 +77,9 @@
-
{$i18n.t('Add Memory')}
+
+ {memory.id ? $i18n.t('Edit Memory') : $i18n.t('Add Memory')} +
+ +