import json
import logging
import mimetypes
import os
import shutil
import uuid

from datetime import datetime
from pathlib import Path
from typing import Iterator, Optional, Sequence, Union

from fastapi import (
    Depends,
    FastAPI,
    File,
    Form,
    HTTPException,
    UploadFile,
    Request,
    status,
    APIRouter,
)
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import tiktoken

from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
from langchain_core.documents import Document

from open_webui.models.files import Files
from open_webui.models.knowledge import Knowledges
from open_webui.storage.provider import Storage

from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT

# Document loaders
from open_webui.retrieval.loaders.main import Loader
from open_webui.retrieval.loaders.youtube import YoutubeLoader

# Web search engines
from open_webui.retrieval.web.main import SearchResult
from open_webui.retrieval.web.utils import get_web_loader
from open_webui.retrieval.web.brave import search_brave
from open_webui.retrieval.web.kagi import search_kagi
from open_webui.retrieval.web.mojeek import search_mojeek
from open_webui.retrieval.web.duckduckgo import search_duckduckgo
from open_webui.retrieval.web.google_pse import search_google_pse
from open_webui.retrieval.web.jina_search import search_jina
from open_webui.retrieval.web.searchapi import search_searchapi
from open_webui.retrieval.web.searxng import search_searxng
from open_webui.retrieval.web.serper import search_serper
from open_webui.retrieval.web.serply import search_serply
from open_webui.retrieval.web.serpstack import search_serpstack
from open_webui.retrieval.web.tavily import search_tavily
from open_webui.retrieval.web.bing import search_bing

from open_webui.retrieval.utils import (
    get_embedding_function,
    get_model_path,
    query_collection,
    query_collection_with_hybrid_search,
    query_doc,
    query_doc_with_hybrid_search,
)

from open_webui.utils.misc import (
    calculate_sha256_string,
)
from open_webui.utils.auth import get_admin_user, get_verified_user

from open_webui.config import (
    ENV,
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
    RAG_RERANKING_MODEL_AUTO_UPDATE,
    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
    UPLOAD_DIR,
    DEFAULT_LOCALE,
)
from open_webui.env import (
    SRC_LOG_LEVELS,
    DEVICE_TYPE,
    DOCKER,
)
from open_webui.constants import ERROR_MESSAGES

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


##########################################
#
# Utility functions
#
##########################################


def update_embedding_model(
    request: Request,
    embedding_model: str,
    auto_update: bool = False,
):
    """Load (or clear) the local SentenceTransformer embedding model on app state.

    A local model is loaded only when a model name is given AND the embedding
    engine is "" (i.e. not delegated to ollama/openai); otherwise the cached
    encoder is cleared so remote engines are used instead.
    """
    if embedding_model and request.app.state.config.RAG_EMBEDDING_ENGINE == "":
        # Imported lazily so sentence-transformers is only required when a
        # local embedding model is actually configured.
        from sentence_transformers import SentenceTransformer

        try:
            request.app.state.sentence_transformer_ef = SentenceTransformer(
                get_model_path(embedding_model, auto_update),
                device=DEVICE_TYPE,
                trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
            )
        except Exception as e:
            # Best-effort: a failed load leaves the encoder unset rather than
            # crashing app startup/config updates.
            log.debug(f"Error loading SentenceTransformer: {e}")
            request.app.state.sentence_transformer_ef = None
    else:
        request.app.state.sentence_transformer_ef = None


def update_reranking_model(
    request: Request,
    reranking_model: str,
    auto_update: bool = False,
):
    """Load (or clear) the reranking model on app state.

    ColBERT-style models (currently only jinaai/jina-colbert-v2) get a
    dedicated loader; anything else is loaded as a sentence-transformers
    CrossEncoder. On failure the reranker is cleared and hybrid search is
    disabled, since hybrid search depends on a working reranker.
    """
    if reranking_model:
        if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]):
            try:
                from open_webui.retrieval.models.colbert import ColBERT

                request.app.state.sentence_transformer_rf = ColBERT(
                    get_model_path(reranking_model, auto_update),
                    env="docker" if DOCKER else None,
                )
            except Exception as e:
                log.error(f"ColBERT: {e}")
                request.app.state.sentence_transformer_rf = None
                request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
        else:
            import sentence_transformers

            try:
                request.app.state.sentence_transformer_rf = (
                    sentence_transformers.CrossEncoder(
                        get_model_path(reranking_model, auto_update),
                        device=DEVICE_TYPE,
                        trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
                    )
                )
            # Fixed: was a bare `except:` which also swallowed
            # KeyboardInterrupt/SystemExit.
            except Exception:
                log.error("CrossEncoder error")
                request.app.state.sentence_transformer_rf = None
                request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
    else:
        request.app.state.sentence_transformer_rf = None


# NOTE(review): the statements below run at import time but reference a
# `request` name that is not defined at module scope, and call the updater
# functions without their `request` parameter. This looks like code that was
# moved out of an app-setup/lifespan function — confirm against the original
# application factory before relying on it.
update_embedding_model(
    request.app.state.config.RAG_EMBEDDING_MODEL,
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)

update_reranking_model(
    request.app.state.config.RAG_RERANKING_MODEL,
    RAG_RERANKING_MODEL_AUTO_UPDATE,
)


##########################################
#
# API routes
#
##########################################

router = APIRouter()

# NOTE(review): same concern as above — module-level `request` is undefined
# here; this assignment presumably belongs in app startup.
request.app.state.EMBEDDING_FUNCTION = get_embedding_function(
    request.app.state.config.RAG_EMBEDDING_ENGINE,
    request.app.state.config.RAG_EMBEDDING_MODEL,
    request.app.state.sentence_transformer_ef,
    (
        request.app.state.config.OPENAI_API_BASE_URL
        if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
        else request.app.state.config.OLLAMA_BASE_URL
    ),
    (
        request.app.state.config.OPENAI_API_KEY
        if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
        else request.app.state.config.OLLAMA_API_KEY
    ),
    request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)


class CollectionNameForm(BaseModel):
    # Optional target collection; endpoints derive a default when omitted.
    collection_name: Optional[str] = None


class ProcessUrlForm(CollectionNameForm):
    url: str


class SearchForm(CollectionNameForm):
    query: str


@router.get("/")
async def get_status(request: Request):
    """Return the current RAG status and core settings (no auth required)."""
    return {
        "status": True,
        "chunk_size": request.app.state.config.CHUNK_SIZE,
        "chunk_overlap": request.app.state.config.CHUNK_OVERLAP,
        "template": request.app.state.config.RAG_TEMPLATE,
        "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
        "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL,
        "reranking_model": request.app.state.config.RAG_RERANKING_MODEL,
        "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
    }


@router.get("/embedding")
async def get_embedding_config(request: Request, user=Depends(get_admin_user)):
    """Return the embedding configuration, including openai/ollama credentials (admin only)."""
    return {
        "status": True,
        "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
        "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL,
        "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
        "openai_config": {
            "url": request.app.state.config.OPENAI_API_BASE_URL,
            "key": request.app.state.config.OPENAI_API_KEY,
        },
        "ollama_config": {
            "url": request.app.state.config.OLLAMA_BASE_URL,
            "key": request.app.state.config.OLLAMA_API_KEY,
        },
    }
@router.get("/reranking")
async def get_reraanking_config(request: Request, user=Depends(get_admin_user)):
    """Return the configured reranking model (admin only).

    NOTE(review): function name contains a typo ("reraanking"); the route path
    is correct, so renaming is cosmetic but should be coordinated.
    """
    return {
        "status": True,
        "reranking_model": request.app.state.config.RAG_RERANKING_MODEL,
    }


class OpenAIConfigForm(BaseModel):
    # OpenAI-compatible embedding endpoint credentials.
    url: str
    key: str


class OllamaConfigForm(BaseModel):
    # Ollama embedding endpoint credentials.
    url: str
    key: str


class EmbeddingModelUpdateForm(BaseModel):
    openai_config: Optional[OpenAIConfigForm] = None
    ollama_config: Optional[OllamaConfigForm] = None
    embedding_engine: str
    embedding_model: str
    embedding_batch_size: Optional[int] = 1


@router.post("/embedding/update")
async def update_embedding_config(
    request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    """Update the embedding engine/model (and remote credentials), then rebuild
    the shared embedding function on app state (admin only)."""
    log.info(
        f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
    )
    try:
        request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
        request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model

        # Remote engines additionally carry endpoint credentials and batch size.
        if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
            if form_data.openai_config is not None:
                request.app.state.config.OPENAI_API_BASE_URL = (
                    form_data.openai_config.url
                )
                request.app.state.config.OPENAI_API_KEY = form_data.openai_config.key

            if form_data.ollama_config is not None:
                request.app.state.config.OLLAMA_BASE_URL = form_data.ollama_config.url
                request.app.state.config.OLLAMA_API_KEY = form_data.ollama_config.key

            request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = (
                form_data.embedding_batch_size
            )

        # NOTE(review): update_embedding_model is defined with a `request`
        # first parameter but is called here with a single argument — confirm
        # the intended signature against the rest of the codebase.
        update_embedding_model(request.app.state.config.RAG_EMBEDDING_MODEL)

        request.app.state.EMBEDDING_FUNCTION = get_embedding_function(
            request.app.state.config.RAG_EMBEDDING_ENGINE,
            request.app.state.config.RAG_EMBEDDING_MODEL,
            request.app.state.sentence_transformer_ef,
            (
                request.app.state.config.OPENAI_API_BASE_URL
                if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
                else request.app.state.config.OLLAMA_BASE_URL
            ),
            (
                request.app.state.config.OPENAI_API_KEY
                if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
                else request.app.state.config.OLLAMA_API_KEY
            ),
            request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
        )

        return {
            "status": True,
            "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
            "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL,
            "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
            "openai_config": {
                "url": request.app.state.config.OPENAI_API_BASE_URL,
                "key": request.app.state.config.OPENAI_API_KEY,
            },
            "ollama_config": {
                "url": request.app.state.config.OLLAMA_BASE_URL,
                "key": request.app.state.config.OLLAMA_API_KEY,
            },
        }
    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


class RerankingModelUpdateForm(BaseModel):
    reranking_model: str


@router.post("/reranking/update")
async def update_reranking_config(
    request: Request, form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
    """Update and reload the reranking model (admin only)."""
    log.info(
        f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
    )
    try:
        request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model

        # NOTE(review): update_reranking_model is defined with a `request`
        # first parameter but is called here without it — confirm signature.
        update_reranking_model(request.app.state.config.RAG_RERANKING_MODEL, True)

        return {
            "status": True,
            "reranking_model": request.app.state.config.RAG_RERANKING_MODEL,
        }
    except Exception as e:
        log.exception(f"Problem updating reranking model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


@router.get("/config")
async def get_rag_config(request: Request, user=Depends(get_admin_user)):
    """Return the full RAG configuration snapshot (admin only)."""
    return {
        "status": True,
        "pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES,
        "content_extraction": {
            "engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
            "tika_server_url": request.app.state.config.TIKA_SERVER_URL,
        },
        "chunk": {
            "text_splitter": request.app.state.config.TEXT_SPLITTER,
            "chunk_size": request.app.state.config.CHUNK_SIZE,
            "chunk_overlap": request.app.state.config.CHUNK_OVERLAP,
        },
        "file": {
            "max_size": request.app.state.config.FILE_MAX_SIZE,
            "max_count": request.app.state.config.FILE_MAX_COUNT,
        },
        "youtube": {
            "language": request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
            # NOTE(review): translation is stored directly on app.state (not
            # under .config), unlike its siblings — confirm this is intended.
            "translation": request.app.state.YOUTUBE_LOADER_TRANSLATION,
            "proxy_url": request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
        },
        "web": {
            "web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
            "search": {
                "enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH,
                "engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE,
                "searxng_query_url": request.app.state.config.SEARXNG_QUERY_URL,
                "google_pse_api_key": request.app.state.config.GOOGLE_PSE_API_KEY,
                "google_pse_engine_id": request.app.state.config.GOOGLE_PSE_ENGINE_ID,
                "brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY,
                "kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY,
                "mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY,
                "serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY,
                "serpstack_https": request.app.state.config.SERPSTACK_HTTPS,
                "serper_api_key": request.app.state.config.SERPER_API_KEY,
                "serply_api_key": request.app.state.config.SERPLY_API_KEY,
                "tavily_api_key": request.app.state.config.TAVILY_API_KEY,
                "searchapi_api_key": request.app.state.config.SEARCHAPI_API_KEY,
                # NOTE(review): key is misspelled ("seaarchapi_engine") — a
                # client may already depend on it, so fix in coordination with
                # the frontend.
                "seaarchapi_engine": request.app.state.config.SEARCHAPI_ENGINE,
                "jina_api_key": request.app.state.config.JINA_API_KEY,
                "bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT,
                "bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
                "result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
                "concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
            },
        },
    }


class FileConfig(BaseModel):
    # Upload limits; None means "no limit".
    max_size: Optional[int] = None
    max_count: Optional[int] = None


class ContentExtractionConfig(BaseModel):
    # "" selects the built-in loader; "tika" requires tika_server_url.
    engine: str = ""
    tika_server_url: Optional[str] = None


class ChunkParamUpdateForm(BaseModel):
    text_splitter: Optional[str] = None
    chunk_size: int
    chunk_overlap: int


class YoutubeLoaderConfig(BaseModel):
    language: list[str]
    translation: Optional[str] = None
    proxy_url: str = ""


class WebSearchConfig(BaseModel):
    enabled: bool
    engine: Optional[str] = None
    searxng_query_url: Optional[str] = None
    google_pse_api_key: Optional[str] = None
    google_pse_engine_id: Optional[str] = None
    brave_search_api_key: Optional[str] = None
    kagi_search_api_key: Optional[str] = None
    mojeek_search_api_key: Optional[str] = None
    serpstack_api_key: Optional[str] = None
    serpstack_https: Optional[bool] = None
    serper_api_key: Optional[str] = None
    serply_api_key: Optional[str] = None
    tavily_api_key: Optional[str] = None
    searchapi_api_key: Optional[str] = None
    searchapi_engine: Optional[str] = None
    jina_api_key: Optional[str] = None
    bing_search_v7_endpoint: Optional[str] = None
    bing_search_v7_subscription_key: Optional[str] = None
    result_count: Optional[int] = None
    concurrent_requests: Optional[int] = None


class WebConfig(BaseModel):
    search: WebSearchConfig
    web_loader_ssl_verification: Optional[bool] = None


class ConfigUpdateForm(BaseModel):
    # All sections optional; only supplied sections are applied.
    pdf_extract_images: Optional[bool] = None
    file: Optional[FileConfig] = None
    content_extraction: Optional[ContentExtractionConfig] = None
    chunk: Optional[ChunkParamUpdateForm] = None
    youtube: Optional[YoutubeLoaderConfig] = None
    web: Optional[WebConfig] = None


@router.post("/config/update")
async def update_rag_config(
    request: Request, form_data: ConfigUpdateForm, user=Depends(get_admin_user)
):
    """Apply a partial RAG configuration update and return the new snapshot (admin only)."""
    request.app.state.config.PDF_EXTRACT_IMAGES = (
        form_data.pdf_extract_images
        if form_data.pdf_extract_images is not None
        else request.app.state.config.PDF_EXTRACT_IMAGES
    )

    if form_data.file is not None:
        request.app.state.config.FILE_MAX_SIZE = form_data.file.max_size
        request.app.state.config.FILE_MAX_COUNT = form_data.file.max_count

    if form_data.content_extraction is not None:
        log.info(f"Updating text settings: {form_data.content_extraction}")
        request.app.state.config.CONTENT_EXTRACTION_ENGINE = (
            form_data.content_extraction.engine
        )
        request.app.state.config.TIKA_SERVER_URL = (
            form_data.content_extraction.tika_server_url
        )

    if form_data.chunk is not None:
        request.app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter
        request.app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size
        request.app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap

    if form_data.youtube is not None:
        request.app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language
        request.app.state.config.YOUTUBE_LOADER_PROXY_URL = form_data.youtube.proxy_url
        request.app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation

    if form_data.web is not None:
        request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
            # Note: When UI "Bypass SSL verification for Websites"=True then ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION=False
            form_data.web.web_loader_ssl_verification
        )

        request.app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled
        request.app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine
        request.app.state.config.SEARXNG_QUERY_URL = (
            form_data.web.search.searxng_query_url
        )
        request.app.state.config.GOOGLE_PSE_API_KEY = (
            form_data.web.search.google_pse_api_key
        )
        request.app.state.config.GOOGLE_PSE_ENGINE_ID = (
            form_data.web.search.google_pse_engine_id
        )
        request.app.state.config.BRAVE_SEARCH_API_KEY = (
            form_data.web.search.brave_search_api_key
        )
        request.app.state.config.KAGI_SEARCH_API_KEY = (
            form_data.web.search.kagi_search_api_key
        )
        request.app.state.config.MOJEEK_SEARCH_API_KEY = (
            form_data.web.search.mojeek_search_api_key
        )
        request.app.state.config.SERPSTACK_API_KEY = (
            form_data.web.search.serpstack_api_key
        )
        request.app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https
        request.app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key
        request.app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key
        request.app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key
        request.app.state.config.SEARCHAPI_API_KEY = (
            form_data.web.search.searchapi_api_key
        )
        request.app.state.config.SEARCHAPI_ENGINE = (
            form_data.web.search.searchapi_engine
        )
        request.app.state.config.JINA_API_KEY = form_data.web.search.jina_api_key
        request.app.state.config.BING_SEARCH_V7_ENDPOINT = (
            form_data.web.search.bing_search_v7_endpoint
        )
        request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = (
            form_data.web.search.bing_search_v7_subscription_key
        )
        request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = (
            form_data.web.search.result_count
        )
        request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
            form_data.web.search.concurrent_requests
        )

    return {
        "status": True,
        "pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES,
        "file": {
            "max_size": request.app.state.config.FILE_MAX_SIZE,
            "max_count": request.app.state.config.FILE_MAX_COUNT,
        },
        "content_extraction": {
            "engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
            "tika_server_url": request.app.state.config.TIKA_SERVER_URL,
        },
        "chunk": {
            "text_splitter": request.app.state.config.TEXT_SPLITTER,
            "chunk_size": request.app.state.config.CHUNK_SIZE,
            "chunk_overlap": request.app.state.config.CHUNK_OVERLAP,
        },
        "youtube": {
            "language": request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
            "proxy_url": request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
            "translation": request.app.state.YOUTUBE_LOADER_TRANSLATION,
        },
        "web": {
            "web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
            "search": {
                "enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH,
                "engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE,
                "searxng_query_url": request.app.state.config.SEARXNG_QUERY_URL,
                "google_pse_api_key": request.app.state.config.GOOGLE_PSE_API_KEY,
                "google_pse_engine_id": request.app.state.config.GOOGLE_PSE_ENGINE_ID,
                "brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY,
                "kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY,
                "mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY,
                "serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY,
                "serpstack_https": request.app.state.config.SERPSTACK_HTTPS,
                "serper_api_key": request.app.state.config.SERPER_API_KEY,
                "serply_api_key": request.app.state.config.SERPLY_API_KEY,
                # NOTE(review): misspelled key ("serachapi_api_key"), and the
                # GET /config handler spells its engine key differently
                # ("seaarchapi_engine") — align with the frontend before fixing.
                "serachapi_api_key": request.app.state.config.SEARCHAPI_API_KEY,
                "searchapi_engine": request.app.state.config.SEARCHAPI_ENGINE,
                "tavily_api_key": request.app.state.config.TAVILY_API_KEY,
                "jina_api_key": request.app.state.config.JINA_API_KEY,
                "bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT,
                "bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
                "result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
                "concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
            },
        },
    }
@router.get("/template")
async def get_rag_template(request: Request, user=Depends(get_verified_user)):
    """Return the RAG prompt template (any verified user)."""
    return {
        "status": True,
        "template": request.app.state.config.RAG_TEMPLATE,
    }


@router.get("/query/settings")
async def get_query_settings(request: Request, user=Depends(get_admin_user)):
    """Return the retrieval query settings (admin only)."""
    return {
        "status": True,
        "template": request.app.state.config.RAG_TEMPLATE,
        "k": request.app.state.config.TOP_K,
        "r": request.app.state.config.RELEVANCE_THRESHOLD,
        "hybrid": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
    }


class QuerySettingsForm(BaseModel):
    k: Optional[int] = None
    r: Optional[float] = None
    template: Optional[str] = None
    hybrid: Optional[bool] = None


@router.post("/query/settings/update")
async def update_query_settings(
    request: Request, form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
    """Update retrieval query settings; falsy/missing values fall back to
    defaults (k=4, r=0.0, hybrid=False)."""
    request.app.state.config.RAG_TEMPLATE = form_data.template
    request.app.state.config.TOP_K = form_data.k if form_data.k else 4
    request.app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0

    request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
        form_data.hybrid if form_data.hybrid else False
    )

    return {
        "status": True,
        "template": request.app.state.config.RAG_TEMPLATE,
        "k": request.app.state.config.TOP_K,
        "r": request.app.state.config.RELEVANCE_THRESHOLD,
        "hybrid": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
    }


####################################
#
# Document process and retrieval
#
####################################


def save_docs_to_vector_db(
    request: Request,
    docs,
    collection_name,
    metadata: Optional[dict] = None,
    overwrite: bool = False,
    split: bool = True,
    add: bool = False,
) -> bool:
    """Embed `docs` and store them in `collection_name` in the vector DB.

    Args:
        request: used to read RAG config and the cached embedding model.
        docs: list of langchain Documents to store.
        collection_name: target vector-DB collection.
        metadata: extra metadata merged into every chunk; if it carries a
            "hash" key, a duplicate-content check is performed first.
        overwrite: drop an existing collection before inserting.
        split: run the configured text splitter before embedding.
        add: append to an existing collection instead of skipping it.

    Returns:
        True on success (also when the collection exists and neither
        overwrite nor add was requested).

    Raises:
        ValueError: duplicate content, empty content, or unknown splitter.
    """

    def _get_docs_info(docs: list[Document]) -> str:
        # Collect a human-readable identifier per document for logging only.
        docs_info = set()

        # Trying to select relevant metadata identifying the document.
        for doc in docs:
            metadata = getattr(doc, "metadata", {})
            doc_name = metadata.get("name", "")
            if not doc_name:
                doc_name = metadata.get("title", "")
            if not doc_name:
                doc_name = metadata.get("source", "")
            if doc_name:
                docs_info.add(doc_name)

        return ", ".join(docs_info)

    log.info(
        f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}"
    )

    # Check if entries with the same hash (metadata.hash) already exist
    if metadata and "hash" in metadata:
        result = VECTOR_DB_CLIENT.query(
            collection_name=collection_name,
            filter={"hash": metadata["hash"]},
        )

        if result is not None:
            existing_doc_ids = result.ids[0]
            if existing_doc_ids:
                log.info(f"Document with hash {metadata['hash']} already exists")
                raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT)

    if split:
        if request.app.state.config.TEXT_SPLITTER in ["", "character"]:
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=request.app.state.config.CHUNK_SIZE,
                chunk_overlap=request.app.state.config.CHUNK_OVERLAP,
                add_start_index=True,
            )
        elif request.app.state.config.TEXT_SPLITTER == "token":
            log.info(
                f"Using token text splitter: {request.app.state.config.TIKTOKEN_ENCODING_NAME}"
            )
            # Pre-fetch the encoding so a bad name fails fast here.
            tiktoken.get_encoding(str(request.app.state.config.TIKTOKEN_ENCODING_NAME))
            text_splitter = TokenTextSplitter(
                encoding_name=str(request.app.state.config.TIKTOKEN_ENCODING_NAME),
                chunk_size=request.app.state.config.CHUNK_SIZE,
                chunk_overlap=request.app.state.config.CHUNK_OVERLAP,
                add_start_index=True,
            )
        else:
            raise ValueError(ERROR_MESSAGES.DEFAULT("Invalid text splitter"))

        docs = text_splitter.split_documents(docs)

    if len(docs) == 0:
        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

    texts = [doc.page_content for doc in docs]
    metadatas = [
        {
            **doc.metadata,
            **(metadata if metadata else {}),
            "embedding_config": json.dumps(
                {
                    "engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
                    "model": request.app.state.config.RAG_EMBEDDING_MODEL,
                }
            ),
        }
        for doc in docs
    ]

    # ChromaDB does not like datetime formats for metadata, so convert them
    # to strings. (Fixed: the loop variable used to be named `metadata`,
    # shadowing and clobbering the `metadata` parameter.)
    for chunk_metadata in metadatas:
        for key, value in chunk_metadata.items():
            if isinstance(value, datetime):
                chunk_metadata[key] = str(value)

    try:
        if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name):
            log.info(f"collection {collection_name} already exists")

            if overwrite:
                VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
                log.info(f"deleting existing collection {collection_name}")
            elif add is False:
                log.info(
                    f"collection {collection_name} already exists, overwrite is False and add is False"
                )
                return True

        log.info(f"adding to collection {collection_name}")
        embedding_function = get_embedding_function(
            request.app.state.config.RAG_EMBEDDING_ENGINE,
            request.app.state.config.RAG_EMBEDDING_MODEL,
            request.app.state.sentence_transformer_ef,
            (
                request.app.state.config.OPENAI_API_BASE_URL
                if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
                else request.app.state.config.OLLAMA_BASE_URL
            ),
            (
                request.app.state.config.OPENAI_API_KEY
                if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
                else request.app.state.config.OLLAMA_API_KEY
            ),
            request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
        )

        # Newlines are stripped before embedding (common practice for
        # embedding-model inputs); the stored text keeps them.
        embeddings = embedding_function(
            list(map(lambda x: x.replace("\n", " "), texts))
        )

        items = [
            {
                "id": str(uuid.uuid4()),
                "text": text,
                "vector": embeddings[idx],
                "metadata": metadatas[idx],
            }
            for idx, text in enumerate(texts)
        ]

        VECTOR_DB_CLIENT.insert(
            collection_name=collection_name,
            items=items,
        )

        return True
    except Exception as e:
        log.exception(e)
        raise e
class ProcessFileForm(BaseModel):
    file_id: str
    # Optional replacement content; when set, the file's collection is rebuilt.
    content: Optional[str] = None
    collection_name: Optional[str] = None


@router.post("/process/file")
def process_file(
    request: Request,
    form_data: ProcessFileForm,
    user=Depends(get_verified_user),
):
    """Extract/ingest a stored file's content into the vector DB.

    Three modes, by precedence:
      1. form_data.content set — replace the file's content and re-index.
      2. form_data.collection_name set — add the file to a knowledge
         collection, reusing already-indexed chunks when present.
      3. neither — first-time processing: run the configured content
         extraction engine over the stored file.
    """
    try:
        file = Files.get_file_by_id(form_data.file_id)

        collection_name = form_data.collection_name
        if collection_name is None:
            collection_name = f"file-{file.id}"

        if form_data.content:
            # Update the content in the file
            # Usage: /files/{file_id}/data/content/update
            VECTOR_DB_CLIENT.delete_collection(collection_name=f"file-{file.id}")

            docs = [
                Document(
                    # NOTE(review): this replaces "\n" with "\n" — a no-op.
                    # Likely garbled; confirm the intended first argument
                    # against the original source.
                    page_content=form_data.content.replace("\n", "\n"),
                    metadata={
                        **file.meta,
                        "name": file.filename,
                        "created_by": file.user_id,
                        "file_id": file.id,
                        "source": file.filename,
                    },
                )
            ]
            text_content = form_data.content
        elif form_data.collection_name:
            # Check if the file has already been processed and save the content
            # Usage: /knowledge/{id}/file/add, /knowledge/{id}/file/update
            result = VECTOR_DB_CLIENT.query(
                collection_name=f"file-{file.id}", filter={"file_id": file.id}
            )

            if result is not None and len(result.ids[0]) > 0:
                # Reuse the chunks already stored in the per-file collection.
                docs = [
                    Document(
                        page_content=result.documents[0][idx],
                        metadata=result.metadatas[0][idx],
                    )
                    for idx, id in enumerate(result.ids[0])
                ]
            else:
                docs = [
                    Document(
                        page_content=file.data.get("content", ""),
                        metadata={
                            **file.meta,
                            "name": file.filename,
                            "created_by": file.user_id,
                            "file_id": file.id,
                            "source": file.filename,
                        },
                    )
                ]

            text_content = file.data.get("content", "")
        else:
            # Process the file and save the content
            # Usage: /files/
            file_path = file.path
            if file_path:
                file_path = Storage.get_file(file_path)
                loader = Loader(
                    engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE,
                    TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
                    PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
                )
                docs = loader.load(
                    file.filename, file.meta.get("content_type"), file_path
                )

                docs = [
                    Document(
                        page_content=doc.page_content,
                        metadata={
                            **doc.metadata,
                            "name": file.filename,
                            "created_by": file.user_id,
                            "file_id": file.id,
                            "source": file.filename,
                        },
                    )
                    for doc in docs
                ]
            else:
                # No stored file on disk — fall back to previously saved text.
                docs = [
                    Document(
                        page_content=file.data.get("content", ""),
                        metadata={
                            **file.meta,
                            "name": file.filename,
                            "created_by": file.user_id,
                            "file_id": file.id,
                            "source": file.filename,
                        },
                    )
                ]
            text_content = " ".join([doc.page_content for doc in docs])

        log.debug(f"text_content: {text_content}")
        # Persist the extracted text and its hash on the file record before
        # indexing, so the duplicate-content check can use the hash.
        Files.update_file_data_by_id(
            file.id,
            {"content": text_content},
        )

        hash = calculate_sha256_string(text_content)
        Files.update_file_hash_by_id(file.id, hash)

        try:
            result = save_docs_to_vector_db(
                request,
                docs=docs,
                collection_name=collection_name,
                metadata={
                    "file_id": file.id,
                    "name": file.filename,
                    "hash": hash,
                },
                add=(True if form_data.collection_name else False),
            )

            if result:
                Files.update_file_metadata_by_id(
                    file.id,
                    {
                        "collection_name": collection_name,
                    },
                )

                return {
                    "status": True,
                    "collection_name": collection_name,
                    "filename": file.filename,
                    "content": text_content,
                }
        except Exception as e:
            raise e
    except Exception as e:
        log.exception(e)
        # Missing pandoc is a user-fixable setup problem; surface it clearly.
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=str(e),
            )


class ProcessTextForm(BaseModel):
    name: str
    content: str
    collection_name: Optional[str] = None


@router.post("/process/text")
def process_text(
    request: Request,
    form_data: ProcessTextForm,
    user=Depends(get_verified_user),
):
    """Index a raw text snippet; the default collection name is the content hash."""
    collection_name = form_data.collection_name
    if collection_name is None:
        collection_name = calculate_sha256_string(form_data.content)

    docs = [
        Document(
            page_content=form_data.content,
            metadata={"name": form_data.name, "created_by": user.id},
        )
    ]
    text_content = form_data.content
    log.debug(f"text_content: {text_content}")

    result = save_docs_to_vector_db(request, docs, collection_name)
    if result:
        return {
            "status": True,
            "collection_name": collection_name,
            "content": text_content,
        }
    else:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )


@router.post("/process/youtube")
def process_youtube_video(
    request: Request, form_data: ProcessUrlForm, user=Depends(get_verified_user)
):
    """Fetch a YouTube transcript and index it; collection name defaults to a
    hash of the URL, truncated to 63 chars (vector-DB name limit)."""
    try:
        collection_name = form_data.collection_name
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.url)[:63]

        loader = YoutubeLoader(
            form_data.url,
            language=request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
            proxy_url=request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
        )

        docs = loader.load()
        content = " ".join([doc.page_content for doc in docs])
        log.debug(f"text_content: {content}")
        save_docs_to_vector_db(request, docs, collection_name, overwrite=True)

        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
            "file": {
                "data": {
                    "content": content,
                },
                "meta": {
                    "name": form_data.url,
                },
            },
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
docs]) log.debug(f"text_content: {content}") save_docs_to_vector_db(request, docs, collection_name, overwrite=True) return { "status": True, "collection_name": collection_name, "filename": form_data.url, "file": { "data": { "content": content, }, "meta": { "name": form_data.url, }, }, } except Exception as e: log.exception(e) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), ) @router.post("/process/web") def process_web( request: Request, form_data: ProcessUrlForm, user=Depends(get_verified_user) ): try: collection_name = form_data.collection_name if not collection_name: collection_name = calculate_sha256_string(form_data.url)[:63] loader = get_web_loader( form_data.url, verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, ) docs = loader.load() content = " ".join([doc.page_content for doc in docs]) log.debug(f"text_content: {content}") save_docs_to_vector_db(request, docs, collection_name, overwrite=True) return { "status": True, "collection_name": collection_name, "filename": form_data.url, "file": { "data": { "content": content, }, "meta": { "name": form_data.url, }, }, } except Exception as e: log.exception(e) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), ) def search_web(request: Request, engine: str, query: str) -> list[SearchResult]: """Search the web using a search engine and return the results as a list of SearchResult objects. 
Will look for a search engine API key in environment variables in the following order: - SEARXNG_QUERY_URL - GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID - BRAVE_SEARCH_API_KEY - KAGI_SEARCH_API_KEY - MOJEEK_SEARCH_API_KEY - SERPSTACK_API_KEY - SERPER_API_KEY - SERPLY_API_KEY - TAVILY_API_KEY - SEARCHAPI_API_KEY + SEARCHAPI_ENGINE (by default `google`) Args: query (str): The query to search for """ # TODO: add playwright to search the web if engine == "searxng": if request.app.state.config.SEARXNG_QUERY_URL: return search_searxng( request.app.state.config.SEARXNG_QUERY_URL, query, request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No SEARXNG_QUERY_URL found in environment variables") elif engine == "google_pse": if ( request.app.state.config.GOOGLE_PSE_API_KEY and request.app.state.config.GOOGLE_PSE_ENGINE_ID ): return search_google_pse( request.app.state.config.GOOGLE_PSE_API_KEY, request.app.state.config.GOOGLE_PSE_ENGINE_ID, query, request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception( "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables" ) elif engine == "brave": if request.app.state.config.BRAVE_SEARCH_API_KEY: return search_brave( request.app.state.config.BRAVE_SEARCH_API_KEY, query, request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables") elif engine == "kagi": if request.app.state.config.KAGI_SEARCH_API_KEY: return search_kagi( request.app.state.config.KAGI_SEARCH_API_KEY, query, request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No KAGI_SEARCH_API_KEY found in environment variables") elif engine == "mojeek": if 
request.app.state.config.MOJEEK_SEARCH_API_KEY: return search_mojeek( request.app.state.config.MOJEEK_SEARCH_API_KEY, query, request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No MOJEEK_SEARCH_API_KEY found in environment variables") elif engine == "serpstack": if request.app.state.config.SERPSTACK_API_KEY: return search_serpstack( request.app.state.config.SERPSTACK_API_KEY, query, request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, https_enabled=request.app.state.config.SERPSTACK_HTTPS, ) else: raise Exception("No SERPSTACK_API_KEY found in environment variables") elif engine == "serper": if request.app.state.config.SERPER_API_KEY: return search_serper( request.app.state.config.SERPER_API_KEY, query, request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No SERPER_API_KEY found in environment variables") elif engine == "serply": if request.app.state.config.SERPLY_API_KEY: return search_serply( request.app.state.config.SERPLY_API_KEY, query, request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No SERPLY_API_KEY found in environment variables") elif engine == "duckduckgo": return search_duckduckgo( query, request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) elif engine == "tavily": if request.app.state.config.TAVILY_API_KEY: return search_tavily( request.app.state.config.TAVILY_API_KEY, query, request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, ) else: raise Exception("No TAVILY_API_KEY found in environment variables") elif engine == "searchapi": if request.app.state.config.SEARCHAPI_API_KEY: return search_searchapi( request.app.state.config.SEARCHAPI_API_KEY, 
request.app.state.config.SEARCHAPI_ENGINE, query, request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No SEARCHAPI_API_KEY found in environment variables") elif engine == "jina": return search_jina( request.app.state.config.JINA_API_KEY, query, request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, ) elif engine == "bing": return search_bing( request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, request.app.state.config.BING_SEARCH_V7_ENDPOINT, str(DEFAULT_LOCALE), query, request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No search engine API key found in environment variables") @router.post("/process/web/search") def process_web_search( request: Request, form_data: SearchForm, user=Depends(get_verified_user) ): try: logging.info( f"trying to web search with {request.app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}" ) web_results = search_web( request, request.app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query ) except Exception as e: log.exception(e) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e), ) try: collection_name = form_data.collection_name if collection_name == "": collection_name = f"web-search-{calculate_sha256_string(form_data.query)}"[ :63 ] urls = [result.link for result in web_results] loader = get_web_loader( urls=urls, verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, ) docs = loader.aload() save_docs_to_vector_db(request, docs, collection_name, overwrite=True) return { "status": True, "collection_name": collection_name, "filenames": urls, } except Exception as e: log.exception(e) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), ) class 
QueryDocForm(BaseModel): collection_name: str query: str k: Optional[int] = None r: Optional[float] = None hybrid: Optional[bool] = None @router.post("/query/doc") def query_doc_handler( request: Request, form_data: QueryDocForm, user=Depends(get_verified_user), ): try: if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH: return query_doc_with_hybrid_search( collection_name=form_data.collection_name, query=form_data.query, embedding_function=request.app.state.EMBEDDING_FUNCTION, k=form_data.k if form_data.k else request.app.state.config.TOP_K, reranking_function=request.app.state.sentence_transformer_rf, r=( form_data.r if form_data.r else request.app.state.config.RELEVANCE_THRESHOLD ), ) else: return query_doc( collection_name=form_data.collection_name, query_embedding=request.app.state.EMBEDDING_FUNCTION(form_data.query), k=form_data.k if form_data.k else request.app.state.config.TOP_K, ) except Exception as e: log.exception(e) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), ) class QueryCollectionsForm(BaseModel): collection_names: list[str] query: str k: Optional[int] = None r: Optional[float] = None hybrid: Optional[bool] = None @router.post("/query/collection") def query_collection_handler( request: Request, form_data: QueryCollectionsForm, user=Depends(get_verified_user), ): try: if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH: return query_collection_with_hybrid_search( collection_names=form_data.collection_names, queries=[form_data.query], embedding_function=request.app.state.EMBEDDING_FUNCTION, k=form_data.k if form_data.k else request.app.state.config.TOP_K, reranking_function=request.app.state.sentence_transformer_rf, r=( form_data.r if form_data.r else request.app.state.config.RELEVANCE_THRESHOLD ), ) else: return query_collection( collection_names=form_data.collection_names, queries=[form_data.query], embedding_function=request.app.state.EMBEDDING_FUNCTION, k=form_data.k if form_data.k else 
request.app.state.config.TOP_K, ) except Exception as e: log.exception(e) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), ) #################################### # # Vector DB operations # #################################### class DeleteForm(BaseModel): collection_name: str file_id: str @router.post("/delete") def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user)): try: if VECTOR_DB_CLIENT.has_collection(collection_name=form_data.collection_name): file = Files.get_file_by_id(form_data.file_id) hash = file.hash VECTOR_DB_CLIENT.delete( collection_name=form_data.collection_name, metadata={"hash": hash}, ) return {"status": True} else: return {"status": False} except Exception as e: log.exception(e) return {"status": False} @router.post("/reset/db") def reset_vector_db(user=Depends(get_admin_user)): VECTOR_DB_CLIENT.reset() Knowledges.delete_all_knowledge() @router.post("/reset/uploads") def reset_upload_dir(user=Depends(get_admin_user)) -> bool: folder = f"{UPLOAD_DIR}" try: # Check if the directory exists if os.path.exists(folder): # Iterate over all the files and directories in the specified directory for filename in os.listdir(folder): file_path = os.path.join(folder, filename) try: if os.path.isfile(file_path) or os.path.islink(file_path): os.unlink(file_path) # Remove the file or link elif os.path.isdir(file_path): shutil.rmtree(file_path) # Remove the directory except Exception as e: print(f"Failed to delete {file_path}. Reason: {e}") else: print(f"The directory {folder} does not exist") except Exception as e: print(f"Failed to process the directory {folder}. Reason: {e}") return True if ENV == "dev": @router.get("/ef/{text}") async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"): return {"result": request.app.state.EMBEDDING_FUNCTION(text)}