commit fd0170c179 (parent 9b56b64cfa)
Timothy Jaeryang Baek, 2024-12-30 16:55:29 -08:00


@@ -1,8 +1,9 @@
 import logging
 import os
-import heapq
+import uuid
 from typing import Optional, Union
+import asyncio

 import requests
 from huggingface_hub import snapshot_download
@@ -33,6 +34,8 @@ class VectorSearchRetriever(BaseRetriever):
     def _get_relevant_documents(
         self,
         query: str,
+        *,
+        run_manager: CallbackManagerForRetrieverRun,
     ) -> list[Document]:
         result = VECTOR_DB_CLIENT.search(
             collection_name=self.collection_name,
@@ -44,12 +47,15 @@ class VectorSearchRetriever(BaseRetriever):
         metadatas = result.metadatas[0]
         documents = result.documents[0]

-        return [
-            Document(
-                metadata=metadatas[idx],
-                page_content=documents[idx],
-            ) for idx in range(len(ids))
-        ]
+        results = []
+        for idx in range(len(ids)):
+            results.append(
+                Document(
+                    metadata=metadatas[idx],
+                    page_content=documents[idx],
+                )
+            )
+        return results


 def query_doc(
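The two sides of this hunk are equivalent; the revert only trades the list comprehension for an explicit loop. A self-contained sketch of what the loop produces, assuming Document is langchain_core's class (as the retriever code suggests) and using made-up sample data in place of the VECTOR_DB_CLIENT result:

from langchain_core.documents import Document

# Stand-ins for result.ids[0], result.documents[0], result.metadatas[0];
# the values are illustrative only.
ids = ["doc-1", "doc-2"]
documents = ["first chunk", "second chunk"]
metadatas = [{"source": "a.pdf"}, {"source": "b.pdf"}]

results = []
for idx in range(len(ids)):
    results.append(
        Document(metadata=metadatas[idx], page_content=documents[idx])
    )

assert results[0].page_content == "first chunk"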
@@ -58,14 +64,16 @@ def query_doc(
     k: int,
 ):
     try:
-        if result := VECTOR_DB_CLIENT.search(
-            collection_name=collection_name,
-            vectors=[query_embedding],
-            limit=k,
-        ):
+        result = VECTOR_DB_CLIENT.search(
+            collection_name=collection_name,
+            vectors=[query_embedding],
+            limit=k,
+        )
+        if result:
             log.info(f"query_doc:result {result.ids} {result.metadatas}")
             return result
     except Exception as e:
         print(e)
         raise e
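Functionally, the walrus form on the left and the plain assignment on the right are the same: the search result is only logged and returned when it is truthy. A hedged usage sketch (the collection name and the zero vector are placeholders; the real embedding dimension depends on the configured model):

# Hypothetical call; these are not values from this commit.
query_embedding = [0.0] * 384
result = query_doc(
    collection_name="docs-example",
    query_embedding=query_embedding,
    k=5,
)
if result:
    print(result.ids[0], result.metadatas[0])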
@@ -127,38 +135,44 @@ def query_doc_with_hybrid_search(
 def merge_and_sort_query_results(
     query_results: list[dict], k: int, reverse: bool = False
 ) -> list[dict]:
-    if not query_results:
-        return {
-            "distances": [[]],
-            "documents": [[]],
-            "metadatas": [[]],
-        }
-
-    combined = (
-        (data.get("distances", [float('inf')])[0],
-         data.get("documents", [None])[0],
-         data.get("metadatas", [{}])[0])
-        for data in query_results
-    )
-
-    if reverse:
-        top_k = heapq.nlargest(k, combined, key=lambda x: x[0])
+    # Initialize lists to store combined data
+    combined_distances = []
+    combined_documents = []
+    combined_metadatas = []
+
+    for data in query_results:
+        combined_distances.extend(data["distances"][0])
+        combined_documents.extend(data["documents"][0])
+        combined_metadatas.extend(data["metadatas"][0])
+
+    # Create a list of tuples (distance, document, metadata)
+    combined = list(zip(combined_distances, combined_documents, combined_metadatas))
+
+    # Sort the list based on distances
+    combined.sort(key=lambda x: x[0], reverse=reverse)
+
+    # We don't have anything :-(
+    if not combined:
+        sorted_distances = []
+        sorted_documents = []
+        sorted_metadatas = []
     else:
-        top_k = heapq.nsmallest(k, combined, key=lambda x: x[0])
-
-    if not top_k:
-        return {
-            "distances": [[]],
-            "documents": [[]],
-            "metadatas": [[]],
-        }
-    else:
-        sorted_distances, sorted_documents, sorted_metadatas = zip(*top_k)
-        return {
-            "distances": [sorted_distances],
-            "documents": [sorted_documents],
-            "metadatas": [sorted_metadatas],
-        }
+        # Unzip the sorted list
+        sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)
+
+        # Slicing the lists to include only k elements
+        sorted_distances = list(sorted_distances)[:k]
+        sorted_documents = list(sorted_documents)[:k]
+        sorted_metadatas = list(sorted_metadatas)[:k]
+
+    # Create the output dictionary
+    result = {
+        "distances": [sorted_distances],
+        "documents": [sorted_documents],
+        "metadatas": [sorted_metadatas],
+    }
+
+    return result


 def query_collection(
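On the new side, the per-collection result lists are concatenated, sorted as (distance, document, metadata) tuples, and cut to k; the old heapq version computed the same top-k lazily and also guarded against empty input up front. A worked example with two made-up query results in the nested shape used above:

query_results = [
    {"distances": [[0.2, 0.8]], "documents": [["a", "b"]], "metadatas": [[{}, {}]]},
    {"distances": [[0.5]], "documents": [["c"]], "metadatas": [[{}]]},
]
merged = merge_and_sort_query_results(query_results, k=2)
# Ascending by distance: 0.2 ("a"), then 0.5 ("c"); 0.8 ("b") falls outside k=2.
# merged == {"distances": [[0.2, 0.5]], "documents": [["a", "c"]], "metadatas": [[{}, {}]]}

With reverse=True the sort is descending, which fits scores where higher is better (e.g. reranker relevance); the default ascending order fits raw distance metrics where lower is better.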
@@ -171,18 +185,19 @@ def query_collection(
     for query in queries:
         query_embedding = embedding_function(query)
         for collection_name in collection_names:
-            if not collection_name:
-                continue
-
-            try:
-                if result := query_doc(
-                    collection_name=collection_name,
-                    k=k,
-                    query_embedding=query_embedding,
-                ):
-                    results.append(result.model_dump())
-            except Exception as e:
-                log.exception(f"Error when querying the collection: {e}")
+            if collection_name:
+                try:
+                    result = query_doc(
+                        collection_name=collection_name,
+                        k=k,
+                        query_embedding=query_embedding,
+                    )
+                    if result is not None:
+                        results.append(result.model_dump())
+                except Exception as e:
+                    log.exception(f"Error when querying the collection: {e}")
+            else:
+                pass

     return merge_and_sort_query_results(results, k=k)
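A sketch of how query_collection fans out: each query is embedded once, then run against every non-empty collection name, failures are logged per collection, and the merged top-k is returned (the added else: pass branch is a no-op). The wiring below is hypothetical; the lambda stands in for the real embedding function and the collection names are made up:

results = query_collection(
    collection_names=["kb-main", "kb-docs"],
    queries=["What is retrieval-augmented generation?"],
    embedding_function=lambda q: [0.0] * 384,  # placeholder embedder
    k=3,
)
# results has the merged {"distances": [...], "documents": [...], "metadatas": [...]} shape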
@@ -198,8 +213,8 @@ def query_collection_with_hybrid_search(
     results = []
     error = False
     for collection_name in collection_names:
-        for query in queries:
-            try:
+        try:
+            for query in queries:
                 result = query_doc_with_hybrid_search(
                     collection_name=collection_name,
                     query=query,
@@ -209,11 +224,11 @@ def query_collection_with_hybrid_search(
                     r=r,
                 )
                 results.append(result)
-            except Exception as e:
-                log.exception(
-                    "Error when querying the collection with " f"hybrid_search: {e}"
-                )
-                error = True
+        except Exception as e:
+            log.exception(
+                "Error when querying the collection with " f"hybrid_search: {e}"
+            )
+            error = True

     if error:
         raise Exception(
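Moving the try outside the inner loop changes failure semantics: before, a failing query was logged and the remaining queries for that collection still ran; now the first failure skips the rest of that collection's queries. A minimal, self-contained sketch of the difference:

def flaky(q):
    if q == "bad":
        raise RuntimeError("boom")
    return q.upper()

# Old shape: try inside the loop -> later items are still processed.
out_old = []
for q in ["ok1", "bad", "ok2"]:
    try:
        out_old.append(flaky(q))
    except Exception:
        pass
assert out_old == ["OK1", "OK2"]

# New shape: try around the loop -> the loop stops at the first failure.
out_new = []
try:
    for q in ["ok1", "bad", "ok2"]:
        out_new.append(flaky(q))
except Exception:
    pass
assert out_new == ["OK1"]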
@@ -244,10 +259,10 @@ def get_embedding_function(
     def generate_multiple(query, func):
         if isinstance(query, list):
-            return [
-                func(query[i : i + embedding_batch_size])
-                for i in range(0, len(query), embedding_batch_size)
-            ]
+            embeddings = []
+            for i in range(0, len(query), embedding_batch_size):
+                embeddings.extend(func(query[i : i + embedding_batch_size]))
+            return embeddings
         else:
             return func(query)
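Note the two sides are not equivalent: the removed comprehension collected one sub-list per batch (a list of batch results), while extend flattens each batch's embeddings into a single list, one embedding per input text. A worked example with a stub embedder and batch size 2:

embedding_batch_size = 2

def func(batch):
    # Stub embedder: one fake one-dimensional "embedding" per text.
    return [[float(len(t))] for t in batch]

query = ["a", "bb", "ccc", "dddd", "eeeee"]
embeddings = []
for i in range(0, len(query), embedding_batch_size):
    embeddings.extend(func(query[i : i + embedding_batch_size]))

# Batches: ["a", "bb"], ["ccc", "dddd"], ["eeeee"] -> 5 flat embeddings.
assert embeddings == [[1.0], [2.0], [3.0], [4.0], [5.0]]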
@@ -421,27 +436,26 @@ def generate_openai_batch_embeddings(
 def generate_ollama_batch_embeddings(
     model: str, texts: list[str], url: str, key: str = ""
 ) -> Optional[list[list[float]]]:
-    r = requests.post(
-        f"{url}/api/embed",
-        headers={
-            "Content-Type": "application/json",
-            "Authorization": f"Bearer {key}",
-        },
-        json={"input": texts, "model": model},
-    )
     try:
+        r = requests.post(
+            f"{url}/api/embed",
+            headers={
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {key}",
+            },
+            json={"input": texts, "model": model},
+        )
         r.raise_for_status()
+        data = r.json()
+        if "embeddings" in data:
+            return data["embeddings"]
+        else:
+            raise Exception("Something went wrong :/")
     except Exception as e:
         print(e)
         return None
-
-    data = r.json()
-    if 'embeddings' not in data:
-        raise Exception("Something went wrong :/")
-    return data['embeddings']


 def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
     url = kwargs.get("url", "")
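With requests.post moved inside the try, connection errors and non-2xx responses are now handled the same way: printed and mapped to None instead of propagating to the caller. A hedged usage sketch against a local Ollama instance (the URL and model name are assumptions, not values from this commit):

# Hypothetical call; requires a running Ollama server with the model pulled.
embeddings = generate_ollama_batch_embeddings(
    model="nomic-embed-text",
    texts=["hello", "world"],
    url="http://localhost:11434",
)
if embeddings is not None:
    print(len(embeddings), len(embeddings[0]))  # e.g. 2 embeddings of model-specific width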