From eb7bba81fe6a4ee6ec8baa57d21963b1b8bd650e Mon Sep 17 00:00:00 2001 From: Que Nguyen Date: Wed, 12 Jun 2024 08:15:04 +0700 Subject: [PATCH] Refactor URL validation function - The check for private IP addresses often did not yield the expected results, especially with errors like: `[Errno -2] Name or service not known`. - Removed the check for private IP addresses in the URL validation process. - Simplified the `validate_url` function to focus on validating the URL format and checking the existence of the URL using a HEAD request. --- backend/apps/rag/main.py | 30 +++++++----------------------- 1 file changed, 7 insertions(+), 23 deletions(-) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index cf2e8b3e6..5167fcf6c 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -8,6 +8,7 @@ from fastapi import ( Form, ) from fastapi.middleware.cors import CORSMiddleware +import requests import os, shutil, logging, re from datetime import datetime @@ -716,36 +717,19 @@ def validate_url(url: Union[str, Sequence[str]]): if isinstance(validators.url(url), validators.ValidationError): raise ValueError(ERROR_MESSAGES.INVALID_URL) if not ENABLE_RAG_LOCAL_WEB_FETCH: - # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses - parsed_url = urllib.parse.urlparse(url) - # Get IPv4 and IPv6 addresses - ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname) - # Check if any of the resolved addresses are private - # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader - for ip in ipv4_addresses: - if validators.ipv4(ip, private=True): - raise ValueError(ERROR_MESSAGES.INVALID_URL) - for ip in ipv6_addresses: - if validators.ipv6(ip, private=True): + # Check if the URL exists by making a HEAD request + try: + response = requests.head(url, allow_redirects=True) + if response.status_code != 200: raise ValueError(ERROR_MESSAGES.INVALID_URL) + except requests.exceptions.RequestException: + raise ValueError(ERROR_MESSAGES.INVALID_URL) return True elif isinstance(url, Sequence): return all(validate_url(u) for u in url) else: return False - -def resolve_hostname(hostname): - # Get address information - addr_info = socket.getaddrinfo(hostname, None) - - # Extract IP addresses from address information - ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET] - ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6] - - return ipv4_addresses, ipv6_addresses - - def search_web(engine: str, query: str) -> list[SearchResult]: """Search the web using a search engine and return the results as a list of SearchResult objects. Will look for a search engine API key in environment variables in the following order: