mirror of
https://github.com/open-webui/open-webui
synced 2025-02-16 18:22:29 +00:00
refac: rag
This commit is contained in:
parent
277fc3feac
commit
f2b9a5f5bf
@ -20,7 +20,7 @@ from langchain.retrievers import (
|
|||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from utils.misc import get_last_user_message, add_or_update_system_message
|
||||||
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
|
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
@ -247,31 +247,7 @@ def rag_messages(
|
|||||||
hybrid_search,
|
hybrid_search,
|
||||||
):
|
):
|
||||||
log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}")
|
log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}")
|
||||||
|
query = get_last_user_message(messages)
|
||||||
last_user_message_idx = None
|
|
||||||
for i in range(len(messages) - 1, -1, -1):
|
|
||||||
if messages[i]["role"] == "user":
|
|
||||||
last_user_message_idx = i
|
|
||||||
break
|
|
||||||
|
|
||||||
user_message = messages[last_user_message_idx]
|
|
||||||
|
|
||||||
if isinstance(user_message["content"], list):
|
|
||||||
# Handle list content input
|
|
||||||
content_type = "list"
|
|
||||||
query = ""
|
|
||||||
for content_item in user_message["content"]:
|
|
||||||
if content_item["type"] == "text":
|
|
||||||
query = content_item["text"]
|
|
||||||
break
|
|
||||||
elif isinstance(user_message["content"], str):
|
|
||||||
# Handle text content input
|
|
||||||
content_type = "text"
|
|
||||||
query = user_message["content"]
|
|
||||||
else:
|
|
||||||
# Fallback in case the input does not match expected types
|
|
||||||
content_type = None
|
|
||||||
query = ""
|
|
||||||
|
|
||||||
extracted_collections = []
|
extracted_collections = []
|
||||||
relevant_contexts = []
|
relevant_contexts = []
|
||||||
@ -349,24 +325,7 @@ def rag_messages(
|
|||||||
)
|
)
|
||||||
|
|
||||||
log.debug(f"ra_content: {ra_content}")
|
log.debug(f"ra_content: {ra_content}")
|
||||||
|
messages = add_or_update_system_message(ra_content, messages)
|
||||||
if content_type == "list":
|
|
||||||
new_content = []
|
|
||||||
for content_item in user_message["content"]:
|
|
||||||
if content_item["type"] == "text":
|
|
||||||
# Update the text item's content with ra_content
|
|
||||||
new_content.append({"type": "text", "text": ra_content})
|
|
||||||
else:
|
|
||||||
# Keep other types of content as they are
|
|
||||||
new_content.append(content_item)
|
|
||||||
new_user_message = {**user_message, "content": new_content}
|
|
||||||
else:
|
|
||||||
new_user_message = {
|
|
||||||
**user_message,
|
|
||||||
"content": ra_content,
|
|
||||||
}
|
|
||||||
|
|
||||||
messages[last_user_message_idx] = new_user_message
|
|
||||||
|
|
||||||
return messages, citations
|
return messages, citations
|
||||||
|
|
||||||
|
@ -3,7 +3,48 @@ import hashlib
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Optional
|
from typing import Optional, List
|
||||||
|
|
||||||
|
|
||||||
|
def get_last_user_message(messages: List[dict]) -> str:
|
||||||
|
for message in reversed(messages):
|
||||||
|
if message["role"] == "user":
|
||||||
|
if isinstance(message["content"], list):
|
||||||
|
for item in message["content"]:
|
||||||
|
if item["type"] == "text":
|
||||||
|
return item["text"]
|
||||||
|
return message["content"]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_last_assistant_message(messages: List[dict]) -> str:
|
||||||
|
for message in reversed(messages):
|
||||||
|
if message["role"] == "assistant":
|
||||||
|
if isinstance(message["content"], list):
|
||||||
|
for item in message["content"]:
|
||||||
|
if item["type"] == "text":
|
||||||
|
return item["text"]
|
||||||
|
return message["content"]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def add_or_update_system_message(content: str, messages: List[dict]):
|
||||||
|
"""
|
||||||
|
Adds a new system message at the beginning of the messages list
|
||||||
|
or updates the existing system message at the beginning.
|
||||||
|
|
||||||
|
:param msg: The message to be added or appended.
|
||||||
|
:param messages: The list of message dictionaries.
|
||||||
|
:return: The updated list of message dictionaries.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if messages and messages[0].get("role") == "system":
|
||||||
|
messages[0]["content"] += f"{content}\n{messages[0]['content']}"
|
||||||
|
else:
|
||||||
|
# Insert at the beginning
|
||||||
|
messages.insert(0, {"role": "system", "content": content})
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
def get_gravatar_url(email):
|
def get_gravatar_url(email):
|
||||||
|
@ -23,7 +23,7 @@
|
|||||||
let selectedTab = 'general';
|
let selectedTab = 'general';
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<div class="flex flex-col md:flex-row w-full h-full py-3 md:space-x-4">
|
<div class="flex flex-col md:flex-row w-full h-full py-2 md:space-x-4">
|
||||||
<div
|
<div
|
||||||
class="tabs flex flex-row overflow-x-auto space-x-1 md:space-x-0 md:space-y-1 md:flex-col md:flex-none md:w-44 dark:text-gray-200 text-xs text-left scrollbar-none"
|
class="tabs flex flex-row overflow-x-auto space-x-1 md:space-x-0 md:space-y-1 md:flex-col md:flex-none md:w-44 dark:text-gray-200 text-xs text-left scrollbar-none"
|
||||||
>
|
>
|
||||||
|
Loading…
Reference in New Issue
Block a user