From c761e4fd087643188f347289b7217692c55b29aa Mon Sep 17 00:00:00 2001
From: orenzhang
Date: Mon, 10 Mar 2025 22:27:31 +0800
Subject: [PATCH] feat(trace): add OpenTelemetry instrumentation

---
 backend/open_webui/config.py                  |   4 +-
 backend/open_webui/env.py                     |   9 +
 backend/open_webui/main.py                    |  14 +-
 .../open_webui/retrieval/loaders/tavily.py    |  31 ++--
 .../retrieval/vector/dbs/opensearch.py        |  91 +++++-----
 backend/open_webui/retrieval/web/utils.py     |  17 +-
 backend/open_webui/utils/trace/__init__.py    |   0
 backend/open_webui/utils/trace/constants.py   |  26 +++
 backend/open_webui/utils/trace/exporters.py   |  31 ++++
 .../open_webui/utils/trace/instrumentors.py   | 155 ++++++++++++++++++
 backend/open_webui/utils/trace/setup.py       |  24 +++
 backend/requirements.txt                      |  15 +-
 12 files changed, 336 insertions(+), 81 deletions(-)
 create mode 100644 backend/open_webui/utils/trace/__init__.py
 create mode 100644 backend/open_webui/utils/trace/constants.py
 create mode 100644 backend/open_webui/utils/trace/exporters.py
 create mode 100644 backend/open_webui/utils/trace/instrumentors.py
 create mode 100644 backend/open_webui/utils/trace/setup.py

diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index 1cb6ab56a..871264718 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -1580,7 +1580,9 @@ QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", None)
 # OpenSearch
 OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200")
 OPENSEARCH_SSL = os.environ.get("OPENSEARCH_SSL", "true").lower() == "true"
-OPENSEARCH_CERT_VERIFY = os.environ.get("OPENSEARCH_CERT_VERIFY", "false").lower() == "true"
+OPENSEARCH_CERT_VERIFY = (
+    os.environ.get("OPENSEARCH_CERT_VERIFY", "false").lower() == "true"
+)
 OPENSEARCH_USERNAME = os.environ.get("OPENSEARCH_USERNAME", None)
 OPENSEARCH_PASSWORD = os.environ.get("OPENSEARCH_PASSWORD", None)
 
diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py
index 2abf65924..5668c2b40 100644
--- a/backend/open_webui/env.py
+++ b/backend/open_webui/env.py
@@ -442,3 +442,12 @@ AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders"
 )
 AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS]
 AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS]
+
+####################################
+# OPENTELEMETRY
+####################################
+
+OT_ENABLED = os.environ.get("OT_ENABLED", "false").lower() == "true"
+OT_SERVICE_NAME = os.environ.get("OT_SERVICE_NAME", "open-webui")
+OT_HOST = os.environ.get("OT_HOST", "http://localhost:4317")
+OT_TOKEN = os.environ.get("OT_TOKEN", "")
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index a453df0d7..4d3c1b683 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -330,6 +330,7 @@ from open_webui.env import (
     BYPASS_MODEL_ACCESS_CONTROL,
     RESET_CONFIG_ON_START,
     OFFLINE_MODE,
+    OT_ENABLED,
 )
 
 
@@ -356,7 +357,7 @@ from open_webui.utils.oauth import OAuthManager
 from open_webui.utils.security_headers import SecurityHeadersMiddleware
 
 from open_webui.tasks import stop_task, list_tasks  # Import from tasks.py
-
+from open_webui.utils.trace.setup import setup
 
 if SAFE_MODE:
     print("SAFE MODE ENABLED")
@@ -426,6 +427,17 @@ app.state.config = AppConfig(redis_url=REDIS_URL)
 app.state.WEBUI_NAME = WEBUI_NAME
 app.state.LICENSE_METADATA = None
 
+
+########################################
+#
+# OPENTELEMETRY
+#
+########################################
+
+if OT_ENABLED:
+    setup(app)
+
+
 ########################################
 #
 # OLLAMA
diff --git a/backend/open_webui/retrieval/loaders/tavily.py b/backend/open_webui/retrieval/loaders/tavily.py
index b96396eba..15a3d7f97 100644
--- a/backend/open_webui/retrieval/loaders/tavily.py
+++ b/backend/open_webui/retrieval/loaders/tavily.py
@@ -9,18 +9,20 @@ from open_webui.env import SRC_LOG_LEVELS
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
+
 class TavilyLoader(BaseLoader):
     """Extract web page content from URLs using Tavily Extract API.
-    
+
     This is a LangChain document loader that uses Tavily's Extract API to
     retrieve content from web pages and return it as Document objects.
-    
+
     Args:
         urls: URL or list of URLs to extract content from.
         api_key: The Tavily API key.
         extract_depth: Depth of extraction, either "basic" or "advanced".
         continue_on_failure: Whether to continue if extraction of a URL fails.
     """
+
     def __init__(
         self,
         urls: Union[str, List[str]],
@@ -29,13 +31,13 @@ class TavilyLoader(BaseLoader):
         continue_on_failure: bool = True,
     ) -> None:
         """Initialize Tavily Extract client.
-        
+
         Args:
             urls: URL or list of URLs to extract content from.
             api_key: The Tavily API key.
             include_images: Whether to include images in the extraction.
             extract_depth: Depth of extraction, either "basic" or "advanced".
-                advanced extraction retrieves more data, including tables and 
+                advanced extraction retrieves more data, including tables and
                 embedded content, with higher success but may increase latency.
                 basic costs 1 credit per 5 successful URL extractions,
                 advanced costs 2 credits per 5 successful URL extractions.
@@ -43,35 +45,28 @@ class TavilyLoader(BaseLoader):
         """
         if not urls:
             raise ValueError("At least one URL must be provided.")
-        
+
         self.api_key = api_key
         self.urls = urls if isinstance(urls, list) else [urls]
         self.extract_depth = extract_depth
         self.continue_on_failure = continue_on_failure
         self.api_url = "https://api.tavily.com/extract"
-        
+
     def lazy_load(self) -> Iterator[Document]:
         """Extract and yield documents from the URLs using Tavily Extract API."""
         batch_size = 20
         for i in range(0, len(self.urls), batch_size):
-            batch_urls = self.urls[i:i + batch_size]
+            batch_urls = self.urls[i : i + batch_size]
             try:
                 headers = {
                     "Content-Type": "application/json",
-                    "Authorization": f"Bearer {self.api_key}"
+                    "Authorization": f"Bearer {self.api_key}",
                 }
                 # Use string for single URL, array for multiple URLs
                 urls_param = batch_urls[0] if len(batch_urls) == 1 else batch_urls
-                payload = {
-                    "urls": urls_param,
-                    "extract_depth": self.extract_depth
-                }
+                payload = {"urls": urls_param, "extract_depth": self.extract_depth}
                 # Make the API call
-                response = requests.post(
-                    self.api_url,
-                    headers=headers,
-                    json=payload
-                )
+                response = requests.post(self.api_url, headers=headers, json=payload)
                 response.raise_for_status()
                 response_data = response.json()
                 # Process successful results
@@ -95,4 +90,4 @@
             if self.continue_on_failure:
                 log.error(f"Error extracting content from batch {batch_urls}: {e}")
             else:
-                raise e
\ No newline at end of file
+                raise e
diff --git a/backend/open_webui/retrieval/vector/dbs/opensearch.py b/backend/open_webui/retrieval/vector/dbs/opensearch.py
index 4844f7d4e..99567c84e 100644
--- a/backend/open_webui/retrieval/vector/dbs/opensearch.py
+++ b/backend/open_webui/retrieval/vector/dbs/opensearch.py
@@ -21,14 +21,14 @@ class OpenSearchClient:
             verify_certs=OPENSEARCH_CERT_VERIFY,
             http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD),
         )
-    
+
     def _get_index_name(self, collection_name: str) -> str:
f"{self.index_prefix}_{collection_name}" def _result_to_get_result(self, result) -> GetResult: if not result["hits"]["hits"]: return None - + ids = [] documents = [] metadatas = [] @@ -43,7 +43,7 @@ class OpenSearchClient: def _result_to_search_result(self, result) -> SearchResult: if not result["hits"]["hits"]: return None - + ids = [] distances = [] documents = [] @@ -56,16 +56,15 @@ class OpenSearchClient: metadatas.append(hit["_source"].get("metadata")) return SearchResult( - ids=[ids], distances=[distances], documents=[documents], metadatas=[metadatas] + ids=[ids], + distances=[distances], + documents=[documents], + metadatas=[metadatas], ) def _create_index(self, collection_name: str, dimension: int): body = { - "settings": { - "index": { - "knn": True - } - }, + "settings": {"index": {"knn": True}}, "mappings": { "properties": { "id": {"type": "keyword"}, @@ -81,13 +80,13 @@ class OpenSearchClient: "parameters": { "ef_construction": 128, "m": 16, - } + }, }, }, "text": {"type": "text"}, "metadata": {"type": "object"}, } - } + }, } self.client.indices.create( index=self._get_index_name(collection_name), body=body @@ -100,9 +99,7 @@ class OpenSearchClient: def has_collection(self, collection_name: str) -> bool: # has_collection here means has index. # We are simply adapting to the norms of the other DBs. - return self.client.indices.exists( - index=self._get_index_name(collection_name) - ) + return self.client.indices.exists(index=self._get_index_name(collection_name)) def delete_collection(self, collection_name: str): # delete_collection here means delete index. @@ -115,33 +112,30 @@ class OpenSearchClient: try: if not self.has_collection(collection_name): return None - + query = { "size": limit, "_source": ["text", "metadata"], "query": { "script_score": { - "query": { - "match_all": {} - }, + "query": {"match_all": {}}, "script": { "source": "cosineSimilarity(params.query_value, doc[params.field]) + 1.0", "params": { - "field": "vector", - "query_value": vectors[0] + "field": "vector", + "query_value": vectors[0], }, # Assuming single query vector }, } }, } - + result = self.client.search( - index=self._get_index_name(collection_name), - body=query + index=self._get_index_name(collection_name), body=query ) return self._result_to_search_result(result) - + except Exception as e: return None @@ -152,20 +146,14 @@ class OpenSearchClient: return None query_body = { - "query": { - "bool": { - "filter": [] - } - }, + "query": {"bool": {"filter": []}}, "_source": ["text", "metadata"], } for field, value in filter.items(): - query_body["query"]["bool"]["filter"].append({ - "match": { - "metadata." + str(field): value - } - }) + query_body["query"]["bool"]["filter"].append( + {"match": {"metadata." 
+            )
 
         size = limit if limit else 10
 
@@ -201,9 +189,9 @@
         for batch in self._create_batches(items):
             actions = [
                 {
-                    "_op_type": "index", 
+                    "_op_type": "index",
                     "_index": self._get_index_name(collection_name),
-                    "_id": item["id"], 
+                    "_id": item["id"],
                     "_source": {
                         "vector": item["vector"],
                         "text": item["text"],
@@ -222,9 +210,9 @@
         for batch in self._create_batches(items):
             actions = [
                 {
-                    "_op_type": "update", 
+                    "_op_type": "update",
                     "_index": self._get_index_name(collection_name),
-                    "_id": item["id"], 
+                    "_id": item["id"],
                     "doc": {
                         "vector": item["vector"],
                         "text": item["text"],
@@ -236,7 +224,12 @@
             ]
             bulk(self.client, actions)
 
-    def delete(self, collection_name: str, ids: Optional[list[str]] = None, filter: Optional[dict] = None):
+    def delete(
+        self,
+        collection_name: str,
+        ids: Optional[list[str]] = None,
+        filter: Optional[dict] = None,
+    ):
         if ids:
             actions = [
                 {
@@ -249,20 +242,16 @@
             bulk(self.client, actions)
         elif filter:
             query_body = {
-                "query": {
-                    "bool": {
-                        "filter": []
-                    }
-                },
+                "query": {"bool": {"filter": []}},
             }
             for field, value in filter.items():
-                query_body["query"]["bool"]["filter"].append({
-                    "match": {
-                        "metadata." + str(field): value
-                    }
-                })
-            self.client.delete_by_query(index=self._get_index_name(collection_name), body=query_body)
-        
+                query_body["query"]["bool"]["filter"].append(
+                    {"match": {"metadata." + str(field): value}}
+                )
+            self.client.delete_by_query(
+                index=self._get_index_name(collection_name), body=query_body
+            )
+
     def reset(self):
         indices = self.client.indices.get(index=f"{self.index_prefix}_*")
         for index in indices:
diff --git a/backend/open_webui/retrieval/web/utils.py b/backend/open_webui/retrieval/web/utils.py
index 65654d8e8..538321372 100644
--- a/backend/open_webui/retrieval/web/utils.py
+++ b/backend/open_webui/retrieval/web/utils.py
@@ -136,18 +136,18 @@
         self.last_request_time = datetime.now()
 
 
-class URLProcessingMixin: 
+class URLProcessingMixin:
     def _verify_ssl_cert(self, url: str) -> bool:
         """Verify SSL certificate for a URL."""
         return verify_ssl_cert(url)
-    
+
     async def _safe_process_url(self, url: str) -> bool:
         """Perform safety checks before processing a URL."""
         if self.verify_ssl and not self._verify_ssl_cert(url):
             raise ValueError(f"SSL certificate verification failed for {url}")
         await self._wait_for_rate_limit()
         return True
-    
+
     def _safe_process_url_sync(self, url: str) -> bool:
         """Synchronous version of safety checks."""
         if self.verify_ssl and not self._verify_ssl_cert(url):
@@ -286,7 +286,7 @@ class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
                 proxy["server"] = env_proxy_server
             else:
                 proxy = {"server": env_proxy_server}
-        
+
         # Store parameters for creating TavilyLoader instances
         self.web_paths = web_paths if isinstance(web_paths, list) else [web_paths]
         self.api_key = api_key
@@ -295,7 +295,7 @@ class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
         self.verify_ssl = verify_ssl
         self.trust_env = trust_env
         self.proxy = proxy
-        
+
         # Add rate limiting
         self.requests_per_second = requests_per_second
         self.last_request_time = None
@@ -329,7 +329,7 @@ class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
                     log.exception(e, "Error extracting content from URLs")
                 else:
                     raise e
-    
+
     async def alazy_load(self) -> AsyncIterator[Document]:
         """Async version with rate limiting and SSL verification."""
         valid_urls = []
@@ -341,13 +341,13 @@ class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
                 log.warning(f"SSL verification failed for {url}: {str(e)}")
                 if not self.continue_on_failure:
                     raise e
-        
+
         if not valid_urls:
             if self.continue_on_failure:
                 log.warning("No valid URLs to process after SSL verification")
                 return
             raise ValueError("No valid URLs to process after SSL verification")
-        
+
         try:
             loader = TavilyLoader(
                 urls=valid_urls,
@@ -477,7 +477,6 @@
                 await browser.close()
 
 
-
 class SafeWebBaseLoader(WebBaseLoader):
     """WebBaseLoader with enhanced error handling for URLs."""
 
diff --git a/backend/open_webui/utils/trace/__init__.py b/backend/open_webui/utils/trace/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/backend/open_webui/utils/trace/constants.py b/backend/open_webui/utils/trace/constants.py
new file mode 100644
index 000000000..6ef511f93
--- /dev/null
+++ b/backend/open_webui/utils/trace/constants.py
@@ -0,0 +1,26 @@
+from opentelemetry.semconv.trace import SpanAttributes as _SpanAttributes
+
+# Span Tags
+SPAN_DB_TYPE = "mysql"
+SPAN_REDIS_TYPE = "redis"
+SPAN_DURATION = "duration"
+SPAN_SQL_STR = "sql"
+SPAN_SQL_EXPLAIN = "explain"
+SPAN_ERROR_TYPE = "error"
+
+
+class SpanAttributes(_SpanAttributes):
+    """
+    Span Attributes
+    """
+
+    DB_INSTANCE = "db.instance"
+    DB_TYPE = "db.type"
+    DB_IP = "db.ip"
+    DB_PORT = "db.port"
+    ERROR_KIND = "error.kind"
+    ERROR_OBJECT = "error.object"
+    ERROR_MESSAGE = "error.message"
+    RESULT_CODE = "result.code"
+    RESULT_MESSAGE = "result.message"
+    RESULT_ERRORS = "result.errors"
diff --git a/backend/open_webui/utils/trace/exporters.py b/backend/open_webui/utils/trace/exporters.py
new file mode 100644
index 000000000..4bf166e65
--- /dev/null
+++ b/backend/open_webui/utils/trace/exporters.py
@@ -0,0 +1,31 @@
+import threading
+
+from opentelemetry.sdk.trace import ReadableSpan
+from opentelemetry.sdk.trace.export import BatchSpanProcessor
+
+
+class LazyBatchSpanProcessor(BatchSpanProcessor):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.done = True
+        with self.condition:
+            self.condition.notify_all()
+        self.worker_thread.join()
+        self.done = False
+        self.worker_thread = None
+
+    def on_end(self, span: ReadableSpan) -> None:
+        if self.worker_thread is None:
+            self.worker_thread = threading.Thread(
+                name=self.__class__.__name__, target=self.worker, daemon=True
+            )
+            self.worker_thread.start()
+        super().on_end(span)
+
+    def shutdown(self) -> None:
+        self.done = True
+        with self.condition:
+            self.condition.notify_all()
+        if self.worker_thread:
+            self.worker_thread.join()
+        self.span_exporter.shutdown()
diff --git a/backend/open_webui/utils/trace/instrumentors.py b/backend/open_webui/utils/trace/instrumentors.py
new file mode 100644
index 000000000..33998db39
--- /dev/null
+++ b/backend/open_webui/utils/trace/instrumentors.py
@@ -0,0 +1,155 @@
+import logging
+import traceback
+from typing import Collection
+
+from chromadb.telemetry.opentelemetry.fastapi import instrument_fastapi
+from opentelemetry.instrumentation.httpx import (
+    HTTPXClientInstrumentor,
+    RequestInfo,
+    ResponseInfo,
+)
+from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
+from opentelemetry.instrumentation.logging import LoggingInstrumentor
+from opentelemetry.instrumentation.redis import RedisInstrumentor
+from opentelemetry.instrumentation.requests import RequestsInstrumentor
+from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
+from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor
+from opentelemetry.trace import Span, StatusCode
+from redis import Redis
+from requests import PreparedRequest, Response
+
+from open_webui.env import SRC_LOG_LEVELS
+from open_webui.utils.trace.constants import SPAN_REDIS_TYPE, SpanAttributes
+
+logger = logging.getLogger(__name__)
+logger.setLevel(SRC_LOG_LEVELS["MAIN"])
+
+
+def requests_hook(span: Span, request: PreparedRequest):
+    """
+    HTTP Request Hook
+    """
+
+    span.update_name(f"{request.method} {request.url}")
+    span.set_attributes(
+        attributes={
+            SpanAttributes.HTTP_URL: request.url,
+            SpanAttributes.HTTP_METHOD: request.method,
+        }
+    )
+
+
+def response_hook(span: Span, request: PreparedRequest, response: Response):
+    """
+    HTTP Response Hook
+    """
+
+    span.set_attributes(
+        attributes={
+            SpanAttributes.HTTP_STATUS_CODE: response.status_code,
+        }
+    )
+    span.set_status(StatusCode.ERROR if response.status_code >= 400 else StatusCode.OK)
+
+
+def redis_request_hook(span: Span, instance: Redis, args, kwargs):
+    """
+    Redis Request Hook
+    """
+
+    try:
+        connection_kwargs: dict = instance.connection_pool.connection_kwargs
+        host = connection_kwargs.get("host")
+        port = connection_kwargs.get("port")
+        db = connection_kwargs.get("db")
+        span.set_attributes(
+            {
+                SpanAttributes.DB_INSTANCE: f"{host}/{db}",
+                SpanAttributes.DB_NAME: f"{host}/{db}",
+                SpanAttributes.DB_TYPE: SPAN_REDIS_TYPE,
+                SpanAttributes.DB_PORT: port,
+                SpanAttributes.DB_IP: host,
+                SpanAttributes.DB_STATEMENT: " ".join([str(i) for i in args]),
+                SpanAttributes.DB_OPERATION: str(args[0]),
+            }
+        )
+    except Exception:  # pylint: disable=W0718
+        logger.error(traceback.format_exc())
+
+
+def httpx_request_hook(span: Span, request: RequestInfo):
+    """
+    HTTPX Request Hook
+    """
+
+    span.update_name(f"{request.method.decode()} {str(request.url)}")
+    span.set_attributes(
+        attributes={
+            SpanAttributes.HTTP_URL: str(request.url),
+            SpanAttributes.HTTP_METHOD: request.method.decode(),
+        }
+    )
+
+
+def httpx_response_hook(span: Span, request: RequestInfo, response: ResponseInfo):
+    """
+    HTTPX Response Hook
+    """
+
+    span.set_attribute(SpanAttributes.HTTP_STATUS_CODE, response.status_code)
+    span.set_status(
+        StatusCode.ERROR
+        if response.status_code >= 400
+        else StatusCode.OK
+    )
+
+
+async def httpx_async_request_hook(span, request):
+    """Async HTTPX Request Hook."""
+    httpx_request_hook(span, request)
+
+
+async def httpx_async_response_hook(span, request, response):
+    """Async HTTPX Response Hook."""
+    httpx_response_hook(span, request, response)
+
+
+class Instrumentor(BaseInstrumentor):
+    """
+    Instrument OpenTelemetry
+    """
+
+    def __init__(self, app):
+        self.app = app
+
+    def instrumentation_dependencies(self) -> Collection[str]:
+        return []
+
+    def _instrument(self, **kwargs):
+        instrument_fastapi(app=self.app)
+        SQLAlchemyInstrumentor().instrument()
+        RedisInstrumentor().instrument(request_hook=redis_request_hook)
+        RequestsInstrumentor().instrument(
+            request_hook=requests_hook, response_hook=response_hook
+        )
+        LoggingInstrumentor().instrument()
+        HTTPXClientInstrumentor().instrument(
+            request_hook=httpx_request_hook,
+            response_hook=httpx_response_hook,
+            async_request_hook=httpx_async_request_hook,
+            async_response_hook=httpx_async_response_hook,
+        )
+        AioHttpClientInstrumentor().instrument()
+
+    def _uninstrument(self, **kwargs):
+        # BaseInstrumentor subclasses are singletons, so instantiating a
+        # class again returns the instance that was instrumented above.
+        for instrumentor_cls in (
+            SQLAlchemyInstrumentor,
+            RedisInstrumentor,
+            RequestsInstrumentor,
+            LoggingInstrumentor,
+            HTTPXClientInstrumentor,
+            AioHttpClientInstrumentor,
+        ):
+            instrumentor_cls().uninstrument()
diff --git a/backend/open_webui/utils/trace/setup.py b/backend/open_webui/utils/trace/setup.py
new file mode 100644
index 000000000..a3374f3fa
--- /dev/null
+++ b/backend/open_webui/utils/trace/setup.py
@@ -0,0 +1,24 @@
+from opentelemetry import trace
+from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
+from opentelemetry.sdk.resources import SERVICE_NAME, Resource
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.sampling import ALWAYS_ON
+
+from open_webui.utils.trace.exporters import LazyBatchSpanProcessor
+from open_webui.utils.trace.instrumentors import Instrumentor
+from open_webui.env import OT_SERVICE_NAME, OT_HOST, OT_TOKEN
+
+
+def setup(app):
+    trace.set_tracer_provider(
+        TracerProvider(
+            resource=Resource.create(
+                {SERVICE_NAME: OT_SERVICE_NAME, "token": OT_TOKEN}
+            ),
+            sampler=ALWAYS_ON,
+        )
+    )
+    # Export spans to the OTLP endpoint over gRPC
+    exporter = OTLPSpanExporter(endpoint=OT_HOST)
+    trace.get_tracer_provider().add_span_processor(LazyBatchSpanProcessor(exporter))
+    Instrumentor(app=app).instrument()
diff --git a/backend/requirements.txt b/backend/requirements.txt
index eb1ee6018..2f1728336 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -37,7 +37,7 @@ asgiref==3.8.1
 # AI libraries
 openai
 anthropic
-google-generativeai==0.7.2
+google-generativeai==0.8.4
 tiktoken
 
 langchain==0.3.19
@@ -118,3 +118,16 @@ ldap3==2.9.1
 
 ## Firecrawl
 firecrawl-py==1.12.0
+
+## Trace
+opentelemetry-api==1.30.0
+opentelemetry-sdk==1.30.0
+opentelemetry-exporter-otlp==1.30.0
+opentelemetry-instrumentation==0.51b0
+opentelemetry-instrumentation-fastapi==0.51b0
+opentelemetry-instrumentation-sqlalchemy==0.51b0
+opentelemetry-instrumentation-redis==0.51b0
+opentelemetry-instrumentation-requests==0.51b0
+opentelemetry-instrumentation-logging==0.51b0
+opentelemetry-instrumentation-httpx==0.51b0
+opentelemetry-instrumentation-aiohttp-client==0.51b0
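
Tracing stays off unless `OT_ENABLED=true` is set; `setup(app)` then builds a `TracerProvider` whose `Resource` carries `OT_SERVICE_NAME` and the `OT_TOKEN` value, and exports over OTLP/gRPC to `OT_HOST`. A minimal sketch of the same pipeline, wired to a console exporter so it can be run without a collector; the service name is illustrative:

```python
# Sketch of the pipeline setup() assembles, using ConsoleSpanExporter instead
# of the patch's LazyBatchSpanProcessor + OTLPSpanExporter(endpoint=OT_HOST).
from opentelemetry import trace
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor

provider = TracerProvider(resource=Resource.create({SERVICE_NAME: "open-webui"}))
provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))
trace.set_tracer_provider(provider)

with trace.get_tracer(__name__).start_as_current_span("demo") as span:
    span.set_attribute("enabled", True)  # span is printed to stdout on exit
```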
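`LazyBatchSpanProcessor` in `exporters.py` exists because the stock `BatchSpanProcessor` starts its export thread inside `__init__`; the subclass shuts that thread down immediately and only recreates it when the first span ends, so processes that never trace hold no live worker thread. A sketch of that behaviour, assuming the patch is applied and the opentelemetry-sdk 1.30 internals (`done`, `condition`, `worker_thread`) it relies on:

```python
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import ConsoleSpanExporter

from open_webui.utils.trace.exporters import LazyBatchSpanProcessor

processor = LazyBatchSpanProcessor(ConsoleSpanExporter())
assert processor.worker_thread is None  # no export thread until a span ends

provider = TracerProvider()
provider.add_span_processor(processor)
with provider.get_tracer(__name__).start_as_current_span("first-span"):
    pass  # on_end fires here and lazily starts the worker thread

assert processor.worker_thread is not None
processor.shutdown()  # joins the worker and shuts the exporter down
```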
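The hooks in `instrumentors.py` rename each outbound-HTTP span to `"<METHOD> <URL>"` and flag 4xx/5xx responses as errors. The same hook shape works standalone; a hedged sketch using the `request_hook`/`response_hook` keywords that the patch itself passes to `RequestsInstrumentor` (the URL below is an example):

```python
import requests
from opentelemetry.instrumentation.requests import RequestsInstrumentor
from opentelemetry.trace import StatusCode


def request_hook(span, request):  # request: requests.PreparedRequest
    span.update_name(f"{request.method} {request.url}")


def response_hook(span, request, response):  # response: requests.Response
    span.set_status(StatusCode.ERROR if response.status_code >= 400 else StatusCode.OK)


RequestsInstrumentor().instrument(request_hook=request_hook, response_hook=response_hook)
requests.get("https://example.com")  # span is renamed to "GET https://example.com/"
RequestsInstrumentor().uninstrument()
```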
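For the reformatted `TavilyLoader`, a hypothetical usage sketch; the API key is a placeholder, and the `"source"` metadata key assumes the loader stores the extracted URL there:

```python
from open_webui.retrieval.loaders.tavily import TavilyLoader

loader = TavilyLoader(
    urls=["https://example.com"],
    api_key="tvly-...",      # placeholder; a real Tavily key is required
    extract_depth="basic",   # "advanced" also pulls tables/embedded content
)
for doc in loader.lazy_load():  # requests are batched 20 URLs at a time
    print(doc.metadata.get("source"), len(doc.page_content))
```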
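The OpenSearch `search()` path scores with `cosineSimilarity(...) + 1.0` because OpenSearch rejects negative `script_score` values; adding 1.0 shifts cosine similarity from [-1, 1] into [0, 2]. The request body it sends, shown standalone with an illustrative index size and query vector:

```python
# Shape of the script_score query built by OpenSearchClient.search().
query = {
    "size": 5,
    "_source": ["text", "metadata"],
    "query": {
        "script_score": {
            "query": {"match_all": {}},
            "script": {
                # +1.0 keeps the score non-negative, as OpenSearch requires.
                "source": "cosineSimilarity(params.query_value, doc[params.field]) + 1.0",
                "params": {"field": "vector", "query_value": [0.1, 0.2, 0.3]},
            },
        }
    },
}
```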