Merge remote-tracking branch 'upstream/dev' into playwright

# Conflicts:
#	backend/open_webui/config.py
#	backend/open_webui/main.py
#	backend/open_webui/retrieval/web/utils.py
#	backend/open_webui/routers/retrieval.py
#	backend/open_webui/utils/middleware.py
#	pyproject.toml
Rory
2025-02-14 20:48:22 -06:00
92 changed files with 2583 additions and 454 deletions

View File

@@ -9,6 +9,7 @@ from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def _parse_response(response):
result = {}
if "data" in response:
@@ -25,7 +26,8 @@ def _parse_response(response):
"summary": item.get("summary", ""),
"siteName": item.get("siteName", ""),
"siteIcon": item.get("siteIcon", ""),
- "datePublished": item.get("datePublished", "") or item.get("dateLastCrawled", ""),
+ "datePublished": item.get("datePublished", "")
+ or item.get("dateLastCrawled", ""),
}
for item in webPages["value"]
]
@@ -42,17 +44,11 @@ def search_bocha(
query (str): The query to search for
"""
url = "https://api.bochaai.com/v1/web-search?utm_source=ollama"
- headers = {
- "Authorization": f"Bearer {api_key}",
- "Content-Type": "application/json"
- }
- payload = json.dumps({
- "query": query,
- "summary": True,
- "freshness": "noLimit",
- "count": count
- })
+ headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
+ payload = json.dumps(
+ {"query": query, "summary": True, "freshness": "noLimit", "count": count}
+ )
response = requests.post(url, headers=headers, data=payload, timeout=5)
response.raise_for_status()
@@ -63,10 +59,7 @@ def search_bocha(
return [
SearchResult(
- link=result["url"],
- title=result.get("name"),
- snippet=result.get("summary")
+ link=result["url"], title=result.get("name"), snippet=result.get("summary")
)
for result in results.get("webpage", [])[:count]
]
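
For reference, a minimal usage sketch of `search_bocha` as shown above. This is not part of the commit; the module path and environment-variable name are assumptions, and the keyword names follow the docstring.

```python
# Minimal usage sketch for search_bocha() (illustrative; not from this commit).
# Assumes the module lives at open_webui.retrieval.web.bocha and that a valid
# Bocha API key is available; the env var name here is made up.
import os

from open_webui.retrieval.web.bocha import search_bocha

results = search_bocha(
    api_key=os.environ["BOCHA_SEARCH_API_KEY"],
    query="open webui web search",
    count=3,
)
for r in results:
    # SearchResult carries link/title/snippet, as the constructor calls above show
    print(r.link, "-", r.title)
```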

View File

@@ -8,6 +8,7 @@ from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_google_pse(
api_key: str,
search_engine_id: str,
@@ -46,12 +47,14 @@ def search_google_pse(
response.raise_for_status()
json_response = response.json()
results = json_response.get("items", [])
- if results: # check if results are returned. If not, no more pages to fetch.
+ if results:  # check if results are returned. If not, no more pages to fetch.
all_results.extend(results)
- count -= len(results) # Decrement count by the number of results fetched in this page.
- start_index += 10 # Increment start index for the next page
+ count -= len(
+ results
+ ) # Decrement count by the number of results fetched in this page.
+ start_index += 10  # Increment start index for the next page
else:
- break # No more results from Google PSE, break the loop
+ break  # No more results from Google PSE, break the loop
if filter_list:
all_results = get_filtered_results(all_results, filter_list)
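
The loop above pages through Google PSE ten results at a time, decrementing `count` by the number of items returned and advancing `start_index` by 10 until enough results are collected or a page comes back empty. A standalone toy helper (not in the codebase) that models that paging arithmetic:

```python
# Illustrative model of the Google PSE paging arithmetic used above.
# Google PSE returns at most 10 items per request and its `start` parameter is 1-based.
def plan_pse_requests(count: int, page_size: int = 10) -> list[tuple[int, int]]:
    """Return the (start_index, num_requested) pairs the pagination loop would issue."""
    planned = []
    start_index = 1
    while count > 0:
        num = min(count, page_size)
        planned.append((start_index, num))
        count -= num              # mirrors: count -= len(results)
        start_index += page_size  # mirrors: start_index += 10
    return planned


print(plan_pse_requests(25))  # [(1, 10), (11, 10), (21, 5)]
```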

View File

@@ -0,0 +1,48 @@
import logging
from typing import Optional
from urllib.parse import urlencode

import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def search_serpapi(
    api_key: str,
    engine: str,
    query: str,
    count: int,
    filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
    """Search using serpapi.com's API and return the results as a list of SearchResult objects.

    Args:
        api_key (str): A serpapi.com API key
        query (str): The query to search for
    """
    url = "https://serpapi.com/search"

    engine = engine or "google"

    payload = {"engine": engine, "q": query, "api_key": api_key}
    url = f"{url}?{urlencode(payload)}"

    response = requests.request("GET", url)
    json_response = response.json()
    log.info(f"results from serpapi search: {json_response}")

    results = sorted(
        json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
    )
    if filter_list:
        results = get_filtered_results(results, filter_list)
    return [
        SearchResult(
            link=result["link"], title=result["title"], snippet=result["snippet"]
        )
        for result in results[:count]
    ]
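
A minimal usage sketch for the new `search_serpapi` function. This is not part of the commit; the module path and environment-variable name are assumptions.

```python
# Minimal usage sketch for search_serpapi() (illustrative; not from this commit).
import os

from open_webui.retrieval.web.serpapi import search_serpapi  # assumed module path

results = search_serpapi(
    api_key=os.environ["SERPAPI_API_KEY"],  # env var name is illustrative
    engine="",       # an empty string falls back to "google" inside the function
    query="open webui",
    count=5,
    filter_list=["wikipedia.org"],  # optional domain filter passed to get_filtered_results
)
for r in results:
    print(r.link, "-", r.title)
```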

View File

@@ -1,4 +1,5 @@
import logging
from typing import Optional
import requests
from open_webui.retrieval.web.main import SearchResult
@@ -8,7 +9,13 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
- def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]:
+ def search_tavily(
+ api_key: str,
+ query: str,
+ count: int,
+ filter_list: Optional[list[str]] = None,
+ # **kwargs,
+ ) -> list[SearchResult]:
"""Search using Tavily's Search API and return the results as a list of SearchResult objects.
Args:
@@ -20,8 +27,8 @@ def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]:
"""
url = "https://api.tavily.com/search"
data = {"query": query, "api_key": api_key}
- response = requests.post(url, json=data)
+ include_domain = filter_list
+ response = requests.post(url, include_domain, json=data)
response.raise_for_status()
json_response = response.json()
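
A minimal usage sketch for the updated `search_tavily` signature. This is not part of the commit; the module path and environment-variable name are assumptions.

```python
# Minimal usage sketch for the new search_tavily() signature (illustrative).
import os

from open_webui.retrieval.web.tavily import search_tavily  # assumed module path

results = search_tavily(
    api_key=os.environ["TAVILY_API_KEY"],  # env var name is illustrative
    query="playwright web scraping",
    count=5,
    filter_list=None,  # new optional parameter introduced by this change
)
for r in results:
    print(r.link)
```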

View File

@@ -2,11 +2,15 @@ import asyncio
from datetime import datetime, time, timedelta
import socket
import ssl
import aiohttp
import asyncio
import urllib.parse
import certifi
import validators
from collections import defaultdict
- from typing import AsyncIterator, Dict, List, Optional, Union, Sequence, Iterator
+ from typing import Any, AsyncIterator, Dict, Iterator, List, Sequence, Union
from langchain_community.document_loaders import (
WebBaseLoader,
@@ -230,6 +234,71 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader):
class SafeWebBaseLoader(WebBaseLoader):
    """WebBaseLoader with enhanced error handling for URLs."""

    def __init__(self, trust_env: bool = False, *args, **kwargs):
        """Initialize SafeWebBaseLoader
        Args:
            trust_env (bool, optional): set to True if using proxy to make web requests, for example
                using http(s)_proxy environment variables. Defaults to False.
        """
        super().__init__(*args, **kwargs)
        self.trust_env = trust_env

    async def _fetch(
        self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
    ) -> str:
        async with aiohttp.ClientSession(trust_env=self.trust_env) as session:
            for i in range(retries):
                try:
                    kwargs: Dict = dict(
                        headers=self.session.headers,
                        cookies=self.session.cookies.get_dict(),
                    )
                    if not self.session.verify:
                        kwargs["ssl"] = False

                    async with session.get(
                        url, **(self.requests_kwargs | kwargs)
                    ) as response:
                        if self.raise_for_status:
                            response.raise_for_status()
                        return await response.text()
                except aiohttp.ClientConnectionError as e:
                    if i == retries - 1:
                        raise
                    else:
                        log.warning(
                            f"Error fetching {url} with attempt "
                            f"{i + 1}/{retries}: {e}. Retrying..."
                        )
                        await asyncio.sleep(cooldown * backoff**i)
        raise ValueError("retry count exceeded")

    def _unpack_fetch_results(
        self, results: Any, urls: List[str], parser: Union[str, None] = None
    ) -> List[Any]:
        """Unpack fetch results into BeautifulSoup objects."""
        from bs4 import BeautifulSoup

        final_results = []
        for i, result in enumerate(results):
            url = urls[i]
            if parser is None:
                if url.endswith(".xml"):
                    parser = "xml"
                else:
                    parser = self.default_parser
                self._check_parser(parser)
            final_results.append(BeautifulSoup(result, parser, **self.bs_kwargs))
        return final_results

    async def ascrape_all(
        self, urls: List[str], parser: Union[str, None] = None
    ) -> List[Any]:
        """Async fetch all urls, then return soups for all results."""
        results = await self.fetch_all(urls)
        return self._unpack_fetch_results(results, urls, parser=parser)

    def lazy_load(self) -> Iterator[Document]:
        """Lazy load text from the url(s) in web_path with error handling."""
        for path in self.web_paths:
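
`_fetch` above retries failed connections with an exponential backoff of `cooldown * backoff**i` seconds between attempts (2.0 s before the second attempt, 3.0 s before the third with the defaults). A self-contained sketch of that retry schedule, simplified to a generic `ConnectionError` rather than aiohttp's exception types:

```python
# Illustrative reimplementation of the retry/backoff pattern used by _fetch() above.
import asyncio


async def fetch_with_backoff(fetch, retries: int = 3, cooldown: int = 2, backoff: float = 1.5):
    """Await an async `fetch()` callable, retrying with exponential backoff."""
    for i in range(retries):
        try:
            return await fetch()
        except ConnectionError as e:  # _fetch() catches aiohttp.ClientConnectionError instead
            if i == retries - 1:
                raise
            delay = cooldown * backoff**i  # 2.0s before retry 2, 3.0s before retry 3 (defaults)
            print(f"attempt {i + 1}/{retries} failed ({e}); retrying in {delay:.1f}s")
            await asyncio.sleep(delay)
    raise ValueError("retry count exceeded")
```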
@@ -245,6 +314,26 @@ class SafeWebBaseLoader(WebBaseLoader):
                # Log the error and continue with the next URL
                log.exception(e, "Error loading %s", path)

    async def alazy_load(self) -> AsyncIterator[Document]:
        """Async lazy load text from the url(s) in web_path."""
        results = await self.ascrape_all(self.web_paths)
        for path, soup in zip(self.web_paths, results):
            text = soup.get_text(**self.bs_get_text_kwargs)
            metadata = {"source": path}
            if title := soup.find("title"):
                metadata["title"] = title.get_text()
            if description := soup.find("meta", attrs={"name": "description"}):
                metadata["description"] = description.get(
                    "content", "No description found."
                )
            if html := soup.find("html"):
                metadata["language"] = html.get("lang", "No language found.")
            yield Document(page_content=text, metadata=metadata)

    async def aload(self) -> list[Document]:
        """Load data into Document objects."""
        return [document async for document in self.alazy_load()]


RAG_WEB_LOADERS = defaultdict(lambda: SafeWebBaseLoader)
RAG_WEB_LOADERS["playwright"] = SafePlaywrightURLLoader
RAG_WEB_LOADERS["safe_web"] = SafeWebBaseLoader
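
Because `RAG_WEB_LOADERS` is a `defaultdict`, any loader name that is not explicitly registered falls back to `SafeWebBaseLoader` rather than raising a `KeyError`. A self-contained illustration of that lookup behaviour using stub classes:

```python
# Standalone illustration of the defaultdict fallback used by RAG_WEB_LOADERS.
from collections import defaultdict


class SafeWebBaseLoader: ...        # stubs standing in for the real loader classes
class SafePlaywrightURLLoader: ...


RAG_WEB_LOADERS = defaultdict(lambda: SafeWebBaseLoader)
RAG_WEB_LOADERS["playwright"] = SafePlaywrightURLLoader
RAG_WEB_LOADERS["safe_web"] = SafeWebBaseLoader

print(RAG_WEB_LOADERS["playwright"].__name__)      # SafePlaywrightURLLoader
print(RAG_WEB_LOADERS["no-such-loader"].__name__)  # SafeWebBaseLoader (default_factory)
```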
@@ -253,16 +342,19 @@ def get_web_loader(
    urls: Union[str, Sequence[str]],
    verify_ssl: bool = True,
    requests_per_second: int = 2,
+   trust_env: bool = False,
):
    # Check if the URLs are valid
    safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)

    web_loader_args = {
-       web_path=safe_urls,
+       "urls": safe_urls,
        "verify_ssl": verify_ssl,
        "requests_per_second": requests_per_second,
-       "continue_on_failure": True
+       "continue_on_failure": True,
        trust_env=trust_env
    }

    if PLAYWRIGHT_WS_URI.value:
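
The hunk above is cut off at the `PLAYWRIGHT_WS_URI` check, but the call-site shape is already visible from the signature. A hedged sketch of how `get_web_loader` might be invoked with the newly merged `trust_env` flag (not from the commit):

```python
# Hypothetical call site for get_web_loader() (illustrative; not from this commit).
from open_webui.retrieval.web.utils import get_web_loader

loader = get_web_loader(
    ["https://example.com", "https://example.org"],
    verify_ssl=True,
    requests_per_second=2,
    trust_env=True,  # honour http(s)_proxy env vars, per the SafeWebBaseLoader docstring
)
docs = loader.load()  # both loader classes expose the standard LangChain load() API
print(f"loaded {len(docs)} documents")
```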