From bc3dd34d8b7980668aa97041d804a84bc3e24e65 Mon Sep 17 00:00:00 2001 From: Jannik Streidl Date: Sun, 18 Feb 2024 09:17:43 +0100 Subject: [PATCH] collection query fix --- backend/apps/rag/main.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index defe10f95..8a5a12d39 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -29,11 +29,13 @@ from langchain_community.document_loaders import ( from langchain.text_splitter import RecursiveCharacterTextSplitter + from pydantic import BaseModel from typing import Optional import uuid + from utils.misc import calculate_sha256, calculate_sha256_string from utils.utils import get_current_user, get_admin_user from config import UPLOAD_DIR, SENTENCE_TRANSFORMER_EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP @@ -113,12 +115,12 @@ def query_doc( # if you use docker use the model from the environment variable collection = CHROMA_CLIENT.get_collection( name=form_data.collection_name, - embedding_function=sentence_transformer_ef + embedding_function=sentence_transformer_ef, ) else: # for local development use the default model - collection = CHROMA_CLIENT.get_collection( - name=form_data.collection_name, + collection = CHROMA_CLIENT.get_collection( + name=form_data.collection_name, ) result = collection.query(query_texts=[form_data.query], n_results=form_data.k) return result @@ -191,16 +193,16 @@ def query_collection( for collection_name in form_data.collection_names: try: if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ: - # if you use docker use the model from the environment variable + # if you use docker use the model from the environment variable collection = CHROMA_CLIENT.get_collection( - name=form_data.collection_name, - embedding_function=sentence_transformer_ef + name=collection_name, + embedding_function=sentence_transformer_ef, ) else: # for local development use the default model collection = CHROMA_CLIENT.get_collection( - name=form_data.collection_name, - ) + name=collection_name, + ) result = collection.query( query_texts=[form_data.query], n_results=form_data.k