diff --git a/backend/open_webui/retrieval/web/tavily.py b/backend/open_webui/retrieval/web/tavily.py index cc468725d..7839b715e 100644 --- a/backend/open_webui/retrieval/web/tavily.py +++ b/backend/open_webui/retrieval/web/tavily.py @@ -1,4 +1,5 @@ import logging +from typing import Optional import requests from open_webui.retrieval.web.main import SearchResult @@ -8,7 +9,13 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) -def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]: +def search_tavily( + api_key: str, + query: str, + count: int, + filter_list: Optional[list[str]] = None, + # **kwargs, +) -> list[SearchResult]: """Search using Tavily's Search API and return the results as a list of SearchResult objects. Args: @@ -20,8 +27,8 @@ def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]: """ url = "https://api.tavily.com/search" data = {"query": query, "api_key": api_key} - - response = requests.post(url, json=data) + include_domain = filter_list + response = requests.post(url, include_domain, json=data) response.raise_for_status() json_response = response.json()