Merge pull request #3049 from que-nguyen/dev

Refactor URL validation function
This commit is contained in:
Timothy Jaeryang Baek 2024-06-12 01:36:34 -07:00 committed by GitHub
commit 5d3db15eca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -8,6 +8,7 @@ from fastapi import (
Form, Form,
) )
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
import requests
import os, shutil, logging, re import os, shutil, logging, re
from datetime import datetime from datetime import datetime
@ -716,36 +717,19 @@ def validate_url(url: Union[str, Sequence[str]]):
if isinstance(validators.url(url), validators.ValidationError): if isinstance(validators.url(url), validators.ValidationError):
raise ValueError(ERROR_MESSAGES.INVALID_URL) raise ValueError(ERROR_MESSAGES.INVALID_URL)
if not ENABLE_RAG_LOCAL_WEB_FETCH: if not ENABLE_RAG_LOCAL_WEB_FETCH:
# Local web fetch is disabled, filter out any URLs that resolve to private IP addresses # Check if the URL exists by making a HEAD request
parsed_url = urllib.parse.urlparse(url) try:
# Get IPv4 and IPv6 addresses response = requests.head(url, allow_redirects=True)
ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname) if response.status_code != 200:
# Check if any of the resolved addresses are private
# This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
for ip in ipv4_addresses:
if validators.ipv4(ip, private=True):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
for ip in ipv6_addresses:
if validators.ipv6(ip, private=True):
raise ValueError(ERROR_MESSAGES.INVALID_URL) raise ValueError(ERROR_MESSAGES.INVALID_URL)
except requests.exceptions.RequestException:
raise ValueError(ERROR_MESSAGES.INVALID_URL)
return True return True
elif isinstance(url, Sequence): elif isinstance(url, Sequence):
return all(validate_url(u) for u in url) return all(validate_url(u) for u in url)
else: else:
return False return False
def resolve_hostname(hostname):
# Get address information
addr_info = socket.getaddrinfo(hostname, None)
# Extract IP addresses from address information
ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
return ipv4_addresses, ipv6_addresses
def search_web(engine: str, query: str) -> list[SearchResult]: def search_web(engine: str, query: str) -> list[SearchResult]:
"""Search the web using a search engine and return the results as a list of SearchResult objects. """Search the web using a search engine and return the results as a list of SearchResult objects.
Will look for a search engine API key in environment variables in the following order: Will look for a search engine API key in environment variables in the following order: