mirror of
https://github.com/open-webui/open-webui
synced 2025-01-29 13:58:09 +00:00
Merge pull request #5378 from thiswillbeyourgithub/fix_RAG_and_web
fix: RAG and Web Search + RAG enhancements
This commit is contained in:
commit
7dc4cb30b2
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Optional, Union
|
||||
|
||||
import requests
|
||||
@ -91,7 +92,7 @@ def query_doc_with_hybrid_search(
|
||||
k: int,
|
||||
reranking_function,
|
||||
r: float,
|
||||
):
|
||||
) -> dict:
|
||||
try:
|
||||
result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
|
||||
|
||||
@ -134,7 +135,7 @@ def query_doc_with_hybrid_search(
|
||||
raise e
|
||||
|
||||
|
||||
def merge_and_sort_query_results(query_results, k, reverse=False):
|
||||
def merge_and_sort_query_results(query_results: list[dict], k: int, reverse: bool = False) -> list[dict]:
|
||||
# Initialize lists to store combined data
|
||||
combined_distances = []
|
||||
combined_documents = []
|
||||
@ -180,7 +181,7 @@ def query_collection(
|
||||
query: str,
|
||||
embedding_function,
|
||||
k: int,
|
||||
):
|
||||
) -> dict:
|
||||
results = []
|
||||
for collection_name in collection_names:
|
||||
if collection_name:
|
||||
@ -192,8 +193,8 @@ def query_collection(
|
||||
embedding_function=embedding_function,
|
||||
)
|
||||
results.append(result)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
log.exception(f"Error when querying the collection: {e}")
|
||||
else:
|
||||
pass
|
||||
|
||||
@ -207,8 +208,9 @@ def query_collection_with_hybrid_search(
|
||||
k: int,
|
||||
reranking_function,
|
||||
r: float,
|
||||
):
|
||||
) -> dict:
|
||||
results = []
|
||||
failed = 0
|
||||
for collection_name in collection_names:
|
||||
try:
|
||||
result = query_doc_with_hybrid_search(
|
||||
@ -220,14 +222,39 @@ def query_collection_with_hybrid_search(
|
||||
r=r,
|
||||
)
|
||||
results.append(result)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
log.exception(
|
||||
"Error when querying the collection with "
|
||||
f"hybrid_search: {e}"
|
||||
)
|
||||
failed += 1
|
||||
if failed == len(collection_names):
|
||||
raise Exception("Hybrid search failed for all collections. Using "
|
||||
"Non hybrid search as fallback.")
|
||||
return merge_and_sort_query_results(results, k=k, reverse=True)
|
||||
|
||||
|
||||
def rag_template(template: str, context: str, query: str):
|
||||
template = template.replace("[context]", context)
|
||||
template = template.replace("[query]", query)
|
||||
count = template.count("[context]")
|
||||
assert count == 1, (
|
||||
f"RAG template contains an unexpected number of '[context]' : {count}"
|
||||
)
|
||||
assert "[context]" in template, "RAG template does not contain '[context]'"
|
||||
if "<context>" in context and "</context>" in context:
|
||||
log.debug(
|
||||
"WARNING: Potential prompt injection attack: the RAG "
|
||||
"context contains '<context>' and '</context>'. This might be "
|
||||
"nothing, or the user might be trying to hack something."
|
||||
)
|
||||
|
||||
if "[query]" in context:
|
||||
query_placeholder = str(uuid.uuid4())
|
||||
template = template.replace("[QUERY]", query_placeholder)
|
||||
template = template.replace("[context]", context)
|
||||
template = template.replace(query_placeholder, query)
|
||||
else:
|
||||
template = template.replace("[context]", context)
|
||||
template = template.replace("[query]", query)
|
||||
return template
|
||||
|
||||
|
||||
@ -304,19 +331,25 @@ def get_rag_context(
|
||||
continue
|
||||
|
||||
try:
|
||||
context = None
|
||||
if file["type"] == "text":
|
||||
context = file["content"]
|
||||
else:
|
||||
if hybrid_search:
|
||||
context = query_collection_with_hybrid_search(
|
||||
collection_names=collection_names,
|
||||
query=query,
|
||||
embedding_function=embedding_function,
|
||||
k=k,
|
||||
reranking_function=reranking_function,
|
||||
r=r,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
context = query_collection_with_hybrid_search(
|
||||
collection_names=collection_names,
|
||||
query=query,
|
||||
embedding_function=embedding_function,
|
||||
k=k,
|
||||
reranking_function=reranking_function,
|
||||
r=r,
|
||||
)
|
||||
except Exception as e:
|
||||
log.debug("Error when using hybrid search, using"
|
||||
" non hybrid search as fallback.")
|
||||
|
||||
if (not hybrid_search) or (context is None):
|
||||
context = query_collection(
|
||||
collection_names=collection_names,
|
||||
query=query,
|
||||
@ -325,7 +358,6 @@ def get_rag_context(
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
context = None
|
||||
|
||||
if context:
|
||||
relevant_contexts.append({**context, "source": file})
|
||||
|
@ -1030,19 +1030,25 @@ CHUNK_OVERLAP = PersistentConfig(
|
||||
int(os.environ.get("CHUNK_OVERLAP", "100")),
|
||||
)
|
||||
|
||||
DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
|
||||
DEFAULT_RAG_TEMPLATE = """You are given a user query, some textual context and rules, all inside xml tags. You have to answer the query based on the context while respecting the rules.
|
||||
|
||||
<context>
|
||||
[context]
|
||||
[context]
|
||||
</context>
|
||||
|
||||
When answer to user:
|
||||
- If you don't know, just say that you don't know.
|
||||
- If you don't know when you are not sure, ask for clarification.
|
||||
Avoid mentioning that you obtained the information from the context.
|
||||
And answer according to the language of the user's question.
|
||||
<rules>
|
||||
- If you don't know, just say so.
|
||||
- If you are not sure, ask for clarification.
|
||||
- Answer in the same language as the user query.
|
||||
- If the context appears unreadable or of poor quality, tell the user then answer as best as you can.
|
||||
- If the answer is not in the context but you think you know the answer, explain that to the user then answer with your own knowledge.
|
||||
- Answer directly and without using xml tags.
|
||||
</rules>
|
||||
|
||||
Given the context information, answer the query.
|
||||
Query: [query]"""
|
||||
<user_query>
|
||||
[query]
|
||||
</user_query>
|
||||
"""
|
||||
|
||||
RAG_TEMPLATE = PersistentConfig(
|
||||
"RAG_TEMPLATE",
|
||||
|
@ -588,6 +588,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
prompt = get_last_user_message(body["messages"])
|
||||
if prompt is None:
|
||||
raise Exception("No user message found")
|
||||
if rag_app.state.config.RELEVANCE_THRESHOLD == 0:
|
||||
assert context_string.strip(), (
|
||||
"With a 0 relevancy threshold for RAG, the context cannot "
|
||||
"be empty"
|
||||
)
|
||||
# Workaround for Ollama 2.0+ system prompt issue
|
||||
# TODO: replace with add_or_update_system_message
|
||||
if model["owned_by"] == "ollama":
|
||||
|
Loading…
Reference in New Issue
Block a user