diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 0e493eaaa..4bd5da86c 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -73,6 +73,7 @@ from apps.rag.search.serper import search_serper from apps.rag.search.serpstack import search_serpstack from apps.rag.search.serply import search_serply from apps.rag.search.duckduckgo import search_duckduckgo +from apps.rag.search.tavily import search_tavily from utils.misc import ( calculate_sha256, @@ -119,6 +120,7 @@ from config import ( SERPSTACK_HTTPS, SERPER_API_KEY, SERPLY_API_KEY, + TAVILY_API_KEY, RAG_WEB_SEARCH_RESULT_COUNT, RAG_WEB_SEARCH_CONCURRENT_REQUESTS, RAG_EMBEDDING_OPENAI_BATCH_SIZE, @@ -172,6 +174,7 @@ app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS app.state.config.SERPER_API_KEY = SERPER_API_KEY app.state.config.SERPLY_API_KEY = SERPLY_API_KEY +app.state.config.TAVILY_API_KEY = TAVILY_API_KEY app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS @@ -400,6 +403,7 @@ async def get_rag_config(user=Depends(get_admin_user)): "serpstack_https": app.state.config.SERPSTACK_HTTPS, "serper_api_key": app.state.config.SERPER_API_KEY, "serply_api_key": app.state.config.SERPLY_API_KEY, + "tavily_api_key": app.state.config.TAVILY_API_KEY, "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, }, @@ -428,6 +432,7 @@ class WebSearchConfig(BaseModel): serpstack_https: Optional[bool] = None serper_api_key: Optional[str] = None serply_api_key: Optional[str] = None + tavily_api_key: Optional[str] = None result_count: Optional[int] = None concurrent_requests: Optional[int] = None @@ -479,6 +484,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key + app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = ( form_data.web.search.concurrent_requests @@ -508,6 +514,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ "serpstack_https": app.state.config.SERPSTACK_HTTPS, "serper_api_key": app.state.config.SERPER_API_KEY, "serply_api_key": app.state.config.SERPLY_API_KEY, + "tavily_api_key": app.state.config.TAVILY_API_KEY, "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, }, @@ -756,7 +763,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]: - SERPSTACK_API_KEY - SERPER_API_KEY - SERPLY_API_KEY - + - TAVILY_API_KEY Args: query (str): The query to search for """ @@ -825,6 +832,15 @@ def search_web(engine: str, query: str) -> list[SearchResult]: raise Exception("No SERPLY_API_KEY found in environment variables") elif engine == "duckduckgo": return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT) + elif engine == "tavily": + if app.state.config.TAVILY_API_KEY: + return search_tavily( + app.state.config.TAVILY_API_KEY, + query, + app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + ) + else: + raise Exception("No TAVILY_API_KEY found in environment variables") else: raise Exception("No search engine API key found in environment variables") diff --git a/backend/apps/rag/search/tavily.py b/backend/apps/rag/search/tavily.py new file mode 100644 index 000000000..b15d6ef9d --- /dev/null +++ b/backend/apps/rag/search/tavily.py @@ -0,0 +1,39 @@ +import logging + +import requests + +from apps.rag.search.main import SearchResult +from config import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + + +def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]: + """Search using Tavily's Search API and return the results as a list of SearchResult objects. + + Args: + api_key (str): A Tavily Search API key + query (str): The query to search for + + Returns: + List[SearchResult]: A list of search results + """ + url = "https://api.tavily.com/search" + data = {"query": query, "api_key": api_key} + + response = requests.post(url, json=data) + response.raise_for_status() + + json_response = response.json() + + raw_search_results = json_response.get("results", []) + + return [ + SearchResult( + link=result["url"], + title=result.get("title", ""), + snippet=result.get("content"), + ) + for result in raw_search_results[:count] + ] diff --git a/backend/config.py b/backend/config.py index 30a23f29e..870c054f1 100644 --- a/backend/config.py +++ b/backend/config.py @@ -942,6 +942,11 @@ SERPLY_API_KEY = PersistentConfig( os.getenv("SERPLY_API_KEY", ""), ) +TAVILY_API_KEY = PersistentConfig( + "TAVILY_API_KEY", + "rag.web.search.tavily_api_key", + os.getenv("TAVILY_API_KEY", ""), +) RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig( "RAG_WEB_SEARCH_RESULT_COUNT", diff --git a/src/lib/components/admin/Settings/WebSearch.svelte b/src/lib/components/admin/Settings/WebSearch.svelte index 864f61d0e..b9f43a9ab 100644 --- a/src/lib/components/admin/Settings/WebSearch.svelte +++ b/src/lib/components/admin/Settings/WebSearch.svelte @@ -18,7 +18,8 @@ 'serpstack', 'serper', 'serply', - 'duckduckgo' + 'duckduckgo', + 'tavily' ]; let youtubeLanguage = 'en'; @@ -214,6 +215,24 @@ + {:else if webConfig.search.engine === 'tavily'} +
+
+ {$i18n.t('Tavily API Key')} +
+ +
+
+ +
+
+
{/if} {/if}