From 36ce157907b8800bed25fd671d60359ce97c93c7 Mon Sep 17 00:00:00 2001
From: "Timothy J. Baek"
Date: Sun, 14 Apr 2024 18:47:45 -0400
Subject: [PATCH] fix: integration

---
 backend/apps/ollama/main.py |  5 +++++
 backend/apps/rag/main.py    | 27 ++++++++++++++++++++-------
 backend/apps/rag/utils.py   |  3 +++
 3 files changed, 28 insertions(+), 7 deletions(-)

diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py
index 0132179f5..387ff05da 100644
--- a/backend/apps/ollama/main.py
+++ b/backend/apps/ollama/main.py
@@ -658,6 +658,9 @@ def generate_ollama_embeddings(
     form_data: GenerateEmbeddingsForm,
     url_idx: Optional[int] = None,
 ):
+
+    log.info(f"generate_ollama_embeddings {form_data}")
+
     if url_idx == None:
         model = form_data.model
 
@@ -685,6 +688,8 @@ def generate_ollama_embeddings(
 
     data = r.json()
 
+    log.info(f"generate_ollama_embeddings {data}")
+
     if "embedding" in data:
         return data["embedding"]
     else:
diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py
index e1a5e6eb8..976c7735b 100644
--- a/backend/apps/rag/main.py
+++ b/backend/apps/rag/main.py
@@ -39,7 +39,7 @@ import uuid
 import json
 
 
-from apps.ollama.main import generate_ollama_embeddings
+from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
 
 from apps.web.models.documents import (
     Documents,
@@ -277,7 +277,12 @@ def query_doc_handler(
     try:
         if app.state.RAG_EMBEDDING_ENGINE == "ollama":
             query_embeddings = generate_ollama_embeddings(
-                {"model": app.state.RAG_EMBEDDING_MODEL, "prompt": form_data.query}
+                GenerateEmbeddingsForm(
+                    **{
+                        "model": app.state.RAG_EMBEDDING_MODEL,
+                        "prompt": form_data.query,
+                    }
+                )
             )
 
             return query_embeddings_doc(
@@ -314,7 +319,12 @@ def query_collection_handler(
     try:
         if app.state.RAG_EMBEDDING_ENGINE == "ollama":
             query_embeddings = generate_ollama_embeddings(
-                {"model": app.state.RAG_EMBEDDING_MODEL, "prompt": form_data.query}
+                GenerateEmbeddingsForm(
+                    **{
+                        "model": app.state.RAG_EMBEDDING_MODEL,
+                        "prompt": form_data.query,
+                    }
+                )
             )
 
             return query_embeddings_collection(
@@ -373,6 +383,7 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> b
     docs = text_splitter.split_documents(data)
 
     if len(docs) > 0:
+        log.info("store_data_in_vector_db: store_docs_in_vector_db")
         return store_docs_in_vector_db(docs, collection_name, overwrite), None
     else:
         raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
@@ -390,9 +401,8 @@ def store_text_in_vector_db(
     return store_docs_in_vector_db(docs, collection_name, overwrite)
 
 
-async def store_docs_in_vector_db(
-    docs, collection_name, overwrite: bool = False
-) -> bool:
+def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
+    log.info(f"store_docs_in_vector_db {docs} {collection_name}")
 
     texts = [doc.page_content for doc in docs]
     metadatas = [doc.metadata for doc in docs]
@@ -413,13 +423,16 @@ async def store_docs_in_vector_db(
                 metadatas=metadatas,
                 embeddings=[
                     generate_ollama_embeddings(
-                        {"model": RAG_EMBEDDING_MODEL, "prompt": text}
+                        GenerateEmbeddingsForm(
+                            **{"model": RAG_EMBEDDING_MODEL, "prompt": text}
+                        )
                     )
                     for text in texts
                 ],
             ):
                 collection.add(*batch)
         else:
+
             collection = CHROMA_CLIENT.create_collection(
                 name=collection_name,
                 embedding_function=app.state.sentence_transformer_ef,
diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py
index 301c63b99..17d8e4a9a 100644
--- a/backend/apps/rag/utils.py
+++ b/backend/apps/rag/utils.py
@@ -32,6 +32,7 @@ def query_doc(collection_name: str, query: str, k: int, embedding_function):
 def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
     try:
         # if you use docker use the model from the environment variable
+        log.info(f"query_embeddings_doc {query_embeddings}")
         collection = CHROMA_CLIENT.get_collection(
             name=collection_name,
         )
@@ -117,6 +118,8 @@ def query_collection(
 def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int):
 
     results = []
+    log.info(f"query_embeddings_collection {query_embeddings}")
+
     for collection_name in collection_names:
         try:
             collection = CHROMA_CLIENT.get_collection(name=collection_name)
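
Note on the recurring change above: callers now wrap the embedding payload
in GenerateEmbeddingsForm instead of passing a raw dict, so pydantic
validates it before generate_ollama_embeddings forwards it to Ollama. A
minimal sketch of that calling convention, assuming only the two fields
visible in this diff (the real class in backend/apps/ollama/main.py may
declare additional optional fields):

    # Sketch only: the field set is inferred from the diff, and the stub
    # below stands in for the real handler, which POSTs the serialized form
    # to Ollama's /api/embeddings endpoint and returns data["embedding"].
    from pydantic import BaseModel


    class GenerateEmbeddingsForm(BaseModel):
        model: str
        prompt: str


    def generate_ollama_embeddings(form_data: GenerateEmbeddingsForm) -> list:
        raise NotImplementedError  # real implementation: apps.ollama.main


    # Before: generate_ollama_embeddings({"model": "...", "prompt": "..."})
    # After: construction validates the payload up front, so a missing or
    # mistyped field raises pydantic.ValidationError instead of failing
    # later inside the HTTP request ("all-minilm" is just an example model):
    form = GenerateEmbeddingsForm(**{"model": "all-minilm", "prompt": "hi"})

The **{...} unpacking mirrors the dict literals that were already in place;
GenerateEmbeddingsForm(model=..., prompt=...) would be equivalent. Dropping
async from store_docs_in_vector_db likewise matches its call sites: both
store_data_in_vector_db and store_text_in_vector_db call it without await,
so the old coroutine was returned unawaited and its body never ran.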