diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index ac52dc3d8..7d92dd10f 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -20,7 +20,7 @@ from langchain.retrievers import ( from typing import Optional - +from utils.misc import get_last_user_message, add_or_update_system_message from config import SRC_LOG_LEVELS, CHROMA_CLIENT log = logging.getLogger(__name__) @@ -247,31 +247,7 @@ def rag_messages( hybrid_search, ): log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}") - - last_user_message_idx = None - for i in range(len(messages) - 1, -1, -1): - if messages[i]["role"] == "user": - last_user_message_idx = i - break - - user_message = messages[last_user_message_idx] - - if isinstance(user_message["content"], list): - # Handle list content input - content_type = "list" - query = "" - for content_item in user_message["content"]: - if content_item["type"] == "text": - query = content_item["text"] - break - elif isinstance(user_message["content"], str): - # Handle text content input - content_type = "text" - query = user_message["content"] - else: - # Fallback in case the input does not match expected types - content_type = None - query = "" + query = get_last_user_message(messages) extracted_collections = [] relevant_contexts = [] @@ -349,24 +325,7 @@ def rag_messages( ) log.debug(f"ra_content: {ra_content}") - - if content_type == "list": - new_content = [] - for content_item in user_message["content"]: - if content_item["type"] == "text": - # Update the text item's content with ra_content - new_content.append({"type": "text", "text": ra_content}) - else: - # Keep other types of content as they are - new_content.append(content_item) - new_user_message = {**user_message, "content": new_content} - else: - new_user_message = { - **user_message, - "content": ra_content, - } - - messages[last_user_message_idx] = new_user_message + messages = add_or_update_system_message(ra_content, messages) return messages, citations diff --git a/backend/utils/misc.py b/backend/utils/misc.py index 4bb5ddf56..c3c65d3f5 100644 --- a/backend/utils/misc.py +++ b/backend/utils/misc.py @@ -3,7 +3,48 @@ import hashlib import json import re from datetime import timedelta -from typing import Optional +from typing import Optional, List + + +def get_last_user_message(messages: List[dict]) -> str: + for message in reversed(messages): + if message["role"] == "user": + if isinstance(message["content"], list): + for item in message["content"]: + if item["type"] == "text": + return item["text"] + return message["content"] + return None + + +def get_last_assistant_message(messages: List[dict]) -> str: + for message in reversed(messages): + if message["role"] == "assistant": + if isinstance(message["content"], list): + for item in message["content"]: + if item["type"] == "text": + return item["text"] + return message["content"] + return None + + +def add_or_update_system_message(content: str, messages: List[dict]): + """ + Adds a new system message at the beginning of the messages list + or updates the existing system message at the beginning. + + :param msg: The message to be added or appended. + :param messages: The list of message dictionaries. + :return: The updated list of message dictionaries. + """ + + if messages and messages[0].get("role") == "system": + messages[0]["content"] += f"{content}\n{messages[0]['content']}" + else: + # Insert at the beginning + messages.insert(0, {"role": "system", "content": content}) + + return messages def get_gravatar_url(email): diff --git a/src/lib/components/admin/Settings.svelte b/src/lib/components/admin/Settings.svelte index f668da2ff..982a5511a 100644 --- a/src/lib/components/admin/Settings.svelte +++ b/src/lib/components/admin/Settings.svelte @@ -23,7 +23,7 @@ let selectedTab = 'general'; -