feat: add websearch endpoint to RAG API

fix: google PSE endpoint uses GET

fix: google PSE returns link, not url

fix: serper returns link, not url
Jun Siang Cheah, 2024-05-06 16:39:25 +08:00
commit 99e4edd364 (parent 501ff7a98b)
4 changed files with 76 additions and 38 deletions


@@ -11,7 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware
 import os, shutil, logging, re
 from pathlib import Path
-from typing import List
+from typing import List, Union, Sequence
 from chromadb.utils.batch_utils import create_batches
@@ -58,6 +58,7 @@ from apps.rag.utils import (
     query_doc_with_hybrid_search,
     query_collection,
     query_collection_with_hybrid_search,
+    search_web,
 )
 from utils.misc import (
@@ -186,6 +187,10 @@ class UrlForm(CollectionNameForm):
     url: str


+class SearchForm(CollectionNameForm):
+    query: str
+
+
 @app.get("/")
 async def get_status():
     return {
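For context, SearchForm extends CollectionNameForm, so a /websearch request carries both the target collection and the query. A minimal sketch of the model, assuming CollectionNameForm holds only a collection_name field (its definition is not shown in this diff):

from pydantic import BaseModel


class CollectionNameForm(BaseModel):
    # Assumed shape: the endpoint added below treats an empty string as
    # "derive a name from the query", so a default of "" is plausible.
    collection_name: str = ""


class SearchForm(CollectionNameForm):
    query: str


# A /websearch request body would then look like:
# {"collection_name": "", "query": "open source llm ui"}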
@@ -506,26 +511,37 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
         )


-def get_web_loader(url: str):
+def get_web_loader(url: Union[str, Sequence[str]]):
     # Check if the URL is valid
-    if isinstance(validators.url(url), validators.ValidationError):
+    if not validate_url(url):
         raise ValueError(ERROR_MESSAGES.INVALID_URL)
-    if not ENABLE_LOCAL_WEB_FETCH:
-        # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
-        parsed_url = urllib.parse.urlparse(url)
-        # Get IPv4 and IPv6 addresses
-        ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
-        # Check if any of the resolved addresses are private
-        # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
-        for ip in ipv4_addresses:
-            if validators.ipv4(ip, private=True):
-                raise ValueError(ERROR_MESSAGES.INVALID_URL)
-        for ip in ipv6_addresses:
-            if validators.ipv6(ip, private=True):
-                raise ValueError(ERROR_MESSAGES.INVALID_URL)
     return WebBaseLoader(url)


+def validate_url(url: Union[str, Sequence[str]]):
+    if isinstance(url, str):
+        if isinstance(validators.url(url), validators.ValidationError):
+            raise ValueError(ERROR_MESSAGES.INVALID_URL)
+        if not ENABLE_LOCAL_WEB_FETCH:
+            # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
+            parsed_url = urllib.parse.urlparse(url)
+            # Get IPv4 and IPv6 addresses
+            ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
+            # Check if any of the resolved addresses are private
+            # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
+            for ip in ipv4_addresses:
+                if validators.ipv4(ip, private=True):
+                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
+            for ip in ipv6_addresses:
+                if validators.ipv6(ip, private=True):
+                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
+        return True
+    elif isinstance(url, Sequence):
+        return all(validate_url(u) for u in url)
+    else:
+        return False


 def resolve_hostname(hostname):
     # Get address information
     addr_info = socket.getaddrinfo(hostname, None)
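The string branch runs before the Sequence branch, which matters because a Python str is itself a Sequence; the recursion then lets one call validate either a single URL or a whole list of them. A standalone sketch of the same shape, using a hypothetical name, dropping the private-IP filtering, and mirroring the diff's validators.ValidationError check:

from typing import Sequence, Union

import validators


def is_valid_url(url: Union[str, Sequence[str]]) -> bool:
    # Hypothetical, boolean-only variant of validate_url above.
    if isinstance(url, str):
        return not isinstance(validators.url(url), validators.ValidationError)
    if isinstance(url, Sequence):
        # A list of URLs is valid only if every element is.
        return all(is_valid_url(u) for u in url)
    return False


print(is_valid_url("https://example.com"))                 # True
print(is_valid_url(["https://example.com", "not a url"]))  # False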
@@ -537,6 +553,32 @@ def resolve_hostname(hostname):
     return ipv4_addresses, ipv6_addresses


+@app.post("/websearch")
+def store_websearch(form_data: SearchForm, user=Depends(get_current_user)):
+    try:
+        web_results = search_web(form_data.query)
+        urls = [result.link for result in web_results]
+        loader = get_web_loader(urls)
+        data = loader.load()
+
+        collection_name = form_data.collection_name
+        if collection_name == "":
+            collection_name = calculate_sha256_string(form_data.query)[:63]
+
+        store_data_in_vector_db(data, collection_name, overwrite=True)
+        return {
+            "status": True,
+            "collection_name": collection_name,
+            "filenames": urls,
+        }
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )
+
+
 def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
     text_splitter = RecursiveCharacterTextSplitter(
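Taken together, the endpoint searches, fetches the result pages, and indexes them in a single call; an empty collection_name falls back to a SHA-256 hash of the query truncated to 63 characters, which keeps it within ChromaDB's collection-name length limit. A hedged usage sketch, assuming the RAG sub-app is mounted under /rag and that TOKEN is a valid API token for get_current_user (both are assumptions, not shown in this diff):

import requests

TOKEN = "your-api-token"  # placeholder

resp = requests.post(
    "http://localhost:8080/rag/websearch",  # assumed mount point
    headers={"Authorization": f"Bearer {TOKEN}"},
    json={"collection_name": "", "query": "retrieval augmented generation"},
)
resp.raise_for_status()
print(resp.json())
# e.g. {"status": True, "collection_name": "<sha256-derived name>", "filenames": ["https://...", ...]}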


@@ -30,14 +30,16 @@ def search_google_pse(
         "num": 5,
     }
-    response = requests.request("POST", url, headers=headers, params=params)
+    response = requests.request("GET", url, headers=headers, params=params)
     response.raise_for_status()

     json_response = response.json()
     results = json_response.get("items", [])
     return [
         SearchResult(
-            link=result["url"], title=result.get("title"), snippet=result.get("snippet")
+            link=result["link"],
+            title=result.get("title"),
+            snippet=result.get("snippet"),
         )
         for result in results
     ]
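Both fixes line up with Google's Programmable Search Engine JSON API: the customsearch/v1 endpoint expects a GET request with key, cx, and q query parameters, and each returned item carries its URL under "link", not "url". A minimal standalone sketch with placeholder credentials:

import requests

API_KEY = "your-api-key"      # placeholder
ENGINE_ID = "your-engine-id"  # placeholder (the "cx" parameter)

resp = requests.get(
    "https://www.googleapis.com/customsearch/v1",
    params={"key": API_KEY, "cx": ENGINE_ID, "q": "open webui", "num": 5},
)
resp.raise_for_status()
for item in resp.json().get("items", []):
    print(item["link"], "-", item.get("title"))  # items expose "link", not "url"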


@@ -31,9 +31,9 @@ def search_serper(api_key: str, query: str) -> list[SearchResult]:
     )
     return [
         SearchResult(
-            link=result["url"],
+            link=result["link"],
             title=result.get("title"),
             snippet=result.get("description"),
         )
-        for result in results
+        for result in results[:5]
     ]
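The serper fix parallels the PSE one: Serper's organic results identify each hit by a "link" field, and the [:5] slice caps the provider at five results, matching the "num": 5 used for Google PSE. A hedged sketch of the underlying API as commonly documented, with a placeholder key:

import requests

resp = requests.post(
    "https://google.serper.dev/search",
    headers={"X-API-KEY": "your-api-key", "Content-Type": "application/json"},
    json={"q": "open webui"},
)
resp.raise_for_status()
for result in resp.json().get("organic", [])[:5]:  # cap at 5, as in the diff
    print(result["link"], "-", result.get("title"))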


@@ -545,21 +545,15 @@ def search_web(query: str) -> list[SearchResult]:
     Args:
         query (str): The query to search for
     """
-    try:
-        if SEARXNG_QUERY_URL:
-            return search_searxng(SEARXNG_QUERY_URL, query)
-        elif GOOGLE_PSE_API_KEY and GOOGLE_PSE_ENGINE_ID:
-            return search_google_pse(GOOGLE_PSE_API_KEY, GOOGLE_PSE_ENGINE_ID, query)
-        elif BRAVE_SEARCH_API_KEY:
-            return search_brave(BRAVE_SEARCH_API_KEY, query)
-        elif SERPSTACK_API_KEY:
-            return search_serpstack(
-                SERPSTACK_API_KEY, query, https_enabled=SERPSTACK_HTTPS
-            )
-        elif SERPER_API_KEY:
-            return search_serper(SERPER_API_KEY, query)
-        else:
-            raise Exception("No search engine API key found in environment variables")
-    except Exception as e:
-        log.error(f"Web search failed: {e}")
-        return []
+    if SEARXNG_QUERY_URL:
+        return search_searxng(SEARXNG_QUERY_URL, query)
+    elif GOOGLE_PSE_API_KEY and GOOGLE_PSE_ENGINE_ID:
+        return search_google_pse(GOOGLE_PSE_API_KEY, GOOGLE_PSE_ENGINE_ID, query)
+    elif BRAVE_SEARCH_API_KEY:
+        return search_brave(BRAVE_SEARCH_API_KEY, query)
+    elif SERPSTACK_API_KEY:
+        return search_serpstack(SERPSTACK_API_KEY, query, https_enabled=SERPSTACK_HTTPS)
+    elif SERPER_API_KEY:
+        return search_serper(SERPER_API_KEY, query)
+    else:
+        raise Exception("No search engine API key found in environment variables")
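Dropping the try/except changes the failure mode: instead of logging and returning an empty list, which would let /websearch silently index nothing, search_web now raises, and store_websearch converts the exception into an HTTP 400 for the client. A small sketch of the resulting behavior, assuming no engine is configured:

# With no SEARXNG_QUERY_URL or *_API_KEY configured, search_web raises
# instead of returning []:
try:
    search_web("anything")
except Exception as e:
    print(e)  # "No search engine API key found in environment variables"
    # store_websearch re-raises this as HTTP 400 via ERROR_MESSAGES.DEFAULT(e)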