mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge remote-tracking branch 'upstream/dev' into playwright
# Conflicts: # backend/open_webui/config.py # backend/open_webui/main.py # backend/open_webui/retrieval/web/utils.py # backend/open_webui/routers/retrieval.py # backend/open_webui/utils/middleware.py # pyproject.toml
This commit is contained in:
@@ -9,6 +9,7 @@ from open_webui.env import SRC_LOG_LEVELS
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def _parse_response(response):
|
||||
result = {}
|
||||
if "data" in response:
|
||||
@@ -25,7 +26,8 @@ def _parse_response(response):
|
||||
"summary": item.get("summary", ""),
|
||||
"siteName": item.get("siteName", ""),
|
||||
"siteIcon": item.get("siteIcon", ""),
|
||||
"datePublished": item.get("datePublished", "") or item.get("dateLastCrawled", ""),
|
||||
"datePublished": item.get("datePublished", "")
|
||||
or item.get("dateLastCrawled", ""),
|
||||
}
|
||||
for item in webPages["value"]
|
||||
]
|
||||
@@ -42,17 +44,11 @@ def search_bocha(
|
||||
query (str): The query to search for
|
||||
"""
|
||||
url = "https://api.bochaai.com/v1/web-search?utm_source=ollama"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
payload = json.dumps({
|
||||
"query": query,
|
||||
"summary": True,
|
||||
"freshness": "noLimit",
|
||||
"count": count
|
||||
})
|
||||
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
|
||||
|
||||
payload = json.dumps(
|
||||
{"query": query, "summary": True, "freshness": "noLimit", "count": count}
|
||||
)
|
||||
|
||||
response = requests.post(url, headers=headers, data=payload, timeout=5)
|
||||
response.raise_for_status()
|
||||
@@ -63,10 +59,7 @@ def search_bocha(
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["url"],
|
||||
title=result.get("name"),
|
||||
snippet=result.get("summary")
|
||||
link=result["url"], title=result.get("name"), snippet=result.get("summary")
|
||||
)
|
||||
for result in results.get("webpage", [])[:count]
|
||||
for result in results.get("webpage", [])[:count]
|
||||
]
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from open_webui.env import SRC_LOG_LEVELS
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def search_google_pse(
|
||||
api_key: str,
|
||||
search_engine_id: str,
|
||||
@@ -46,12 +47,14 @@ def search_google_pse(
|
||||
response.raise_for_status()
|
||||
json_response = response.json()
|
||||
results = json_response.get("items", [])
|
||||
if results: # check if results are returned. If not, no more pages to fetch.
|
||||
if results: # check if results are returned. If not, no more pages to fetch.
|
||||
all_results.extend(results)
|
||||
count -= len(results) # Decrement count by the number of results fetched in this page.
|
||||
start_index += 10 # Increment start index for the next page
|
||||
count -= len(
|
||||
results
|
||||
) # Decrement count by the number of results fetched in this page.
|
||||
start_index += 10 # Increment start index for the next page
|
||||
else:
|
||||
break # No more results from Google PSE, break the loop
|
||||
break # No more results from Google PSE, break the loop
|
||||
|
||||
if filter_list:
|
||||
all_results = get_filtered_results(all_results, filter_list)
|
||||
|
||||
48
backend/open_webui/retrieval/web/serpapi.py
Normal file
48
backend/open_webui/retrieval/web/serpapi.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import requests
|
||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def search_serpapi(
|
||||
api_key: str,
|
||||
engine: str,
|
||||
query: str,
|
||||
count: int,
|
||||
filter_list: Optional[list[str]] = None,
|
||||
) -> list[SearchResult]:
|
||||
"""Search using serpapi.com's API and return the results as a list of SearchResult objects.
|
||||
|
||||
Args:
|
||||
api_key (str): A serpapi.com API key
|
||||
query (str): The query to search for
|
||||
"""
|
||||
url = "https://serpapi.com/search"
|
||||
|
||||
engine = engine or "google"
|
||||
|
||||
payload = {"engine": engine, "q": query, "api_key": api_key}
|
||||
|
||||
url = f"{url}?{urlencode(payload)}"
|
||||
response = requests.request("GET", url)
|
||||
|
||||
json_response = response.json()
|
||||
log.info(f"results from serpapi search: {json_response}")
|
||||
|
||||
results = sorted(
|
||||
json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
|
||||
)
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["link"], title=result["title"], snippet=result["snippet"]
|
||||
)
|
||||
for result in results[:count]
|
||||
]
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from open_webui.retrieval.web.main import SearchResult
|
||||
@@ -8,7 +9,13 @@ log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]:
|
||||
def search_tavily(
|
||||
api_key: str,
|
||||
query: str,
|
||||
count: int,
|
||||
filter_list: Optional[list[str]] = None,
|
||||
# **kwargs,
|
||||
) -> list[SearchResult]:
|
||||
"""Search using Tavily's Search API and return the results as a list of SearchResult objects.
|
||||
|
||||
Args:
|
||||
@@ -20,8 +27,8 @@ def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]:
|
||||
"""
|
||||
url = "https://api.tavily.com/search"
|
||||
data = {"query": query, "api_key": api_key}
|
||||
|
||||
response = requests.post(url, json=data)
|
||||
include_domain = filter_list
|
||||
response = requests.post(url, include_domain, json=data)
|
||||
response.raise_for_status()
|
||||
|
||||
json_response = response.json()
|
||||
|
||||
@@ -2,11 +2,15 @@ import asyncio
|
||||
from datetime import datetime, time, timedelta
|
||||
import socket
|
||||
import ssl
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import urllib.parse
|
||||
import certifi
|
||||
import validators
|
||||
from collections import defaultdict
|
||||
from typing import AsyncIterator, Dict, List, Optional, Union, Sequence, Iterator
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Sequence, Union
|
||||
|
||||
|
||||
from langchain_community.document_loaders import (
|
||||
WebBaseLoader,
|
||||
@@ -230,6 +234,71 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
|
||||
class SafeWebBaseLoader(WebBaseLoader):
|
||||
"""WebBaseLoader with enhanced error handling for URLs."""
|
||||
|
||||
def __init__(self, trust_env: bool = False, *args, **kwargs):
|
||||
"""Initialize SafeWebBaseLoader
|
||||
Args:
|
||||
trust_env (bool, optional): set to True if using proxy to make web requests, for example
|
||||
using http(s)_proxy environment variables. Defaults to False.
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.trust_env = trust_env
|
||||
|
||||
async def _fetch(
|
||||
self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
|
||||
) -> str:
|
||||
async with aiohttp.ClientSession(trust_env=self.trust_env) as session:
|
||||
for i in range(retries):
|
||||
try:
|
||||
kwargs: Dict = dict(
|
||||
headers=self.session.headers,
|
||||
cookies=self.session.cookies.get_dict(),
|
||||
)
|
||||
if not self.session.verify:
|
||||
kwargs["ssl"] = False
|
||||
|
||||
async with session.get(
|
||||
url, **(self.requests_kwargs | kwargs)
|
||||
) as response:
|
||||
if self.raise_for_status:
|
||||
response.raise_for_status()
|
||||
return await response.text()
|
||||
except aiohttp.ClientConnectionError as e:
|
||||
if i == retries - 1:
|
||||
raise
|
||||
else:
|
||||
log.warning(
|
||||
f"Error fetching {url} with attempt "
|
||||
f"{i + 1}/{retries}: {e}. Retrying..."
|
||||
)
|
||||
await asyncio.sleep(cooldown * backoff**i)
|
||||
raise ValueError("retry count exceeded")
|
||||
|
||||
def _unpack_fetch_results(
|
||||
self, results: Any, urls: List[str], parser: Union[str, None] = None
|
||||
) -> List[Any]:
|
||||
"""Unpack fetch results into BeautifulSoup objects."""
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
final_results = []
|
||||
for i, result in enumerate(results):
|
||||
url = urls[i]
|
||||
if parser is None:
|
||||
if url.endswith(".xml"):
|
||||
parser = "xml"
|
||||
else:
|
||||
parser = self.default_parser
|
||||
self._check_parser(parser)
|
||||
final_results.append(BeautifulSoup(result, parser, **self.bs_kwargs))
|
||||
return final_results
|
||||
|
||||
async def ascrape_all(
|
||||
self, urls: List[str], parser: Union[str, None] = None
|
||||
) -> List[Any]:
|
||||
"""Async fetch all urls, then return soups for all results."""
|
||||
results = await self.fetch_all(urls)
|
||||
return self._unpack_fetch_results(results, urls, parser=parser)
|
||||
|
||||
|
||||
def lazy_load(self) -> Iterator[Document]:
|
||||
"""Lazy load text from the url(s) in web_path with error handling."""
|
||||
for path in self.web_paths:
|
||||
@@ -245,6 +314,26 @@ class SafeWebBaseLoader(WebBaseLoader):
|
||||
# Log the error and continue with the next URL
|
||||
log.exception(e, "Error loading %s", path)
|
||||
|
||||
async def alazy_load(self) -> AsyncIterator[Document]:
|
||||
"""Async lazy load text from the url(s) in web_path."""
|
||||
results = await self.ascrape_all(self.web_paths)
|
||||
for path, soup in zip(self.web_paths, results):
|
||||
text = soup.get_text(**self.bs_get_text_kwargs)
|
||||
metadata = {"source": path}
|
||||
if title := soup.find("title"):
|
||||
metadata["title"] = title.get_text()
|
||||
if description := soup.find("meta", attrs={"name": "description"}):
|
||||
metadata["description"] = description.get(
|
||||
"content", "No description found."
|
||||
)
|
||||
if html := soup.find("html"):
|
||||
metadata["language"] = html.get("lang", "No language found.")
|
||||
yield Document(page_content=text, metadata=metadata)
|
||||
|
||||
async def aload(self) -> list[Document]:
|
||||
"""Load data into Document objects."""
|
||||
return [document async for document in self.alazy_load()]
|
||||
|
||||
RAG_WEB_LOADERS = defaultdict(lambda: SafeWebBaseLoader)
|
||||
RAG_WEB_LOADERS["playwright"] = SafePlaywrightURLLoader
|
||||
RAG_WEB_LOADERS["safe_web"] = SafeWebBaseLoader
|
||||
@@ -253,16 +342,19 @@ def get_web_loader(
|
||||
urls: Union[str, Sequence[str]],
|
||||
verify_ssl: bool = True,
|
||||
requests_per_second: int = 2,
|
||||
trust_env: bool = False,
|
||||
):
|
||||
# Check if the URLs are valid
|
||||
safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
|
||||
|
||||
|
||||
web_loader_args = {
|
||||
web_path=safe_urls,
|
||||
"urls": safe_urls,
|
||||
"verify_ssl": verify_ssl,
|
||||
"requests_per_second": requests_per_second,
|
||||
"continue_on_failure": True
|
||||
"continue_on_failure": True,
|
||||
trust_env=trust_env
|
||||
}
|
||||
|
||||
if PLAYWRIGHT_WS_URI.value:
|
||||
|
||||
Reference in New Issue
Block a user