This commit is contained in:
Timothy J. Baek
2024-09-10 04:37:06 +01:00
parent d5f13dd9e0
commit 522afbb0a0
7 changed files with 240 additions and 127 deletions

View File

@@ -24,6 +24,44 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
from typing import Any
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.retrievers import BaseRetriever
class VectorSearchRetriever(BaseRetriever):
    """Retriever that answers a query through ``VECTOR_DB_CLIENT.search``.

    The query text is embedded with ``embedding_function`` and the hits found
    in ``collection_name`` are returned as ``Document`` objects.
    """

    # Collection identifier forwarded to VECTOR_DB_CLIENT.search.
    collection_name: Any
    # Callable invoked with the query string; its result is sent as the only
    # search vector.
    embedding_function: Any
    # Forwarded to the search call as ``limit``.
    top_k: int

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[Document]:
        """Search the collection for ``query`` and wrap each hit in a ``Document``.

        Args:
            query: Text to embed and search for.
            run_manager: Callback manager supplied by the caller; not used here.

        Returns:
            One ``Document`` per returned id, built in index order from the
            matching metadata and document text.
        """
        search_result = VECTOR_DB_CLIENT.search(
            collection_name=self.collection_name,
            vectors=[self.embedding_function(query)],
            limit=self.top_k,
        )

        # Only one vector is submitted, so only entry 0 of each field is read.
        # NOTE(review): presumably each field holds one list per query vector,
        # and the result supports key access -- confirm against the client.
        hit_ids = search_result["ids"][0]
        hit_metadatas = search_result["metadatas"][0]
        hit_texts = search_result["documents"][0]

        # Indexing (rather than zip) keeps a length mismatch loud: a short
        # metadata or document list raises IndexError instead of truncating.
        return [
            Document(
                metadata=hit_metadatas[position],
                page_content=hit_texts[position],
            )
            for position in range(len(hit_ids))
        ]
def query_doc(
collection_name: str,
query: str,
@@ -31,15 +69,18 @@ def query_doc(
k: int,
):
try:
result = VECTOR_DB_CLIENT.query_collection(
name=collection_name,
query_embeddings=embedding_function(query),
k=k,
result = VECTOR_DB_CLIENT.search(
collection_name=collection_name,
vectors=[embedding_function(query)],
limit=k,
)
print("result", result)
log.info(f"query_doc:result {result}")
return result
except Exception as e:
print(e)
raise e
@@ -52,25 +93,23 @@ def query_doc_with_hybrid_search(
r: float,
):
try:
collection = VECTOR_DB_CLIENT.get_collection(name=collection_name)
documents = collection.get() # get all documents
result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
bm25_retriever = BM25Retriever.from_texts(
texts=documents.get("documents"),
metadatas=documents.get("metadatas"),
texts=result.documents,
metadatas=result.metadatas,
)
bm25_retriever.k = k
chroma_retriever = ChromaRetriever(
collection=collection,
vector_search_retriever = VectorSearchRetriever(
collection_name=collection_name,
embedding_function=embedding_function,
top_n=k,
top_k=k,
)
ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
retrievers=[bm25_retriever, vector_search_retriever], weights=[0.5, 0.5]
)
compressor = RerankCompressor(
embedding_function=embedding_function,
top_n=k,
@@ -394,45 +433,6 @@ def generate_openai_batch_embeddings(
return None
from typing import Any
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.retrievers import BaseRetriever
class ChromaRetriever(BaseRetriever):
    """Retriever that answers a query by calling ``collection.query``.

    The query text is embedded with ``embedding_function`` and the hits
    returned by the collection are wrapped in ``Document`` objects.
    """

    # Object whose ``query`` method is called with the embedded query.
    collection: Any
    # Callable invoked with the query string to produce the query embedding.
    embedding_function: Any
    # Forwarded to ``collection.query`` as ``n_results``.
    top_n: int

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[Document]:
        """Query the collection for ``query`` and wrap each hit in a ``Document``.

        Args:
            query: Text to embed and search for.
            run_manager: Callback manager supplied by the caller; not used here.

        Returns:
            One ``Document`` per returned id, built in index order from the
            matching metadata and document text.
        """
        embedded_query = self.embedding_function(query)
        query_result = self.collection.query(
            query_embeddings=[embedded_query],
            n_results=self.top_n,
        )

        # A single embedding is submitted, so only entry 0 of each field is read.
        hit_ids = query_result["ids"][0]
        hit_metadatas = query_result["metadatas"][0]
        hit_texts = query_result["documents"][0]

        # Indexing (rather than zip) keeps a length mismatch loud: a short
        # metadata or document list raises IndexError instead of truncating.
        return [
            Document(
                metadata=hit_metadatas[position],
                page_content=hit_texts[position],
            )
            for position in range(len(hit_ids))
        ]
import operator
from typing import Optional, Sequence