From 19c340d3fb4d961cd0c960f7d34211a2fb0a2063 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Sat, 15 Feb 2025 22:25:18 -0800 Subject: [PATCH] refac: pipelines --- backend/open_webui/routers/pipelines.py | 112 +++++++++++++----------- backend/open_webui/utils/chat.py | 8 +- backend/open_webui/utils/middleware.py | 35 +++++--- 3 files changed, 84 insertions(+), 71 deletions(-) diff --git a/backend/open_webui/routers/pipelines.py b/backend/open_webui/routers/pipelines.py index 062663671..ad280b65c 100644 --- a/backend/open_webui/routers/pipelines.py +++ b/backend/open_webui/routers/pipelines.py @@ -9,6 +9,7 @@ from fastapi import ( status, APIRouter, ) +import aiohttp import os import logging import shutil @@ -56,96 +57,103 @@ def get_sorted_filters(model_id, models): return sorted_filters -def process_pipeline_inlet_filter(request, payload, user, models): +async def process_pipeline_inlet_filter(request, payload, user, models): user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role} model_id = payload["model"] - sorted_filters = get_sorted_filters(model_id, models) model = models[model_id] if "pipeline" in model: sorted_filters.append(model) - for filter in sorted_filters: - r = None - try: - urlIdx = filter["urlIdx"] + async with aiohttp.ClientSession() as session: + for filter in sorted_filters: + urlIdx = filter.get("urlIdx") + if urlIdx is None: + continue url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] key = request.app.state.config.OPENAI_API_KEYS[urlIdx] - if key == "": + if not key: continue headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/{filter['id']}/filter/inlet", - headers=headers, - json={ - "user": user, - "body": payload, - }, - ) + request_data = { + "user": user, + "body": payload, + } - r.raise_for_status() - payload = r.json() - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - if r is not None: - res = r.json() + try: + async with session.post( + f"{url}/{filter['id']}/filter/inlet", + headers=headers, + json=request_data, + ) as response: + response.raise_for_status() + payload = await response.json() + except aiohttp.ClientResponseError as e: + res = ( + await response.json() + if response.content_type == "application/json" + else {} + ) if "detail" in res: - raise Exception(r.status_code, res["detail"]) + raise Exception(response.status, res["detail"]) + except Exception as e: + print(f"Connection error: {e}") return payload -def process_pipeline_outlet_filter(request, payload, user, models): +async def process_pipeline_outlet_filter(request, payload, user, models): user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role} model_id = payload["model"] - sorted_filters = get_sorted_filters(model_id, models) model = models[model_id] if "pipeline" in model: sorted_filters = [model] + sorted_filters - for filter in sorted_filters: - r = None - try: - urlIdx = filter["urlIdx"] + async with aiohttp.ClientSession() as session: + for filter in sorted_filters: + urlIdx = filter.get("urlIdx") + if urlIdx is None: + continue url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] key = request.app.state.config.OPENAI_API_KEYS[urlIdx] - if key != "": - r = requests.post( + if not key: + continue + + headers = {"Authorization": f"Bearer {key}"} + request_data = { + "user": user, + "body": payload, + } + + try: + async with session.post( f"{url}/{filter['id']}/filter/outlet", - headers={"Authorization": f"Bearer {key}"}, - json={ - "user": user, - "body": payload, - }, - ) - - r.raise_for_status() - data = r.json() - payload = data - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - if r is not None: + headers=headers, + json=request_data, + ) as response: + response.raise_for_status() + payload = await response.json() + except aiohttp.ClientResponseError as e: try: - res = r.json() + res = ( + await response.json() + if "application/json" in response.content_type + else {} + ) if "detail" in res: - return Exception(r.status_code, res) + raise Exception(response.status, res) except Exception: pass - - else: - pass + except Exception as e: + print(f"Connection error: {e}") return payload diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index 569bcad85..73e4264bf 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -186,12 +186,6 @@ async def generate_chat_completion( if model_id not in models: raise Exception("Model not found") - # Process the form_data through the pipeline - try: - form_data = process_pipeline_inlet_filter(request, form_data, user, models) - except Exception as e: - raise e - model = models[model_id] if getattr(request.state, "direct", False): @@ -308,7 +302,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any): model = models[model_id] try: - data = process_pipeline_outlet_filter(request, data, user, models) + data = await process_pipeline_outlet_filter(request, data, user, models) except Exception as e: return Exception(f"Error: {e}") diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 52f50b8a9..e708abacd 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -39,7 +39,10 @@ from open_webui.routers.tasks import ( ) from open_webui.routers.retrieval import process_web_search, SearchForm from open_webui.routers.images import image_generations, GenerateImageForm - +from open_webui.routers.pipelines import ( + process_pipeline_inlet_filter, + process_pipeline_outlet_filter, +) from open_webui.utils.webhook import post_webhook @@ -676,6 +679,25 @@ async def process_chat_payload(request, form_data, metadata, user, model): variables = form_data.pop("variables", None) + # Process the form_data through the pipeline + try: + form_data = await process_pipeline_inlet_filter( + request, form_data, user, models + ) + except Exception as e: + raise e + + try: + form_data, flags = await process_filter_functions( + request=request, + filter_ids=get_sorted_filter_ids(model), + filter_type="inlet", + form_data=form_data, + extra_params=extra_params, + ) + except Exception as e: + raise Exception(f"Error: {e}") + features = form_data.pop("features", None) if features: if "web_search" in features and features["web_search"]: @@ -698,17 +720,6 @@ async def process_chat_payload(request, form_data, metadata, user, model): form_data["messages"], ) - try: - form_data, flags = await process_filter_functions( - request=request, - filter_ids=get_sorted_filter_ids(model), - filter_type="inlet", - form_data=form_data, - extra_params=extra_params, - ) - except Exception as e: - raise Exception(f"Error: {e}") - tool_ids = form_data.pop("tool_ids", None) files = form_data.pop("files", None) # Remove files duplicates