"""FastAPI sub-application exposing the RAG endpoints.

Covers document/web/YouTube ingestion into Chroma, vector queries, web
search, and admin-editable configuration. This section wires up imports,
logging and the initial ``app.state.config`` values.
"""

from fastapi import (
    FastAPI,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
from fastapi.middleware.cors import CORSMiddleware

import os
import shutil
import logging
import re

from pathlib import Path
from typing import List, Optional, Sequence, Union

from chromadb.utils.batch_utils import create_batches

from langchain_community.document_loaders import (
    WebBaseLoader,
    TextLoader,
    PyPDFLoader,
    CSVLoader,
    BSHTMLLoader,
    Docx2txtLoader,
    UnstructuredEPubLoader,
    UnstructuredWordDocumentLoader,
    UnstructuredMarkdownLoader,
    UnstructuredXMLLoader,
    UnstructuredRSTLoader,
    UnstructuredExcelLoader,
    UnstructuredPowerPointLoader,
    YoutubeLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter

import validators
import urllib.parse
import socket

from pydantic import BaseModel
import mimetypes
import uuid
import json

import sentence_transformers

from apps.webui.models.documents import (
    Documents,
    DocumentForm,
    DocumentResponse,
)

from apps.rag.utils import (
    get_model_path,
    get_embedding_function,
    query_doc,
    query_doc_with_hybrid_search,
    query_collection,
    query_collection_with_hybrid_search,
)

from apps.rag.search.brave import search_brave
from apps.rag.search.google_pse import search_google_pse
from apps.rag.search.main import SearchResult
from apps.rag.search.searxng import search_searxng
from apps.rag.search.serper import search_serper
from apps.rag.search.serpstack import search_serpstack

from utils.misc import (
    calculate_sha256,
    calculate_sha256_string,
    sanitize_filename,
    extract_folders_after_data_docs,
)
from utils.utils import get_current_user, get_admin_user

from config import (
    ENV,
    SRC_LOG_LEVELS,
    UPLOAD_DIR,
    DOCS_DIR,
    RAG_TOP_K,
    RAG_RELEVANCE_THRESHOLD,
    RAG_EMBEDDING_ENGINE,
    RAG_EMBEDDING_MODEL,
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
    ENABLE_RAG_HYBRID_SEARCH,
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
    RAG_RERANKING_MODEL,
    PDF_EXTRACT_IMAGES,
    RAG_RERANKING_MODEL_AUTO_UPDATE,
    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
    RAG_OPENAI_API_BASE_URL,
    RAG_OPENAI_API_KEY,
    DEVICE_TYPE,
    CHROMA_CLIENT,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
    RAG_TEMPLATE,
    ENABLE_RAG_LOCAL_WEB_FETCH,
    YOUTUBE_LOADER_LANGUAGE,
    ENABLE_RAG_WEB_SEARCH,
    RAG_WEB_SEARCH_ENGINE,
    SEARXNG_QUERY_URL,
    GOOGLE_PSE_API_KEY,
    GOOGLE_PSE_ENGINE_ID,
    BRAVE_SEARCH_API_KEY,
    SERPSTACK_API_KEY,
    SERPSTACK_HTTPS,
    SERPER_API_KEY,
    RAG_WEB_SEARCH_RESULT_COUNT,
    RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
    AppConfig,
)

from constants import ERROR_MESSAGES

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

app = FastAPI()

# Admin-editable settings live on app.state.config; the handlers below read
# from there so that /config/update and friends take effect at runtime.
app.state.config = AppConfig()

# Retrieval
app.state.config.TOP_K = RAG_TOP_K
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)

# Chunking
app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP

# Models and prompt template
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE

# OpenAI-compatible embedding endpoint
app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY

app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES

# YouTube loader. The translation setting is kept on app.state directly,
# not on app.state.config.
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
app.state.YOUTUBE_LOADER_TRANSLATION = None

# Web search backends
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE

app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
app.state.config.SERPER_API_KEY = SERPER_API_KEY
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS


def update_embedding_model(
    embedding_model: str,
    update_model: bool = False,
):
    """Load or clear the local SentenceTransformer embedding model.

    A SentenceTransformer is only loaded when ``embedding_model`` is non-empty
    and the configured embedding engine is the empty string. Otherwise
    ``app.state.sentence_transformer_ef`` is reset to ``None``.

    Args:
        embedding_model: Model name, resolved through ``get_model_path``.
        update_model: Forwarded to ``get_model_path`` (presumably forces a
            fresh download of the model; confirm in ``apps.rag.utils``).
    """
    if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
        app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
            get_model_path(embedding_model, update_model),
            device=DEVICE_TYPE,
            trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
        )
    else:
        app.state.sentence_transformer_ef = None


def update_reranking_model(
    reranking_model: str,
    update_model: bool = False,
):
    """Load or clear the CrossEncoder used as the hybrid-search reranker.

    ``app.state.sentence_transformer_rf`` becomes a CrossEncoder when
    ``reranking_model`` is non-empty, and ``None`` otherwise.

    Args:
        reranking_model: Model name, resolved through ``get_model_path``.
        update_model: Forwarded to ``get_model_path``.
    """
    if reranking_model:
        app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
            get_model_path(reranking_model, update_model),
            device=DEVICE_TYPE,
            trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
        )
    else:
        app.state.sentence_transformer_rf = None


# Load the configured models once, at import time.
update_embedding_model(
    app.state.config.RAG_EMBEDDING_MODEL,
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)

update_reranking_model(
    app.state.config.RAG_RERANKING_MODEL,
    RAG_RERANKING_MODEL_AUTO_UPDATE,
)


app.state.EMBEDDING_FUNCTION = get_embedding_function(
    app.state.config.RAG_EMBEDDING_ENGINE,
    app.state.config.RAG_EMBEDDING_MODEL,
    app.state.sentence_transformer_ef,
    app.state.config.OPENAI_API_KEY,
    app.state.config.OPENAI_API_BASE_URL,
)

origins = ["*"]


app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class CollectionNameForm(BaseModel):
    collection_name: Optional[str] = "test"


class UrlForm(CollectionNameForm):
    url: str


class SearchForm(CollectionNameForm):
    query: str


@app.get("/")
async def get_status():
    """Return chunking, template and model settings (no auth dependency)."""
    return {
        "status": True,
        "chunk_size": app.state.config.CHUNK_SIZE,
        "chunk_overlap": app.state.config.CHUNK_OVERLAP,
        "template": app.state.config.RAG_TEMPLATE,
        "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
        "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
        "reranking_model": app.state.config.RAG_RERANKING_MODEL,
    }


@app.get("/embedding")
async def get_embedding_config(user=Depends(get_admin_user)):
    """Return embedding engine/model and the OpenAI URL and key (admin dependency)."""
    return {
        "status": True,
        "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
        "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
        "openai_config": {
            "url": app.state.config.OPENAI_API_BASE_URL,
            "key": app.state.config.OPENAI_API_KEY,
        },
    }


# NOTE: the misspelled name is kept so any existing importers keep working.
@app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)):
    """Return the configured reranking model name (admin dependency)."""
    return {
        "status": True,
        "reranking_model": app.state.config.RAG_RERANKING_MODEL,
    }


class OpenAIConfigForm(BaseModel):
    url: str
    key: str


class EmbeddingModelUpdateForm(BaseModel):
    openai_config: Optional[OpenAIConfigForm] = None
    embedding_engine: str
    embedding_model: str


@app.post("/embedding/update")
async def update_embedding_config(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch embedding engine/model and rebuild ``app.state.EMBEDDING_FUNCTION``.

    For the "ollama" and "openai" engines, the OpenAI URL and key are also
    replaced when ``openai_config`` is supplied.

    Raises:
        HTTPException: 500 if reloading the model or embedding function fails.
    """
    log.info(
        f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
    )
    try:
        app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
        app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model

        if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
            if form_data.openai_config is not None:
                app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
                app.state.config.OPENAI_API_KEY = form_data.openai_config.key

        update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)

        app.state.EMBEDDING_FUNCTION = get_embedding_function(
            app.state.config.RAG_EMBEDDING_ENGINE,
            app.state.config.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.config.OPENAI_API_KEY,
            app.state.config.OPENAI_API_BASE_URL,
        )

        return {
            "status": True,
            "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
            "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
            "openai_config": {
                "url": app.state.config.OPENAI_API_BASE_URL,
                "key": app.state.config.OPENAI_API_KEY,
            },
        }
    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


class RerankingModelUpdateForm(BaseModel):
    reranking_model: str


@app.post("/reranking/update")
async def update_reranking_config(
    form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the reranking model and reload it.

    Raises:
        HTTPException: 500 if loading the new model fails.
    """
    log.info(
        f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
    )
    try:
        app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model

        # update_model=True must be an argument of the call; a trailing
        # `, True` outside the parentheses would just build a discarded tuple.
        update_reranking_model(app.state.config.RAG_RERANKING_MODEL, True)

        return {
            "status": True,
            "reranking_model": app.state.config.RAG_RERANKING_MODEL,
        }
    except Exception as e:
        log.exception(f"Problem updating reranking model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


def _rag_config_response() -> dict:
    """Build the config payload shared by ``GET /config`` and ``POST /config/update``."""
    return {
        "status": True,
        "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
        "chunk": {
            "chunk_size": app.state.config.CHUNK_SIZE,
            "chunk_overlap": app.state.config.CHUNK_OVERLAP,
        },
        "youtube": {
            "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
        },
        "web": {
            "ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
            "search": {
                "enable": app.state.config.ENABLE_RAG_WEB_SEARCH,
                "engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
                "searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
                "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
                "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
                "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
                "serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
                "serpstack_https": app.state.config.SERPSTACK_HTTPS,
                "serper_api_key": app.state.config.SERPER_API_KEY,
                "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
                "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
            },
        },
    }


@app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)):
    """Return the full RAG configuration (admin dependency)."""
    return _rag_config_response()


class ChunkParamUpdateForm(BaseModel):
    chunk_size: int
    chunk_overlap: int


class YoutubeLoaderConfig(BaseModel):
    language: List[str]
    translation: Optional[str] = None


class WebSearchConfig(BaseModel):
    enable: bool
    engine: Optional[str] = None
    searxng_query_url: Optional[str] = None
    google_pse_api_key: Optional[str] = None
    google_pse_engine_id: Optional[str] = None
    brave_search_api_key: Optional[str] = None
    serpstack_api_key: Optional[str] = None
    serpstack_https: Optional[bool] = None
    serper_api_key: Optional[str] = None
    result_count: Optional[int] = None
    concurrent_requests: Optional[int] = None


class WebConfig(BaseModel):
    search: WebSearchConfig
    web_loader_ssl_verification: Optional[bool] = None


class ConfigUpdateForm(BaseModel):
    pdf_extract_images: Optional[bool] = None
    chunk: Optional[ChunkParamUpdateForm] = None
    youtube: Optional[YoutubeLoaderConfig] = None
    web: Optional[WebConfig] = None


@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Apply a partial RAG configuration update and return the resulting config.

    Top-level sections that are ``None`` are left untouched. Within ``web``:

    - ``web_loader_ssl_verification`` keeps its current value when omitted,
      rather than storing ``None``.
    - Every ``search`` field is applied exactly as sent.
    """
    app.state.config.PDF_EXTRACT_IMAGES = (
        form_data.pdf_extract_images
        if form_data.pdf_extract_images is not None
        else app.state.config.PDF_EXTRACT_IMAGES
    )

    if form_data.chunk is not None:
        app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size
        app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap

    if form_data.youtube is not None:
        app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language
        app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation

    if form_data.web is not None:
        web = form_data.web
        search = web.search

        app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
            web.web_loader_ssl_verification
            if web.web_loader_ssl_verification is not None
            else app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
        )

        app.state.config.ENABLE_RAG_WEB_SEARCH = search.enable
        app.state.config.RAG_WEB_SEARCH_ENGINE = search.engine
        app.state.config.SEARXNG_QUERY_URL = search.searxng_query_url
        app.state.config.GOOGLE_PSE_API_KEY = search.google_pse_api_key
        app.state.config.GOOGLE_PSE_ENGINE_ID = search.google_pse_engine_id
        app.state.config.BRAVE_SEARCH_API_KEY = search.brave_search_api_key
        app.state.config.SERPSTACK_API_KEY = search.serpstack_api_key
        app.state.config.SERPSTACK_HTTPS = search.serpstack_https
        app.state.config.SERPER_API_KEY = search.serper_api_key
        app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = search.result_count
        app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
            search.concurrent_requests
        )

    return _rag_config_response()


@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
    """Return the RAG prompt template (any authenticated user)."""
    return {
        "status": True,
        "template": app.state.config.RAG_TEMPLATE,
    }


@app.get("/query/settings")
async def get_query_settings(user=Depends(get_admin_user)):
    """Return template, top-k, relevance threshold and hybrid-search flag."""
    return {
        "status": True,
        "template": app.state.config.RAG_TEMPLATE,
        "k": app.state.config.TOP_K,
        "r": app.state.config.RELEVANCE_THRESHOLD,
        "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
    }


class QuerySettingsForm(BaseModel):
    k: Optional[int] = None
    r: Optional[float] = None
    template: Optional[str] = None
    hybrid: Optional[bool] = None


@app.post("/query/settings/update")
async def update_query_settings(
    form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
    """Replace the query settings.

    Every field is always written. Falsy or omitted values fall back to:
    ``RAG_TEMPLATE``, k=4, r=0.0, hybrid=False.
    """
    app.state.config.RAG_TEMPLATE = (
        form_data.template if form_data.template else RAG_TEMPLATE
    )
    app.state.config.TOP_K = form_data.k if form_data.k else 4
    app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
    app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
        form_data.hybrid if form_data.hybrid else False
    )

    return {
        "status": True,
        "template": app.state.config.RAG_TEMPLATE,
        "k": app.state.config.TOP_K,
        "r": app.state.config.RELEVANCE_THRESHOLD,
        "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
    }


class QueryDocForm(BaseModel):
    collection_name: str
    query: str
    k: Optional[int] = None
    r: Optional[float] = None
    hybrid: Optional[bool] = None


@app.post("/query/doc")
def query_doc_handler(
    form_data: QueryDocForm,
    user=Depends(get_current_user),
):
    """Query one collection, using hybrid search when globally enabled.

    ``k``/``r`` fall back to the configured TOP_K / RELEVANCE_THRESHOLD when
    falsy. ``form_data.hybrid`` is accepted but not consulted; only the
    global ENABLE_RAG_HYBRID_SEARCH flag decides the mode.

    Raises:
        HTTPException: 400 on any query failure.
    """
    try:
        k = form_data.k if form_data.k else app.state.config.TOP_K
        if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
            return query_doc_with_hybrid_search(
                collection_name=form_data.collection_name,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=k,
                reranking_function=app.state.sentence_transformer_rf,
                r=(
                    form_data.r
                    if form_data.r
                    else app.state.config.RELEVANCE_THRESHOLD
                ),
            )
        return query_doc(
            collection_name=form_data.collection_name,
            query=form_data.query,
            embedding_function=app.state.EMBEDDING_FUNCTION,
            k=k,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


class QueryCollectionsForm(BaseModel):
    collection_names: List[str]
    query: str
    k: Optional[int] = None
    r: Optional[float] = None
    hybrid: Optional[bool] = None


@app.post("/query/collection")
def query_collection_handler(
    form_data: QueryCollectionsForm,
    user=Depends(get_current_user),
):
    """Query several collections at once; same fallbacks as ``/query/doc``.

    Raises:
        HTTPException: 400 on any query failure.
    """
    try:
        k = form_data.k if form_data.k else app.state.config.TOP_K
        if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
            return query_collection_with_hybrid_search(
                collection_names=form_data.collection_names,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=k,
                reranking_function=app.state.sentence_transformer_rf,
                r=(
                    form_data.r
                    if form_data.r
                    else app.state.config.RELEVANCE_THRESHOLD
                ),
            )
        return query_collection(
            collection_names=form_data.collection_names,
            query=form_data.query,
            embedding_function=app.state.EMBEDDING_FUNCTION,
            k=k,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


@app.post("/youtube")
def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
    """Load a YouTube video (with video info) and store it, overwriting the collection.

    An empty or null collection name is replaced by a URL-derived hash
    (63 chars).

    Raises:
        HTTPException: 400 on any loading/storing failure.
    """
    try:
        loader = YoutubeLoader.from_youtube_url(
            form_data.url,
            add_video_info=True,
            language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
            translation=app.state.YOUTUBE_LOADER_TRANSLATION,
        )
        data = loader.load()

        collection_name = form_data.collection_name
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.url)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


@app.post("/web")
def store_web(form_data: UrlForm, user=Depends(get_current_user)):
    """Fetch a web page and store it, overwriting the collection.

    The URL is validated by ``get_web_loader``. An empty or null collection
    name is replaced by a URL-derived hash (63 chars).

    Raises:
        HTTPException: 400 on invalid URL or any loading/storing failure.
    """
    try:
        loader = get_web_loader(
            form_data.url,
            verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        )
        data = loader.load()

        collection_name = form_data.collection_name
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.url)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
    """Validate ``url`` (one or many) and return a configured WebBaseLoader.

    The request rate comes from ``app.state.config`` so that changes made
    via ``/config/update`` take effect.

    Raises:
        ValueError: if any URL fails ``validate_url``.
    """
    if not validate_url(url):
        raise ValueError(ERROR_MESSAGES.INVALID_URL)
    return WebBaseLoader(
        url,
        verify_ssl=verify_ssl,
        requests_per_second=app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
        continue_on_failure=True,
    )


def validate_url(url: Union[str, Sequence[str]]):
    """Check that ``url`` (or every URL in a sequence) is acceptable to fetch.

    Validation has two parts:

    - Each URL must pass ``validators.url``.
    - Unless ENABLE_RAG_LOCAL_WEB_FETCH is set, every address the host
      resolves to must be non-private.

    Returns:
        True when all URLs pass. False when ``url`` is neither a str nor a
        Sequence.

    Raises:
        ValueError: on a malformed URL or a private address.
        OSError: from ``socket.getaddrinfo`` if the host cannot be resolved.
    """
    if isinstance(url, str):
        if isinstance(validators.url(url), validators.ValidationError):
            raise ValueError(ERROR_MESSAGES.INVALID_URL)
        if not ENABLE_RAG_LOCAL_WEB_FETCH:
            # Local web fetch is disabled: reject URLs that resolve to private IPs.
            parsed_url = urllib.parse.urlparse(url)
            ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
            # Still vulnerable to DNS rebinding: WebBaseLoader resolves the
            # host again when it fetches, and we don't control that lookup.
            for ip in ipv4_addresses:
                if validators.ipv4(ip, private=True):
                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
            for ip in ipv6_addresses:
                if validators.ipv6(ip, private=True):
                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
        return True
    elif isinstance(url, Sequence):
        return all(validate_url(u) for u in url)
    else:
        return False


def resolve_hostname(hostname):
    """Resolve ``hostname`` and return ``(ipv4_addresses, ipv6_addresses)`` lists."""
    addr_info = socket.getaddrinfo(hostname, None)

    # Each getaddrinfo entry is (family, type, proto, canonname, sockaddr);
    # sockaddr[0] is the address string.
    ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
    ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]

    return ipv4_addresses, ipv6_addresses


def search_web(engine: str, query: str) -> list[SearchResult]:
    """Run ``query`` against the web search backend named ``engine``.

    Credentials and the result count are read from ``app.state.config``.
    Each engine requires:

    - ``"searxng"``: SEARXNG_QUERY_URL
    - ``"google_pse"``: GOOGLE_PSE_API_KEY and GOOGLE_PSE_ENGINE_ID
    - ``"brave"``: BRAVE_SEARCH_API_KEY
    - ``"serpstack"``: SERPSTACK_API_KEY (SERPSTACK_HTTPS is passed as
      ``https_enabled``)
    - ``"serper"``: SERPER_API_KEY

    Args:
        engine: Name of the search backend.
        query: The query to search for.

    Returns:
        The backend's list of SearchResult objects.

    Raises:
        Exception: if the engine is unknown or its credentials are missing.
    """
    # TODO: add playwright to search the web
    if engine == "searxng":
        if app.state.config.SEARXNG_QUERY_URL:
            return search_searxng(
                app.state.config.SEARXNG_QUERY_URL,
                query,
                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
            )
        raise Exception("No SEARXNG_QUERY_URL found in environment variables")
    elif engine == "google_pse":
        if (
            app.state.config.GOOGLE_PSE_API_KEY
            and app.state.config.GOOGLE_PSE_ENGINE_ID
        ):
            return search_google_pse(
                app.state.config.GOOGLE_PSE_API_KEY,
                app.state.config.GOOGLE_PSE_ENGINE_ID,
                query,
                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
            )
        raise Exception(
            "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables"
        )
    elif engine == "brave":
        if app.state.config.BRAVE_SEARCH_API_KEY:
            return search_brave(
                app.state.config.BRAVE_SEARCH_API_KEY,
                query,
                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
            )
        raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
    elif engine == "serpstack":
        if app.state.config.SERPSTACK_API_KEY:
            return search_serpstack(
                app.state.config.SERPSTACK_API_KEY,
                query,
                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
                https_enabled=app.state.config.SERPSTACK_HTTPS,
            )
        raise Exception("No SERPSTACK_API_KEY found in environment variables")
    elif engine == "serper":
        if app.state.config.SERPER_API_KEY:
            return search_serper(
                app.state.config.SERPER_API_KEY,
                query,
                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
            )
        raise Exception("No SERPER_API_KEY found in environment variables")
    else:
        raise Exception("No search engine API key found in environment variables")


@app.post("/web/search")
def store_web_search(form_data: SearchForm, user=Depends(get_current_user)):
    """Search the web, fetch every result page and store them in one collection.

    Page fetching honours the same SSL-verification setting as ``/web``. An
    empty or null collection name is replaced by a query-derived hash
    (63 chars).

    Raises:
        HTTPException: 400 with WEB_SEARCH_ERROR if the search fails, or 400
            with DEFAULT if fetching/storing fails.
    """
    try:
        web_results = search_web(
            app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
        )

    try:
        urls = [result.link for result in web_results]
        loader = get_web_loader(
            urls,
            verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        )
        data = loader.load()

        collection_name = form_data.collection_name
        if not collection_name:
            collection_name = calculate_sha256_string(form_data.query)[:63]

        store_data_in_vector_db(data, collection_name, overwrite=True)
        return {
            "status": True,
            "collection_name": collection_name,
            "filenames": urls,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


def _make_text_splitter() -> RecursiveCharacterTextSplitter:
    """Build a splitter from the current chunk settings, recording start offsets."""
    return RecursiveCharacterTextSplitter(
        chunk_size=app.state.config.CHUNK_SIZE,
        chunk_overlap=app.state.config.CHUNK_OVERLAP,
        add_start_index=True,
    )


def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
    """Split loaded documents into chunks and store them in ``collection_name``.

    Returns:
        The result of ``store_docs_in_vector_db``: True on success, False on
        failure. This is a plain bool; a ``(bool, None)`` tuple would always
        be truthy and would hide failures from callers.

    Raises:
        ValueError: with EMPTY_CONTENT when splitting yields no chunks.
    """
    docs = _make_text_splitter().split_documents(data)

    if not docs:
        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

    log.info(f"store_data_in_vector_db {docs}")
    return store_docs_in_vector_db(docs, collection_name, overwrite)


def store_text_in_vector_db(
    text, metadata, collection_name, overwrite: bool = False
) -> bool:
    """Split a raw string (with ``metadata`` on every chunk) and store it.

    Returns:
        The result of ``store_docs_in_vector_db``.
    """
    docs = _make_text_splitter().create_documents([text], metadatas=[metadata])
    return store_docs_in_vector_db(docs, collection_name, overwrite)


def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
    """Embed ``docs`` and add them to a newly created Chroma collection.

    With ``overwrite``, an existing collection of the same name is deleted
    first. Newlines are replaced by spaces for embedding only; the stored
    documents keep the original text.

    Returns:
        True on success. Also True when the raised exception's class is named
        ``UniqueConstraintError`` (presumably: the collection already exists).
        False on any other error, which is logged.
    """
    log.info(f"store_docs_in_vector_db {docs} {collection_name}")

    texts = [doc.page_content for doc in docs]
    metadatas = [doc.metadata for doc in docs]

    try:
        if overwrite:
            for collection in CHROMA_CLIENT.list_collections():
                if collection_name == collection.name:
                    log.info(f"deleting existing collection {collection_name}")
                    CHROMA_CLIENT.delete_collection(name=collection_name)

        collection = CHROMA_CLIENT.create_collection(name=collection_name)

        embedding_func = get_embedding_function(
            app.state.config.RAG_EMBEDDING_ENGINE,
            app.state.config.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.config.OPENAI_API_KEY,
            app.state.config.OPENAI_API_BASE_URL,
        )

        embedding_texts = [text.replace("\n", " ") for text in texts]
        embeddings = embedding_func(embedding_texts)

        # create_batches splits the payload to respect the client's max batch size.
        for batch in create_batches(
            api=CHROMA_CLIENT,
            ids=[str(uuid.uuid4()) for _ in texts],
            metadatas=metadatas,
            embeddings=embeddings,
            documents=texts,
        ):
            collection.add(*batch)

        return True
    except Exception as e:
        log.exception(e)
        if e.__class__.__name__ == "UniqueConstraintError":
            return True

        return False


# File extensions (lower-case, without the dot) that are loaded as plain text.
_KNOWN_SOURCE_EXTENSIONS = frozenset(
    {
        "go", "py", "java", "sh", "bat", "ps1", "cmd", "js", "ts", "css",
        "cpp", "hpp", "h", "c", "cs", "sql", "log", "ini", "pl", "pm", "r",
        "dart", "dockerfile", "env", "php", "hs", "hsc", "lua", "nginxconf",
        "conf", "m", "mm", "plsql", "perl", "rb", "rs", "db2", "scala",
        "bash", "swift", "vue", "svelte",
    }
)


def get_loader(filename: str, file_content_type: str, file_path: str):
    """Pick a document loader from the file extension and/or content type.

    Returns:
        ``(loader, known_type)``. ``known_type`` is False only when nothing
        matched and the file falls back to a plain TextLoader.
    """
    file_ext = filename.split(".")[-1].lower()
    known_type = True

    if file_ext == "pdf":
        loader = PyPDFLoader(
            file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES
        )
    elif file_ext == "csv":
        loader = CSVLoader(file_path)
    elif file_ext == "rst":
        loader = UnstructuredRSTLoader(file_path, mode="elements")
    elif file_ext == "xml":
        loader = UnstructuredXMLLoader(file_path)
    elif file_ext in ["htm", "html"]:
        loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
    elif file_ext == "md":
        loader = UnstructuredMarkdownLoader(file_path)
    elif file_content_type == "application/epub+zip":
        loader = UnstructuredEPubLoader(file_path)
    elif (
        file_content_type
        == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
        or file_ext in ["doc", "docx"]
    ):
        loader = Docx2txtLoader(file_path)
    elif file_content_type in [
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    ] or file_ext in ["xls", "xlsx"]:
        loader = UnstructuredExcelLoader(file_path)
    elif file_content_type in [
        "application/vnd.ms-powerpoint",
        "application/vnd.openxmlformats-officedocument.presentationml.presentation",
    ] or file_ext in ["ppt", "pptx"]:
        loader = UnstructuredPowerPointLoader(file_path)
    elif file_ext in _KNOWN_SOURCE_EXTENSIONS or (
        file_content_type and "text/" in file_content_type
    ):
        loader = TextLoader(file_path, autodetect_encoding=True)
    else:
        loader = TextLoader(file_path, autodetect_encoding=True)
        known_type = False

    return loader, known_type


@app.post("/doc")
def store_doc(
    collection_name: Optional[str] = Form(None),
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    """Save an uploaded file under UPLOAD_DIR and store its contents.

    Only the basename of the client-supplied filename is used. Without an
    explicit collection name, the first 63 chars of the file's SHA-256 are
    used.

    Raises:
        HTTPException: 400 on any saving/loading/splitting error (with
            PANDOC_NOT_INSTALLED when pandoc is missing), or 500 when the
            vector store reports failure.
    """
    log.info(f"file.content_type: {file.content_type}")
    try:
        filename = os.path.basename(file.filename)
        file_path = f"{UPLOAD_DIR}/{filename}"

        contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)

        if collection_name is None:
            with open(file_path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(filename, file.content_type, file_path)
        data = loader.load()
        stored = store_data_in_vector_db(data, collection_name)
    except Exception as e:
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

    if not stored:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )

    return {
        "status": True,
        "collection_name": collection_name,
        "filename": filename,
        "known_type": known_type,
    }


class TextRAGForm(BaseModel):
    name: str
    content: str
    collection_name: Optional[str] = None


@app.post("/text")
def store_text(
    form_data: TextRAGForm,
    user=Depends(get_current_user),
):
    """Store raw text, tagging every chunk with ``name`` and ``created_by``.

    Without an explicit collection name, the first 63 chars of the content's
    SHA-256 are used. This matches the other ingestion routes and fits
    Chroma's 63-character collection-name limit.

    Raises:
        HTTPException: 500 when the vector store reports failure.
    """
    collection_name = form_data.collection_name
    if collection_name is None:
        collection_name = calculate_sha256_string(form_data.content)[:63]

    result = store_text_in_vector_db(
        form_data.content,
        metadata={"name": form_data.name, "created_by": user.id},
        collection_name=collection_name,
    )

    if result:
        return {"status": True, "collection_name": collection_name}
    raise HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail=ERROR_MESSAGES.DEFAULT(),
    )


@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
    """Ingest every non-hidden file under DOCS_DIR and register new documents.

    Each file goes into a collection named after its SHA-256 (63 chars). A
    Documents row is inserted only if the store succeeded and no document
    with the sanitized name exists yet. Folder names after the data/docs root
    become tags. Per-file errors are logged and skipped.
    """
    for path in Path(DOCS_DIR).rglob("*"):
        try:
            if not path.is_file() or path.name.startswith("."):
                continue

            tags = extract_folders_after_data_docs(path)
            filename = path.name
            content_type, _ = mimetypes.guess_type(path)

            with open(path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

            loader, known_type = get_loader(filename, content_type, str(path))
            data = loader.load()

            if not store_data_in_vector_db(data, collection_name):
                continue

            sanitized_filename = sanitize_filename(filename)
            if Documents.get_doc_by_name(sanitized_filename) is None:
                content = (
                    json.dumps({"tags": [{"name": name} for name in tags]})
                    if tags
                    else "{}"
                )
                Documents.insert_new_doc(
                    user.id,
                    DocumentForm(
                        **{
                            "name": sanitized_filename,
                            "title": filename,
                            "collection_name": collection_name,
                            "filename": filename,
                            "content": content,
                        }
                    ),
                )
        except Exception as e:
            log.exception(e)

    return True


@app.get("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
    """Reset the Chroma client (admin dependency)."""
    CHROMA_CLIENT.reset()


@app.get("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
    """Delete everything in UPLOAD_DIR, then reset the Chroma client.

    Individual deletion failures and a failing Chroma reset are logged, not
    raised.
    """
    folder = f"{UPLOAD_DIR}"
    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            log.error("Failed to delete %s. Reason: %s", file_path, e)

    try:
        CHROMA_CLIENT.reset()
    except Exception as e:
        log.exception(e)

    return True


if ENV == "dev":

    @app.get("/ef")
    async def get_embeddings():
        """Dev-only: embed the fixed string "hello world"."""
        return {"result": app.state.EMBEDDING_FUNCTION("hello world")}

    @app.get("/ef/{text}")
    async def get_embeddings_text(text: str):
        """Dev-only: embed the path parameter ``text``."""
        return {"result": app.state.EMBEDDING_FUNCTION(text)}