feat: add websearch endpoint to RAG API

fix: google PSE endpoint uses GET

fix: google PSE returns link, not url

fix: serper wrong field
Jun Siang Cheah 2024-05-06 16:39:25 +08:00
parent 501ff7a98b
commit 99e4edd364
4 changed files with 76 additions and 38 deletions


@@ -11,7 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware
 import os, shutil, logging, re
 from pathlib import Path
-from typing import List
+from typing import List, Union, Sequence
 from chromadb.utils.batch_utils import create_batches
@@ -58,6 +58,7 @@ from apps.rag.utils import (
     query_doc_with_hybrid_search,
     query_collection,
     query_collection_with_hybrid_search,
+    search_web,
 )
 from utils.misc import (
@@ -186,6 +187,10 @@ class UrlForm(CollectionNameForm):
     url: str
+class SearchForm(CollectionNameForm):
+    query: str
 @app.get("/")
 async def get_status():
     return {
@@ -506,26 +511,37 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
         )
-def get_web_loader(url: str):
+def get_web_loader(url: Union[str, Sequence[str]]):
     # Check if the URL is valid
-    if isinstance(validators.url(url), validators.ValidationError):
+    if not validate_url(url):
         raise ValueError(ERROR_MESSAGES.INVALID_URL)
-    if not ENABLE_LOCAL_WEB_FETCH:
-        # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
-        parsed_url = urllib.parse.urlparse(url)
-        # Get IPv4 and IPv6 addresses
-        ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
-        # Check if any of the resolved addresses are private
-        # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
-        for ip in ipv4_addresses:
-            if validators.ipv4(ip, private=True):
-                raise ValueError(ERROR_MESSAGES.INVALID_URL)
-        for ip in ipv6_addresses:
-            if validators.ipv6(ip, private=True):
-                raise ValueError(ERROR_MESSAGES.INVALID_URL)
     return WebBaseLoader(url)
+def validate_url(url: Union[str, Sequence[str]]):
+    if isinstance(url, str):
+        if isinstance(validators.url(url), validators.ValidationError):
+            raise ValueError(ERROR_MESSAGES.INVALID_URL)
+        if not ENABLE_LOCAL_WEB_FETCH:
+            # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
+            parsed_url = urllib.parse.urlparse(url)
+            # Get IPv4 and IPv6 addresses
+            ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
+            # Check if any of the resolved addresses are private
+            # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
+            for ip in ipv4_addresses:
+                if validators.ipv4(ip, private=True):
+                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
+            for ip in ipv6_addresses:
+                if validators.ipv6(ip, private=True):
+                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
+        return True
+    elif isinstance(url, Sequence):
+        return all(validate_url(u) for u in url)
+    else:
+        return False
 def resolve_hostname(hostname):
     # Get address information
     addr_info = socket.getaddrinfo(hostname, None)
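With this refactor get_web_loader accepts either a single URL or a sequence of URLs, and the per-URL checks move into validate_url, which simply recurses over sequences. A minimal usage sketch of the new helpers (the URLs are illustrative, and it assumes LangChain's WebBaseLoader can take a list of pages just as it takes a single one, which is what get_web_loader now passes it):

urls = ["https://example.com/a", "https://example.com/b"]
if validate_url(urls):  # True when every URL passes; raises ValueError for a malformed or private-resolving URL
    loader = get_web_loader(urls)  # one WebBaseLoader over all pages
    docs = loader.load()  # one Document per fetched page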
@@ -537,6 +553,32 @@ def resolve_hostname(hostname):
     return ipv4_addresses, ipv6_addresses
+@app.post("/websearch")
+def store_websearch(form_data: SearchForm, user=Depends(get_current_user)):
+    try:
+        web_results = search_web(form_data.query)
+        urls = [result.link for result in web_results]
+        loader = get_web_loader(urls)
+        data = loader.load()
+        collection_name = form_data.collection_name
+        if collection_name == "":
+            collection_name = calculate_sha256_string(form_data.query)[:63]
+        store_data_in_vector_db(data, collection_name, overwrite=True)
+        return {
+            "status": True,
+            "collection_name": collection_name,
+            "filenames": urls,
+        }
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )
 def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
     text_splitter = RecursiveCharacterTextSplitter(
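Taken together, the new /websearch route runs a web search, loads the returned pages, and stores them in a collection named either by the caller or by a hash of the query. A client-side sketch of how it might be called, assuming the RAG sub-app is mounted at /rag/api/v1 and bearer-token auth (host, mount path, and token are placeholders):

import requests

BASE = "http://localhost:8080/rag/api/v1"  # placeholder mount point
headers = {"Authorization": "Bearer <token>"}  # placeholder credentials

resp = requests.post(
    f"{BASE}/websearch",
    headers=headers,
    json={"collection_name": "", "query": "open source rag pipelines"},  # empty name -> sha256-of-query name
)
resp.raise_for_status()
body = resp.json()
# expected shape: {"status": True, "collection_name": "...", "filenames": ["https://...", ...]}
print(body["collection_name"], len(body["filenames"]))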


@@ -30,14 +30,16 @@ def search_google_pse(
         "num": 5,
     }
-    response = requests.request("POST", url, headers=headers, params=params)
+    response = requests.request("GET", url, headers=headers, params=params)
     response.raise_for_status()
     json_response = response.json()
     results = json_response.get("items", [])
     return [
         SearchResult(
-            link=result["url"], title=result.get("title"), snippet=result.get("snippet")
+            link=result["link"],
+            title=result.get("title"),
+            snippet=result.get("snippet"),
         )
         for result in results
     ]
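Both fixes here line up with Google's Custom Search JSON API: it is queried with GET, and each entry under items carries its URL in link (there is no url field). An illustrative standalone call, with the API key and engine ID as placeholders:

import requests

resp = requests.get(
    "https://www.googleapis.com/customsearch/v1",
    params={
        "key": "<GOOGLE_PSE_API_KEY>",  # placeholder API key
        "cx": "<GOOGLE_PSE_ENGINE_ID>",  # placeholder engine ID
        "q": "open webui",
        "num": 5,
    },
)
resp.raise_for_status()
for item in resp.json().get("items", []):
    print(item["link"], item.get("title"))  # "link", not "url", holds the result URL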


@@ -31,9 +31,9 @@ def search_serper(api_key: str, query: str) -> list[SearchResult]:
     )
     return [
         SearchResult(
-            link=result["url"],
+            link=result["link"],
             title=result.get("title"),
             snippet=result.get("description"),
         )
-        for result in results
+        for result in results[:5]
     ]
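Same story for Serper: its /search endpoint returns organic entries whose URL lives in link, and the [:5] slice caps the results at five to match the other providers. An illustrative standalone query (the API key is a placeholder):

import requests

resp = requests.post(
    "https://google.serper.dev/search",
    headers={"X-API-KEY": "<SERPER_API_KEY>", "Content-Type": "application/json"},
    json={"q": "open webui"},
)
resp.raise_for_status()
for result in resp.json().get("organic", [])[:5]:
    print(result["link"], result.get("title"))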


@@ -545,21 +545,15 @@ def search_web(query: str) -> list[SearchResult]:
     Args:
         query (str): The query to search for
     """
-    try:
-        if SEARXNG_QUERY_URL:
-            return search_searxng(SEARXNG_QUERY_URL, query)
-        elif GOOGLE_PSE_API_KEY and GOOGLE_PSE_ENGINE_ID:
-            return search_google_pse(GOOGLE_PSE_API_KEY, GOOGLE_PSE_ENGINE_ID, query)
-        elif BRAVE_SEARCH_API_KEY:
-            return search_brave(BRAVE_SEARCH_API_KEY, query)
-        elif SERPSTACK_API_KEY:
-            return search_serpstack(
-                SERPSTACK_API_KEY, query, https_enabled=SERPSTACK_HTTPS
-            )
-        elif SERPER_API_KEY:
-            return search_serper(SERPER_API_KEY, query)
-        else:
-            raise Exception("No search engine API key found in environment variables")
-    except Exception as e:
-        log.error(f"Web search failed: {e}")
-        return []
+    if SEARXNG_QUERY_URL:
+        return search_searxng(SEARXNG_QUERY_URL, query)
+    elif GOOGLE_PSE_API_KEY and GOOGLE_PSE_ENGINE_ID:
+        return search_google_pse(GOOGLE_PSE_API_KEY, GOOGLE_PSE_ENGINE_ID, query)
+    elif BRAVE_SEARCH_API_KEY:
+        return search_brave(BRAVE_SEARCH_API_KEY, query)
+    elif SERPSTACK_API_KEY:
+        return search_serpstack(SERPSTACK_API_KEY, query, https_enabled=SERPSTACK_HTTPS)
+    elif SERPER_API_KEY:
+        return search_serper(SERPER_API_KEY, query)
+    else:
+        raise Exception("No search engine API key found in environment variables")
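Dropping the try/except means search_web now propagates failures (no configured provider, HTTP errors) instead of silently returning an empty list; the new /websearch endpoint relies on this and converts the exception into a 400 response. Any other caller that prefers the old fallback has to restore it at the call site, roughly like so:

import logging

log = logging.getLogger(__name__)

try:
    results = search_web("open webui")  # raises if no provider is configured or the request fails
except Exception as e:
    log.error(f"Web search failed: {e}")
    results = []  # reinstate the old empty-result fallback locally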