# open-webui/backend/open_webui/apps/retrieval/utils.py
# (586 lines, 18 KiB, Python; text copied from a repository "blame" view)

import logging
2024-08-27 22:10:27 +00:00
import os
import uuid
2024-08-27 22:10:27 +00:00
from typing import Optional, Union
2024-03-09 03:26:39 +00:00
2024-11-16 12:41:07 +00:00
import asyncio
2024-08-27 22:10:27 +00:00
import requests
2024-09-10 01:27:50 +00:00
2024-04-25 12:49:59 +00:00
from huggingface_hub import snapshot_download
2024-08-27 22:10:27 +00:00
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
2024-08-27 22:10:27 +00:00
from langchain_core.documents import Document
2024-09-10 01:27:50 +00:00
from open_webui.apps.ollama.main import (
GenerateEmbedForm,
generate_ollama_batch_embeddings,
2024-09-10 01:27:50 +00:00
)
2024-09-27 23:28:45 +00:00
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message
2024-04-14 23:48:15 +00:00
2024-09-10 01:27:50 +00:00
from open_webui.env import SRC_LOG_LEVELS
2024-10-13 11:24:13 +00:00
from open_webui.config import DEFAULT_RAG_TEMPLATE
2024-09-10 01:27:50 +00:00
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
2024-03-09 03:26:39 +00:00
2024-09-10 03:37:06 +00:00
from typing import Any
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.retrievers import BaseRetriever
class VectorSearchRetriever(BaseRetriever):
    """LangChain retriever backed by ``VECTOR_DB_CLIENT.search``.

    The query is embedded with ``embedding_function`` and up to ``top_k`` hits
    from ``collection_name`` are returned as LangChain ``Document`` objects
    (metadata + page content taken from the first query's result lists).
    """

    collection_name: Any
    embedding_function: Any
    top_k: int

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[Document]:
        result = VECTOR_DB_CLIENT.search(
            collection_name=self.collection_name,
            vectors=[self.embedding_function(query)],
            limit=self.top_k,
        )
        # The vector DB client can return None (other callers in this module
        # check for it); treat that as "no hits" instead of crashing on
        # `result.ids`.
        if result is None:
            return []

        ids = result.ids[0]
        metadatas = result.metadatas[0]
        documents = result.documents[0]

        return [
            Document(
                metadata=metadatas[idx],
                page_content=documents[idx],
            )
            for idx in range(len(ids))
        ]
2024-04-27 19:38:50 +00:00
def query_doc(
    collection_name: str,
    query_embedding: list[float],
    k: int,
):
    """Run a top-``k`` vector search for one embedding against one collection.

    Args:
        collection_name: Name of the vector-DB collection to search.
        query_embedding: Embedding of the query.
        k: Maximum number of hits to request.

    Returns:
        Whatever ``VECTOR_DB_CLIENT.search`` returns. That may be ``None``,
        which is passed through so callers can skip the collection.

    Raises:
        Any exception raised by the vector-DB client, after logging it.
    """
    try:
        result = VECTOR_DB_CLIENT.search(
            collection_name=collection_name,
            vectors=[query_embedding],
            limit=k,
        )
        # Guard the log line: dereferencing a None result here used to turn
        # "no result" into an AttributeError.
        if result is not None:
            log.info(f"query_doc:result {result.ids} {result.metadatas}")
        return result
    except Exception as e:
        # Use the module logger (not stdout) and re-raise with the original
        # traceback intact.
        log.exception(f"Error when querying the collection {collection_name}: {e}")
        raise
2024-04-25 21:03:00 +00:00
2024-04-27 19:38:50 +00:00
def query_doc_with_hybrid_search(
    collection_name: str,
    query: str,
    embedding_function,
    k: int,
    reranking_function,
    r: float,
) -> dict:
    """Hybrid (BM25 + vector) search over one collection, then rerank.

    Every chunk stored in the collection feeds a BM25 retriever. It is
    combined with a ``VectorSearchRetriever`` through LangChain's
    ``EnsembleRetriever`` (equal weights). The combined hits are passed through
    ``RerankCompressor``, which keeps at most ``k`` documents scoring at least
    ``r`` (when ``r`` is non-zero).

    Returns:
        ``{"distances": [[score, ...]], "documents": [[text, ...]],
        "metadatas": [[meta, ...]]}``, where each score is the
        ``metadata["score"]`` set by the reranker.

    Raises:
        Any exception from the vector DB or the retrievers, unchanged.
    """
    stored = VECTOR_DB_CLIENT.get(collection_name=collection_name)

    # Lexical retriever over the full contents of the collection.
    keyword_retriever = BM25Retriever.from_texts(
        texts=stored.documents[0],
        metadatas=stored.metadatas[0],
    )
    keyword_retriever.k = k

    # Embedding-based retriever over the same collection.
    semantic_retriever = VectorSearchRetriever(
        collection_name=collection_name,
        embedding_function=embedding_function,
        top_k=k,
    )

    hybrid_retriever = EnsembleRetriever(
        retrievers=[keyword_retriever, semantic_retriever], weights=[0.5, 0.5]
    )

    reranking_retriever = ContextualCompressionRetriever(
        base_compressor=RerankCompressor(
            embedding_function=embedding_function,
            top_n=k,
            reranking_function=reranking_function,
            r_score=r,
        ),
        base_retriever=hybrid_retriever,
    )

    ranked_docs = reranking_retriever.invoke(query)
    merged = {
        "distances": [[doc.metadata.get("score") for doc in ranked_docs]],
        "documents": [[doc.page_content for doc in ranked_docs]],
        "metadatas": [[doc.metadata for doc in ranked_docs]],
    }

    log.info(
        "query_doc_with_hybrid_search:result "
        + f'{merged["metadatas"]} {merged["distances"]}'
    )
    return merged
2024-09-13 04:48:54 +00:00
def merge_and_sort_query_results(
    query_results: list[dict], k: int, reverse: bool = False
) -> dict:
    """Merge several query-result dicts and keep the best ``k`` hits.

    Each element of ``query_results`` must look like
    ``{"distances": [[...]], "documents": [[...]], "metadatas": [[...]]}``.
    Only the first inner list of each key is used. Hits are sorted by
    distance: ascending by default, descending when ``reverse`` is true
    (used for relevance scores where higher is better).

    Args:
        query_results: Result dicts to merge; may be empty.
        k: Maximum number of hits to keep.
        reverse: Sort in descending order of distance/score.

    Returns:
        A single dict of the same shape holding at most ``k`` hits.
    """
    combined = []
    for data in query_results:
        # Zip each result on its own so a length mismatch inside one result
        # cannot shift documents/metadata onto another result's distances.
        combined.extend(
            zip(data["distances"][0], data["documents"][0], data["metadatas"][0])
        )

    # Sort on the distance only: documents/metadata dicts are not orderable.
    combined.sort(key=lambda item: item[0], reverse=reverse)
    top = combined[:k]

    return {
        "distances": [[distance for distance, _, _ in top]],
        "documents": [[document for _, document, _ in top]],
        "metadatas": [[metadata for _, _, metadata in top]],
    }
2024-03-09 03:26:39 +00:00
2024-04-27 19:38:50 +00:00
def query_collection(
    collection_names: list[str],
    query: str,
    embedding_function,
    k: int,
) -> dict:
    """Vector-search several collections with one query and merge the hits.

    The query is embedded once and reused for every collection. Falsy
    collection names are ignored. A collection whose search raises is logged
    and skipped, as is one whose search returns ``None``.

    Returns:
        The merged top-``k`` hits (ascending distance), in the
        ``{"distances": [[...]], "documents": [[...]], "metadatas": [[...]]}``
        shape.
    """
    collected = []
    query_embedding = embedding_function(query)

    for name in filter(None, collection_names):
        try:
            hit = query_doc(
                collection_name=name,
                k=k,
                query_embedding=query_embedding,
            )
            if hit is not None:
                collected.append(hit.model_dump())
        except Exception as e:
            log.exception(f"Error when querying the collection: {e}")

    return merge_and_sort_query_results(collected, k=k)
def query_collection_with_hybrid_search(
    collection_names: list[str],
    query: str,
    embedding_function,
    k: int,
    reranking_function,
    r: float,
) -> dict:
    """Hybrid-search several collections and merge the reranked hits.

    Each collection is searched with ``query_doc_with_hybrid_search``. A
    failing collection is logged and skipped, so the results of the others
    are still returned.

    Returns:
        The merged top-``k`` hits sorted by descending rerank score.

    Raises:
        Exception: if every collection failed, so the caller can fall back to
            non-hybrid search.
    """
    results = []
    error = False

    for collection_name in collection_names:
        try:
            result = query_doc_with_hybrid_search(
                collection_name=collection_name,
                query=query,
                embedding_function=embedding_function,
                k=k,
                reranking_function=reranking_function,
                r=r,
            )
            results.append(result)
        except Exception as e:
            log.exception(
                "Error when querying the collection with " f"hybrid_search: {e}"
            )
            error = True

    # Give up only when nothing succeeded. Previously any single failure
    # discarded the successful hybrid results, contradicting this message.
    if error and not results:
        raise Exception(
            "Hybrid search failed for all collections. Using Non hybrid search as fallback."
        )

    return merge_and_sort_query_results(results, k=k, reverse=True)
2024-04-14 21:55:00 +00:00
2024-03-09 06:34:47 +00:00
def rag_template(template: str, context: str, query: str):
    """Fill a RAG prompt template with retrieved context and the user query.

    Supported placeholders are ``[context]`` / ``{{CONTEXT}}`` and
    ``[query]`` / ``{{QUERY}}``. An empty (or ``None``) template falls back to
    ``DEFAULT_RAG_TEMPLATE``.

    Substitution is a single pass over the template. Placeholder-like text
    that appears inside the inserted context (e.g. a document containing
    "[query]") or inside the query is therefore kept verbatim rather than
    being substituted again.

    Args:
        template: Prompt template.
        context: Retrieved context text.
        query: The user's query.

    Returns:
        The filled-in prompt.
    """
    import re

    if not template:
        template = DEFAULT_RAG_TEMPLATE

    if "[context]" not in template and "{{CONTEXT}}" not in template:
        log.debug(
            "WARNING: The RAG template does not contain the '[context]' or '{{CONTEXT}}' placeholder."
        )

    if "<context>" in context and "</context>" in context:
        log.debug(
            "WARNING: Potential prompt injection attack: the RAG "
            "context contains '<context>' and '</context>'. This might be "
            "nothing, or the user might be trying to hack something."
        )

    def _substitute(match):
        # A function replacement is used literally by re.sub (no backslash
        # escape processing), and inserted text is never re-scanned.
        if match.group(0) in ("[context]", "{{CONTEXT}}"):
            return context
        return query

    return re.sub(
        r"\[context\]|\{\{CONTEXT\}\}|\[query\]|\{\{QUERY\}\}",
        _substitute,
        template,
    )
2024-03-11 01:40:50 +00:00
2024-04-27 19:38:50 +00:00
def get_embedding_function(
    embedding_engine,
    embedding_model,
    embedding_function,
    openai_key,
    openai_url,
    embedding_batch_size,
):
    """Build a callable that embeds a string or a list of strings.

    * ``embedding_engine == ""``: the returned callable delegates to
      ``embedding_function.encode(query).tolist()``.
    * ``"ollama"`` / ``"openai"``: the returned callable drives the async
      ``generate_embeddings`` with ``asyncio.run``. List inputs are sent in
      chunks of ``embedding_batch_size`` and the results are concatenated.
      The OpenAI key/URL are only forwarded for the ``"openai"`` engine.
    * Any other engine: ``None`` is returned.
    """
    if embedding_engine == "":

        def encode_locally(query):
            return embedding_function.encode(query).tolist()

        return encode_locally

    if embedding_engine not in ["ollama", "openai"]:
        return None

    uses_openai = embedding_engine == "openai"

    def embed_once(query):
        # Run the coroutine to completion on a fresh event loop.
        return asyncio.run(
            generate_embeddings(
                engine=embedding_engine,
                model=embedding_model,
                text=query,
                key=openai_key if uses_openai else "",
                url=openai_url if uses_openai else "",
            )
        )

    def embed(query):
        if not isinstance(query, list):
            return embed_once(query)
        vectors = []
        for start in range(0, len(query), embedding_batch_size):
            vectors.extend(embed_once(query[start : start + embedding_batch_size]))
        return vectors

    return embed
2024-04-22 20:49:58 +00:00
2024-06-11 08:10:24 +00:00
def _get_file_collection_names(file: dict) -> list[str]:
    """Return the vector-DB collection name(s) that index ``file`` (may be empty)."""
    if file.get("type") == "collection":
        if file.get("legacy"):
            return list(file.get("collection_names", []))
        return [file["id"]]
    if file.get("collection_name"):
        return [file["collection_name"]]
    if file.get("id"):
        # Legacy files use the bare id as collection name; others "file-<id>".
        if file.get("legacy"):
            return [f"{file['id']}"]
        return [f"file-{file['id']}"]
    return []


def get_rag_context(
    files,
    messages,
    embedding_function,
    k,
    reranking_function,
    r,
    hybrid_search,
):
    """Build retrieval context strings and citations for a chat's files.

    For each entry in ``files``:

    * ``context == "full"``: ``file["file"]["data"]["content"]`` is used as-is.
    * otherwise the entry is resolved to collection names. Entries whose
      collections were all already searched for an earlier file are skipped.
      Then:

      * ``type == "text"``: ``file["content"]`` is used as-is.
      * else the collections are searched with the last user message of
        ``messages``. This is hybrid search when ``hybrid_search`` is true,
        falling back to plain vector search if hybrid search raises or
        returns nothing. Search errors are logged and the entry is skipped.

    Note: the ``"data"`` key is deleted from each ``file`` dict that yields
    context (the caller's dict is mutated).

    Returns:
        ``(contexts, citations)``. ``contexts`` is a list of strings (an
        optional "file names:" header followed by the retrieved chunks).
        ``citations`` is a list of dicts with ``source``, ``document``,
        ``metadata`` and, when available, ``distances``.
    """
    log.debug(f"files: {files} {messages} {embedding_function} {reranking_function}")

    query = get_last_user_message(messages)

    extracted_collections = []
    relevant_contexts = []

    for file in files or []:
        if file.get("context") == "full":
            file_data = (file.get("file") or {}).get("data") or {}
            context = {
                "documents": [[file_data.get("content")]],
                "metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
            }
        else:
            collection_names = set(_get_file_collection_names(file)).difference(
                extracted_collections
            )
            if not collection_names:
                log.debug(f"skipping {file} as it has already been extracted")
                continue

            context = None
            try:
                if file.get("type") == "text":
                    content = file["content"]
                    # Raw text must be wrapped in the query-result shape; a
                    # bare string cannot be merged with `{**context, ...}`.
                    if isinstance(content, dict):
                        context = content
                    else:
                        context = {
                            "documents": [[content]],
                            "metadatas": [
                                [{"file_id": file.get("id"), "name": file.get("name")}]
                            ],
                        }
                else:
                    if hybrid_search:
                        try:
                            context = query_collection_with_hybrid_search(
                                collection_names=collection_names,
                                query=query,
                                embedding_function=embedding_function,
                                k=k,
                                reranking_function=reranking_function,
                                r=r,
                            )
                        except Exception as e:
                            log.debug(
                                "Error when using hybrid search, using"
                                f" non hybrid search as fallback: {e}"
                            )
                    if (not hybrid_search) or (context is None):
                        context = query_collection(
                            collection_names=collection_names,
                            query=query,
                            embedding_function=embedding_function,
                            k=k,
                        )
            except Exception as e:
                log.exception(e)

            extracted_collections.extend(collection_names)

        if context:
            # Strip the raw "data" payload before attaching the file.
            if "data" in file:
                del file["data"]
            relevant_contexts.append({**context, "file": file})

    contexts = []
    citations = []
    for context in relevant_contexts:
        try:
            if "documents" in context:
                # Sorted so the prompt header is deterministic across runs.
                file_names = sorted(
                    {
                        metadata["name"]
                        for metadata in context["metadatas"][0]
                        if metadata is not None and "name" in metadata
                    }
                )
                contexts.append(
                    ((", ".join(file_names) + ":\n\n") if file_names else "")
                    + "\n\n".join(
                        [text for text in context["documents"][0] if text is not None]
                    )
                )

                if "metadatas" in context:
                    citation = {
                        "source": context["file"],
                        "document": context["documents"][0],
                        "metadata": context["metadatas"][0],
                    }
                    if "distances" in context and context["distances"]:
                        citation["distances"] = context["distances"][0]
                    citations.append(citation)
        except Exception as e:
            log.exception(e)

    log.debug(f"contexts: {contexts}")
    log.debug(f"citations: {citations}")
    return contexts, citations
2024-04-04 18:07:42 +00:00
2024-04-25 12:49:59 +00:00
def get_model_path(model: str, update_model: bool = False):
    """Resolve a sentence-transformers model name to a local snapshot path.

    ``model`` is returned unchanged when it is an existing path. It is also
    returned unchanged when ``update_model`` is false and it looks like a
    filesystem path (contains a backslash or more than one "/").

    Otherwise a bare name is prefixed with ``"sentence-transformers/"`` and
    ``huggingface_hub.snapshot_download`` is called. The cache dir comes from
    ``SENTENCE_TRANSFORMERS_HOME`` and ``local_files_only = not update_model``.
    The downloaded/cached snapshot path is returned. On any error the
    (possibly prefixed) repo id is returned instead.
    """
    local_files_only = not update_model
    snapshot_kwargs = {
        "cache_dir": os.getenv("SENTENCE_TRANSFORMERS_HOME"),
        "local_files_only": local_files_only,
    }
    log.debug(f"model: {model}")
    log.debug(f"snapshot_kwargs: {snapshot_kwargs}")

    # Inspiration from upstream sentence_transformers
    looks_like_local_path = "\\" in model or model.count("/") > 1
    if os.path.exists(model) or (looks_like_local_path and local_files_only):
        return model

    repo_id = model if "/" in model else f"sentence-transformers/{model}"
    snapshot_kwargs["repo_id"] = repo_id

    try:
        repo_path = snapshot_download(**snapshot_kwargs)
        log.debug(f"model_repo_path: {repo_path}")
        return repo_path
    except Exception as e:
        log.exception(f"Cannot determine model snapshot path: {e}")
        return repo_id
2024-04-25 12:49:59 +00:00
2024-11-16 12:41:07 +00:00
async def generate_openai_batch_embeddings(
    model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
) -> Optional[list[list[float]]]:
    """Request embeddings for ``texts`` from an OpenAI-compatible endpoint.

    POSTs ``{"input": texts, "model": model}`` to ``{url}/embeddings`` with a
    bearer ``key``.

    Returns:
        One embedding per entry of the response's ``"data"`` list, or
        ``None`` if the request fails or the response has no ``"data"``
        field (the failure is logged).
    """
    try:
        # NOTE(review): requests.post is blocking inside this coroutine; in
        # this module it is driven via asyncio.run from a sync wrapper, but
        # an async HTTP client would be needed for use on a shared loop.
        r = requests.post(
            f"{url}/embeddings",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
            },
            json={"input": texts, "model": model},
        )
        r.raise_for_status()
        data = r.json()
        if "data" not in data:
            # A bare string cannot be raised; use a real exception type.
            raise ValueError("Something went wrong :/")
        return [elem["embedding"] for elem in data["data"]]
    except Exception as e:
        log.exception(f"Error generating openai batch embeddings: {e}")
        return None
2024-04-22 20:49:58 +00:00
2024-11-16 12:41:07 +00:00
async def generate_embeddings(
    engine: str, model: str, text: Union[str, list[str]], **kwargs
):
    """Embed ``text`` with the ``"ollama"`` or ``"openai"`` engine.

    Input is always sent as a batch (a lone value is wrapped in a list).

    Returns:
        A single embedding when ``text`` is a str, otherwise the list of
        embeddings. ``None`` for an unrecognised engine.

    For ``"openai"``, ``kwargs`` may supply ``key`` (default ``""``) and
    ``url`` (default ``"https://api.openai.com/v1"``).
    """
    batch = text if isinstance(text, list) else [text]
    single = isinstance(text, str)

    if engine == "ollama":
        response = await generate_ollama_batch_embeddings(
            GenerateEmbedForm(**{"model": model, "input": batch})
        )
        vectors = response["embeddings"]
    elif engine == "openai":
        # NOTE(review): generate_openai_batch_embeddings returns None on
        # failure, so a str input then fails with TypeError on indexing.
        vectors = await generate_openai_batch_embeddings(
            model,
            batch,
            kwargs.get("key", ""),
            kwargs.get("url", "https://api.openai.com/v1"),
        )
    else:
        return None

    return vectors[0] if single else vectors
import operator
from typing import Optional, Sequence
from langchain_core.callbacks import Callbacks
2024-08-27 22:10:27 +00:00
from langchain_core.documents import BaseDocumentCompressor, Document
class RerankCompressor(BaseDocumentCompressor):
    """Score, filter and re-order documents by relevance to the query.

    Scores come from ``reranking_function.predict`` on ``(query, text)`` pairs
    when a reranking function is set. Otherwise they are the cosine similarity
    (``sentence_transformers.util.cos_sim``) between ``embedding_function``
    embeddings of the query and of each document. When ``r_score`` is
    non-zero, documents scoring below it are dropped. The rest are sorted by
    descending score, truncated to ``top_n``, and returned as new
    ``Document`` objects with ``metadata["score"]`` set.
    """

    embedding_function: Any
    top_n: int
    reranking_function: Any
    r_score: float

    class Config:
        extra = "forbid"
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        if not documents:
            # Nothing to score; avoid calling the model with an empty batch.
            return []

        if self.reranking_function is not None:
            scores = self.reranking_function.predict(
                [(query, doc.page_content) for doc in documents]
            )
        else:
            from sentence_transformers import util

            query_embedding = self.embedding_function(query)
            document_embedding = self.embedding_function(
                [doc.page_content for doc in documents]
            )
            scores = util.cos_sim(query_embedding, document_embedding)[0]

        # Array/tensor scores expose .tolist(); accept plain sequences too.
        score_list = scores.tolist() if hasattr(scores, "tolist") else list(scores)
        docs_with_scores = list(zip(documents, score_list))
        if self.r_score:
            docs_with_scores = [
                (d, s) for d, s in docs_with_scores if s >= self.r_score
            ]

        ranked = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)

        # Copy the metadata instead of writing the score into the input
        # documents' dicts, which belong to the caller / upstream retriever.
        return [
            Document(
                page_content=doc.page_content,
                metadata={**doc.metadata, "score": doc_score},
            )
            for doc, doc_score in ranked[: self.top_n]
        ]