diff --git a/backend/open_webui/retrieval/loaders/mistral.py b/backend/open_webui/retrieval/loaders/mistral.py index 8f3a960a2..2f0e2b8bb 100644 --- a/backend/open_webui/retrieval/loaders/mistral.py +++ b/backend/open_webui/retrieval/loaders/mistral.py @@ -1,8 +1,12 @@ import requests +import aiohttp +import asyncio import logging import os import sys +import time from typing import List, Dict, Any +from contextlib import asynccontextmanager from langchain_core.documents import Document from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL @@ -14,18 +18,29 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) class MistralLoader: """ + Enhanced Mistral OCR loader with both sync and async support. Loads documents by processing them through the Mistral OCR API. """ BASE_API_URL = "https://api.mistral.ai/v1" - def __init__(self, api_key: str, file_path: str): + def __init__( + self, + api_key: str, + file_path: str, + timeout: int = 300, # 5 minutes default + max_retries: int = 3, + enable_debug_logging: bool = False + ): """ - Initializes the loader. + Initializes the loader with enhanced features. Args: api_key: Your Mistral API key. file_path: The local path to the PDF file to process. + timeout: Request timeout in seconds. + max_retries: Maximum number of retry attempts. + enable_debug_logging: Enable detailed debug logs. 
""" if not api_key: raise ValueError("API key cannot be empty.") @@ -34,7 +49,23 @@ class MistralLoader: self.api_key = api_key self.file_path = file_path - self.headers = {"Authorization": f"Bearer {self.api_key}"} + self.timeout = timeout + self.max_retries = max_retries + self.debug = enable_debug_logging + + # Pre-compute file info for performance + self.file_name = os.path.basename(file_path) + self.file_size = os.path.getsize(file_path) + + self.headers = { + "Authorization": f"Bearer {self.api_key}", + "User-Agent": "OpenWebUI-MistralLoader/2.0" + } + + def _debug_log(self, message: str, *args) -> None: + """Conditional debug logging for performance.""" + if self.debug: + log.debug(message, *args) def _handle_response(self, response: requests.Response) -> Dict[str, Any]: """Checks response status and returns JSON content.""" @@ -54,24 +85,81 @@ class MistralLoader: log.error(f"JSON decode error: {json_err} - Response: {response.text}") raise # Re-raise after logging + async def _handle_response_async(self, response: aiohttp.ClientResponse) -> Dict[str, Any]: + """Async version of response handling with better error info.""" + try: + response.raise_for_status() + + # Check content type + content_type = response.headers.get('content-type', '') + if 'application/json' not in content_type: + if response.status == 204: + return {} + text = await response.text() + raise ValueError(f"Unexpected content type: {content_type}, body: {text[:200]}...") + + return await response.json() + + except aiohttp.ClientResponseError as e: + error_text = await response.text() if response else "No response" + log.error(f"HTTP {e.status}: {e.message} - Response: {error_text[:500]}") + raise + except aiohttp.ClientError as e: + log.error(f"Client error: {e}") + raise + except Exception as e: + log.error(f"Unexpected error processing response: {e}") + raise + + def _retry_request_sync(self, request_func, *args, **kwargs): + """Synchronous retry logic with exponential backoff.""" + for 
attempt in range(self.max_retries): + try: + return request_func(*args, **kwargs) + except (requests.exceptions.RequestException, Exception) as e: + if attempt == self.max_retries - 1: + raise + + wait_time = (2 ** attempt) + 0.5 + log.warning(f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s...") + time.sleep(wait_time) + + async def _retry_request_async(self, request_func, *args, **kwargs): + """Async retry logic with exponential backoff.""" + for attempt in range(self.max_retries): + try: + return await request_func(*args, **kwargs) + except (aiohttp.ClientError, asyncio.TimeoutError) as e: + if attempt == self.max_retries - 1: + raise + + wait_time = (2 ** attempt) + 0.5 + log.warning(f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s...") + await asyncio.sleep(wait_time) + def _upload_file(self) -> str: - """Uploads the file to Mistral for OCR processing.""" + """Uploads the file to Mistral for OCR processing (sync version).""" log.info("Uploading file to Mistral API") url = f"{self.BASE_API_URL}/files" file_name = os.path.basename(self.file_path) - try: + def upload_request(): with open(self.file_path, "rb") as f: files = {"file": (file_name, f, "application/pdf")} data = {"purpose": "ocr"} - upload_headers = self.headers.copy() # Avoid modifying self.headers - response = requests.post( - url, headers=upload_headers, files=files, data=data + url, + headers=self.headers, + files=files, + data=data, + timeout=self.timeout ) - response_data = self._handle_response(response) + return self._handle_response(response) + + try: + response_data = self._retry_request_sync(upload_request) file_id = response_data.get("id") if not file_id: raise ValueError("File ID not found in upload response.") @@ -81,16 +169,63 @@ class MistralLoader: log.error(f"Failed to upload file: {e}") raise + async def _upload_file_async(self, session: aiohttp.ClientSession) -> str: + """Async file upload with 
streaming for better memory efficiency.""" + url = f"{self.BASE_API_URL}/files" + + async def upload_request(): + # Create multipart writer for streaming upload + writer = aiohttp.MultipartWriter('form-data') + + # Add purpose field + purpose_part = writer.append('ocr') + purpose_part.set_content_disposition('form-data', name='purpose') + + # Add file part (FilePayload lives in aiohttp.payload and takes an open binary file object, not a path) + file_part = writer.append_payload(aiohttp.payload.FilePayload( + open(self.file_path, 'rb'), + filename=self.file_name, + content_type='application/pdf' + )) + file_part.set_content_disposition('form-data', name='file', filename=self.file_name) + + self._debug_log(f"Uploading file: {self.file_name} ({self.file_size:,} bytes)") + + async with session.post( + url, + data=writer, + headers=self.headers, + timeout=aiohttp.ClientTimeout(total=self.timeout) + ) as response: + return await self._handle_response_async(response) + + response_data = await self._retry_request_async(upload_request) + + file_id = response_data.get("id") + if not file_id: + raise ValueError("File ID not found in upload response.") + + log.info(f"File uploaded successfully. 
File ID: {file_id}") + return file_id + def _get_signed_url(self, file_id: str) -> str: - """Retrieves a temporary signed URL for the uploaded file.""" + """Retrieves a temporary signed URL for the uploaded file (sync version).""" log.info(f"Getting signed URL for file ID: {file_id}") url = f"{self.BASE_API_URL}/files/{file_id}/url" params = {"expiry": 1} signed_url_headers = {**self.headers, "Accept": "application/json"} + def url_request(): + response = requests.get( + url, + headers=signed_url_headers, + params=params, + timeout=self.timeout + ) + return self._handle_response(response) + try: - response = requests.get(url, headers=signed_url_headers, params=params) - response_data = self._handle_response(response) + response_data = self._retry_request_sync(url_request) signed_url = response_data.get("url") if not signed_url: raise ValueError("Signed URL not found in response.") @@ -100,8 +235,37 @@ class MistralLoader: log.error(f"Failed to get signed URL: {e}") raise + async def _get_signed_url_async(self, session: aiohttp.ClientSession, file_id: str) -> str: + """Async signed URL retrieval.""" + url = f"{self.BASE_API_URL}/files/{file_id}/url" + params = {"expiry": 1} + + headers = { + **self.headers, + "Accept": "application/json" + } + + async def url_request(): + self._debug_log(f"Getting signed URL for file ID: {file_id}") + async with session.get( + url, + headers=headers, + params=params, + timeout=aiohttp.ClientTimeout(total=self.timeout) + ) as response: + return await self._handle_response_async(response) + + response_data = await self._retry_request_async(url_request) + + signed_url = response_data.get("url") + if not signed_url: + raise ValueError("Signed URL not found in response.") + + self._debug_log("Signed URL received successfully") + return signed_url + def _process_ocr(self, signed_url: str) -> Dict[str, Any]: - """Sends the signed URL to the OCR endpoint for processing.""" + """Sends the signed URL to the OCR endpoint for processing (sync 
version).""" log.info("Processing OCR via Mistral API") url = f"{self.BASE_API_URL}/ocr" ocr_headers = { @@ -118,43 +282,179 @@ class MistralLoader: "include_image_base64": False, } + def ocr_request(): + response = requests.post( + url, + headers=ocr_headers, + json=payload, + timeout=self.timeout + ) + return self._handle_response(response) + try: - response = requests.post(url, headers=ocr_headers, json=payload) - ocr_response = self._handle_response(response) + ocr_response = self._retry_request_sync(ocr_request) log.info("OCR processing done.") - log.debug("OCR response: %s", ocr_response) + self._debug_log("OCR response: %s", ocr_response) return ocr_response except Exception as e: log.error(f"Failed during OCR processing: {e}") raise + async def _process_ocr_async(self, session: aiohttp.ClientSession, signed_url: str) -> Dict[str, Any]: + """Async OCR processing with timing metrics.""" + url = f"{self.BASE_API_URL}/ocr" + + headers = { + **self.headers, + "Content-Type": "application/json", + "Accept": "application/json", + } + + payload = { + "model": "mistral-ocr-latest", + "document": { + "type": "document_url", + "document_url": signed_url, + }, + "include_image_base64": False, + } + + async def ocr_request(): + log.info("Starting OCR processing via Mistral API") + start_time = time.time() + + async with session.post( + url, + json=payload, + headers=headers, + timeout=aiohttp.ClientTimeout(total=self.timeout) + ) as response: + ocr_response = await self._handle_response_async(response) + + processing_time = time.time() - start_time + log.info(f"OCR processing completed in {processing_time:.2f}s") + + return ocr_response + + return await self._retry_request_async(ocr_request) + def _delete_file(self, file_id: str) -> None: - """Deletes the file from Mistral storage.""" + """Deletes the file from Mistral storage (sync version).""" log.info(f"Deleting uploaded file ID: {file_id}") url = f"{self.BASE_API_URL}/files/{file_id}" - # No specific Accept header 
needed, default or Authorization is usually sufficient try: - response = requests.delete(url, headers=self.headers) - delete_response = self._handle_response( - response - ) # Check status, ignore response body unless needed - log.info( - f"File deleted successfully: {delete_response}" - ) # Log the response if available + response = requests.delete(url, headers=self.headers, timeout=30) + delete_response = self._handle_response(response) + log.info(f"File deleted successfully: {delete_response}") except Exception as e: # Log error but don't necessarily halt execution if deletion fails log.error(f"Failed to delete file ID {file_id}: {e}") - # Depending on requirements, you might choose to raise the error here + + async def _delete_file_async(self, session: aiohttp.ClientSession, file_id: str) -> None: + """Async file deletion with error tolerance.""" + try: + async def delete_request(): + self._debug_log(f"Deleting file ID: {file_id}") + async with session.delete( + url=f"{self.BASE_API_URL}/files/{file_id}", + headers=self.headers, + timeout=aiohttp.ClientTimeout(total=30) # Shorter timeout for cleanup + ) as response: + return await self._handle_response_async(response) + + await self._retry_request_async(delete_request) + self._debug_log(f"File {file_id} deleted successfully") + + except Exception as e: + # Don't fail the entire process if cleanup fails + log.warning(f"Failed to delete file ID {file_id}: {e}") + + @asynccontextmanager + async def _get_session(self): + """Context manager for HTTP session with optimized settings.""" + connector = aiohttp.TCPConnector( + limit=10, # Total connection limit + limit_per_host=5, # Per-host connection limit + ttl_dns_cache=300, # DNS cache TTL + use_dns_cache=True, + keepalive_timeout=30, + enable_cleanup_closed=True + ) + + async with aiohttp.ClientSession( + connector=connector, + timeout=aiohttp.ClientTimeout(total=self.timeout), + headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"} + ) as session: + yield session 
+ + def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]: + """Process OCR results into Document objects with enhanced metadata.""" + pages_data = ocr_response.get("pages") + if not pages_data: + log.warning("No pages found in OCR response.") + return [Document(page_content="No text content found", metadata={"error": "no_pages"})] + + documents = [] + total_pages = len(pages_data) + skipped_pages = 0 + + for page_data in pages_data: + page_content = page_data.get("markdown") + page_index = page_data.get("index") # API uses 0-based index + + if page_content is not None and page_index is not None: + # Clean up content efficiently + cleaned_content = page_content.strip() if isinstance(page_content, str) else str(page_content) + + if cleaned_content: # Only add non-empty pages + documents.append( + Document( + page_content=cleaned_content, + metadata={ + "page": page_index, # 0-based index from API + "page_label": page_index + 1, # 1-based label for convenience + "total_pages": total_pages, + "file_name": self.file_name, + "file_size": self.file_size, + "processing_engine": "mistral-ocr" + }, + ) + ) + else: + skipped_pages += 1 + self._debug_log(f"Skipping empty page {page_index}") + else: + skipped_pages += 1 + self._debug_log(f"Skipping page due to missing 'markdown' or 'index'. Data: {page_data}") + + if skipped_pages > 0: + log.info(f"Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages") + + if not documents: + # Case where pages existed but none had valid markdown/index + log.warning("OCR response contained pages, but none had valid content/index.") + return [ + Document( + page_content="No valid text content found in document", + metadata={"error": "no_valid_pages", "total_pages": total_pages} + ) + ] + + return documents def load(self) -> List[Document]: """ Executes the full OCR workflow: upload, get URL, process OCR, delete file. + Synchronous version for backward compatibility. 
Returns: A list of Document objects, one for each page processed. """ file_id = None + start_time = time.time() + try: # 1. Upload file file_id = self._upload_file() @@ -166,53 +466,21 @@ class MistralLoader: ocr_response = self._process_ocr(signed_url) # 4. Process results - pages_data = ocr_response.get("pages") - if not pages_data: - log.warning("No pages found in OCR response.") - return [Document(page_content="No text content found", metadata={})] - - documents = [] - total_pages = len(pages_data) - for page_data in pages_data: - page_content = page_data.get("markdown") - page_index = page_data.get("index") # API uses 0-based index - - if page_content is not None and page_index is not None: - documents.append( - Document( - page_content=page_content, - metadata={ - "page": page_index, # 0-based index from API - "page_label": page_index - + 1, # 1-based label for convenience - "total_pages": total_pages, - # Add other relevant metadata from page_data if available/needed - # e.g., page_data.get('width'), page_data.get('height') - }, - ) - ) - else: - log.warning( - f"Skipping page due to missing 'markdown' or 'index'. Data: {page_data}" - ) - - if not documents: - # Case where pages existed but none had valid markdown/index - log.warning( - "OCR response contained pages, but none had valid content/index." 
- ) - return [ - Document( - page_content="No text content found in valid pages", metadata={} - ) - ] - + documents = self._process_results(ocr_response) + + total_time = time.time() - start_time + log.info(f"Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents") + return documents except Exception as e: - log.error(f"An error occurred during the loading process: {e}") - # Return an empty list or a specific error document on failure - return [Document(page_content=f"Error during processing: {e}", metadata={})] + total_time = time.time() - start_time + log.error(f"An error occurred during the loading process after {total_time:.2f}s: {e}") + # Return an error document on failure + return [Document( + page_content=f"Error during processing: {e}", + metadata={"error": "processing_failed", "file_name": self.file_name} + )] finally: # 5. Delete file (attempt even if prior steps failed after upload) if file_id: @@ -220,6 +488,88 @@ class MistralLoader: self._delete_file(file_id) except Exception as del_e: # Log deletion error, but don't overwrite original error if one occurred - log.error( - f"Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}" - ) + log.error(f"Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}") + + async def load_async(self) -> List[Document]: + """ + Asynchronous OCR workflow execution with optimized performance. + + Returns: + A list of Document objects, one for each page processed. + """ + file_id = None + start_time = time.time() + + try: + async with self._get_session() as session: + # 1. Upload file with streaming + file_id = await self._upload_file_async(session) + + # 2. Get signed URL + signed_url = await self._get_signed_url_async(session, file_id) + + # 3. Process OCR + ocr_response = await self._process_ocr_async(session, signed_url) + + # 4. 
Process results + documents = self._process_results(ocr_response) + + total_time = time.time() - start_time + log.info(f"Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents") + + return documents + + except Exception as e: + total_time = time.time() - start_time + log.error(f"Async OCR workflow failed after {total_time:.2f}s: {e}") + return [Document( + page_content=f"Error during OCR processing: {e}", + metadata={"error": "processing_failed", "file_name": self.file_name} + )] + finally: + # 5. Cleanup - always attempt file deletion + if file_id: + try: + async with self._get_session() as session: + await self._delete_file_async(session, file_id) + except Exception as cleanup_error: + log.error(f"Cleanup failed for file ID {file_id}: {cleanup_error}") + + @staticmethod + async def load_multiple_async(loaders: List['MistralLoader']) -> List[List[Document]]: + """ + Process multiple files concurrently for maximum performance. + + Args: + loaders: List of MistralLoader instances + + Returns: + List of document lists, one for each loader + """ + if not loaders: + return [] + + log.info(f"Starting concurrent processing of {len(loaders)} files") + start_time = time.time() + + # Process all files concurrently + tasks = [loader.load_async() for loader in loaders] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Handle any exceptions in results + processed_results = [] + for i, result in enumerate(results): + if isinstance(result, Exception): + log.error(f"File {i} failed: {result}") + processed_results.append([Document( + page_content=f"Error processing file: {result}", + metadata={"error": "batch_processing_failed", "file_index": i} + )]) + else: + processed_results.append(result) + + total_time = time.time() - start_time + total_docs = sum(len(docs) for docs in processed_results) + log.info(f"Batch processing completed in {total_time:.2f}s, produced {total_docs} total documents") + + return processed_results