mirror of
https://github.com/open-webui/open-webui
synced 2025-06-23 02:16:52 +00:00
feat: add websearch endpoint to RAG API
fix: google PSE endpoint uses GET fix: google PSE returns link, not url fix: serper wrong field
This commit is contained in:
parent
501ff7a98b
commit
99e4edd364
@ -11,7 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
import os, shutil, logging, re
|
import os, shutil, logging, re
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List, Union, Sequence
|
||||||
|
|
||||||
from chromadb.utils.batch_utils import create_batches
|
from chromadb.utils.batch_utils import create_batches
|
||||||
|
|
||||||
@ -58,6 +58,7 @@ from apps.rag.utils import (
|
|||||||
query_doc_with_hybrid_search,
|
query_doc_with_hybrid_search,
|
||||||
query_collection,
|
query_collection,
|
||||||
query_collection_with_hybrid_search,
|
query_collection_with_hybrid_search,
|
||||||
|
search_web,
|
||||||
)
|
)
|
||||||
|
|
||||||
from utils.misc import (
|
from utils.misc import (
|
||||||
@ -186,6 +187,10 @@ class UrlForm(CollectionNameForm):
|
|||||||
url: str
|
url: str
|
||||||
|
|
||||||
|
|
||||||
|
class SearchForm(CollectionNameForm):
|
||||||
|
query: str
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def get_status():
|
async def get_status():
|
||||||
return {
|
return {
|
||||||
@ -506,26 +511,37 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_web_loader(url: str):
|
def get_web_loader(url: Union[str, Sequence[str]]):
|
||||||
# Check if the URL is valid
|
# Check if the URL is valid
|
||||||
if isinstance(validators.url(url), validators.ValidationError):
|
if not validate_url(url):
|
||||||
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
||||||
if not ENABLE_LOCAL_WEB_FETCH:
|
|
||||||
# Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
|
|
||||||
parsed_url = urllib.parse.urlparse(url)
|
|
||||||
# Get IPv4 and IPv6 addresses
|
|
||||||
ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
|
|
||||||
# Check if any of the resolved addresses are private
|
|
||||||
# This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
|
|
||||||
for ip in ipv4_addresses:
|
|
||||||
if validators.ipv4(ip, private=True):
|
|
||||||
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
|
||||||
for ip in ipv6_addresses:
|
|
||||||
if validators.ipv6(ip, private=True):
|
|
||||||
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
|
||||||
return WebBaseLoader(url)
|
return WebBaseLoader(url)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_url(url: Union[str, Sequence[str]]):
|
||||||
|
if isinstance(url, str):
|
||||||
|
if isinstance(validators.url(url), validators.ValidationError):
|
||||||
|
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
||||||
|
if not ENABLE_LOCAL_WEB_FETCH:
|
||||||
|
# Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
|
||||||
|
parsed_url = urllib.parse.urlparse(url)
|
||||||
|
# Get IPv4 and IPv6 addresses
|
||||||
|
ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
|
||||||
|
# Check if any of the resolved addresses are private
|
||||||
|
# This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
|
||||||
|
for ip in ipv4_addresses:
|
||||||
|
if validators.ipv4(ip, private=True):
|
||||||
|
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
||||||
|
for ip in ipv6_addresses:
|
||||||
|
if validators.ipv6(ip, private=True):
|
||||||
|
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
||||||
|
return True
|
||||||
|
elif isinstance(url, Sequence):
|
||||||
|
return all(validate_url(u) for u in url)
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def resolve_hostname(hostname):
|
def resolve_hostname(hostname):
|
||||||
# Get address information
|
# Get address information
|
||||||
addr_info = socket.getaddrinfo(hostname, None)
|
addr_info = socket.getaddrinfo(hostname, None)
|
||||||
@ -537,6 +553,32 @@ def resolve_hostname(hostname):
|
|||||||
return ipv4_addresses, ipv6_addresses
|
return ipv4_addresses, ipv6_addresses
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/websearch")
|
||||||
|
def store_websearch(form_data: SearchForm, user=Depends(get_current_user)):
|
||||||
|
try:
|
||||||
|
web_results = search_web(form_data.query)
|
||||||
|
urls = [result.link for result in web_results]
|
||||||
|
loader = get_web_loader(urls)
|
||||||
|
data = loader.load()
|
||||||
|
|
||||||
|
collection_name = form_data.collection_name
|
||||||
|
if collection_name == "":
|
||||||
|
collection_name = calculate_sha256_string(form_data.query)[:63]
|
||||||
|
|
||||||
|
store_data_in_vector_db(data, collection_name, overwrite=True)
|
||||||
|
return {
|
||||||
|
"status": True,
|
||||||
|
"collection_name": collection_name,
|
||||||
|
"filenames": urls,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
log.exception(e)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
|
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
|
||||||
|
|
||||||
text_splitter = RecursiveCharacterTextSplitter(
|
text_splitter = RecursiveCharacterTextSplitter(
|
||||||
|
@ -30,14 +30,16 @@ def search_google_pse(
|
|||||||
"num": 5,
|
"num": 5,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = requests.request("POST", url, headers=headers, params=params)
|
response = requests.request("GET", url, headers=headers, params=params)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
json_response = response.json()
|
json_response = response.json()
|
||||||
results = json_response.get("items", [])
|
results = json_response.get("items", [])
|
||||||
return [
|
return [
|
||||||
SearchResult(
|
SearchResult(
|
||||||
link=result["url"], title=result.get("title"), snippet=result.get("snippet")
|
link=result["link"],
|
||||||
|
title=result.get("title"),
|
||||||
|
snippet=result.get("snippet"),
|
||||||
)
|
)
|
||||||
for result in results
|
for result in results
|
||||||
]
|
]
|
||||||
|
@ -31,9 +31,9 @@ def search_serper(api_key: str, query: str) -> list[SearchResult]:
|
|||||||
)
|
)
|
||||||
return [
|
return [
|
||||||
SearchResult(
|
SearchResult(
|
||||||
link=result["url"],
|
link=result["link"],
|
||||||
title=result.get("title"),
|
title=result.get("title"),
|
||||||
snippet=result.get("description"),
|
snippet=result.get("description"),
|
||||||
)
|
)
|
||||||
for result in results
|
for result in results[:5]
|
||||||
]
|
]
|
||||||
|
@ -545,21 +545,15 @@ def search_web(query: str) -> list[SearchResult]:
|
|||||||
Args:
|
Args:
|
||||||
query (str): The query to search for
|
query (str): The query to search for
|
||||||
"""
|
"""
|
||||||
try:
|
if SEARXNG_QUERY_URL:
|
||||||
if SEARXNG_QUERY_URL:
|
return search_searxng(SEARXNG_QUERY_URL, query)
|
||||||
return search_searxng(SEARXNG_QUERY_URL, query)
|
elif GOOGLE_PSE_API_KEY and GOOGLE_PSE_ENGINE_ID:
|
||||||
elif GOOGLE_PSE_API_KEY and GOOGLE_PSE_ENGINE_ID:
|
return search_google_pse(GOOGLE_PSE_API_KEY, GOOGLE_PSE_ENGINE_ID, query)
|
||||||
return search_google_pse(GOOGLE_PSE_API_KEY, GOOGLE_PSE_ENGINE_ID, query)
|
elif BRAVE_SEARCH_API_KEY:
|
||||||
elif BRAVE_SEARCH_API_KEY:
|
return search_brave(BRAVE_SEARCH_API_KEY, query)
|
||||||
return search_brave(BRAVE_SEARCH_API_KEY, query)
|
elif SERPSTACK_API_KEY:
|
||||||
elif SERPSTACK_API_KEY:
|
return search_serpstack(SERPSTACK_API_KEY, query, https_enabled=SERPSTACK_HTTPS)
|
||||||
return search_serpstack(
|
elif SERPER_API_KEY:
|
||||||
SERPSTACK_API_KEY, query, https_enabled=SERPSTACK_HTTPS
|
return search_serper(SERPER_API_KEY, query)
|
||||||
)
|
else:
|
||||||
elif SERPER_API_KEY:
|
raise Exception("No search engine API key found in environment variables")
|
||||||
return search_serper(SERPER_API_KEY, query)
|
|
||||||
else:
|
|
||||||
raise Exception("No search engine API key found in environment variables")
|
|
||||||
except Exception as e:
|
|
||||||
log.error(f"Web search failed: {e}")
|
|
||||||
return []
|
|
||||||
|
Loading…
Reference in New Issue
Block a user