diff --git a/backend/open_webui/apps/rag/main.py b/backend/open_webui/apps/rag/main.py index 6c064fe81..5f0ef21f1 100644 --- a/backend/open_webui/apps/rag/main.py +++ b/backend/open_webui/apps/rag/main.py @@ -12,11 +12,16 @@ from typing import Iterator, Optional, Sequence, Union import requests import validators + +from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel + +from open_webui.apps.rag.search.main import SearchResult from open_webui.apps.rag.search.brave import search_brave from open_webui.apps.rag.search.duckduckgo import search_duckduckgo from open_webui.apps.rag.search.google_pse import search_google_pse from open_webui.apps.rag.search.jina_search import search_jina -from open_webui.apps.rag.search.main import SearchResult from open_webui.apps.rag.search.searchapi import search_searchapi from open_webui.apps.rag.search.searxng import search_searxng from open_webui.apps.rag.search.serper import search_serper @@ -33,10 +38,8 @@ from open_webui.apps.rag.utils import ( ) from open_webui.apps.webui.models.documents import DocumentForm, Documents from open_webui.apps.webui.models.files import Files -from chromadb.utils.batch_utils import create_batches from open_webui.config import ( BRAVE_SEARCH_API_KEY, - CHROMA_CLIENT, CHUNK_OVERLAP, CHUNK_SIZE, CONTENT_EXTRACTION_ENGINE, @@ -84,9 +87,17 @@ from open_webui.config import ( AppConfig, ) from open_webui.constants import ERROR_MESSAGES -from open_webui.env import SRC_LOG_LEVELS -from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status -from fastapi.middleware.cors import CORSMiddleware +from open_webui.env import SRC_LOG_LEVELS, DEVICE_TYPE +from open_webui.utils.misc import ( + calculate_sha256, + calculate_sha256_string, + extract_folders_after_data_docs, + sanitize_filename, +) +from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT + +from chromadb.utils.batch_utils import create_batches from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.document_loaders import ( BSHTMLLoader, @@ -105,14 +116,6 @@ from langchain_community.document_loaders import ( YoutubeLoader, ) from langchain_core.documents import Document -from pydantic import BaseModel -from open_webui.utils.misc import ( - calculate_sha256, - calculate_sha256_string, - extract_folders_after_data_docs, - sanitize_filename, -) -from open_webui.utils.utils import get_admin_user, get_verified_user log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) @@ -143,13 +146,11 @@ app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SI app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL app.state.config.RAG_TEMPLATE = RAG_TEMPLATE - app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES - app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE app.state.YOUTUBE_LOADER_TRANSLATION = None @@ -998,12 +999,12 @@ def store_docs_in_vector_db( try: if overwrite: - for collection in CHROMA_CLIENT.list_collections(): + for collection in VECTOR_DB_CLIENT.list_collections(): if collection_name == collection.name: log.info(f"deleting existing collection {collection_name}") - CHROMA_CLIENT.delete_collection(name=collection_name) + VECTOR_DB_CLIENT.delete_collection(name=collection_name) - collection = CHROMA_CLIENT.create_collection(name=collection_name) + collection = VECTOR_DB_CLIENT.create_collection(name=collection_name) embedding_func = get_embedding_function( app.state.config.RAG_EMBEDDING_ENGINE, @@ -1018,7 +1019,7 @@ def store_docs_in_vector_db( embeddings = embedding_func(embedding_texts) for batch in create_batches( - api=CHROMA_CLIENT, + api=VECTOR_DB_CLIENT, ids=[str(uuid.uuid4()) for _ in texts], metadatas=metadatas, embeddings=embeddings, @@ -1396,7 +1397,7 @@ def scan_docs_dir(user=Depends(get_admin_user)): @app.post("/reset/db") def reset_vector_db(user=Depends(get_admin_user)): - CHROMA_CLIENT.reset() + VECTOR_DB_CLIENT.reset() @app.post("/reset/uploads") @@ -1437,7 +1438,7 @@ def reset(user=Depends(get_admin_user)) -> bool: log.error("Failed to delete %s. Reason: %s" % (file_path, e)) try: - CHROMA_CLIENT.reset() + VECTOR_DB_CLIENT.reset() except Exception as e: log.exception(e) diff --git a/backend/open_webui/apps/rag/utils.py b/backend/open_webui/apps/rag/utils.py index 2bf8a02e4..035fefc60 100644 --- a/backend/open_webui/apps/rag/utils.py +++ b/backend/open_webui/apps/rag/utils.py @@ -3,18 +3,23 @@ import os from typing import Optional, Union import requests -from open_webui.apps.ollama.main import ( - GenerateEmbeddingsForm, - generate_ollama_embeddings, -) -from open_webui.config import CHROMA_CLIENT -from open_webui.env import SRC_LOG_LEVELS + from huggingface_hub import snapshot_download from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever from langchain_community.retrievers import BM25Retriever from langchain_core.documents import Document + + +from open_webui.apps.ollama.main import ( + GenerateEmbeddingsForm, + generate_ollama_embeddings, +) +from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT from open_webui.utils.misc import get_last_user_message +from open_webui.env import SRC_LOG_LEVELS + + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) @@ -26,12 +31,10 @@ def query_doc( k: int, ): try: - collection = CHROMA_CLIENT.get_collection(name=collection_name) - query_embeddings = embedding_function(query) - - result = collection.query( - query_embeddings=[query_embeddings], - n_results=k, + result = VECTOR_DB_CLIENT.query_collection( + name=collection_name, + query_embeddings=embedding_function(query), + k=k, ) log.info(f"query_doc:result {result}") @@ -49,7 +52,7 @@ def query_doc_with_hybrid_search( r: float, ): try: - collection = CHROMA_CLIENT.get_collection(name=collection_name) + collection = VECTOR_DB_CLIENT.get_collection(name=collection_name) documents = collection.get() # get all documents bm25_retriever = BM25Retriever.from_texts( diff --git a/backend/open_webui/apps/rag/vector/connector.py b/backend/open_webui/apps/rag/vector/connector.py new file mode 100644 index 000000000..d7ca615bf --- /dev/null +++ b/backend/open_webui/apps/rag/vector/connector.py @@ -0,0 +1,4 @@ +from open_webui.apps.rag.vector.dbs.chroma import Chroma +from open_webui.config import VECTOR_DB + +VECTOR_DB_CLIENT = Chroma() diff --git a/backend/open_webui/apps/rag/vector/dbs/chroma.py b/backend/open_webui/apps/rag/vector/dbs/chroma.py new file mode 100644 index 000000000..1fd560642 --- /dev/null +++ b/backend/open_webui/apps/rag/vector/dbs/chroma.py @@ -0,0 +1,58 @@ +import chromadb +from chromadb import Settings + +from open_webui.config import ( + CHROMA_DATA_PATH, + CHROMA_HTTP_HOST, + CHROMA_HTTP_PORT, + CHROMA_HTTP_HEADERS, + CHROMA_HTTP_SSL, + CHROMA_TENANT, + CHROMA_DATABASE, +) + + +class Chroma: + def __init__(self): + if CHROMA_HTTP_HOST != "": + self.client = chromadb.HttpClient( + host=CHROMA_HTTP_HOST, + port=CHROMA_HTTP_PORT, + headers=CHROMA_HTTP_HEADERS, + ssl=CHROMA_HTTP_SSL, + tenant=CHROMA_TENANT, + database=CHROMA_DATABASE, + settings=Settings(allow_reset=True, anonymized_telemetry=False), + ) + else: + self.client = chromadb.PersistentClient( + path=CHROMA_DATA_PATH, + settings=Settings(allow_reset=True, anonymized_telemetry=False), + tenant=CHROMA_TENANT, + database=CHROMA_DATABASE, + ) + + def query_collection(self, name, query_embeddings, k): + collection = self.client.get_collection(name=name) + if collection: + result = collection.query( + query_embeddings=[query_embeddings], + n_results=k, + ) + return result + return None + + def list_collections(self): + return self.client.list_collections() + + def create_collection(self, name): + return self.client.create_collection(name=name) + + def get_or_create_collection(self, name): + return self.client.get_or_create_collection(name=name) + + def delete_collection(self, name): + return self.client.delete_collection(name=name) + + def reset(self): + return self.client.reset() diff --git a/backend/open_webui/apps/rag/vector/dbs/milvus.py b/backend/open_webui/apps/rag/vector/dbs/milvus.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/open_webui/apps/webui/routers/memories.py b/backend/open_webui/apps/webui/routers/memories.py index 914b69e7e..1b44063e7 100644 --- a/backend/open_webui/apps/webui/routers/memories.py +++ b/backend/open_webui/apps/webui/routers/memories.py @@ -1,12 +1,13 @@ +from fastapi import APIRouter, Depends, HTTPException, Request +from pydantic import BaseModel import logging from typing import Optional from open_webui.apps.webui.models.memories import Memories, MemoryModel -from open_webui.config import CHROMA_CLIENT -from open_webui.env import SRC_LOG_LEVELS -from fastapi import APIRouter, Depends, HTTPException, Request -from pydantic import BaseModel +from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT from open_webui.utils.utils import get_verified_user +from open_webui.env import SRC_LOG_LEVELS + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -51,7 +52,9 @@ async def add_memory( memory = Memories.insert_new_memory(user.id, form_data.content) memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content) - collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}") + collection = VECTOR_DB_CLIENT.get_or_create_collection( + name=f"user-memory-{user.id}" + ) collection.upsert( documents=[memory.content], ids=[memory.id], @@ -77,7 +80,9 @@ async def query_memory( request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user) ): query_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content) - collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}") + collection = VECTOR_DB_CLIENT.get_or_create_collection( + name=f"user-memory-{user.id}" + ) results = collection.query( query_embeddings=[query_embedding], @@ -94,8 +99,10 @@ async def query_memory( async def reset_memory_from_vector_db( request: Request, user=Depends(get_verified_user) ): - CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}") - collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}") + VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}") + collection = VECTOR_DB_CLIENT.get_or_create_collection( + name=f"user-memory-{user.id}" + ) memories = Memories.get_memories_by_user_id(user.id) for memory in memories: @@ -119,7 +126,7 @@ async def delete_memory_by_user_id(user=Depends(get_verified_user)): if result: try: - CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}") + VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}") except Exception as e: log.error(e) return True @@ -145,7 +152,7 @@ async def update_memory_by_id( if form_data.content is not None: memory_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content) - collection = CHROMA_CLIENT.get_or_create_collection( + collection = VECTOR_DB_CLIENT.get_or_create_collection( name=f"user-memory-{user.id}" ) collection.upsert( @@ -170,7 +177,7 @@ async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)): result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id) if result: - collection = CHROMA_CLIENT.get_or_create_collection( + collection = VECTOR_DB_CLIENT.get_or_create_collection( name=f"user-memory-{user.id}" ) collection.delete(ids=[memory_id]) diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 5ccb40d47..2eeb67207 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -11,7 +11,6 @@ import chromadb import requests import yaml from open_webui.apps.webui.internal.db import Base, get_db -from chromadb import Settings from open_webui.env import ( OPEN_WEBUI_DIR, DATA_DIR, @@ -926,22 +925,9 @@ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig( # RAG document content extraction #################################### -CONTENT_EXTRACTION_ENGINE = PersistentConfig( - "CONTENT_EXTRACTION_ENGINE", - "rag.CONTENT_EXTRACTION_ENGINE", - os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(), -) - -TIKA_SERVER_URL = PersistentConfig( - "TIKA_SERVER_URL", - "rag.tika_server_url", - os.getenv("TIKA_SERVER_URL", "http://tika:9998"), # Default for sidecar deployment -) - -#################################### -# RAG -#################################### +VECTOR_DB = os.environ.get("VECTOR_DB", "chroma") +# Chroma CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT) CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE) @@ -958,6 +944,23 @@ else: CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true" # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2) +#################################### +# RAG +#################################### + +# RAG Content Extraction +CONTENT_EXTRACTION_ENGINE = PersistentConfig( + "CONTENT_EXTRACTION_ENGINE", + "rag.CONTENT_EXTRACTION_ENGINE", + os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(), +) + +TIKA_SERVER_URL = PersistentConfig( + "TIKA_SERVER_URL", + "rag.tika_server_url", + os.getenv("TIKA_SERVER_URL", "http://tika:9998"), # Default for sidecar deployment +) + RAG_TOP_K = PersistentConfig( "RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "5")) )