From bcb84235b17007f6d176074ee1ee327fbcdc4844 Mon Sep 17 00:00:00 2001 From: Que Nguyen Date: Mon, 17 Jun 2024 14:37:52 +0700 Subject: [PATCH] Set filter_list as optional param in serply.py --- backend/apps/rag/search/serply.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/backend/apps/rag/search/serply.py b/backend/apps/rag/search/serply.py index e4040a848..24b249b73 100644 --- a/backend/apps/rag/search/serply.py +++ b/backend/apps/rag/search/serply.py @@ -1,10 +1,10 @@ import json import logging -from typing import List +from typing import List, Optional import requests from urllib.parse import urlencode -from apps.rag.search.main import SearchResult, filter_by_whitelist +from apps.rag.search.main import SearchResult, get_filtered_results from config import SRC_LOG_LEVELS log = logging.getLogger(__name__) @@ -15,11 +15,11 @@ def search_serply( api_key: str, query: str, count: int, - whitelist:List[str], hl: str = "us", limit: int = 10, device_type: str = "desktop", proxy_location: str = "US", + filter_list: Optional[List[str]] = None, ) -> list[SearchResult]: """Search using serper.dev's API and return the results as a list of SearchResult objects. @@ -58,12 +58,13 @@ def search_serply( results = sorted( json_response.get("results", []), key=lambda x: x.get("realPosition", 0) ) - filtered_results = filter_by_whitelist(results, whitelist) + if filter_list: + results = get_filtered_results(results, filter_list) return [ SearchResult( link=result["link"], title=result.get("title"), snippet=result.get("description"), ) - for result in filtered_results[:count] + for result in results[:count] ]