diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 27d7f47dd..35549be84 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -97,9 +97,11 @@ from config import ( ENABLE_RAG_LOCAL_WEB_FETCH, YOUTUBE_LOADER_LANGUAGE, ENABLE_RAG_WEB_SEARCH, + RAG_WEB_SEARCH_ENGINE, SEARXNG_QUERY_URL, GOOGLE_PSE_API_KEY, GOOGLE_PSE_ENGINE_ID, + BRAVE_SEARCH_API_KEY, SERPSTACK_API_KEY, SERPSTACK_HTTPS, SERPER_API_KEY, @@ -145,9 +147,12 @@ app.state.YOUTUBE_LOADER_TRANSLATION = None app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH +app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE + app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID +app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS app.state.config.SERPER_API_KEY = SERPER_API_KEY @@ -351,23 +356,25 @@ async def get_rag_config(user=Depends(get_admin_user)): "chunk_size": app.state.config.CHUNK_SIZE, "chunk_overlap": app.state.config.CHUNK_OVERLAP, }, - "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, "youtube": { "language": app.state.config.YOUTUBE_LOADER_LANGUAGE, "translation": app.state.YOUTUBE_LOADER_TRANSLATION, }, "web": { + "ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, "search": { "enable": app.state.config.ENABLE_RAG_WEB_SEARCH, + "engine": app.state.config.RAG_WEB_SEARCH_ENGINE, "searxng_query_url": app.state.config.SEARXNG_QUERY_URL, "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY, "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID, + "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY, "serpstack_api_key": app.state.config.SERPSTACK_API_KEY, "serpstack_https": app.state.config.SERPSTACK_HTTPS, "serper_api_key": app.state.config.SERPER_API_KEY, "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, - } + }, }, } @@ -384,9 +391,11 @@ class YoutubeLoaderConfig(BaseModel): class WebSearchConfig(BaseModel): enable: bool + engine: Optional[str] = None searxng_query_url: Optional[str] = None google_pse_api_key: Optional[str] = None google_pse_engine_id: Optional[str] = None + brave_search_api_key: Optional[str] = None serpstack_api_key: Optional[str] = None serpstack_https: Optional[bool] = None serper_api_key: Optional[str] = None @@ -394,11 +403,16 @@ class WebSearchConfig(BaseModel): concurrent_requests: Optional[int] = None +class WebConfig(BaseModel): + search: WebSearchConfig + web_loader_ssl_verification: Optional[bool] = None + + class ConfigUpdateForm(BaseModel): pdf_extract_images: Optional[bool] = None chunk: Optional[ChunkParamUpdateForm] = None - web_loader_ssl_verification: Optional[bool] = None youtube: Optional[YoutubeLoaderConfig] = None + web: Optional[WebConfig] = None @app.post("/config/update") @@ -409,35 +423,36 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ else app.state.config.PDF_EXTRACT_IMAGES ) - app.state.config.CHUNK_SIZE = ( - form_data.chunk.chunk_size - if form_data.chunk is not None - else app.state.config.CHUNK_SIZE - ) + if form_data.chunk is not None: + app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size + app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap - app.state.config.CHUNK_OVERLAP = ( - form_data.chunk.chunk_overlap - if form_data.chunk is not None - else app.state.config.CHUNK_OVERLAP - ) + if form_data.youtube is not None: + app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language + app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation - app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( - form_data.web_loader_ssl_verification - if form_data.web_loader_ssl_verification != None - else app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION - ) + if form_data.web is not None: + app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( + form_data.web.web_loader_ssl_verification + ) - app.state.config.YOUTUBE_LOADER_LANGUAGE = ( - form_data.youtube.language - if form_data.youtube is not None - else app.state.config.YOUTUBE_LOADER_LANGUAGE - ) - - app.state.YOUTUBE_LOADER_TRANSLATION = ( - form_data.youtube.translation - if form_data.youtube is not None - else app.state.YOUTUBE_LOADER_TRANSLATION - ) + app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enable + app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine + app.state.config.SEARXNG_QUERY_URL = form_data.web.search.searxng_query_url + app.state.config.GOOGLE_PSE_API_KEY = form_data.web.search.google_pse_api_key + app.state.config.GOOGLE_PSE_ENGINE_ID = ( + form_data.web.search.google_pse_engine_id + ) + app.state.config.BRAVE_SEARCH_API_KEY = ( + form_data.web.search.brave_search_api_key + ) + app.state.config.SERPSTACK_API_KEY = form_data.web.search.serpstack_api_key + app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https + app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key + app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count + app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = ( + form_data.web.search.concurrent_requests + ) return { "status": True, @@ -446,11 +461,26 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ "chunk_size": app.state.config.CHUNK_SIZE, "chunk_overlap": app.state.config.CHUNK_OVERLAP, }, - "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, "youtube": { "language": app.state.config.YOUTUBE_LOADER_LANGUAGE, "translation": app.state.YOUTUBE_LOADER_TRANSLATION, }, + "web": { + "ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + "search": { + "enable": app.state.config.ENABLE_RAG_WEB_SEARCH, + "engine": app.state.config.RAG_WEB_SEARCH_ENGINE, + "searxng_query_url": app.state.config.SEARXNG_QUERY_URL, + "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY, + "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID, + "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY, + "serpstack_api_key": app.state.config.SERPSTACK_API_KEY, + "serpstack_https": app.state.config.SERPSTACK_HTTPS, + "serper_api_key": app.state.config.SERPER_API_KEY, + "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, + }, + }, } @@ -690,7 +720,9 @@ def resolve_hostname(hostname): def store_web_search(form_data: SearchForm, user=Depends(get_current_user)): try: try: - web_results = search_web(form_data.query) + web_results = search_web( + app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query + ) except Exception as e: log.exception(e) raise HTTPException( diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index d5a826acc..c809d0169 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -538,7 +538,7 @@ class RerankCompressor(BaseDocumentCompressor): return final_results -def search_web(query: str) -> list[SearchResult]: +def search_web(engine: str, query: str) -> list[SearchResult]: """Search the web using a search engine and return the results as a list of SearchResult objects. Will look for a search engine API key in environment variables in the following order: - SEARXNG_QUERY_URL @@ -552,15 +552,34 @@ def search_web(query: str) -> list[SearchResult]: """ # TODO: add playwright to search the web - if SEARXNG_QUERY_URL: - return search_searxng(SEARXNG_QUERY_URL, query) - elif GOOGLE_PSE_API_KEY and GOOGLE_PSE_ENGINE_ID: - return search_google_pse(GOOGLE_PSE_API_KEY, GOOGLE_PSE_ENGINE_ID, query) - elif BRAVE_SEARCH_API_KEY: - return search_brave(BRAVE_SEARCH_API_KEY, query) - elif SERPSTACK_API_KEY: - return search_serpstack(SERPSTACK_API_KEY, query, https_enabled=SERPSTACK_HTTPS) - elif SERPER_API_KEY: - return search_serper(SERPER_API_KEY, query) + if engine == "searxng": + if SEARXNG_QUERY_URL: + return search_searxng(SEARXNG_QUERY_URL, query) + else: + raise Exception("No SEARXNG_QUERY_URL found in environment variables") + elif engine == "google_pse": + if GOOGLE_PSE_API_KEY and GOOGLE_PSE_ENGINE_ID: + return search_google_pse(GOOGLE_PSE_API_KEY, GOOGLE_PSE_ENGINE_ID, query) + else: + raise Exception( + "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables" + ) + elif engine == "brave": + if BRAVE_SEARCH_API_KEY: + return search_brave(BRAVE_SEARCH_API_KEY, query) + else: + raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables") + elif engine == "serpstack": + if SERPSTACK_API_KEY: + return search_serpstack( + SERPSTACK_API_KEY, query, https_enabled=SERPSTACK_HTTPS + ) + else: + raise Exception("No SERPSTACK_API_KEY found in environment variables") + elif engine == "serper": + if SERPER_API_KEY: + return search_serper(SERPER_API_KEY, query) + else: + raise Exception("No SERPER_API_KEY found in environment variables") else: raise Exception("No search engine API key found in environment variables") diff --git a/backend/config.py b/backend/config.py index 8f97458f2..32820072f 100644 --- a/backend/config.py +++ b/backend/config.py @@ -773,6 +773,11 @@ ENABLE_RAG_WEB_SEARCH = PersistentConfig( os.getenv("ENABLE_RAG_WEB_SEARCH", "False").lower() == "true", ) +RAG_WEB_SEARCH_ENGINE = PersistentConfig( + "RAG_WEB_SEARCH_ENGINE", + "rag.web.search.engine", + os.getenv("RAG_WEB_SEARCH_ENGINE", ""), +) SEARXNG_QUERY_URL = PersistentConfig( "SEARXNG_QUERY_URL", diff --git a/src/lib/components/documents/Settings/WebParams.svelte b/src/lib/components/documents/Settings/WebParams.svelte index 2ca2f3ace..5b15e3e41 100644 --- a/src/lib/components/documents/Settings/WebParams.svelte +++ b/src/lib/components/documents/Settings/WebParams.svelte @@ -1,5 +1,6 @@