Merge branch 'upstream-dev' into dev

This commit is contained in:
Jannik Streidl
2024-10-14 09:50:40 +02:00
39 changed files with 1235 additions and 469 deletions

View File

@@ -19,6 +19,7 @@ from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message
from open_webui.env import SRC_LOG_LEVELS
from open_webui.config import DEFAULT_RAG_TEMPLATE
log = logging.getLogger(__name__)
@@ -239,8 +240,13 @@ def query_collection_with_hybrid_search(
def rag_template(template: str, context: str, query: str):
count = template.count("[context]")
assert "[context]" in template, "RAG template does not contain '[context]'"
if template == "":
template = DEFAULT_RAG_TEMPLATE
if "[context]" not in template and "{{CONTEXT}}" not in template:
log.debug(
"WARNING: The RAG template does not contain the '[context]' or '{{CONTEXT}}' placeholder."
)
if "<context>" in context and "</context>" in context:
log.debug(
@@ -249,14 +255,25 @@ def rag_template(template: str, context: str, query: str):
"nothing, or the user might be trying to hack something."
)
query_placeholders = []
if "[query]" in context:
query_placeholder = f"[query-{str(uuid.uuid4())}]"
query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
template = template.replace("[query]", query_placeholder)
template = template.replace("[context]", context)
query_placeholders.append(query_placeholder)
if "{{QUERY}}" in context:
query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
template = template.replace("{{QUERY}}", query_placeholder)
query_placeholders.append(query_placeholder)
template = template.replace("[context]", context)
template = template.replace("{{CONTEXT}}", context)
template = template.replace("[query]", query)
template = template.replace("{{QUERY}}", query)
for query_placeholder in query_placeholders:
template = template.replace(query_placeholder, query)
else:
template = template.replace("[context]", context)
template = template.replace("[query]", query)
return template
@@ -375,8 +392,21 @@ def get_rag_context(
for context in relevant_contexts:
try:
if "documents" in context:
file_names = list(
set(
[
metadata["name"]
for metadata in context["metadatas"][0]
if metadata is not None and "name" in metadata
]
)
)
contexts.append(
"\n\n".join(
(", ".join(file_names) + ":\n\n")
if file_names
else ""
+ "\n\n".join(
[text for text in context["documents"][0] if text is not None]
)
)
@@ -393,6 +423,7 @@ def get_rag_context(
except Exception as e:
log.exception(e)
print(contexts, citations)
return contexts, citations