open-webui/backend/open_webui/apps/rag/utils.py

492 lines
14 KiB
Python
Raw Normal View History

import logging
2024-08-27 22:10:27 +00:00
import os
from typing import Optional, Union
2024-03-09 03:26:39 +00:00
2024-08-27 22:10:27 +00:00
import requests
2024-09-10 01:27:50 +00:00
2024-04-25 12:49:59 +00:00
from huggingface_hub import snapshot_download
2024-08-27 22:10:27 +00:00
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
2024-08-27 22:10:27 +00:00
from langchain_core.documents import Document
2024-09-10 01:27:50 +00:00
from open_webui.apps.ollama.main import (
GenerateEmbeddingsForm,
generate_ollama_embeddings,
)
from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message
2024-04-14 23:48:15 +00:00
2024-09-10 01:27:50 +00:00
from open_webui.env import SRC_LOG_LEVELS
# Module logger; its level is taken from the "RAG" entry of SRC_LOG_LEVELS.
log: logging.Logger = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
2024-03-09 03:26:39 +00:00
2024-04-27 19:38:50 +00:00
def query_doc(
    collection_name: str,
    query: str,
    embedding_function,
    k: int,
):
    """Run a pure vector-similarity search against one collection.

    Args:
        collection_name: Name of the collection to search.
        query: Query text; it is embedded with ``embedding_function``.
        embedding_function: Callable mapping the query text to an embedding.
        k: Number of results requested from the vector store.

    Returns:
        Whatever ``VECTOR_DB_CLIENT.query_collection`` returns for the query.

    Any exception raised while embedding or querying propagates unchanged.
    """
    result = VECTOR_DB_CLIENT.query_collection(
        name=collection_name,
        query_embeddings=embedding_function(query),
        k=k,
    )

    log.info(f"query_doc:result {result}")
    return result
2024-04-25 21:03:00 +00:00
2024-04-27 19:38:50 +00:00
def query_doc_with_hybrid_search(
    collection_name: str,
    query: str,
    embedding_function,
    k: int,
    reranking_function,
    r: float,
):
    """Search one collection with BM25 + vector retrieval, then rerank.

    Lexical (BM25) and vector (``ChromaRetriever``) hits are combined by an
    ``EnsembleRetriever`` with equal weights and rescored by
    ``RerankCompressor``, which keeps at most ``k`` documents and, when ``r``
    is truthy, only those scoring at least ``r``.

    Args:
        collection_name: Name of the collection to search.
        query: Query text.
        embedding_function: Callable mapping text to an embedding.
        k: Number of hits per retriever and maximum number returned.
        reranking_function: Optional reranker (``None`` means cosine scoring).
        r: Minimum relevance score.

    Returns:
        A dict with ``distances``, ``documents`` and ``metadatas`` keys, each
        wrapping a single list; ``distances`` are the rerank scores stored
        under ``metadata["score"]``.

    Exceptions from the vector store, retrievers or reranker propagate.
    """
    collection = VECTOR_DB_CLIENT.get_collection(name=collection_name)
    documents = collection.get()  # fetch every document to build the BM25 index

    bm25_retriever = BM25Retriever.from_texts(
        texts=documents.get("documents"),
        metadatas=documents.get("metadatas"),
    )
    bm25_retriever.k = k

    chroma_retriever = ChromaRetriever(
        collection=collection,
        embedding_function=embedding_function,
        top_n=k,
    )

    ensemble_retriever = EnsembleRetriever(
        retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
    )
    compressor = RerankCompressor(
        embedding_function=embedding_function,
        top_n=k,
        reranking_function=reranking_function,
        r_score=r,
    )
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=ensemble_retriever
    )

    ranked_docs = compression_retriever.invoke(query)
    result = {
        "distances": [[d.metadata.get("score") for d in ranked_docs]],
        "documents": [[d.page_content for d in ranked_docs]],
        "metadatas": [[d.metadata for d in ranked_docs]],
    }

    log.info(f"query_doc_with_hybrid_search:result {result}")
    return result
2024-04-26 01:00:47 +00:00
def merge_and_sort_query_results(query_results, k, reverse=False):
    """Combine per-collection query results into a single top-``k`` result.

    Args:
        query_results: Iterable of result dicts; from each one, the first
            inner list of ``distances``, ``documents`` and ``metadatas`` is
            used.
        k: Slice bound applied to the sorted entries (``[:k]``).
        reverse: Sort by distance descending instead of ascending (the hybrid
            search path passes True for its rerank scores).

    Returns:
        ``{"distances": [...], "documents": [...], "metadatas": [...]}`` where
        each value wraps a single list of aligned, sorted, truncated entries.
    """
    all_distances, all_documents, all_metadatas = [], [], []
    for data in query_results:
        all_distances += data["distances"][0]
        all_documents += data["documents"][0]
        all_metadatas += data["metadatas"][0]

    # Stable sort keyed on the distance only, so ties keep their input order.
    ranked = sorted(
        zip(all_distances, all_documents, all_metadatas),
        key=lambda entry: entry[0],
        reverse=reverse,
    )[:k]

    return {
        "distances": [[distance for distance, _, _ in ranked]],
        "documents": [[document for _, document, _ in ranked]],
        "metadatas": [[metadata for _, _, metadata in ranked]],
    }
2024-03-09 03:26:39 +00:00
2024-04-27 19:38:50 +00:00
def query_collection(
    collection_names: list[str],
    query: str,
    embedding_function,
    k: int,
):
    """Vector-search several collections and merge their hits.

    Empty/falsy collection names are skipped. A collection whose query fails
    is logged and skipped, so the remaining collections still contribute.

    Args:
        collection_names: Collections to query.
        query: Query text.
        embedding_function: Callable mapping the query text to an embedding.
        k: Results per collection and size of the merged result.

    Returns:
        The output of ``merge_and_sort_query_results`` for all successful
        queries, sorted by ascending distance.
    """
    results = []
    for collection_name in collection_names:
        if not collection_name:
            continue
        try:
            results.append(
                query_doc(
                    collection_name=collection_name,
                    query=query,
                    k=k,
                    embedding_function=embedding_function,
                )
            )
        except Exception:
            # Best-effort: one broken collection must not abort the whole
            # query, but the failure should be visible rather than silent.
            log.exception(f"Error when querying the collection {collection_name}")

    return merge_and_sort_query_results(results, k=k)
def query_collection_with_hybrid_search(
    collection_names: list[str],
    query: str,
    embedding_function,
    k: int,
    reranking_function,
    r: float,
):
    """Hybrid-search several collections and merge their reranked hits.

    Empty/falsy collection names are skipped (matching ``query_collection``).
    A collection whose search fails is logged and skipped, so the remaining
    collections still contribute.

    Args:
        collection_names: Collections to query.
        query: Query text.
        embedding_function: Callable mapping text to an embedding.
        k: Results per collection and size of the merged result.
        reranking_function: Optional reranker passed to each search.
        r: Minimum relevance score passed to each search.

    Returns:
        The output of ``merge_and_sort_query_results`` for all successful
        searches, sorted by descending rerank score.
    """
    results = []
    for collection_name in collection_names:
        if not collection_name:
            continue
        try:
            results.append(
                query_doc_with_hybrid_search(
                    collection_name=collection_name,
                    query=query,
                    embedding_function=embedding_function,
                    k=k,
                    reranking_function=reranking_function,
                    r=r,
                )
            )
        except Exception:
            # Best-effort, but do not hide the failure.
            log.exception(
                f"Error when querying the collection with hybrid search {collection_name}"
            )

    return merge_and_sort_query_results(results, k=k, reverse=True)
2024-04-14 21:55:00 +00:00
2024-03-09 06:34:47 +00:00
def rag_template(template: str, context: str, query: str):
    """Fill the ``[context]`` and ``[query]`` placeholders of a RAG prompt.

    Both placeholders are substituted in one pass over ``template``. Any
    placeholder-like text inside the inserted values (e.g. a retrieved
    document containing ``[query]``) is therefore left untouched instead of
    being substituted again.

    Args:
        template: Prompt template containing the placeholders.
        context: Text inserted for every ``[context]``.
        query: Text inserted for every ``[query]``.

    Returns:
        The filled-in prompt.
    """
    import re

    values = {"context": context, "query": query}
    # A replacement *function* inserts the values literally; a replacement
    # string would interpret backslashes in them as regex escapes.
    return re.sub(r"\[(context|query)\]", lambda m: values[m.group(1)], template)
2024-03-11 01:40:50 +00:00
2024-04-27 19:38:50 +00:00
def get_embedding_function(
    embedding_engine,
    embedding_model,
    embedding_function,
    openai_key,
    openai_url,
    batch_size,
):
    """Return a callable that embeds a string or a list of strings.

    Args:
        embedding_engine: ``""`` for a local encoder object, ``"ollama"`` or
            ``"openai"`` for remote APIs.
        embedding_model: Model name sent to the remote API.
        embedding_function: Local encoder whose ``encode(...)`` result has a
            ``tolist()`` method; used only when ``embedding_engine`` is ``""``.
        openai_key: API key for the ``"openai"`` engine.
        openai_url: Base URL for the ``"openai"`` engine.
        batch_size: Number of texts per OpenAI request when embedding a list.

    Returns:
        The embedding callable, or ``None`` for any other engine value. For
        the remote engines a list input yields a list of embeddings: OpenAI is
        called in ``batch_size`` chunks, Ollama once per text.
    """
    if embedding_engine == "":

        def encode_locally(query):
            return embedding_function.encode(query).tolist()

        return encode_locally

    if embedding_engine not in ("ollama", "openai"):
        return None

    if embedding_engine == "ollama":

        def embed_one(query):
            return generate_ollama_embeddings(
                GenerateEmbeddingsForm(model=embedding_model, prompt=query)
            )

    else:

        def embed_one(query):
            return generate_openai_embeddings(
                model=embedding_model,
                text=query,
                key=openai_key,
                url=openai_url,
            )

    def embed(query):
        if not isinstance(query, list):
            return embed_one(query)
        if embedding_engine != "openai":
            return [embed_one(item) for item in query]
        vectors = []
        for start in range(0, len(query), batch_size):
            vectors.extend(embed_one(query[start : start + batch_size]))
        return vectors

    return embed
2024-04-22 20:49:58 +00:00
2024-06-11 08:10:24 +00:00
def get_rag_context(
    files,
    messages,
    embedding_function,
    k,
    reranking_function,
    r,
    hybrid_search,
):
    """Collect retrieval context for the last user message from attached files.

    Each entry of ``files`` is a dict dispatched on its ``type`` key.
    ``"collection"`` entries supply ``collection_names``; every other entry
    supplies a (possibly empty) ``collection_name``. ``"text"`` entries use
    their ``content`` directly instead of querying. A file with no new
    collection names (none given, or all handled by an earlier file) is
    skipped. Errors while retrieving a file's context are logged and that
    file contributes nothing.

    Args:
        files: Attached file descriptors.
        messages: Chat messages; the query is the last user message.
        embedding_function: Callable used to embed the query.
        k: Number of results per query.
        reranking_function: Reranker passed to hybrid search.
        r: Relevance threshold passed to hybrid search.
        hybrid_search: If truthy, use BM25 + vector search with reranking.

    Returns:
        ``(contexts, citations)``. For every retrieved source carrying
        ``documents``, ``contexts`` gets its non-None texts joined by blank
        lines. ``citations`` gets a ``source``/``document``/``metadata`` dict
        for each such source that also carries ``metadatas``.
    """
    log.debug(f"files: {files} {messages} {embedding_function} {reranking_function}")

    query = get_last_user_message(messages)

    handled_collections = []
    retrieved = []

    for file in files:
        if file["type"] == "collection":
            candidate_names = file["collection_names"]
        elif file["collection_name"]:
            candidate_names = [file["collection_name"]]
        else:
            candidate_names = []

        pending_names = set(candidate_names).difference(handled_collections)
        if not pending_names:
            log.debug(f"skipping {file} as it has already been extracted")
            continue

        try:
            if file["type"] == "text":
                context = file["content"]
            elif hybrid_search:
                context = query_collection_with_hybrid_search(
                    collection_names=pending_names,
                    query=query,
                    embedding_function=embedding_function,
                    k=k,
                    reranking_function=reranking_function,
                    r=r,
                )
            else:
                context = query_collection(
                    collection_names=pending_names,
                    query=query,
                    embedding_function=embedding_function,
                    k=k,
                )
        except Exception as e:
            log.exception(e)
            context = None

        if context:
            retrieved.append({**context, "source": file})

        handled_collections.extend(pending_names)

    contexts = []
    citations = []

    for item in retrieved:
        try:
            if "documents" not in item:
                continue
            contexts.append(
                "\n\n".join(text for text in item["documents"][0] if text is not None)
            )
            if "metadatas" in item:
                citations.append(
                    {
                        "source": item["source"],
                        "document": item["documents"][0],
                        "metadata": item["metadatas"][0],
                    }
                )
        except Exception as e:
            log.exception(e)

    return contexts, citations
2024-04-04 18:07:42 +00:00
2024-04-25 12:49:59 +00:00
def get_model_path(model: str, update_model: bool = False):
    """Resolve a sentence-transformers model name to a local snapshot path.

    Args:
        model: A filesystem path, a Hugging Face repo id (``org/name``), or a
            bare name, which is treated as ``sentence-transformers/<name>``.
        update_model: When False, ``snapshot_download`` is called with
            ``local_files_only=True``. When True, it may also update the
            snapshot.

    Returns:
        ``model`` unchanged when it is an existing path, or when it looks
        like a path (contains a backslash or more than one ``/``) and
        ``update_model`` is False. Otherwise the path returned by
        ``snapshot_download``, or the (possibly prefixed) repo id if that
        call fails.
    """
    local_files_only = not update_model
    # huggingface_hub kwargs; local_files_only makes it return the cached path.
    snapshot_kwargs = {
        "cache_dir": os.getenv("SENTENCE_TRANSFORMERS_HOME"),
        "local_files_only": local_files_only,
    }
    log.debug(f"model: {model}")
    log.debug(f"snapshot_kwargs: {snapshot_kwargs}")

    # Path detection modelled on upstream sentence_transformers.
    if os.path.exists(model) or (
        local_files_only and ("\\" in model or model.count("/") > 1)
    ):
        return model

    repo_id = model if "/" in model else f"sentence-transformers/{model}"
    snapshot_kwargs["repo_id"] = repo_id

    try:
        model_repo_path = snapshot_download(**snapshot_kwargs)
        log.debug(f"model_repo_path: {model_repo_path}")
        return model_repo_path
    except Exception as e:
        log.exception(f"Cannot determine model snapshot path: {e}")
        return repo_id
2024-04-25 12:49:59 +00:00
2024-04-14 23:15:39 +00:00
def generate_openai_embeddings(
    model: str,
    text: Union[str, list[str]],
    key: str,
    url: str = "https://api.openai.com/v1",
):
    """Embed one string or a list of strings via an OpenAI-compatible API.

    Args:
        model: Embedding model name.
        text: A single string or a list of strings.
        key: Bearer token sent with the request.
        url: API base URL (``/embeddings`` is appended).

    Returns:
        For a ``str`` input, that string's embedding. For any other input,
        the list of embeddings as returned by
        ``generate_openai_batch_embeddings``. The result is ``None`` when the
        batch request failed (or, for a ``str``, when no embedding came back).
    """
    batch = text if isinstance(text, list) else [text]
    embeddings = generate_openai_batch_embeddings(model, batch, key, url)

    if not isinstance(text, str):
        return embeddings
    # The batch helper returns None on failure. Indexing that would raise an
    # opaque TypeError, so propagate None just like the list path does.
    return embeddings[0] if embeddings else None
def generate_openai_batch_embeddings(
    model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
) -> Optional[list[list[float]]]:
    """POST ``texts`` to ``{url}/embeddings`` and return one embedding per item.

    Args:
        model: Embedding model name.
        texts: Texts to embed in a single request.
        key: Bearer token sent in the ``Authorization`` header.
        url: API base URL.

    Returns:
        The ``embedding`` field of each item in the response's ``data`` list,
        in response order. Returns ``None`` if the request fails, the server
        returns an HTTP error status, or the JSON body has no ``data`` field.
        Failures are logged, not raised.
    """
    try:
        r = requests.post(
            f"{url}/embeddings",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
            },
            json={"input": texts, "model": model},
        )
        r.raise_for_status()
        data = r.json()
        if "data" not in data:
            # Raising a plain str is itself a TypeError in Python 3; raise a
            # real exception so the log shows the actual problem.
            raise ValueError(
                "Something went wrong :/ (no 'data' field in embeddings response)"
            )
        return [elem["embedding"] for elem in data["data"]]
    except Exception as e:
        log.exception(f"Error generating OpenAI batch embeddings: {e}")
        return None
2024-04-22 20:49:58 +00:00
from typing import Any
from langchain_core.callbacks import CallbackManagerForRetrieverRun
2024-08-27 22:10:27 +00:00
from langchain_core.retrievers import BaseRetriever
2024-04-22 20:49:58 +00:00
class ChromaRetriever(BaseRetriever):
    """LangChain retriever backed by a vector-store collection's ``query`` API.

    ``collection`` must expose ``query(query_embeddings=..., n_results=...)``
    returning a mapping with ``ids``, ``metadatas`` and ``documents`` entries,
    each holding one inner list per query embedding. ``embedding_function``
    turns the query text into a vector. ``top_n`` is the number of neighbours
    requested.
    """

    collection: Any
    embedding_function: Any
    top_n: int

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[Document]:
        """Embed ``query`` and wrap each returned hit in a ``Document``."""
        response = self.collection.query(
            query_embeddings=[self.embedding_function(query)],
            n_results=self.top_n,
        )
        # A single query embedding was sent, so only the first inner lists matter.
        hit_ids = response["ids"][0]
        hit_metadatas = response["metadatas"][0]
        hit_texts = response["documents"][0]

        return [
            Document(metadata=hit_metadatas[pos], page_content=hit_texts[pos])
            for pos, _ in enumerate(hit_ids)
        ]
import operator
from typing import Optional, Sequence
from langchain_core.callbacks import Callbacks
2024-08-27 22:10:27 +00:00
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.pydantic_v1 import Extra
class RerankCompressor(BaseDocumentCompressor):
    """Document compressor that rescores retrieved documents and keeps the best.

    Scoring works as follows:

    * When ``reranking_function`` is set, scores come from its ``predict``
      applied to ``(query, text)`` pairs.
    * Otherwise, scores are the cosine similarity between the query embedding
      and each document embedding.

    If ``r_score`` is truthy, documents scoring below it are dropped. The top
    ``top_n`` survivors are returned in descending score order, each with its
    score stored under ``metadata["score"]``.
    """

    embedding_function: Any
    top_n: int
    reranking_function: Any
    r_score: float

    class Config:
        extra = Extra.forbid
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Score, filter and truncate ``documents`` for ``query``.

        Returns new ``Document`` objects. The input documents and their
        metadata dicts are not modified.
        """
        if not documents:
            # Nothing to score: skip the reranker/embedding calls entirely.
            return []

        if self.reranking_function is not None:
            scores = self.reranking_function.predict(
                [(query, doc.page_content) for doc in documents]
            )
        else:
            from sentence_transformers import util

            query_embedding = self.embedding_function(query)
            document_embedding = self.embedding_function(
                [doc.page_content for doc in documents]
            )
            scores = util.cos_sim(query_embedding, document_embedding)[0]

        # Arrays/tensors expose ``tolist``; also accept plain sequences from
        # custom rerankers.
        score_list = scores.tolist() if hasattr(scores, "tolist") else list(scores)
        docs_with_scores = list(zip(documents, score_list))
        if self.r_score:
            docs_with_scores = [
                (d, s) for d, s in docs_with_scores if s >= self.r_score
            ]

        ranked = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)

        # Copy the metadata instead of writing "score" into the caller's dict.
        return [
            Document(
                page_content=doc.page_content,
                metadata={**doc.metadata, "score": doc_score},
            )
            for doc, doc_score in ranked[: self.top_n]
        ]