Merge pull request #8212 from ashm-dev/main
Some checks are pending
Deploy to HuggingFace Spaces / check-secret (push) Waiting to run
Deploy to HuggingFace Spaces / deploy (push) Blocked by required conditions
Create and publish Docker images with specific build args / build-main-image (linux/amd64) (push) Waiting to run
Create and publish Docker images with specific build args / build-main-image (linux/arm64) (push) Waiting to run
Create and publish Docker images with specific build args / build-cuda-image (linux/amd64) (push) Waiting to run
Create and publish Docker images with specific build args / build-cuda-image (linux/arm64) (push) Waiting to run
Create and publish Docker images with specific build args / build-ollama-image (linux/amd64) (push) Waiting to run
Create and publish Docker images with specific build args / build-ollama-image (linux/arm64) (push) Waiting to run
Create and publish Docker images with specific build args / merge-main-images (push) Blocked by required conditions
Create and publish Docker images with specific build args / merge-cuda-images (push) Blocked by required conditions
Create and publish Docker images with specific build args / merge-ollama-images (push) Blocked by required conditions
Python CI / Format Backend (3.11) (push) Waiting to run
Frontend Build / Format & Build Frontend (push) Waiting to run
Frontend Build / Frontend Unit Tests (push) Waiting to run
Integration Test / Run Cypress Integration Tests (push) Waiting to run
Integration Test / Run Migration Tests (push) Waiting to run

feat: Small optimization
This commit is contained in:
Timothy Jaeryang Baek 2024-12-30 16:00:18 -08:00 committed by GitHub
commit 9b56b64cfa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,9 +1,8 @@
import logging
import os
import uuid
import heapq
from typing import Optional, Union
import asyncio
import requests
from huggingface_hub import snapshot_download
@ -34,8 +33,6 @@ class VectorSearchRetriever(BaseRetriever):
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
) -> list[Document]:
result = VECTOR_DB_CLIENT.search(
collection_name=self.collection_name,
@ -47,15 +44,12 @@ class VectorSearchRetriever(BaseRetriever):
metadatas = result.metadatas[0]
documents = result.documents[0]
results = []
for idx in range(len(ids)):
results.append(
Document(
metadata=metadatas[idx],
page_content=documents[idx],
)
)
return results
return [
Document(
metadata=metadatas[idx],
page_content=documents[idx],
) for idx in range(len(ids))
]
def query_doc(
@ -64,16 +58,14 @@ def query_doc(
k: int,
):
try:
result = VECTOR_DB_CLIENT.search(
if result := VECTOR_DB_CLIENT.search(
collection_name=collection_name,
vectors=[query_embedding],
limit=k,
)
if result:
):
log.info(f"query_doc:result {result.ids} {result.metadatas}")
return result
return result
except Exception as e:
print(e)
raise e
@ -135,44 +127,38 @@ def query_doc_with_hybrid_search(
def merge_and_sort_query_results(
    query_results: list[dict], k: int, reverse: bool = False
) -> dict:
    """Merge several vector-search result dicts and keep the top-k by distance.

    Each entry of ``query_results`` is expected to look like
    ``{"distances": [[...]], "documents": [[...]], "metadatas": [[...]]}``
    (one inner list per query, parallel element-wise).  All triples from all
    entries are flattened together, ranked by distance (largest first when
    ``reverse`` is True, smallest first otherwise), and the best ``k`` are
    returned in the same nested-list shape.
    """
    empty_result = {
        "distances": [[]],
        "documents": [[]],
        "metadatas": [[]],
    }
    if not query_results:
        return empty_result

    # Flatten every (distance, document, metadata) triple from all results.
    # NOTE: taking data["distances"][0] as a single value would hand whole
    # inner *lists* to heapq and drop all but one triple per result; zip()
    # pairs the parallel inner lists element-by-element instead.
    combined = (
        triple
        for data in query_results
        for triple in zip(
            data.get("distances", [[]])[0],
            data.get("documents", [[]])[0],
            data.get("metadatas", [[]])[0],
        )
    )

    # heapq.nsmallest/nlargest run in O(n log k) — no need for a full sort.
    if reverse:
        top_k = heapq.nlargest(k, combined, key=lambda x: x[0])
    else:
        top_k = heapq.nsmallest(k, combined, key=lambda x: x[0])

    if not top_k:
        return empty_result

    sorted_distances, sorted_documents, sorted_metadatas = zip(*top_k)
    return {
        "distances": [list(sorted_distances)],
        "documents": [list(sorted_documents)],
        "metadatas": [list(sorted_metadatas)],
    }
def query_collection(
@ -185,19 +171,18 @@ def query_collection(
for query in queries:
query_embedding = embedding_function(query)
for collection_name in collection_names:
if collection_name:
try:
result = query_doc(
collection_name=collection_name,
k=k,
query_embedding=query_embedding,
)
if result is not None:
results.append(result.model_dump())
except Exception as e:
log.exception(f"Error when querying the collection: {e}")
else:
pass
if not collection_name:
continue
try:
if result := query_doc(
collection_name=collection_name,
k=k,
query_embedding=query_embedding,
):
results.append(result.model_dump())
except Exception as e:
log.exception(f"Error when querying the collection: {e}")
return merge_and_sort_query_results(results, k=k)
@ -213,8 +198,8 @@ def query_collection_with_hybrid_search(
results = []
error = False
for collection_name in collection_names:
try:
for query in queries:
for query in queries:
try:
result = query_doc_with_hybrid_search(
collection_name=collection_name,
query=query,
@ -224,11 +209,11 @@ def query_collection_with_hybrid_search(
r=r,
)
results.append(result)
except Exception as e:
log.exception(
"Error when querying the collection with " f"hybrid_search: {e}"
)
error = True
except Exception as e:
log.exception(
"Error when querying the collection with " f"hybrid_search: {e}"
)
error = True
if error:
raise Exception(
@ -259,10 +244,10 @@ def get_embedding_function(
def generate_multiple(query, func, batch_size=None):
    """Embed ``query`` in batches, returning one embedding per input text.

    ``func`` maps a list of texts to a list of embedding vectors (one per
    text).  For a list input the per-batch results are flattened back into a
    single flat list — returning ``[func(batch) for batch in ...]`` would
    yield a list of *lists of* embeddings and break callers that expect one
    vector per text.  A non-list query is forwarded to ``func`` unchanged.

    ``batch_size`` defaults to ``embedding_batch_size`` from the enclosing
    scope; passing it explicitly overrides that.
    """
    if batch_size is None:
        # NOTE(review): falls back to the enclosing get_embedding_function
        # scope — assumed to be a positive int; confirm at the call site.
        batch_size = embedding_batch_size
    if isinstance(query, list):
        embeddings = []
        for start in range(0, len(query), batch_size):
            # extend (not append): flatten each batch's list of vectors.
            embeddings.extend(func(query[start : start + batch_size]))
        return embeddings
    else:
        return func(query)
@ -436,25 +421,26 @@ def generate_openai_batch_embeddings(
def generate_ollama_batch_embeddings(
    model: str, texts: list[str], url: str, key: str = ""
) -> Optional[list[list[float]]]:
    """Request embeddings for ``texts`` from an Ollama ``/api/embed`` endpoint.

    Returns the list of embedding vectors on success, or ``None`` on any
    failure: network error, non-2xx HTTP status, or a response payload
    missing the ``embeddings`` key.  Failures are printed, not raised.
    """
    try:
        r = requests.post(
            f"{url}/api/embed",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
            },
            json={"input": texts, "model": model},
        )
        r.raise_for_status()

        data = r.json()
        if "embeddings" not in data:
            # `raise "..."` is a TypeError in Python 3 (exceptions must derive
            # from BaseException) — raise a real exception with the same message.
            raise ValueError("Something went wrong :/")
        return data["embeddings"]
    except Exception as e:
        print(e)
        return None
def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):