Implement domain whitelisting for web search results

This commit is contained in:
Que Nguyen 2024-06-13 07:14:48 +07:00 committed by GitHub
parent a382e82dec
commit 7b5f434a07
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 63 additions and 29 deletions

View File

@ -111,6 +111,7 @@ from config import (
YOUTUBE_LOADER_LANGUAGE, YOUTUBE_LOADER_LANGUAGE,
ENABLE_RAG_WEB_SEARCH, ENABLE_RAG_WEB_SEARCH,
RAG_WEB_SEARCH_ENGINE, RAG_WEB_SEARCH_ENGINE,
RAG_WEB_SEARCH_WHITE_LIST_DOMAINS,
SEARXNG_QUERY_URL, SEARXNG_QUERY_URL,
GOOGLE_PSE_API_KEY, GOOGLE_PSE_API_KEY,
GOOGLE_PSE_ENGINE_ID, GOOGLE_PSE_ENGINE_ID,
@ -163,6 +164,7 @@ app.state.YOUTUBE_LOADER_TRANSLATION = None
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS = RAG_WEB_SEARCH_WHITE_LIST_DOMAINS
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
@ -768,6 +770,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.SEARXNG_QUERY_URL, app.state.config.SEARXNG_QUERY_URL,
query, query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS
) )
else: else:
raise Exception("No SEARXNG_QUERY_URL found in environment variables") raise Exception("No SEARXNG_QUERY_URL found in environment variables")
@ -781,6 +784,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.GOOGLE_PSE_ENGINE_ID, app.state.config.GOOGLE_PSE_ENGINE_ID,
query, query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS
) )
else: else:
raise Exception( raise Exception(
@ -792,6 +796,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.BRAVE_SEARCH_API_KEY, app.state.config.BRAVE_SEARCH_API_KEY,
query, query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS
) )
else: else:
raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables") raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
@ -801,6 +806,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.SERPSTACK_API_KEY, app.state.config.SERPSTACK_API_KEY,
query, query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS,
https_enabled=app.state.config.SERPSTACK_HTTPS, https_enabled=app.state.config.SERPSTACK_HTTPS,
) )
else: else:
@ -811,6 +817,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.SERPER_API_KEY, app.state.config.SERPER_API_KEY,
query, query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS
) )
else: else:
raise Exception("No SERPER_API_KEY found in environment variables") raise Exception("No SERPER_API_KEY found in environment variables")
@ -820,11 +827,12 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.SERPLY_API_KEY, app.state.config.SERPLY_API_KEY,
query, query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS
) )
else: else:
raise Exception("No SERPLY_API_KEY found in environment variables") raise Exception("No SERPLY_API_KEY found in environment variables")
elif engine == "duckduckgo": elif engine == "duckduckgo":
return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT) return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS)
else: else:
raise Exception("No search engine API key found in environment variables") raise Exception("No search engine API key found in environment variables")

View File

@ -1,15 +1,15 @@
import logging import logging
from typing import List
import requests import requests
from apps.rag.search.main import SearchResult from apps.rag.search.main import SearchResult, filter_by_whitelist
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_brave(api_key: str, query: str, count: int) -> list[SearchResult]: def search_brave(api_key: str, query: str, whitelist:List[str], count: int) -> list[SearchResult]:
"""Search using Brave's Search API and return the results as a list of SearchResult objects. """Search using Brave's Search API and return the results as a list of SearchResult objects.
Args: Args:
@ -29,9 +29,10 @@ def search_brave(api_key: str, query: str, count: int) -> list[SearchResult]:
json_response = response.json() json_response = response.json()
results = json_response.get("web", {}).get("results", []) results = json_response.get("web", {}).get("results", [])
filtered_results = filter_by_whitelist(results, whitelist)
return [ return [
SearchResult( SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("snippet") link=result["url"], title=result.get("title"), snippet=result.get("snippet")
) )
for result in results[:count] for result in filtered_results[:count]
] ]

View File

@ -1,6 +1,6 @@
import logging import logging
from typing import List
from apps.rag.search.main import SearchResult from apps.rag.search.main import SearchResult, filter_by_whitelist
from duckduckgo_search import DDGS from duckduckgo_search import DDGS
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
@ -8,7 +8,7 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_duckduckgo(query: str, count: int) -> list[SearchResult]: def search_duckduckgo(query: str, count: int, whitelist:List[str]) -> list[SearchResult]:
""" """
Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects. Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
Args: Args:
@ -41,6 +41,7 @@ def search_duckduckgo(query: str, count: int) -> list[SearchResult]:
snippet=result.get("body"), snippet=result.get("body"),
) )
) )
print(results) # print(results)
filtered_results = filter_by_whitelist(results, whitelist)
# Return the list of search results # Return the list of search results
return results return filtered_results

View File

@ -1,9 +1,9 @@
import json import json
import logging import logging
from typing import List
import requests import requests
from apps.rag.search.main import SearchResult from apps.rag.search.main import SearchResult, filter_by_whitelist
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -11,7 +11,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_google_pse( def search_google_pse(
api_key: str, search_engine_id: str, query: str, count: int api_key: str, search_engine_id: str, query: str, count: int, whitelist:List[str]
) -> list[SearchResult]: ) -> list[SearchResult]:
"""Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects. """Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
@ -35,11 +35,12 @@ def search_google_pse(
json_response = response.json() json_response = response.json()
results = json_response.get("items", []) results = json_response.get("items", [])
filtered_results = filter_by_whitelist(results, whitelist)
return [ return [
SearchResult( SearchResult(
link=result["link"], link=result["link"],
title=result.get("title"), title=result.get("title"),
snippet=result.get("snippet"), snippet=result.get("snippet"),
) )
for result in results for result in filtered_results
] ]

View File

@ -1,8 +1,18 @@
from typing import Optional from typing import Optional
from urllib.parse import urlparse
from pydantic import BaseModel from pydantic import BaseModel
def _result_url(result):
    """Return the URL carried by a single search result, or None.

    The search providers in this package do not agree on a result shape:
    some return dicts keyed by ``"url"``, others dicts keyed by ``"link"``,
    and the DuckDuckGo wrapper hands over already-built ``SearchResult``
    objects that expose a ``link`` attribute. All three are accepted here.
    """
    if isinstance(result, dict):
        return result.get("url") or result.get("link")
    return getattr(result, "link", None) or getattr(result, "url", None)


def filter_by_whitelist(results, whitelist):
    """Keep only the results whose host belongs to a whitelisted domain.

    Args:
        results: Iterable of search results. Each item is either a dict with
            a ``"url"`` or ``"link"`` key, or an object with a ``link``
            attribute.
        whitelist: Domain names such as ``"example.com"``. A falsy value
            (``None`` or an empty list) disables filtering.

    Returns:
        ``results`` itself when filtering is disabled; otherwise a new list
        with the matching items in their original order. Items without a
        parsable absolute URL are dropped.
    """
    if not whitelist:
        return results

    # Normalise once: case-insensitive, ignore stray dots and blank entries.
    domains = {
        entry.strip().lower().strip(".")
        for entry in whitelist
        if isinstance(entry, str)
    }
    domains.discard("")
    if not domains:
        return results

    # Matching on ".domain" (not just "domain") stops "evilexample.com"
    # from slipping through for a whitelist entry of "example.com".
    subdomain_suffixes = tuple("." + domain for domain in domains)

    filtered_results = []
    for result in results:
        url = _result_url(result)
        if not url:
            continue
        try:
            # .hostname is lower-cased and excludes port and userinfo,
            # unlike .netloc.
            hostname = urlparse(str(url)).hostname
        except ValueError:
            # Malformed URL (e.g. a broken IPv6 literal): treat as no match.
            continue
        if not hostname:
            continue
        hostname = hostname.rstrip(".")
        if hostname in domains or hostname.endswith(subdomain_suffixes):
            filtered_results.append(result)
    return filtered_results
class SearchResult(BaseModel): class SearchResult(BaseModel):
link: str link: str
title: Optional[str] title: Optional[str]

View File

@ -11,7 +11,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_searxng( def search_searxng(
query_url: str, query: str, count: int, **kwargs query_url: str, query: str, count: int, whitelist:List[str], **kwargs
) -> List[SearchResult]: ) -> List[SearchResult]:
""" """
Search a SearXNG instance for a given query and return the results as a list of SearchResult objects. Search a SearXNG instance for a given query and return the results as a list of SearchResult objects.
@ -78,9 +78,10 @@ def search_searxng(
json_response = response.json() json_response = response.json()
results = json_response.get("results", []) results = json_response.get("results", [])
sorted_results = sorted(results, key=lambda x: x.get("score", 0), reverse=True) sorted_results = sorted(results, key=lambda x: x.get("score", 0), reverse=True)
filtered_results = filter_by_whitelist(sorted_results, whitelist)
return [ return [
SearchResult( SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("content") link=result["url"], title=result.get("title"), snippet=result.get("content")
) )
for result in sorted_results[:count] for result in filtered_results[:count]
] ]

View File

@ -1,16 +1,16 @@
import json import json
import logging import logging
from typing import List
import requests import requests
from apps.rag.search.main import SearchResult from apps.rag.search.main import SearchResult, filter_by_whitelist
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serper(api_key: str, query: str, count: int) -> list[SearchResult]: def search_serper(api_key: str, query: str, count: int, whitelist:List[str]) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects. """Search using serper.dev's API and return the results as a list of SearchResult objects.
Args: Args:
@ -29,11 +29,12 @@ def search_serper(api_key: str, query: str, count: int) -> list[SearchResult]:
results = sorted( results = sorted(
json_response.get("organic", []), key=lambda x: x.get("position", 0) json_response.get("organic", []), key=lambda x: x.get("position", 0)
) )
filtered_results = filter_by_whitelist(results, whitelist)
return [ return [
SearchResult( SearchResult(
link=result["link"], link=result["link"],
title=result.get("title"), title=result.get("title"),
snippet=result.get("description"), snippet=result.get("description"),
) )
for result in results[:count] for result in filtered_results[:count]
] ]

View File

@ -1,10 +1,10 @@
import json import json
import logging import logging
from typing import List
import requests import requests
from urllib.parse import urlencode from urllib.parse import urlencode
from apps.rag.search.main import SearchResult from apps.rag.search.main import SearchResult, filter_by_whitelist
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -15,6 +15,7 @@ def search_serply(
api_key: str, api_key: str,
query: str, query: str,
count: int, count: int,
whitelist:List[str],
hl: str = "us", hl: str = "us",
limit: int = 10, limit: int = 10,
device_type: str = "desktop", device_type: str = "desktop",
@ -57,12 +58,12 @@ def search_serply(
results = sorted( results = sorted(
json_response.get("results", []), key=lambda x: x.get("realPosition", 0) json_response.get("results", []), key=lambda x: x.get("realPosition", 0)
) )
filtered_results = filter_by_whitelist(results, whitelist)
return [ return [
SearchResult( SearchResult(
link=result["link"], link=result["link"],
title=result.get("title"), title=result.get("title"),
snippet=result.get("description"), snippet=result.get("description"),
) )
for result in results[:count] for result in filtered_results[:count]
] ]

View File

@ -1,9 +1,9 @@
import json import json
import logging import logging
from typing import List
import requests import requests
from apps.rag.search.main import SearchResult from apps.rag.search.main import SearchResult, filter_by_whitelist
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -11,7 +11,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serpstack( def search_serpstack(
api_key: str, query: str, count: int, https_enabled: bool = True api_key: str, query: str, count: int, whitelist:List[str], https_enabled: bool = True
) -> list[SearchResult]: ) -> list[SearchResult]:
"""Search using serpstack.com's and return the results as a list of SearchResult objects. """Search using serpstack.com's and return the results as a list of SearchResult objects.
@ -35,9 +35,10 @@ def search_serpstack(
results = sorted( results = sorted(
json_response.get("organic_results", []), key=lambda x: x.get("position", 0) json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
) )
filtered_results = filter_by_whitelist(results, whitelist)
return [ return [
SearchResult( SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("snippet") link=result["url"], title=result.get("title"), snippet=result.get("snippet")
) )
for result in results[:count] for result in filtered_results[:count]
] ]

View File

@ -894,6 +894,15 @@ RAG_WEB_SEARCH_ENGINE = PersistentConfig(
os.getenv("RAG_WEB_SEARCH_ENGINE", ""), os.getenv("RAG_WEB_SEARCH_ENGINE", ""),
) )
# Domains that web search results are restricted to. The value is handed to
# every search_* provider, which passes it on to filter_by_whitelist(); that
# helper returns results unfiltered when the list is empty, so the empty
# default below means "no restriction".
# NOTE(review): unlike the neighbouring settings, the default here is a
# hard-coded list rather than a value read via os.getenv — confirm intended.
RAG_WEB_SEARCH_WHITE_LIST_DOMAINS = PersistentConfig(
"RAG_WEB_SEARCH_WHITE_LIST_DOMAINS",
"rag.rag_web_search_white_list_domains",
[
# Example entries (uncomment and edit to enable filtering):
# "example.com",
# "anotherdomain.com",
],
)
SEARXNG_QUERY_URL = PersistentConfig( SEARXNG_QUERY_URL = PersistentConfig(
"SEARXNG_QUERY_URL", "SEARXNG_QUERY_URL",
"rag.web.search.searxng_query_url", "rag.web.search.searxng_query_url",