From a02139ba9df0513696c3cb89aecf037e19aee4d2 Mon Sep 17 00:00:00 2001 From: Que Nguyen Date: Mon, 17 Jun 2024 14:34:17 +0700 Subject: [PATCH] Set filter_list as optional param in brave.py --- backend/apps/rag/search/brave.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/backend/apps/rag/search/brave.py b/backend/apps/rag/search/brave.py index 04cd18496..a20a2cde8 100644 --- a/backend/apps/rag/search/brave.py +++ b/backend/apps/rag/search/brave.py @@ -1,15 +1,15 @@ import logging -from typing import List +from typing import List, Optional import requests -from apps.rag.search.main import SearchResult, filter_by_whitelist +from apps.rag.search.main import SearchResult, get_filtered_results from config import SRC_LOG_LEVELS log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) -def search_brave(api_key: str, query: str, whitelist:List[str], count: int) -> list[SearchResult]: +def search_brave(api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None) -> list[SearchResult]: """Search using Brave's Search API and return the results as a list of SearchResult objects. Args: @@ -29,10 +29,12 @@ def search_brave(api_key: str, query: str, whitelist:List[str], count: int) -> l json_response = response.json() results = json_response.get("web", {}).get("results", []) - filtered_results = filter_by_whitelist(results, whitelist) + if filter_list: + results = get_filtered_results(results, filter_list) + return [ SearchResult( link=result["url"], title=result.get("title"), snippet=result.get("snippet") ) - for result in filtered_results[:count] + for result in results[:count] ]