This commit is contained in:
Timothy J. Baek 2024-09-10 02:27:50 +01:00
parent 1023ff8454
commit 4354f270ce
7 changed files with 138 additions and 62 deletions

View File

@ -12,11 +12,16 @@ from typing import Iterator, Optional, Sequence, Union
import requests
import validators
from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from open_webui.apps.rag.search.main import SearchResult
from open_webui.apps.rag.search.brave import search_brave
from open_webui.apps.rag.search.duckduckgo import search_duckduckgo
from open_webui.apps.rag.search.google_pse import search_google_pse
from open_webui.apps.rag.search.jina_search import search_jina
from open_webui.apps.rag.search.main import SearchResult
from open_webui.apps.rag.search.searchapi import search_searchapi
from open_webui.apps.rag.search.searxng import search_searxng
from open_webui.apps.rag.search.serper import search_serper
@ -33,10 +38,8 @@ from open_webui.apps.rag.utils import (
)
from open_webui.apps.webui.models.documents import DocumentForm, Documents
from open_webui.apps.webui.models.files import Files
from chromadb.utils.batch_utils import create_batches
from open_webui.config import (
BRAVE_SEARCH_API_KEY,
CHROMA_CLIENT,
CHUNK_OVERLAP,
CHUNK_SIZE,
CONTENT_EXTRACTION_ENGINE,
@ -84,9 +87,17 @@ from open_webui.config import (
AppConfig,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status
from fastapi.middleware.cors import CORSMiddleware
from open_webui.env import SRC_LOG_LEVELS, DEVICE_TYPE
from open_webui.utils.misc import (
calculate_sha256,
calculate_sha256_string,
extract_folders_after_data_docs,
sanitize_filename,
)
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT
from chromadb.utils.batch_utils import create_batches
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import (
BSHTMLLoader,
@ -105,14 +116,6 @@ from langchain_community.document_loaders import (
YoutubeLoader,
)
from langchain_core.documents import Document
from pydantic import BaseModel
from open_webui.utils.misc import (
calculate_sha256,
calculate_sha256_string,
extract_folders_after_data_docs,
sanitize_filename,
)
from open_webui.utils.utils import get_admin_user, get_verified_user
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
@ -143,13 +146,11 @@ app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SI
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
app.state.YOUTUBE_LOADER_TRANSLATION = None
@ -998,12 +999,12 @@ def store_docs_in_vector_db(
try:
if overwrite:
for collection in CHROMA_CLIENT.list_collections():
for collection in VECTOR_DB_CLIENT.list_collections():
if collection_name == collection.name:
log.info(f"deleting existing collection {collection_name}")
CHROMA_CLIENT.delete_collection(name=collection_name)
VECTOR_DB_CLIENT.delete_collection(name=collection_name)
collection = CHROMA_CLIENT.create_collection(name=collection_name)
collection = VECTOR_DB_CLIENT.create_collection(name=collection_name)
embedding_func = get_embedding_function(
app.state.config.RAG_EMBEDDING_ENGINE,
@ -1018,7 +1019,7 @@ def store_docs_in_vector_db(
embeddings = embedding_func(embedding_texts)
for batch in create_batches(
api=CHROMA_CLIENT,
api=VECTOR_DB_CLIENT,
ids=[str(uuid.uuid4()) for _ in texts],
metadatas=metadatas,
embeddings=embeddings,
@ -1396,7 +1397,7 @@ def scan_docs_dir(user=Depends(get_admin_user)):
@app.post("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
CHROMA_CLIENT.reset()
VECTOR_DB_CLIENT.reset()
@app.post("/reset/uploads")
@ -1437,7 +1438,7 @@ def reset(user=Depends(get_admin_user)) -> bool:
log.error("Failed to delete %s. Reason: %s" % (file_path, e))
try:
CHROMA_CLIENT.reset()
VECTOR_DB_CLIENT.reset()
except Exception as e:
log.exception(e)

View File

@ -3,18 +3,23 @@ import os
from typing import Optional, Union
import requests
from open_webui.apps.ollama.main import (
GenerateEmbeddingsForm,
generate_ollama_embeddings,
)
from open_webui.config import CHROMA_CLIENT
from open_webui.env import SRC_LOG_LEVELS
from huggingface_hub import snapshot_download
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document
from open_webui.apps.ollama.main import (
GenerateEmbeddingsForm,
generate_ollama_embeddings,
)
from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
@ -26,12 +31,10 @@ def query_doc(
k: int,
):
try:
collection = CHROMA_CLIENT.get_collection(name=collection_name)
query_embeddings = embedding_function(query)
result = collection.query(
query_embeddings=[query_embeddings],
n_results=k,
result = VECTOR_DB_CLIENT.query_collection(
name=collection_name,
query_embeddings=embedding_function(query),
k=k,
)
log.info(f"query_doc:result {result}")
@ -49,7 +52,7 @@ def query_doc_with_hybrid_search(
r: float,
):
try:
collection = CHROMA_CLIENT.get_collection(name=collection_name)
collection = VECTOR_DB_CLIENT.get_collection(name=collection_name)
documents = collection.get() # get all documents
bm25_retriever = BM25Retriever.from_texts(

View File

@ -0,0 +1,4 @@
from open_webui.apps.rag.vector.dbs.chroma import Chroma
from open_webui.config import VECTOR_DB
VECTOR_DB_CLIENT = Chroma()

View File

@ -0,0 +1,58 @@
import chromadb
from chromadb import Settings
from open_webui.config import (
CHROMA_DATA_PATH,
CHROMA_HTTP_HOST,
CHROMA_HTTP_PORT,
CHROMA_HTTP_HEADERS,
CHROMA_HTTP_SSL,
CHROMA_TENANT,
CHROMA_DATABASE,
)
class Chroma:
def __init__(self):
if CHROMA_HTTP_HOST != "":
self.client = chromadb.HttpClient(
host=CHROMA_HTTP_HOST,
port=CHROMA_HTTP_PORT,
headers=CHROMA_HTTP_HEADERS,
ssl=CHROMA_HTTP_SSL,
tenant=CHROMA_TENANT,
database=CHROMA_DATABASE,
settings=Settings(allow_reset=True, anonymized_telemetry=False),
)
else:
self.client = chromadb.PersistentClient(
path=CHROMA_DATA_PATH,
settings=Settings(allow_reset=True, anonymized_telemetry=False),
tenant=CHROMA_TENANT,
database=CHROMA_DATABASE,
)
def query_collection(self, name, query_embeddings, k):
collection = self.client.get_collection(name=name)
if collection:
result = collection.query(
query_embeddings=[query_embeddings],
n_results=k,
)
return result
return None
def list_collections(self):
return self.client.list_collections()
def create_collection(self, name):
return self.client.create_collection(name=name)
def get_or_create_collection(self, name):
return self.client.get_or_create_collection(name=name)
def delete_collection(self, name):
return self.client.delete_collection(name=name)
def reset(self):
return self.client.reset()

View File

@ -1,12 +1,13 @@
from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel
import logging
from typing import Optional
from open_webui.apps.webui.models.memories import Memories, MemoryModel
from open_webui.config import CHROMA_CLIENT
from open_webui.env import SRC_LOG_LEVELS
from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel
from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.utils import get_verified_user
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -51,7 +52,9 @@ async def add_memory(
memory = Memories.insert_new_memory(user.id, form_data.content)
memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
collection = VECTOR_DB_CLIENT.get_or_create_collection(
name=f"user-memory-{user.id}"
)
collection.upsert(
documents=[memory.content],
ids=[memory.id],
@ -77,7 +80,9 @@ async def query_memory(
request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
):
query_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
collection = VECTOR_DB_CLIENT.get_or_create_collection(
name=f"user-memory-{user.id}"
)
results = collection.query(
query_embeddings=[query_embedding],
@ -94,8 +99,10 @@ async def query_memory(
async def reset_memory_from_vector_db(
request: Request, user=Depends(get_verified_user)
):
CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
collection = VECTOR_DB_CLIENT.get_or_create_collection(
name=f"user-memory-{user.id}"
)
memories = Memories.get_memories_by_user_id(user.id)
for memory in memories:
@ -119,7 +126,7 @@ async def delete_memory_by_user_id(user=Depends(get_verified_user)):
if result:
try:
CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
except Exception as e:
log.error(e)
return True
@ -145,7 +152,7 @@ async def update_memory_by_id(
if form_data.content is not None:
memory_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
collection = CHROMA_CLIENT.get_or_create_collection(
collection = VECTOR_DB_CLIENT.get_or_create_collection(
name=f"user-memory-{user.id}"
)
collection.upsert(
@ -170,7 +177,7 @@ async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
if result:
collection = CHROMA_CLIENT.get_or_create_collection(
collection = VECTOR_DB_CLIENT.get_or_create_collection(
name=f"user-memory-{user.id}"
)
collection.delete(ids=[memory_id])

View File

@ -11,7 +11,6 @@ import chromadb
import requests
import yaml
from open_webui.apps.webui.internal.db import Base, get_db
from chromadb import Settings
from open_webui.env import (
OPEN_WEBUI_DIR,
DATA_DIR,
@ -926,22 +925,9 @@ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
# RAG document content extraction
####################################
CONTENT_EXTRACTION_ENGINE = PersistentConfig(
"CONTENT_EXTRACTION_ENGINE",
"rag.CONTENT_EXTRACTION_ENGINE",
os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(),
)
TIKA_SERVER_URL = PersistentConfig(
"TIKA_SERVER_URL",
"rag.tika_server_url",
os.getenv("TIKA_SERVER_URL", "http://tika:9998"), # Default for sidecar deployment
)
####################################
# RAG
####################################
VECTOR_DB = os.environ.get("VECTOR_DB", "chroma")
# Chroma
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT)
CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE)
@ -958,6 +944,23 @@ else:
CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
####################################
# RAG
####################################
# RAG Content Extraction
CONTENT_EXTRACTION_ENGINE = PersistentConfig(
"CONTENT_EXTRACTION_ENGINE",
"rag.CONTENT_EXTRACTION_ENGINE",
os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(),
)
TIKA_SERVER_URL = PersistentConfig(
"TIKA_SERVER_URL",
"rag.tika_server_url",
os.getenv("TIKA_SERVER_URL", "http://tika:9998"), # Default for sidecar deployment
)
RAG_TOP_K = PersistentConfig(
"RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "5"))
)