From febab588219a4500d224fb86f2824a3eabe2f193 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 19 May 2024 08:40:46 -0700 Subject: [PATCH] feat: memory integration --- backend/apps/web/routers/memories.py | 2 +- .../Personalization/AddMemoryModal.svelte | 2 +- src/routes/(app)/+page.svelte | 36 ++++++++++++++++--- src/routes/(app)/c/[id]/+page.svelte | 36 ++++++++++++++++--- 4 files changed, 66 insertions(+), 10 deletions(-) diff --git a/backend/apps/web/routers/memories.py b/backend/apps/web/routers/memories.py index 1e61bd478..97dd5f930 100644 --- a/backend/apps/web/routers/memories.py +++ b/backend/apps/web/routers/memories.py @@ -71,7 +71,7 @@ class QueryMemoryForm(BaseModel): content: str -@router.post("/query", response_model=Optional[MemoryModel]) +@router.post("/query") async def query_memory( request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user) ): diff --git a/src/lib/components/chat/Settings/Personalization/AddMemoryModal.svelte b/src/lib/components/chat/Settings/Personalization/AddMemoryModal.svelte index cac86f3a0..9d67467fe 100644 --- a/src/lib/components/chat/Settings/Personalization/AddMemoryModal.svelte +++ b/src/lib/components/chat/Settings/Personalization/AddMemoryModal.svelte @@ -26,8 +26,8 @@ if (res) { console.log(res); toast.success('Memory added successfully'); + content = ''; show = false; - dispatch('save'); } diff --git a/src/routes/(app)/+page.svelte b/src/routes/(app)/+page.svelte index db3daf47a..cbb925fde 100644 --- a/src/routes/(app)/+page.svelte +++ b/src/routes/(app)/+page.svelte @@ -41,6 +41,7 @@ import { LITELLM_API_BASE_URL, OLLAMA_API_BASE_URL, OPENAI_API_BASE_URL } from '$lib/constants'; import { WEBUI_BASE_URL } from '$lib/constants'; import { createOpenAITextStream } from '$lib/apis/streaming'; + import { queryMemory } from '$lib/apis/memories'; const i18n = getContext('i18n'); @@ -254,6 +255,26 @@ const sendPrompt = async (prompt, parentId, modelId = null) => { const _chatId = JSON.parse(JSON.stringify($chatId)); + let userContext = null; + + if ($settings?.memory ?? false) { + const res = await queryMemory(localStorage.token, prompt).catch((error) => { + toast.error(error); + return null; + }); + + if (res) { + userContext = res.documents.reduce((acc, doc, index) => { + const createdAtTimestamp = res.metadatas[index][0].created_at; + const createdAtDate = new Date(createdAtTimestamp * 1000).toISOString().split('T')[0]; + acc.push(`${index + 1}. [${createdAtDate}]. ${doc[0]}`); + return acc; + }, []); + + console.log(userContext); + } + } + await Promise.all( (modelId ? [modelId] : atSelectedModel !== '' ? [atSelectedModel.id] : selectedModels).map( async (modelId) => { @@ -270,6 +291,7 @@ role: 'assistant', content: '', model: model.id, + userContext: userContext, timestamp: Math.floor(Date.now() / 1000) // Unix epoch }; @@ -311,10 +333,13 @@ scrollToBottom(); const messagesBody = [ - $settings.system + $settings.system || responseMessage?.userContext ? { role: 'system', - content: $settings.system + content: + $settings.system + (responseMessage?.userContext ?? null) + ? `\n\nUser Context:\n${responseMessage.userContext.join('\n')}` + : '' } : undefined, ...messages @@ -567,10 +592,13 @@ model: model.id, stream: true, messages: [ - $settings.system + $settings.system || responseMessage?.userContext ? { role: 'system', - content: $settings.system + content: + $settings.system + (responseMessage?.userContext ?? null) + ? `\n\nUser Context:\n${responseMessage.userContext.join('\n')}` + : '' } : undefined, ...messages diff --git a/src/routes/(app)/c/[id]/+page.svelte b/src/routes/(app)/c/[id]/+page.svelte index d330aadb3..8680895e0 100644 --- a/src/routes/(app)/c/[id]/+page.svelte +++ b/src/routes/(app)/c/[id]/+page.svelte @@ -43,6 +43,7 @@ WEBUI_BASE_URL } from '$lib/constants'; import { createOpenAITextStream } from '$lib/apis/streaming'; + import { queryMemory } from '$lib/apis/memories'; const i18n = getContext('i18n'); @@ -260,6 +261,26 @@ const sendPrompt = async (prompt, parentId, modelId = null) => { const _chatId = JSON.parse(JSON.stringify($chatId)); + let userContext = null; + + if ($settings?.memory ?? false) { + const res = await queryMemory(localStorage.token, prompt).catch((error) => { + toast.error(error); + return null; + }); + + if (res) { + userContext = res.documents.reduce((acc, doc, index) => { + const createdAtTimestamp = res.metadatas[index][0].created_at; + const createdAtDate = new Date(createdAtTimestamp * 1000).toISOString().split('T')[0]; + acc.push(`${index + 1}. [${createdAtDate}]. ${doc[0]}`); + return acc; + }, []); + + console.log(userContext); + } + } + await Promise.all( (modelId ? [modelId] : atSelectedModel !== '' ? [atSelectedModel.id] : selectedModels).map( async (modelId) => { @@ -317,10 +338,13 @@ scrollToBottom(); const messagesBody = [ - $settings.system + $settings.system || responseMessage?.userContext ? { role: 'system', - content: $settings.system + content: + $settings.system + (responseMessage?.userContext ?? null) + ? `\n\nUser Context:\n${responseMessage.userContext.join('\n')}` + : '' } : undefined, ...messages @@ -573,10 +597,13 @@ model: model.id, stream: true, messages: [ - $settings.system + $settings.system || responseMessage?.userContext ? { role: 'system', - content: $settings.system + content: + $settings.system + (responseMessage?.userContext ?? null) + ? `\n\nUser Context:\n${responseMessage.userContext.join('\n')}` + : '' } : undefined, ...messages @@ -705,6 +732,7 @@ } catch (error) { await handleOpenAIError(error, null, model, responseMessage); } + messages = messages; stopResponseFlag = false; await tick();