diff --git a/backend/apps/rag/search/serpstack.py b/backend/apps/rag/search/serpstack.py index 344e25073..f69baaf92 100644 --- a/backend/apps/rag/search/serpstack.py +++ b/backend/apps/rag/search/serpstack.py @@ -1,9 +1,9 @@ import json import logging -from typing import List +from typing import List, Optional import requests -from apps.rag.search.main import SearchResult, filter_by_whitelist +from apps.rag.search.main import SearchResult, get_filtered_results from config import SRC_LOG_LEVELS log = logging.getLogger(__name__) @@ -11,7 +11,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) def search_serpstack( - api_key: str, query: str, count: int, whitelist:List[str], https_enabled: bool = True + api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None, https_enabled: bool = True ) -> list[SearchResult]: """Search using serpstack.com's and return the results as a list of SearchResult objects. @@ -35,10 +35,11 @@ def search_serpstack( results = sorted( json_response.get("organic_results", []), key=lambda x: x.get("position", 0) ) - filtered_results = filter_by_whitelist(results, whitelist) + if filter_list: + results = get_filtered_results(results, filter_list) return [ SearchResult( link=result["url"], title=result.get("title"), snippet=result.get("snippet") ) - for result in filtered_results[:count] + for result in results[:count] ]