diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py
index a3e3c1134..3923eb459 100644
--- a/backend/apps/rag/main.py
+++ b/backend/apps/rag/main.py
@@ -11,7 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware
 import os, shutil, logging, re
 
 from pathlib import Path
-from typing import List
+from typing import List, Union, Sequence
 
 from chromadb.utils.batch_utils import create_batches
 
@@ -58,6 +58,7 @@ from apps.rag.utils import (
     query_doc_with_hybrid_search,
     query_collection,
     query_collection_with_hybrid_search,
+    search_web,
 )
 
 from utils.misc import (
@@ -186,6 +187,10 @@ class UrlForm(CollectionNameForm):
     url: str
 
 
+class SearchForm(CollectionNameForm):
+    query: str
+
+
 @app.get("/")
 async def get_status():
     return {
@@ -506,26 +511,37 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
     )
 
 
-def get_web_loader(url: str):
+def get_web_loader(url: Union[str, Sequence[str]]):
     # Check if the URL is valid
-    if isinstance(validators.url(url), validators.ValidationError):
+    if not validate_url(url):
         raise ValueError(ERROR_MESSAGES.INVALID_URL)
-    if not ENABLE_LOCAL_WEB_FETCH:
-        # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
-        parsed_url = urllib.parse.urlparse(url)
-        # Get IPv4 and IPv6 addresses
-        ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
-        # Check if any of the resolved addresses are private
-        # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
-        for ip in ipv4_addresses:
-            if validators.ipv4(ip, private=True):
-                raise ValueError(ERROR_MESSAGES.INVALID_URL)
-        for ip in ipv6_addresses:
-            if validators.ipv6(ip, private=True):
-                raise ValueError(ERROR_MESSAGES.INVALID_URL)
     return WebBaseLoader(url)
 
 
+def validate_url(url: Union[str, Sequence[str]]) -> bool:
+    if isinstance(url, str):
+        if isinstance(validators.url(url), validators.ValidationError):
+            return False
+        if not ENABLE_LOCAL_WEB_FETCH:
+            # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
+            parsed_url = urllib.parse.urlparse(url)
+            # Get IPv4 and IPv6 addresses
+            ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
+            # Check if any of the resolved addresses are private
+            # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
+            for ip in ipv4_addresses:
+                if validators.ipv4(ip, private=True):
+                    return False
+            for ip in ipv6_addresses:
+                if validators.ipv6(ip, private=True):
+                    return False
+        return True
+    elif isinstance(url, Sequence):
+        return all(validate_url(u) for u in url)
+    else:
+        return False
+
+
 def resolve_hostname(hostname):
     # Get address information
     addr_info = socket.getaddrinfo(hostname, None)
@@ -537,6 +553,32 @@ def resolve_hostname(hostname):
     return ipv4_addresses, ipv6_addresses
 
 
+@app.post("/websearch")
+def store_websearch(form_data: SearchForm, user=Depends(get_current_user)):
+    try:
+        web_results = search_web(form_data.query)
+        urls = [result.link for result in web_results]
+        loader = get_web_loader(urls)
+        data = loader.load()
+
+        collection_name = form_data.collection_name
+        if collection_name == "":
+            collection_name = calculate_sha256_string(form_data.query)[:63]
+
+        store_data_in_vector_db(data, collection_name, overwrite=True)
+        return {
+            "status": True,
+            "collection_name": collection_name,
+            "filenames": urls,
+        }
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )
+
+
 def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
 
     text_splitter = RecursiveCharacterTextSplitter(
diff --git a/backend/apps/rag/search/google_pse.py b/backend/apps/rag/search/google_pse.py
index ac637853d..a2a8f4640 100644
--- a/backend/apps/rag/search/google_pse.py
+++ b/backend/apps/rag/search/google_pse.py
@@ -30,14 +30,16 @@ def search_google_pse(
         "num": 5,
     }
 
-    response = requests.request("POST", url, headers=headers, params=params)
+    response = requests.request("GET", url, headers=headers, params=params)
     response.raise_for_status()
     json_response = response.json()
     results = json_response.get("items", [])
 
     return [
         SearchResult(
-            link=result["url"], title=result.get("title"), snippet=result.get("snippet")
+            link=result["link"],
+            title=result.get("title"),
+            snippet=result.get("snippet"),
         )
         for result in results
     ]
diff --git a/backend/apps/rag/search/serper.py b/backend/apps/rag/search/serper.py
index b40ea6053..c7c18a895 100644
--- a/backend/apps/rag/search/serper.py
+++ b/backend/apps/rag/search/serper.py
@@ -31,9 +31,9 @@ def search_serper(api_key: str, query: str) -> list[SearchResult]:
     )
     return [
         SearchResult(
-            link=result["url"],
+            link=result["link"],
             title=result.get("title"),
             snippet=result.get("description"),
         )
-        for result in results
+        for result in results[:5]
     ]
diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py
index e106bd4c8..96401b277 100644
--- a/backend/apps/rag/utils.py
+++ b/backend/apps/rag/utils.py
@@ -545,21 +545,15 @@ def search_web(query: str) -> list[SearchResult]:
     Args:
        query (str): The query to search for
     """
-    try:
-        if SEARXNG_QUERY_URL:
-            return search_searxng(SEARXNG_QUERY_URL, query)
-        elif GOOGLE_PSE_API_KEY and GOOGLE_PSE_ENGINE_ID:
-            return search_google_pse(GOOGLE_PSE_API_KEY, GOOGLE_PSE_ENGINE_ID, query)
-        elif BRAVE_SEARCH_API_KEY:
-            return search_brave(BRAVE_SEARCH_API_KEY, query)
-        elif SERPSTACK_API_KEY:
-            return search_serpstack(
-                SERPSTACK_API_KEY, query, https_enabled=SERPSTACK_HTTPS
-            )
-        elif SERPER_API_KEY:
-            return search_serper(SERPER_API_KEY, query)
-        else:
-            raise Exception("No search engine API key found in environment variables")
-    except Exception as e:
-        log.error(f"Web search failed: {e}")
-        return []
+    if SEARXNG_QUERY_URL:
+        return search_searxng(SEARXNG_QUERY_URL, query)
+    elif GOOGLE_PSE_API_KEY and GOOGLE_PSE_ENGINE_ID:
+        return search_google_pse(GOOGLE_PSE_API_KEY, GOOGLE_PSE_ENGINE_ID, query)
+    elif BRAVE_SEARCH_API_KEY:
+        return search_brave(BRAVE_SEARCH_API_KEY, query)
+    elif SERPSTACK_API_KEY:
+        return search_serpstack(SERPSTACK_API_KEY, query, https_enabled=SERPSTACK_HTTPS)
+    elif SERPER_API_KEY:
+        return search_serper(SERPER_API_KEY, query)
+    else:
+        raise Exception("No search engine API key found in environment variables")