This commit is contained in:
Timothy J. Baek
2024-09-10 02:27:50 +01:00
parent 1023ff8454
commit 4354f270ce
7 changed files with 138 additions and 62 deletions

View File

@@ -3,18 +3,23 @@ import os
from typing import Optional, Union
import requests
from open_webui.apps.ollama.main import (
GenerateEmbeddingsForm,
generate_ollama_embeddings,
)
from open_webui.config import CHROMA_CLIENT
from open_webui.env import SRC_LOG_LEVELS
from huggingface_hub import snapshot_download
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document
from open_webui.apps.ollama.main import (
GenerateEmbeddingsForm,
generate_ollama_embeddings,
)
from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -26,12 +31,10 @@ def query_doc(
k: int,
):
try:
collection = CHROMA_CLIENT.get_collection(name=collection_name)
query_embeddings = embedding_function(query)
result = collection.query(
query_embeddings=[query_embeddings],
n_results=k,
result = VECTOR_DB_CLIENT.query_collection(
name=collection_name,
query_embeddings=embedding_function(query),
k=k,
)
log.info(f"query_doc:result {result}")
@@ -49,7 +52,7 @@ def query_doc_with_hybrid_search(
r: float,
):
try:
collection = CHROMA_CLIENT.get_collection(name=collection_name)
collection = VECTOR_DB_CLIENT.get_collection(name=collection_name)
documents = collection.get() # get all documents
bm25_retriever = BM25Retriever.from_texts(