diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index a33a29659..19019ffbd 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -391,16 +391,16 @@ def query_doc_handler( return query_doc_with_hybrid_search( collection_name=form_data.collection_name, query=form_data.query, - embeddings_function=app.state.EMBEDDING_FUNCTION, - reranking_function=app.state.sentence_transformer_rf, + embedding_function=app.state.EMBEDDING_FUNCTION, k=form_data.k if form_data.k else app.state.TOP_K, + reranking_function=app.state.sentence_transformer_rf, r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, ) else: return query_doc( collection_name=form_data.collection_name, query=form_data.query, - embeddings_function=app.state.EMBEDDING_FUNCTION, + embedding_function=app.state.EMBEDDING_FUNCTION, k=form_data.k if form_data.k else app.state.TOP_K, ) except Exception as e: @@ -429,16 +429,16 @@ def query_collection_handler( return query_collection_with_hybrid_search( collection_names=form_data.collection_names, query=form_data.query, - embeddings_function=app.state.EMBEDDING_FUNCTION, - reranking_function=app.state.sentence_transformer_rf, + embedding_function=app.state.EMBEDDING_FUNCTION, k=form_data.k if form_data.k else app.state.TOP_K, + reranking_function=app.state.sentence_transformer_rf, r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, ) else: return query_collection( collection_names=form_data.collection_names, query=form_data.query, - embeddings_function=app.state.EMBEDDING_FUNCTION, + embedding_function=app.state.EMBEDDING_FUNCTION, k=form_data.k if form_data.k else app.state.TOP_K, ) diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index eb9d5c84b..10f1f7bed 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -35,6 +35,7 @@ def query_doc( try: collection = CHROMA_CLIENT.get_collection(name=collection_name) query_embeddings = embedding_function(query) + result = collection.query( query_embeddings=[query_embeddings], n_results=k, @@ -76,9 +77,9 @@ def query_doc_with_hybrid_search( compressor = RerankCompressor( embedding_function=embedding_function, + top_n=k, reranking_function=reranking_function, r_score=r, - top_n=k, ) compression_retriever = ContextualCompressionRetriever( @@ -91,6 +92,7 @@ def query_doc_with_hybrid_search( "documents": [[d.page_content for d in result]], "metadatas": [[d.metadata for d in result]], } + log.info(f"query_doc_with_hybrid_search:result {result}") return result except Exception as e: @@ -167,7 +169,6 @@ def query_collection_with_hybrid_search( reranking_function, r: float, ): - results = [] for collection_name in collection_names: try: @@ -182,7 +183,6 @@ def query_collection_with_hybrid_search( results.append(result) except: pass - return merge_and_sort_query_results(results, k=k, reverse=True) @@ -443,13 +443,15 @@ class ChromaRetriever(BaseRetriever): metadatas = results["metadatas"][0] documents = results["documents"][0] - return [ - Document( - metadata=metadatas[idx], - page_content=documents[idx], + results = [] + for idx in range(len(ids)): + results.append( + Document( + metadata=metadatas[idx], + page_content=documents[idx], + ) ) - for idx in range(len(ids)) - ] + return results import operator @@ -465,9 +467,9 @@ from sentence_transformers import util class RerankCompressor(BaseDocumentCompressor): embedding_function: Any + top_n: int reranking_function: Any r_score: float - top_n: int class Config: extra = Extra.forbid @@ -479,7 +481,9 @@ class RerankCompressor(BaseDocumentCompressor): query: str, callbacks: Optional[Callbacks] = None, ) -> Sequence[Document]: - if self.reranking_function: + reranking = self.reranking_function is not None + + if reranking: scores = self.reranking_function.predict( [(query, doc.page_content) for doc in documents] ) @@ -496,9 +500,7 @@ class RerankCompressor(BaseDocumentCompressor): (d, s) for d, s in docs_with_scores if s >= self.r_score ] - reverse = self.reranking_function is not None - result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=reverse) - + result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True) final_results = [] for doc, doc_score in result[: self.top_n]: metadata = doc.metadata diff --git a/src/lib/components/admin/UserChatsModal.svelte b/src/lib/components/admin/UserChatsModal.svelte index 34998834b..67fa367cd 100644 --- a/src/lib/components/admin/UserChatsModal.svelte +++ b/src/lib/components/admin/UserChatsModal.svelte @@ -133,7 +133,10 @@ {/each} --> {:else} -