From 7b5f434a079b335be8452af7a05982a779973fcd Mon Sep 17 00:00:00 2001
From: Que Nguyen
Date: Thu, 13 Jun 2024 07:14:48 +0700
Subject: [PATCH] Implement domain whitelisting for web search results

---
 backend/apps/rag/main.py              | 10 +++++++++-
 backend/apps/rag/search/brave.py      |  9 +++++----
 backend/apps/rag/search/duckduckgo.py | 11 ++++++-----
 backend/apps/rag/search/google_pse.py |  9 +++++----
 backend/apps/rag/search/main.py       | 12 +++++++++++-
 backend/apps/rag/search/searxng.py    |  5 +++--
 backend/apps/rag/search/serper.py     |  9 +++++----
 backend/apps/rag/search/serply.py     |  9 +++++----
 backend/apps/rag/search/serpstack.py  |  9 +++++----
 backend/config.py                     |  9 +++++++++
 10 files changed, 63 insertions(+), 29 deletions(-)

diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py
index 0e493eaaa..37da4db5a 100644
--- a/backend/apps/rag/main.py
+++ b/backend/apps/rag/main.py
@@ -111,6 +111,7 @@ from config import (
     YOUTUBE_LOADER_LANGUAGE,
     ENABLE_RAG_WEB_SEARCH,
     RAG_WEB_SEARCH_ENGINE,
+    RAG_WEB_SEARCH_WHITE_LIST_DOMAINS,
     SEARXNG_QUERY_URL,
     GOOGLE_PSE_API_KEY,
     GOOGLE_PSE_ENGINE_ID,
@@ -163,6 +164,7 @@ app.state.YOUTUBE_LOADER_TRANSLATION = None
 
 app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
 app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
+app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS = RAG_WEB_SEARCH_WHITE_LIST_DOMAINS
 
 app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
 app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
@@ -768,6 +770,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
                 app.state.config.SEARXNG_QUERY_URL,
                 query,
                 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS
             )
         else:
             raise Exception("No SEARXNG_QUERY_URL found in environment variables")
@@ -781,6 +784,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
                 app.state.config.GOOGLE_PSE_ENGINE_ID,
                 query,
                 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS
             )
         else:
             raise Exception(
@@ -792,6 +796,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
                 app.state.config.BRAVE_SEARCH_API_KEY,
                 query,
                 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS
             )
         else:
             raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
@@ -801,6 +806,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
                 app.state.config.SERPSTACK_API_KEY,
                 query,
                 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS,
                 https_enabled=app.state.config.SERPSTACK_HTTPS,
             )
         else:
@@ -811,6 +817,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
                 app.state.config.SERPER_API_KEY,
                 query,
                 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS
             )
         else:
             raise Exception("No SERPER_API_KEY found in environment variables")
@@ -820,11 +827,12 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
                 app.state.config.SERPLY_API_KEY,
                 query,
                 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS
             )
         else:
             raise Exception("No SERPLY_API_KEY found in environment variables")
     elif engine == "duckduckgo":
-        return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
+        return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS)
     else:
         raise Exception("No search engine API key found in environment variables")
 
diff --git a/backend/apps/rag/search/brave.py b/backend/apps/rag/search/brave.py
index 4e0f56807..04cd18496 100644
--- a/backend/apps/rag/search/brave.py
+++ b/backend/apps/rag/search/brave.py
@@ -1,15 +1,15 @@
 import logging
-
+from typing import List
 import requests
 
-from apps.rag.search.main import SearchResult
+from apps.rag.search.main import SearchResult, filter_by_whitelist
 from config import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
-def search_brave(api_key: str, query: str, count: int) -> list[SearchResult]:
+def search_brave(api_key: str, query: str, count: int, whitelist: List[str]) -> list[SearchResult]:
     """Search using Brave's Search API and return the results as a list of SearchResult objects.
 
     Args:
@@ -29,9 +29,10 @@ def search_brave(api_key: str, query: str, count: int) -> list[SearchResult]:
 
     json_response = response.json()
     results = json_response.get("web", {}).get("results", [])
+    filtered_results = filter_by_whitelist(results, whitelist)
     return [
         SearchResult(
             link=result["url"], title=result.get("title"), snippet=result.get("snippet")
         )
-        for result in results[:count]
+        for result in filtered_results[:count]
     ]
diff --git a/backend/apps/rag/search/duckduckgo.py b/backend/apps/rag/search/duckduckgo.py
index 188ae2bea..9342e53e4 100644
--- a/backend/apps/rag/search/duckduckgo.py
+++ b/backend/apps/rag/search/duckduckgo.py
@@ -1,6 +1,6 @@
 import logging
-
-from apps.rag.search.main import SearchResult
+from typing import List
+from apps.rag.search.main import SearchResult, filter_by_whitelist
 from duckduckgo_search import DDGS
 from config import SRC_LOG_LEVELS
 
@@ -8,7 +8,7 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
-def search_duckduckgo(query: str, count: int) -> list[SearchResult]:
+def search_duckduckgo(query: str, count: int, whitelist: List[str]) -> list[SearchResult]:
     """
     Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
     Args:
@@ -41,6 +41,7 @@ def search_duckduckgo(query: str, count: int) -> list[SearchResult]:
                 snippet=result.get("body"),
             )
         )
-    print(results)
+    # print(results)
+    filtered_results = filter_by_whitelist(results, whitelist)
     # Return the list of search results
-    return results
+    return filtered_results
diff --git a/backend/apps/rag/search/google_pse.py b/backend/apps/rag/search/google_pse.py
index 7ff54c785..bc89b2f3a 100644
--- a/backend/apps/rag/search/google_pse.py
+++ b/backend/apps/rag/search/google_pse.py
@@ -1,9 +1,9 @@
 import json
 import logging
-
+from typing import List
 import requests
 
-from apps.rag.search.main import SearchResult
+from apps.rag.search.main import SearchResult, filter_by_whitelist
 from config import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
@@ -11,7 +11,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
 def search_google_pse(
-    api_key: str, search_engine_id: str, query: str, count: int
+    api_key: str, search_engine_id: str, query: str, count: int, whitelist: List[str]
 ) -> list[SearchResult]:
     """Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
 
@@ -35,11 +35,12 @@ def search_google_pse(
 
     json_response = response.json()
     results = json_response.get("items", [])
+    filtered_results = filter_by_whitelist(results, whitelist)
     return [
         SearchResult(
             link=result["link"],
             title=result.get("title"),
             snippet=result.get("snippet"),
         )
-        for result in results
+        for result in filtered_results
     ]
diff --git a/backend/apps/rag/search/main.py b/backend/apps/rag/search/main.py
index b5478f949..612177581 100644
--- a/backend/apps/rag/search/main.py
+++ b/backend/apps/rag/search/main.py
@@ -1,8 +1,18 @@
 from typing import Optional
-
+from urllib.parse import urlparse
 from pydantic import BaseModel
 
 
+def filter_by_whitelist(results, whitelist):
+    if not whitelist:
+        return results
+    filtered_results = []
+    for result in results:
+        url = (result.get("url") or result.get("link") or result.get("href")) if isinstance(result, dict) else result.link
+        if any(urlparse(url).netloc.endswith(whitelisted_domain) for whitelisted_domain in whitelist):
+            filtered_results.append(result)
+    return filtered_results
+
 class SearchResult(BaseModel):
     link: str
     title: Optional[str]
diff --git a/backend/apps/rag/search/searxng.py b/backend/apps/rag/search/searxng.py
index c8ad88813..954aaf072 100644
--- a/backend/apps/rag/search/searxng.py
+++ b/backend/apps/rag/search/searxng.py
@@ -11,7 +11,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
 def search_searxng(
-    query_url: str, query: str, count: int, **kwargs
+    query_url: str, query: str, count: int, whitelist: List[str], **kwargs
 ) -> List[SearchResult]:
     """
     Search a SearXNG instance for a given query and return the results as a list of SearchResult objects.
@@ -78,9 +78,10 @@ def search_searxng(
     json_response = response.json()
     results = json_response.get("results", [])
     sorted_results = sorted(results, key=lambda x: x.get("score", 0), reverse=True)
+    filtered_results = filter_by_whitelist(sorted_results, whitelist)
     return [
         SearchResult(
             link=result["url"], title=result.get("title"), snippet=result.get("content")
         )
-        for result in sorted_results[:count]
+        for result in filtered_results[:count]
     ]
diff --git a/backend/apps/rag/search/serper.py b/backend/apps/rag/search/serper.py
index 150da6e07..e12126a35 100644
--- a/backend/apps/rag/search/serper.py
+++ b/backend/apps/rag/search/serper.py
@@ -1,16 +1,16 @@
 import json
 import logging
-
+from typing import List
 import requests
 
-from apps.rag.search.main import SearchResult
+from apps.rag.search.main import SearchResult, filter_by_whitelist
 from config import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
-def search_serper(api_key: str, query: str, count: int) -> list[SearchResult]:
+def search_serper(api_key: str, query: str, count: int, whitelist: List[str]) -> list[SearchResult]:
     """Search using serper.dev's API and return the results as a list of SearchResult objects.
 
     Args:
@@ -29,11 +29,12 @@ def search_serper(api_key: str, query: str, count: int) -> list[SearchResult]:
     results = sorted(
         json_response.get("organic", []), key=lambda x: x.get("position", 0)
     )
+    filtered_results = filter_by_whitelist(results, whitelist)
     return [
         SearchResult(
             link=result["link"],
             title=result.get("title"),
             snippet=result.get("description"),
         )
-        for result in results[:count]
+        for result in filtered_results[:count]
     ]
diff --git a/backend/apps/rag/search/serply.py b/backend/apps/rag/search/serply.py
index fccf70ecd..e4040a848 100644
--- a/backend/apps/rag/search/serply.py
+++ b/backend/apps/rag/search/serply.py
@@ -1,10 +1,10 @@
 import json
 import logging
-
+from typing import List
 import requests
 from urllib.parse import urlencode
 
-from apps.rag.search.main import SearchResult
+from apps.rag.search.main import SearchResult, filter_by_whitelist
 from config import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
@@ -15,6 +15,7 @@ def search_serply(
     api_key: str,
     query: str,
     count: int,
+    whitelist: List[str],
     hl: str = "us",
     limit: int = 10,
     device_type: str = "desktop",
@@ -57,12 +58,12 @@ def search_serply(
     results = sorted(
         json_response.get("results", []), key=lambda x: x.get("realPosition", 0)
     )
-
+    filtered_results = filter_by_whitelist(results, whitelist)
     return [
         SearchResult(
             link=result["link"],
             title=result.get("title"),
             snippet=result.get("description"),
         )
-        for result in results[:count]
+        for result in filtered_results[:count]
     ]
diff --git a/backend/apps/rag/search/serpstack.py b/backend/apps/rag/search/serpstack.py
index 0d247d1ab..344e25073 100644
--- a/backend/apps/rag/search/serpstack.py
+++ b/backend/apps/rag/search/serpstack.py
@@ -1,9 +1,9 @@
 import json
 import logging
-
+from typing import List
 import requests
 
-from apps.rag.search.main import SearchResult
+from apps.rag.search.main import SearchResult, filter_by_whitelist
 from config import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
@@ -11,7 +11,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
 def search_serpstack(
-    api_key: str, query: str, count: int, https_enabled: bool = True
+    api_key: str, query: str, count: int, whitelist: List[str], https_enabled: bool = True
 ) -> list[SearchResult]:
     """Search using serpstack.com's and return the results as a list of SearchResult objects.
 
@@ -35,9 +35,10 @@ def search_serpstack(
     results = sorted(
         json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
     )
+    filtered_results = filter_by_whitelist(results, whitelist)
     return [
         SearchResult(
             link=result["url"], title=result.get("title"), snippet=result.get("snippet")
         )
-        for result in results[:count]
+        for result in filtered_results[:count]
     ]
diff --git a/backend/config.py b/backend/config.py
index 30a23f29e..6d145465a 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -894,6 +894,15 @@ RAG_WEB_SEARCH_ENGINE = PersistentConfig(
     os.getenv("RAG_WEB_SEARCH_ENGINE", ""),
 )
 
+RAG_WEB_SEARCH_WHITE_LIST_DOMAINS = PersistentConfig(
+    "RAG_WEB_SEARCH_WHITE_LIST_DOMAINS",
+    "rag.rag_web_search_white_list_domains",
+    [
+        # "example.com",
+        # "anotherdomain.com",
+    ],
+)
+
 SEARXNG_QUERY_URL = PersistentConfig(
     "SEARXNG_QUERY_URL",
     "rag.web.search.searxng_query_url",
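--
Reviewer note, not part of the patch and illustrative only: a minimal sketch of how the new filter is expected to behave, assuming a hypothetical whitelist of two domains. The result entries below are invented; the only identifiers taken from the patch are filter_by_whitelist and its module path, backend/apps/rag/search/main.py.

    # Hypothetical sanity check; run from the backend/ directory so the import resolves.
    from apps.rag.search.main import filter_by_whitelist

    whitelist = ["wikipedia.org", "python.org"]  # hypothetical whitelist entries
    results = [
        {"url": "https://en.wikipedia.org/wiki/Retrieval-augmented_generation"},  # kept: subdomain of wikipedia.org
        {"url": "https://docs.python.org/3/library/urllib.parse.html"},           # kept: subdomain of python.org
        {"url": "https://example.com/article"},                                   # dropped: not whitelisted
    ]

    assert len(filter_by_whitelist(results, whitelist)) == 2
    assert filter_by_whitelist(results, []) == results  # empty whitelist disables filtering

Because matching uses a suffix check on the URL's netloc, an entry such as "wikipedia.org" admits its subdomains (en.wikipedia.org) but also any host that merely ends with that string, so longer, fully qualified entries are the safer choice.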