From 1163745a035b74b2cee7a97b13e37eafacb25dde Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Wed, 12 Jun 2024 11:08:05 -0700 Subject: [PATCH] revert --- backend/apps/rag/main.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 113e60ea8..0e493eaaa 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -717,13 +717,18 @@ def validate_url(url: Union[str, Sequence[str]]): if isinstance(validators.url(url), validators.ValidationError): raise ValueError(ERROR_MESSAGES.INVALID_URL) if not ENABLE_RAG_LOCAL_WEB_FETCH: - # Check if the URL exists by making a HEAD request - try: - response = requests.head(url, allow_redirects=True) - if response.status_code != 200: + # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses + parsed_url = urllib.parse.urlparse(url) + # Get IPv4 and IPv6 addresses + ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname) + # Check if any of the resolved addresses are private + # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader + for ip in ipv4_addresses: + if validators.ipv4(ip, private=True): + raise ValueError(ERROR_MESSAGES.INVALID_URL) + for ip in ipv6_addresses: + if validators.ipv6(ip, private=True): raise ValueError(ERROR_MESSAGES.INVALID_URL) - except requests.exceptions.RequestException: - raise ValueError(ERROR_MESSAGES.INVALID_URL) return True elif isinstance(url, Sequence): return all(validate_url(u) for u in url) @@ -731,6 +736,17 @@ def validate_url(url: Union[str, Sequence[str]]): return False +def resolve_hostname(hostname): + # Get address information + addr_info = socket.getaddrinfo(hostname, None) + + # Extract IP addresses from address information + ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET] + ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6] + + return ipv4_addresses, ipv6_addresses + + def search_web(engine: str, query: str) -> list[SearchResult]: """Search the web using a search engine and return the results as a list of SearchResult objects. Will look for a search engine API key in environment variables in the following order: