Merge pull request from que-nguyen/searxng

Domain whitelisting for web search results
This commit is contained in:
Timothy Jaeryang Baek 2024-06-17 14:30:17 -07:00 committed by GitHub
commit 20f052eb37
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 68 additions and 24 deletions

View File

@ -112,6 +112,7 @@ from config import (
YOUTUBE_LOADER_LANGUAGE, YOUTUBE_LOADER_LANGUAGE,
ENABLE_RAG_WEB_SEARCH, ENABLE_RAG_WEB_SEARCH,
RAG_WEB_SEARCH_ENGINE, RAG_WEB_SEARCH_ENGINE,
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
SEARXNG_QUERY_URL, SEARXNG_QUERY_URL,
GOOGLE_PSE_API_KEY, GOOGLE_PSE_API_KEY,
GOOGLE_PSE_ENGINE_ID, GOOGLE_PSE_ENGINE_ID,
@ -165,6 +166,7 @@ app.state.YOUTUBE_LOADER_TRANSLATION = None
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
@ -775,6 +777,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.SEARXNG_QUERY_URL, app.state.config.SEARXNG_QUERY_URL,
query, query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
) )
else: else:
raise Exception("No SEARXNG_QUERY_URL found in environment variables") raise Exception("No SEARXNG_QUERY_URL found in environment variables")
@ -788,6 +791,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.GOOGLE_PSE_ENGINE_ID, app.state.config.GOOGLE_PSE_ENGINE_ID,
query, query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
) )
else: else:
raise Exception( raise Exception(
@ -799,6 +803,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.BRAVE_SEARCH_API_KEY, app.state.config.BRAVE_SEARCH_API_KEY,
query, query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
) )
else: else:
raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables") raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
@ -808,6 +813,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.SERPSTACK_API_KEY, app.state.config.SERPSTACK_API_KEY,
query, query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
https_enabled=app.state.config.SERPSTACK_HTTPS, https_enabled=app.state.config.SERPSTACK_HTTPS,
) )
else: else:
@ -818,6 +824,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.SERPER_API_KEY, app.state.config.SERPER_API_KEY,
query, query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
) )
else: else:
raise Exception("No SERPER_API_KEY found in environment variables") raise Exception("No SERPER_API_KEY found in environment variables")
@ -827,11 +834,12 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.SERPLY_API_KEY, app.state.config.SERPLY_API_KEY,
query, query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
) )
else: else:
raise Exception("No SERPLY_API_KEY found in environment variables") raise Exception("No SERPLY_API_KEY found in environment variables")
elif engine == "duckduckgo": elif engine == "duckduckgo":
return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT) return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST)
elif engine == "tavily": elif engine == "tavily":
if app.state.config.TAVILY_API_KEY: if app.state.config.TAVILY_API_KEY:
return search_tavily( return search_tavily(

View File

@ -1,15 +1,15 @@
import logging import logging
from typing import List, Optional
import requests import requests
from apps.rag.search.main import SearchResult from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_brave(api_key: str, query: str, count: int) -> list[SearchResult]: def search_brave(api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None) -> list[SearchResult]:
"""Search using Brave's Search API and return the results as a list of SearchResult objects. """Search using Brave's Search API and return the results as a list of SearchResult objects.
Args: Args:
@ -29,6 +29,9 @@ def search_brave(api_key: str, query: str, count: int) -> list[SearchResult]:
json_response = response.json() json_response = response.json()
results = json_response.get("web", {}).get("results", []) results = json_response.get("web", {}).get("results", [])
if filter_list:
results = get_filtered_results(results, filter_list)
return [ return [
SearchResult( SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("snippet") link=result["url"], title=result.get("title"), snippet=result.get("snippet")

View File

@ -1,6 +1,6 @@
import logging import logging
from typing import List, Optional
from apps.rag.search.main import SearchResult from apps.rag.search.main import SearchResult, get_filtered_results
from duckduckgo_search import DDGS from duckduckgo_search import DDGS
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
@ -8,7 +8,7 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_duckduckgo(query: str, count: int) -> list[SearchResult]: def search_duckduckgo(query: str, count: int, filter_list: Optional[List[str]] = None) -> list[SearchResult]:
""" """
Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects. Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
Args: Args:
@ -41,6 +41,7 @@ def search_duckduckgo(query: str, count: int) -> list[SearchResult]:
snippet=result.get("body"), snippet=result.get("body"),
) )
) )
print(results) if filter_list:
results = get_filtered_results(results, filter_list)
# Return the list of search results # Return the list of search results
return results return results

View File

@ -1,9 +1,9 @@
import json import json
import logging import logging
from typing import List, Optional
import requests import requests
from apps.rag.search.main import SearchResult from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -11,7 +11,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_google_pse( def search_google_pse(
api_key: str, search_engine_id: str, query: str, count: int api_key: str, search_engine_id: str, query: str, count: int, filter_list: Optional[List[str]] = None
) -> list[SearchResult]: ) -> list[SearchResult]:
"""Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects. """Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
@ -35,6 +35,8 @@ def search_google_pse(
json_response = response.json() json_response = response.json()
results = json_response.get("items", []) results = json_response.get("items", [])
if filter_list:
results = get_filtered_results(results, filter_list)
return [ return [
SearchResult( SearchResult(
link=result["link"], link=result["link"],

View File

@ -1,8 +1,18 @@
from typing import Optional from typing import Optional
from urllib.parse import urlparse
from pydantic import BaseModel from pydantic import BaseModel
def _result_url(result):
    """Return the URL carried by a raw result mapping or a result object.

    Raw engine payloads in this package use either a "url" or a "link" key,
    while already-built result objects expose a ``link`` attribute. Returns
    None when no URL can be found.
    """
    if isinstance(result, dict):
        return result.get("url") or result.get("link")
    return getattr(result, "link", None) or getattr(result, "url", None)


def get_filtered_results(results, filter_list):
    """Keep only the results whose host belongs to an allowed domain.

    Args:
        results: Iterable of search results, either mappings carrying the
            address under "url" or "link", or objects with a ``link``/``url``
            attribute.
        filter_list: Allowed domains (e.g. ``["wikipedia.org"]``). A result
            matches when its host equals an entry or is a subdomain of it.

    Returns:
        ``results`` unchanged when ``filter_list`` is empty or contains no
        usable entry; otherwise a new list with the matching results, in
        their original order. Results without a parseable URL are dropped.
    """
    if not filter_list:
        return results

    # Normalise once: host names are case-insensitive, and a leading dot
    # (".example.com") means the same thing as the bare domain.
    allowed = [
        entry.strip().lower().lstrip(".")
        for entry in filter_list
        if isinstance(entry, str) and entry.strip().strip(".")
    ]
    if not allowed:
        return results

    filtered_results = []
    for result in results:
        url = _result_url(result)
        if not isinstance(url, str):
            continue
        try:
            # hostname (unlike netloc) excludes port and credentials.
            host = (urlparse(url).hostname or "").lower()
        except ValueError:
            # Malformed URL (e.g. a broken IPv6 literal): cannot be verified.
            continue
        # Match on a label boundary so "notexample.com" does not pass a
        # filter for "example.com".
        if any(host == domain or host.endswith("." + domain) for domain in allowed):
            filtered_results.append(result)
    return filtered_results
class SearchResult(BaseModel): class SearchResult(BaseModel):
link: str link: str
title: Optional[str] title: Optional[str]

View File

@ -1,9 +1,9 @@
import logging import logging
import requests import requests
from typing import List from typing import List, Optional
from apps.rag.search.main import SearchResult from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -11,7 +11,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_searxng( def search_searxng(
query_url: str, query: str, count: int, **kwargs query_url: str, query: str, count: int, filter_list: Optional[List[str]] = None, **kwargs
) -> List[SearchResult]: ) -> List[SearchResult]:
""" """
Search a SearXNG instance for a given query and return the results as a list of SearchResult objects. Search a SearXNG instance for a given query and return the results as a list of SearchResult objects.
@ -78,6 +78,8 @@ def search_searxng(
json_response = response.json() json_response = response.json()
results = json_response.get("results", []) results = json_response.get("results", [])
sorted_results = sorted(results, key=lambda x: x.get("score", 0), reverse=True) sorted_results = sorted(results, key=lambda x: x.get("score", 0), reverse=True)
if filter_list:
    # Restrict the score-sorted results to the configured domains. The
    # argument must be this function's `filter_list` parameter; the previous
    # name `whitelist` is not defined in this scope and raised NameError.
    sorted_results = get_filtered_results(sorted_results, filter_list)
return [ return [
SearchResult( SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("content") link=result["url"], title=result.get("title"), snippet=result.get("content")

View File

@ -1,16 +1,16 @@
import json import json
import logging import logging
from typing import List, Optional
import requests import requests
from apps.rag.search.main import SearchResult from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serper(api_key: str, query: str, count: int) -> list[SearchResult]: def search_serper(api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects. """Search using serper.dev's API and return the results as a list of SearchResult objects.
Args: Args:
@ -29,6 +29,8 @@ def search_serper(api_key: str, query: str, count: int) -> list[SearchResult]:
results = sorted( results = sorted(
json_response.get("organic", []), key=lambda x: x.get("position", 0) json_response.get("organic", []), key=lambda x: x.get("position", 0)
) )
if filter_list:
results = get_filtered_results(results, filter_list)
return [ return [
SearchResult( SearchResult(
link=result["link"], link=result["link"],

View File

@ -1,10 +1,10 @@
import json import json
import logging import logging
from typing import List, Optional
import requests import requests
from urllib.parse import urlencode from urllib.parse import urlencode
from apps.rag.search.main import SearchResult from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -19,6 +19,7 @@ def search_serply(
limit: int = 10, limit: int = 10,
device_type: str = "desktop", device_type: str = "desktop",
proxy_location: str = "US", proxy_location: str = "US",
filter_list: Optional[List[str]] = None,
) -> list[SearchResult]: ) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects. """Search using serper.dev's API and return the results as a list of SearchResult objects.
@ -57,7 +58,8 @@ def search_serply(
results = sorted( results = sorted(
json_response.get("results", []), key=lambda x: x.get("realPosition", 0) json_response.get("results", []), key=lambda x: x.get("realPosition", 0)
) )
if filter_list:
results = get_filtered_results(results, filter_list)
return [ return [
SearchResult( SearchResult(
link=result["link"], link=result["link"],

View File

@ -1,9 +1,9 @@
import json import json
import logging import logging
from typing import List, Optional
import requests import requests
from apps.rag.search.main import SearchResult from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -11,7 +11,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serpstack( def search_serpstack(
api_key: str, query: str, count: int, https_enabled: bool = True api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None, https_enabled: bool = True
) -> list[SearchResult]: ) -> list[SearchResult]:
"""Search using serpstack.com's and return the results as a list of SearchResult objects. """Search using serpstack.com's and return the results as a list of SearchResult objects.
@ -35,6 +35,8 @@ def search_serpstack(
results = sorted( results = sorted(
json_response.get("organic_results", []), key=lambda x: x.get("position", 0) json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
) )
if filter_list:
results = get_filtered_results(results, filter_list)
return [ return [
SearchResult( SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("snippet") link=result["url"], title=result.get("title"), snippet=result.get("snippet")

View File

@ -903,6 +903,18 @@ RAG_WEB_SEARCH_ENGINE = PersistentConfig(
os.getenv("RAG_WEB_SEARCH_ENGINE", ""), os.getenv("RAG_WEB_SEARCH_ENGINE", ""),
) )
# Optional allow-list of domains applied to web search results after the
# search engine has returned them. The default is an empty list, and the
# search helpers only filter when the list is non-empty.
# NOTE(review): the config path below starts with "rag.rag.", whereas the
# neighbouring SEARXNG_QUERY_URL entry uses "rag.web.search..." - this looks
# like an accidentally doubled prefix; confirm before persisted values
# depend on it.
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
"RAG_WEB_SEARCH_DOMAIN_FILTER_LIST",
"rag.rag.web.search.domain.filter_list",
[
# Example entries (uncomment to enable):
# "wikipedia.org",
# "wikimedia.org",
# "wikidata.org",
],
)
SEARXNG_QUERY_URL = PersistentConfig( SEARXNG_QUERY_URL = PersistentConfig(
"SEARXNG_QUERY_URL", "SEARXNG_QUERY_URL",
"rag.web.search.searxng_query_url", "rag.web.search.searxng_query_url",