diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index a3f9b45f9..f0ff460c9 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -41,6 +41,7 @@ from open_webui.routers.pipelines import ( process_pipeline_inlet_filter, process_pipeline_outlet_filter, ) +from open_webui.routers.memories import query_memory, QueryMemoryForm from open_webui.utils.webhook import post_webhook @@ -290,6 +291,38 @@ async def chat_completion_tools_handler( return body, {"sources": sources} +async def chat_memory_handler( + request: Request, form_data: dict, extra_params: dict, user +): + results = await query_memory( + request, + QueryMemoryForm( + **{"content": get_last_user_message(form_data["messages"]), "k": 3} + ), + user, + ) + + user_context = "" + if results and hasattr(results, "documents"): + if results.documents and len(results.documents) > 0: + for doc_idx, doc in enumerate(results.documents[0]): + created_at_date = "Unknown Date" + + if results.metadatas[0][doc_idx].get("created_at"): + created_at_timestamp = results.metadatas[0][doc_idx]["created_at"] + created_at_date = time.strftime( + "%Y-%m-%d", time.localtime(created_at_timestamp) + ) + + user_context += f"{doc_idx + 1}. [{created_at_date}] {doc}\n" + + form_data["messages"] = add_or_update_system_message( + f"User Context:\n{user_context}\n", form_data["messages"], append=True + ) + + return form_data + + async def chat_web_search_handler( request: Request, form_data: dict, extra_params: dict, user ): @@ -774,6 +807,11 @@ async def process_chat_payload(request, form_data, user, metadata, model): features = form_data.pop("features", None) if features: + if "memory" in features and features["memory"]: + form_data = await chat_memory_handler( + request, form_data, extra_params, user + ) + if "web_search" in features and features["web_search"]: form_data = await chat_web_search_handler( request, form_data, extra_params, user diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py index 98938dfea..b804afd2c 100644 --- a/backend/open_webui/utils/misc.py +++ b/backend/open_webui/utils/misc.py @@ -130,7 +130,9 @@ def prepend_to_first_user_message_content( return messages -def add_or_update_system_message(content: str, messages: list[dict]): +def add_or_update_system_message( + content: str, messages: list[dict], append: bool = False +): """ Adds a new system message at the beginning of the messages list or updates the existing system message at the beginning. @@ -141,7 +143,10 @@ def add_or_update_system_message(content: str, messages: list[dict]): """ if messages and messages[0].get("role") == "system": - messages[0]["content"] = f"{content}\n{messages[0]['content']}" + if append: + messages[0]["content"] = f"{messages[0]['content']}\n{content}" + else: + messages[0]["content"] = f"{content}\n{messages[0]['content']}" else: # Insert at the beginning messages.insert(0, {"role": "system", "content": content}) diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index cee724e30..d30cb811d 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -1431,7 +1431,6 @@ model: model.id, modelName: model.name ?? model.id, modelIdx: modelIdx ? modelIdx : _modelIdx, - userContext: null, timestamp: Math.floor(Date.now() / 1000) // Unix epoch }; @@ -1486,32 +1485,6 @@ let responseMessageId = responseMessageIds[`${modelId}-${modelIdx ? modelIdx : _modelIdx}`]; - let responseMessage = _history.messages[responseMessageId]; - - let userContext = null; - if ($settings?.memory ?? false) { - if (userContext === null) { - const res = await queryMemory(localStorage.token, prompt).catch((error) => { - toast.error(`${error}`); - return null; - }); - if (res) { - if (res.documents[0].length > 0) { - userContext = res.documents[0].reduce((acc, doc, index) => { - const createdAtTimestamp = res.metadatas[0][index].created_at; - const createdAtDate = new Date(createdAtTimestamp * 1000) - .toISOString() - .split('T')[0]; - return `${acc}${index + 1}. [${createdAtDate}]. ${doc}\n`; - }, ''); - } - - console.log(userContext); - } - } - } - responseMessage.userContext = userContext; - const chatEventEmitter = await getChatEventEmitter(model.id, _chatId); scrollToBottom(); @@ -1573,7 +1546,7 @@ true; let messages = [ - params?.system || $settings.system || (responseMessage?.userContext ?? null) + params?.system || $settings.system ? { role: 'system', content: `${promptTemplate( @@ -1585,11 +1558,7 @@ return undefined; }) : undefined - )}${ - (responseMessage?.userContext ?? null) - ? `\n\nUser Context:\n${responseMessage?.userContext ?? ''}` - : '' - }` + )}` } : undefined, ...createMessagesList(_history, responseMessageId).map((message) => ({ @@ -1666,7 +1635,8 @@ $config?.features?.enable_web_search && ($user?.role === 'admin' || $user?.permissions?.features?.web_search) ? webSearchEnabled || ($settings?.webSearch ?? false) === 'always' - : false + : false, + memory: $settings?.memory ?? false }, variables: { ...getPromptVariables(