diff --git a/backend/open_webui/retrieval/web/tavily.py b/backend/open_webui/retrieval/web/tavily.py index da70aa8e7..bfd102afa 100644 --- a/backend/open_webui/retrieval/web/tavily.py +++ b/backend/open_webui/retrieval/web/tavily.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.retrieval.web.main import SearchResult +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) @@ -21,18 +21,25 @@ def search_tavily( Args: api_key (str): A Tavily Search API key query (str): The query to search for + count (int): The maximum number of results to return Returns: list[SearchResult]: A list of search results """ url = "https://api.tavily.com/search" - data = {"query": query, "api_key": api_key} - response = requests.post(url, json=data) + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}", + } + data = {"query": query, "max_results": count} + response = requests.post(url, headers=headers, json=data) response.raise_for_status() json_response = response.json() - raw_search_results = json_response.get("results", []) + results = json_response.get("results", []) + if filter_list: + results = get_filtered_results(results, filter_list) return [ SearchResult( @@ -40,5 +47,5 @@ def search_tavily( title=result.get("title", ""), snippet=result.get("content"), ) - for result in raw_search_results[:count] + for result in results ]