Mirror of https://github.com/open-webui/open-webui
feat: add websearch endpoint to RAG API
fix: google PSE endpoint uses GET
fix: google PSE returns link, not url
fix: serper wrong field
This commit is contained in: parent 501ff7a98b, commit 99e4edd364
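For context: all three fix: lines concern the shape of provider responses. A minimal sketch of the SearchResult model these functions return — the field names (link, title, snippet) are inferred from the diff below, not copied from the repo:

    from typing import Optional
    from pydantic import BaseModel

    class SearchResult(BaseModel):
        # Google PSE and Serper both expose the result URL under "link",
        # which is why indexing result["url"] failed before this fix.
        link: str
        title: Optional[str] = None
        snippet: Optional[str] = None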
@@ -11,7 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware
 import os, shutil, logging, re

 from pathlib import Path
-from typing import List
+from typing import List, Union, Sequence

 from chromadb.utils.batch_utils import create_batches
@@ -58,6 +58,7 @@ from apps.rag.utils import (
     query_doc_with_hybrid_search,
     query_collection,
     query_collection_with_hybrid_search,
+    search_web,
 )

 from utils.misc import (
@@ -186,6 +187,10 @@ class UrlForm(CollectionNameForm):
     url: str


+class SearchForm(CollectionNameForm):
+    query: str
+
+
 @app.get("/")
 async def get_status():
     return {
@@ -506,26 +511,37 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
     )


-def get_web_loader(url: str):
+def get_web_loader(url: Union[str, Sequence[str]]):
     # Check if the URL is valid
-    if isinstance(validators.url(url), validators.ValidationError):
+    if not validate_url(url):
         raise ValueError(ERROR_MESSAGES.INVALID_URL)
-    if not ENABLE_LOCAL_WEB_FETCH:
-        # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
-        parsed_url = urllib.parse.urlparse(url)
-        # Get IPv4 and IPv6 addresses
-        ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
-        # Check if any of the resolved addresses are private
-        # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
-        for ip in ipv4_addresses:
-            if validators.ipv4(ip, private=True):
-                raise ValueError(ERROR_MESSAGES.INVALID_URL)
-        for ip in ipv6_addresses:
-            if validators.ipv6(ip, private=True):
-                raise ValueError(ERROR_MESSAGES.INVALID_URL)
     return WebBaseLoader(url)


+def validate_url(url: Union[str, Sequence[str]]):
+    if isinstance(url, str):
+        if isinstance(validators.url(url), validators.ValidationError):
+            raise ValueError(ERROR_MESSAGES.INVALID_URL)
+        if not ENABLE_LOCAL_WEB_FETCH:
+            # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
+            parsed_url = urllib.parse.urlparse(url)
+            # Get IPv4 and IPv6 addresses
+            ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
+            # Check if any of the resolved addresses are private
+            # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
+            for ip in ipv4_addresses:
+                if validators.ipv4(ip, private=True):
+                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
+            for ip in ipv6_addresses:
+                if validators.ipv6(ip, private=True):
+                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
+        return True
+    elif isinstance(url, Sequence):
+        return all(validate_url(u) for u in url)
+    else:
+        return False


 def resolve_hostname(hostname):
     # Get address information
     addr_info = socket.getaddrinfo(hostname, None)
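Note the order of the isinstance checks in validate_url: a str is itself a Sequence, so the string case must be handled first or a single URL would be validated character by character. A standalone sketch of the same dispatch pattern (simplified, without the private-IP filtering):

    from typing import Sequence, Union
    import validators  # pip install validators

    def is_valid_url(url: Union[str, Sequence[str]]) -> bool:
        if isinstance(url, str):
            # validators.url() returns a ValidationError object on failure
            return not isinstance(validators.url(url), validators.ValidationError)
        elif isinstance(url, Sequence):
            # handles the list of URLs produced by a web search
            return all(is_valid_url(u) for u in url)
        return False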
@@ -537,6 +553,32 @@ def resolve_hostname(hostname):
     return ipv4_addresses, ipv6_addresses


+@app.post("/websearch")
+def store_websearch(form_data: SearchForm, user=Depends(get_current_user)):
+    try:
+        web_results = search_web(form_data.query)
+        urls = [result.link for result in web_results]
+        loader = get_web_loader(urls)
+        data = loader.load()
+
+        collection_name = form_data.collection_name
+        if collection_name == "":
+            collection_name = calculate_sha256_string(form_data.query)[:63]
+
+        store_data_in_vector_db(data, collection_name, overwrite=True)
+        return {
+            "status": True,
+            "collection_name": collection_name,
+            "filenames": urls,
+        }
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )
+
+
 def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:

     text_splitter = RecursiveCharacterTextSplitter(
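A hedged usage sketch of the new endpoint. The base URL, route prefix, and bearer token are assumptions about a local deployment (get_current_user implies authentication); only the body fields come from SearchForm in the hunk above. An empty collection_name lets the server derive one from a SHA-256 of the query:

    import requests

    # hypothetical deployment values
    BASE = "http://localhost:8080/rag/api/v1"
    TOKEN = "sk-..."

    resp = requests.post(
        f"{BASE}/websearch",
        headers={"Authorization": f"Bearer {TOKEN}"},
        json={"collection_name": "", "query": "open source llm ui"},
    )
    resp.raise_for_status()
    # expected shape: {"status": true, "collection_name": ..., "filenames": [...]}
    print(resp.json())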
@@ -30,14 +30,16 @@ def search_google_pse(
         "num": 5,
     }

-    response = requests.request("POST", url, headers=headers, params=params)
+    response = requests.request("GET", url, headers=headers, params=params)
    response.raise_for_status()

    json_response = response.json()
    results = json_response.get("items", [])
    return [
        SearchResult(
-            link=result["url"], title=result.get("title"), snippet=result.get("snippet")
+            link=result["link"],
+            title=result.get("title"),
+            snippet=result.get("snippet"),
        )
        for result in results
    ]
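For reference, the Google Programmable Search (Custom Search JSON) API serves queries over GET, and each entry in items carries the URL under link — which is exactly what the two fixes above encode. A minimal standalone sketch (key and engine values are placeholders):

    import requests

    def google_pse_search(api_key: str, engine_id: str, query: str) -> list[dict]:
        resp = requests.get(
            "https://www.googleapis.com/customsearch/v1",
            params={"key": api_key, "cx": engine_id, "q": query, "num": 5},
        )
        resp.raise_for_status()
        # each item has "title", "link" (not "url"), and "snippet"
        return resp.json().get("items", [])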
@@ -31,9 +31,9 @@ def search_serper(api_key: str, query: str) -> list[SearchResult]:
    )
    return [
        SearchResult(
-            link=result["url"],
+            link=result["link"],
            title=result.get("title"),
            snippet=result.get("description"),
        )
-        for result in results
+        for result in results[:5]
    ]
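The Serper change is the same field fix plus a cap at five results via slicing. A sketch of the kind of request search_serper presumably wraps — the endpoint and payload follow Serper's public API, so treat them as assumptions rather than the repo's exact code:

    import requests

    def serper_search(api_key: str, query: str) -> list[dict]:
        resp = requests.post(
            "https://google.serper.dev/search",
            headers={"X-API-KEY": api_key, "Content-Type": "application/json"},
            json={"q": query},
        )
        resp.raise_for_status()
        # organic results expose the URL as "link"; slicing keeps the top 5
        return resp.json().get("organic", [])[:5]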
@@ -545,21 +545,15 @@ def search_web(query: str) -> list[SearchResult]:
     Args:
         query (str): The query to search for
     """
-    try:
-        if SEARXNG_QUERY_URL:
-            return search_searxng(SEARXNG_QUERY_URL, query)
-        elif GOOGLE_PSE_API_KEY and GOOGLE_PSE_ENGINE_ID:
-            return search_google_pse(GOOGLE_PSE_API_KEY, GOOGLE_PSE_ENGINE_ID, query)
-        elif BRAVE_SEARCH_API_KEY:
-            return search_brave(BRAVE_SEARCH_API_KEY, query)
-        elif SERPSTACK_API_KEY:
-            return search_serpstack(
-                SERPSTACK_API_KEY, query, https_enabled=SERPSTACK_HTTPS
-            )
-        elif SERPER_API_KEY:
-            return search_serper(SERPER_API_KEY, query)
-        else:
-            raise Exception("No search engine API key found in environment variables")
-    except Exception as e:
-        log.error(f"Web search failed: {e}")
-        return []
+    if SEARXNG_QUERY_URL:
+        return search_searxng(SEARXNG_QUERY_URL, query)
+    elif GOOGLE_PSE_API_KEY and GOOGLE_PSE_ENGINE_ID:
+        return search_google_pse(GOOGLE_PSE_API_KEY, GOOGLE_PSE_ENGINE_ID, query)
+    elif BRAVE_SEARCH_API_KEY:
+        return search_brave(BRAVE_SEARCH_API_KEY, query)
+    elif SERPSTACK_API_KEY:
+        return search_serpstack(SERPSTACK_API_KEY, query, https_enabled=SERPSTACK_HTTPS)
+    elif SERPER_API_KEY:
+        return search_serper(SERPER_API_KEY, query)
+    else:
+        raise Exception("No search engine API key found in environment variables")
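Removing the try/except changes who handles failure: search_web no longer swallows provider errors into an empty result list, so they propagate to store_websearch, which converts them into an HTTP 400 carrying the real error message. A tiny sketch of the resulting caller-side behavior (illustrative only):

    # before: a failing provider yielded [] and the endpoint went on to
    # index nothing; after: the exception surfaces and becomes HTTP 400.
    try:
        results = search_web("some query")
    except Exception as e:
        # e.g. requests.HTTPError from a provider, or the
        # "No search engine API key found..." Exception above
        print(f"web search failed: {e}")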