mirror of
https://github.com/open-webui/open-webui
synced 2025-01-29 13:58:09 +00:00
refac
This commit is contained in:
parent
1023ff8454
commit
4354f270ce
@ -12,11 +12,16 @@ from typing import Iterator, Optional, Sequence, Union
|
||||
|
||||
import requests
|
||||
import validators
|
||||
|
||||
from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
|
||||
from open_webui.apps.rag.search.main import SearchResult
|
||||
from open_webui.apps.rag.search.brave import search_brave
|
||||
from open_webui.apps.rag.search.duckduckgo import search_duckduckgo
|
||||
from open_webui.apps.rag.search.google_pse import search_google_pse
|
||||
from open_webui.apps.rag.search.jina_search import search_jina
|
||||
from open_webui.apps.rag.search.main import SearchResult
|
||||
from open_webui.apps.rag.search.searchapi import search_searchapi
|
||||
from open_webui.apps.rag.search.searxng import search_searxng
|
||||
from open_webui.apps.rag.search.serper import search_serper
|
||||
@ -33,10 +38,8 @@ from open_webui.apps.rag.utils import (
|
||||
)
|
||||
from open_webui.apps.webui.models.documents import DocumentForm, Documents
|
||||
from open_webui.apps.webui.models.files import Files
|
||||
from chromadb.utils.batch_utils import create_batches
|
||||
from open_webui.config import (
|
||||
BRAVE_SEARCH_API_KEY,
|
||||
CHROMA_CLIENT,
|
||||
CHUNK_OVERLAP,
|
||||
CHUNK_SIZE,
|
||||
CONTENT_EXTRACTION_ENGINE,
|
||||
@ -84,9 +87,17 @@ from open_webui.config import (
|
||||
AppConfig,
|
||||
)
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from open_webui.env import SRC_LOG_LEVELS, DEVICE_TYPE
|
||||
from open_webui.utils.misc import (
|
||||
calculate_sha256,
|
||||
calculate_sha256_string,
|
||||
extract_folders_after_data_docs,
|
||||
sanitize_filename,
|
||||
)
|
||||
from open_webui.utils.utils import get_admin_user, get_verified_user
|
||||
from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT
|
||||
|
||||
from chromadb.utils.batch_utils import create_batches
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain_community.document_loaders import (
|
||||
BSHTMLLoader,
|
||||
@ -105,14 +116,6 @@ from langchain_community.document_loaders import (
|
||||
YoutubeLoader,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from pydantic import BaseModel
|
||||
from open_webui.utils.misc import (
|
||||
calculate_sha256,
|
||||
calculate_sha256_string,
|
||||
extract_folders_after_data_docs,
|
||||
sanitize_filename,
|
||||
)
|
||||
from open_webui.utils.utils import get_admin_user, get_verified_user
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
@ -143,13 +146,11 @@ app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SI
|
||||
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
|
||||
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
|
||||
|
||||
|
||||
app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
|
||||
app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
|
||||
|
||||
app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
|
||||
|
||||
|
||||
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
|
||||
app.state.YOUTUBE_LOADER_TRANSLATION = None
|
||||
|
||||
@ -998,12 +999,12 @@ def store_docs_in_vector_db(
|
||||
|
||||
try:
|
||||
if overwrite:
|
||||
for collection in CHROMA_CLIENT.list_collections():
|
||||
for collection in VECTOR_DB_CLIENT.list_collections():
|
||||
if collection_name == collection.name:
|
||||
log.info(f"deleting existing collection {collection_name}")
|
||||
CHROMA_CLIENT.delete_collection(name=collection_name)
|
||||
VECTOR_DB_CLIENT.delete_collection(name=collection_name)
|
||||
|
||||
collection = CHROMA_CLIENT.create_collection(name=collection_name)
|
||||
collection = VECTOR_DB_CLIENT.create_collection(name=collection_name)
|
||||
|
||||
embedding_func = get_embedding_function(
|
||||
app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
@ -1018,7 +1019,7 @@ def store_docs_in_vector_db(
|
||||
embeddings = embedding_func(embedding_texts)
|
||||
|
||||
for batch in create_batches(
|
||||
api=CHROMA_CLIENT,
|
||||
api=VECTOR_DB_CLIENT,
|
||||
ids=[str(uuid.uuid4()) for _ in texts],
|
||||
metadatas=metadatas,
|
||||
embeddings=embeddings,
|
||||
@ -1396,7 +1397,7 @@ def scan_docs_dir(user=Depends(get_admin_user)):
|
||||
|
||||
@app.post("/reset/db")
|
||||
def reset_vector_db(user=Depends(get_admin_user)):
|
||||
CHROMA_CLIENT.reset()
|
||||
VECTOR_DB_CLIENT.reset()
|
||||
|
||||
|
||||
@app.post("/reset/uploads")
|
||||
@ -1437,7 +1438,7 @@ def reset(user=Depends(get_admin_user)) -> bool:
|
||||
log.error("Failed to delete %s. Reason: %s" % (file_path, e))
|
||||
|
||||
try:
|
||||
CHROMA_CLIENT.reset()
|
||||
VECTOR_DB_CLIENT.reset()
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
|
@ -3,18 +3,23 @@ import os
|
||||
from typing import Optional, Union
|
||||
|
||||
import requests
|
||||
from open_webui.apps.ollama.main import (
|
||||
GenerateEmbeddingsForm,
|
||||
generate_ollama_embeddings,
|
||||
)
|
||||
from open_webui.config import CHROMA_CLIENT
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
|
||||
from langchain_community.retrievers import BM25Retriever
|
||||
from langchain_core.documents import Document
|
||||
|
||||
|
||||
from open_webui.apps.ollama.main import (
|
||||
GenerateEmbeddingsForm,
|
||||
generate_ollama_embeddings,
|
||||
)
|
||||
from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT
|
||||
from open_webui.utils.misc import get_last_user_message
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
@ -26,12 +31,10 @@ def query_doc(
|
||||
k: int,
|
||||
):
|
||||
try:
|
||||
collection = CHROMA_CLIENT.get_collection(name=collection_name)
|
||||
query_embeddings = embedding_function(query)
|
||||
|
||||
result = collection.query(
|
||||
query_embeddings=[query_embeddings],
|
||||
n_results=k,
|
||||
result = VECTOR_DB_CLIENT.query_collection(
|
||||
name=collection_name,
|
||||
query_embeddings=embedding_function(query),
|
||||
k=k,
|
||||
)
|
||||
|
||||
log.info(f"query_doc:result {result}")
|
||||
@ -49,7 +52,7 @@ def query_doc_with_hybrid_search(
|
||||
r: float,
|
||||
):
|
||||
try:
|
||||
collection = CHROMA_CLIENT.get_collection(name=collection_name)
|
||||
collection = VECTOR_DB_CLIENT.get_collection(name=collection_name)
|
||||
documents = collection.get() # get all documents
|
||||
|
||||
bm25_retriever = BM25Retriever.from_texts(
|
||||
|
4
backend/open_webui/apps/rag/vector/connector.py
Normal file
4
backend/open_webui/apps/rag/vector/connector.py
Normal file
@ -0,0 +1,4 @@
|
||||
from open_webui.apps.rag.vector.dbs.chroma import Chroma
|
||||
from open_webui.config import VECTOR_DB
|
||||
|
||||
# Module-level vector DB client instance that other modules import.
# NOTE(review): VECTOR_DB is imported in this module but not consulted here,
# so the backend is always Chroma regardless of its value.
VECTOR_DB_CLIENT = Chroma()
|
58
backend/open_webui/apps/rag/vector/dbs/chroma.py
Normal file
58
backend/open_webui/apps/rag/vector/dbs/chroma.py
Normal file
@ -0,0 +1,58 @@
|
||||
import chromadb
|
||||
from chromadb import Settings
|
||||
|
||||
from open_webui.config import (
|
||||
CHROMA_DATA_PATH,
|
||||
CHROMA_HTTP_HOST,
|
||||
CHROMA_HTTP_PORT,
|
||||
CHROMA_HTTP_HEADERS,
|
||||
CHROMA_HTTP_SSL,
|
||||
CHROMA_TENANT,
|
||||
CHROMA_DATABASE,
|
||||
)
|
||||
|
||||
|
||||
class Chroma:
    """Thin wrapper around a chromadb client.

    The wrapper owns a single ``self.client`` and exposes the collection
    operations used by callers. Anything not defined explicitly here is
    delegated to the underlying client (see ``__getattr__``).
    """

    def __init__(self):
        """Create the underlying chromadb client.

        An HTTP client is built when ``CHROMA_HTTP_HOST`` is non-empty;
        otherwise a persistent client rooted at ``CHROMA_DATA_PATH`` is used.
        Both variants enable ``allow_reset`` (needed by ``reset()``) and
        disable anonymized telemetry.
        """
        if CHROMA_HTTP_HOST != "":
            self.client = chromadb.HttpClient(
                host=CHROMA_HTTP_HOST,
                port=CHROMA_HTTP_PORT,
                headers=CHROMA_HTTP_HEADERS,
                ssl=CHROMA_HTTP_SSL,
                tenant=CHROMA_TENANT,
                database=CHROMA_DATABASE,
                settings=Settings(allow_reset=True, anonymized_telemetry=False),
            )
        else:
            self.client = chromadb.PersistentClient(
                path=CHROMA_DATA_PATH,
                settings=Settings(allow_reset=True, anonymized_telemetry=False),
                tenant=CHROMA_TENANT,
                database=CHROMA_DATABASE,
            )

    def __getattr__(self, name):
        """Delegate attributes not defined on the wrapper to the client.

        Python only calls this when normal attribute lookup fails, so the
        explicit methods below always take precedence. This keeps code that
        hands the wrapper to helpers expecting a raw chromadb client working.
        NOTE(review): confirm which client attributes ``create_batches``
        reads for the installed chromadb version.
        """
        # Guard ``client`` itself (and dunder probes) so a partially
        # constructed instance raises AttributeError instead of recursing.
        if name == "client" or name.startswith("__"):
            raise AttributeError(name)
        return getattr(self.client, name)

    def query_collection(self, name, query_embeddings, k):
        """Query the collection ``name`` for the ``k`` nearest results.

        ``query_embeddings`` is wrapped in a one-element list before being
        passed on, so callers presumably supply a single embedding vector.

        Returns the raw result of ``collection.query``, or ``None`` when the
        looked-up collection object is falsy.
        """
        collection = self.client.get_collection(name=name)
        if not collection:
            return None
        return collection.query(
            query_embeddings=[query_embeddings],
            n_results=k,
        )

    def list_collections(self):
        """Return the client's list of collections."""
        return self.client.list_collections()

    def get_collection(self, name):
        """Return the existing collection called ``name`` from the client."""
        return self.client.get_collection(name=name)

    def create_collection(self, name):
        """Create and return a new collection called ``name``."""
        return self.client.create_collection(name=name)

    def get_or_create_collection(self, name):
        """Return the collection called ``name``, creating it if needed."""
        return self.client.get_or_create_collection(name=name)

    def delete_collection(self, name):
        """Delete the collection called ``name``."""
        return self.client.delete_collection(name=name)

    def reset(self):
        """Reset the underlying client (relies on ``allow_reset=True``)."""
        return self.client.reset()
|
0
backend/open_webui/apps/rag/vector/dbs/milvus.py
Normal file
0
backend/open_webui/apps/rag/vector/dbs/milvus.py
Normal file
@ -1,12 +1,13 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.apps.webui.models.memories import Memories, MemoryModel
|
||||
from open_webui.config import CHROMA_CLIENT
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT
|
||||
from open_webui.utils.utils import get_verified_user
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
@ -51,7 +52,9 @@ async def add_memory(
|
||||
memory = Memories.insert_new_memory(user.id, form_data.content)
|
||||
memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
|
||||
|
||||
collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
|
||||
collection = VECTOR_DB_CLIENT.get_or_create_collection(
|
||||
name=f"user-memory-{user.id}"
|
||||
)
|
||||
collection.upsert(
|
||||
documents=[memory.content],
|
||||
ids=[memory.id],
|
||||
@ -77,7 +80,9 @@ async def query_memory(
|
||||
request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
|
||||
):
|
||||
query_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
|
||||
collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
|
||||
collection = VECTOR_DB_CLIENT.get_or_create_collection(
|
||||
name=f"user-memory-{user.id}"
|
||||
)
|
||||
|
||||
results = collection.query(
|
||||
query_embeddings=[query_embedding],
|
||||
@ -94,8 +99,10 @@ async def query_memory(
|
||||
async def reset_memory_from_vector_db(
|
||||
request: Request, user=Depends(get_verified_user)
|
||||
):
|
||||
CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
|
||||
collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
|
||||
VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
|
||||
collection = VECTOR_DB_CLIENT.get_or_create_collection(
|
||||
name=f"user-memory-{user.id}"
|
||||
)
|
||||
|
||||
memories = Memories.get_memories_by_user_id(user.id)
|
||||
for memory in memories:
|
||||
@ -119,7 +126,7 @@ async def delete_memory_by_user_id(user=Depends(get_verified_user)):
|
||||
|
||||
if result:
|
||||
try:
|
||||
CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
|
||||
VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
|
||||
except Exception as e:
|
||||
log.error(e)
|
||||
return True
|
||||
@ -145,7 +152,7 @@ async def update_memory_by_id(
|
||||
|
||||
if form_data.content is not None:
|
||||
memory_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
|
||||
collection = CHROMA_CLIENT.get_or_create_collection(
|
||||
collection = VECTOR_DB_CLIENT.get_or_create_collection(
|
||||
name=f"user-memory-{user.id}"
|
||||
)
|
||||
collection.upsert(
|
||||
@ -170,7 +177,7 @@ async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
|
||||
result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
|
||||
|
||||
if result:
|
||||
collection = CHROMA_CLIENT.get_or_create_collection(
|
||||
collection = VECTOR_DB_CLIENT.get_or_create_collection(
|
||||
name=f"user-memory-{user.id}"
|
||||
)
|
||||
collection.delete(ids=[memory_id])
|
||||
|
@ -11,7 +11,6 @@ import chromadb
|
||||
import requests
|
||||
import yaml
|
||||
from open_webui.apps.webui.internal.db import Base, get_db
|
||||
from chromadb import Settings
|
||||
from open_webui.env import (
|
||||
OPEN_WEBUI_DIR,
|
||||
DATA_DIR,
|
||||
@ -926,22 +925,9 @@ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
|
||||
# RAG document content extraction
|
||||
####################################
|
||||
|
||||
CONTENT_EXTRACTION_ENGINE = PersistentConfig(
|
||||
"CONTENT_EXTRACTION_ENGINE",
|
||||
"rag.CONTENT_EXTRACTION_ENGINE",
|
||||
os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(),
|
||||
)
|
||||
|
||||
TIKA_SERVER_URL = PersistentConfig(
|
||||
"TIKA_SERVER_URL",
|
||||
"rag.tika_server_url",
|
||||
os.getenv("TIKA_SERVER_URL", "http://tika:9998"), # Default for sidecar deployment
|
||||
)
|
||||
|
||||
####################################
|
||||
# RAG
|
||||
####################################
|
||||
VECTOR_DB = os.environ.get("VECTOR_DB", "chroma")
|
||||
|
||||
# Chroma
|
||||
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
|
||||
CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT)
|
||||
CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE)
|
||||
@ -958,6 +944,23 @@ else:
|
||||
CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
|
||||
# This uses the model defined in the Dockerfile ENV variable. If you don't use Docker or Docker-based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
|
||||
|
||||
####################################
|
||||
# RAG
|
||||
####################################
|
||||
|
||||
# RAG Content Extraction
|
||||
CONTENT_EXTRACTION_ENGINE = PersistentConfig(
|
||||
"CONTENT_EXTRACTION_ENGINE",
|
||||
"rag.CONTENT_EXTRACTION_ENGINE",
|
||||
os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(),
|
||||
)
|
||||
|
||||
TIKA_SERVER_URL = PersistentConfig(
|
||||
"TIKA_SERVER_URL",
|
||||
"rag.tika_server_url",
|
||||
os.getenv("TIKA_SERVER_URL", "http://tika:9998"), # Default for sidecar deployment
|
||||
)
|
||||
|
||||
RAG_TOP_K = PersistentConfig(
|
||||
"RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "5"))
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user