refac: pipelines

This commit is contained in:
Timothy Jaeryang Baek 2025-02-15 22:25:18 -08:00
parent ea15d91e29
commit 19c340d3fb
3 changed files with 84 additions and 71 deletions

View File

@ -9,6 +9,7 @@ from fastapi import (
status, status,
APIRouter, APIRouter,
) )
import aiohttp
import os import os
import logging import logging
import shutil import shutil
@ -56,96 +57,103 @@ def get_sorted_filters(model_id, models):
return sorted_filters return sorted_filters
def process_pipeline_inlet_filter(request, payload, user, models): async def process_pipeline_inlet_filter(request, payload, user, models):
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role} user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
model_id = payload["model"] model_id = payload["model"]
sorted_filters = get_sorted_filters(model_id, models) sorted_filters = get_sorted_filters(model_id, models)
model = models[model_id] model = models[model_id]
if "pipeline" in model: if "pipeline" in model:
sorted_filters.append(model) sorted_filters.append(model)
async with aiohttp.ClientSession() as session:
for filter in sorted_filters: for filter in sorted_filters:
r = None urlIdx = filter.get("urlIdx")
try: if urlIdx is None:
urlIdx = filter["urlIdx"] continue
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = request.app.state.config.OPENAI_API_KEYS[urlIdx] key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
if key == "": if not key:
continue continue
headers = {"Authorization": f"Bearer {key}"} headers = {"Authorization": f"Bearer {key}"}
r = requests.post( request_data = {
f"{url}/{filter['id']}/filter/inlet",
headers=headers,
json={
"user": user, "user": user,
"body": payload, "body": payload,
}, }
try:
async with session.post(
f"{url}/{filter['id']}/filter/inlet",
headers=headers,
json=request_data,
) as response:
response.raise_for_status()
payload = await response.json()
except aiohttp.ClientResponseError as e:
res = (
await response.json()
if response.content_type == "application/json"
else {}
) )
r.raise_for_status()
payload = r.json()
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
if r is not None:
res = r.json()
if "detail" in res: if "detail" in res:
raise Exception(r.status_code, res["detail"]) raise Exception(response.status, res["detail"])
except Exception as e:
print(f"Connection error: {e}")
return payload return payload
def process_pipeline_outlet_filter(request, payload, user, models): async def process_pipeline_outlet_filter(request, payload, user, models):
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role} user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
model_id = payload["model"] model_id = payload["model"]
sorted_filters = get_sorted_filters(model_id, models) sorted_filters = get_sorted_filters(model_id, models)
model = models[model_id] model = models[model_id]
if "pipeline" in model: if "pipeline" in model:
sorted_filters = [model] + sorted_filters sorted_filters = [model] + sorted_filters
async with aiohttp.ClientSession() as session:
for filter in sorted_filters: for filter in sorted_filters:
r = None urlIdx = filter.get("urlIdx")
try: if urlIdx is None:
urlIdx = filter["urlIdx"] continue
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = request.app.state.config.OPENAI_API_KEYS[urlIdx] key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
if key != "": if not key:
r = requests.post( continue
f"{url}/{filter['id']}/filter/outlet",
headers={"Authorization": f"Bearer {key}"}, headers = {"Authorization": f"Bearer {key}"}
json={ request_data = {
"user": user, "user": user,
"body": payload, "body": payload,
}, }
)
r.raise_for_status()
data = r.json()
payload = data
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
if r is not None:
try: try:
res = r.json() async with session.post(
f"{url}/{filter['id']}/filter/outlet",
headers=headers,
json=request_data,
) as response:
response.raise_for_status()
payload = await response.json()
except aiohttp.ClientResponseError as e:
try:
res = (
await response.json()
if "application/json" in response.content_type
else {}
)
if "detail" in res: if "detail" in res:
return Exception(r.status_code, res) raise Exception(response.status, res)
except Exception: except Exception:
pass pass
except Exception as e:
else: print(f"Connection error: {e}")
pass
return payload return payload

View File

@ -186,12 +186,6 @@ async def generate_chat_completion(
if model_id not in models: if model_id not in models:
raise Exception("Model not found") raise Exception("Model not found")
# Process the form_data through the pipeline
try:
form_data = process_pipeline_inlet_filter(request, form_data, user, models)
except Exception as e:
raise e
model = models[model_id] model = models[model_id]
if getattr(request.state, "direct", False): if getattr(request.state, "direct", False):
@ -308,7 +302,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
model = models[model_id] model = models[model_id]
try: try:
data = process_pipeline_outlet_filter(request, data, user, models) data = await process_pipeline_outlet_filter(request, data, user, models)
except Exception as e: except Exception as e:
return Exception(f"Error: {e}") return Exception(f"Error: {e}")

View File

@ -39,7 +39,10 @@ from open_webui.routers.tasks import (
) )
from open_webui.routers.retrieval import process_web_search, SearchForm from open_webui.routers.retrieval import process_web_search, SearchForm
from open_webui.routers.images import image_generations, GenerateImageForm from open_webui.routers.images import image_generations, GenerateImageForm
from open_webui.routers.pipelines import (
process_pipeline_inlet_filter,
process_pipeline_outlet_filter,
)
from open_webui.utils.webhook import post_webhook from open_webui.utils.webhook import post_webhook
@ -676,6 +679,25 @@ async def process_chat_payload(request, form_data, metadata, user, model):
variables = form_data.pop("variables", None) variables = form_data.pop("variables", None)
# Process the form_data through the pipeline
try:
form_data = await process_pipeline_inlet_filter(
request, form_data, user, models
)
except Exception as e:
raise e
try:
form_data, flags = await process_filter_functions(
request=request,
filter_ids=get_sorted_filter_ids(model),
filter_type="inlet",
form_data=form_data,
extra_params=extra_params,
)
except Exception as e:
raise Exception(f"Error: {e}")
features = form_data.pop("features", None) features = form_data.pop("features", None)
if features: if features:
if "web_search" in features and features["web_search"]: if "web_search" in features and features["web_search"]:
@ -698,17 +720,6 @@ async def process_chat_payload(request, form_data, metadata, user, model):
form_data["messages"], form_data["messages"],
) )
try:
form_data, flags = await process_filter_functions(
request=request,
filter_ids=get_sorted_filter_ids(model),
filter_type="inlet",
form_data=form_data,
extra_params=extra_params,
)
except Exception as e:
raise Exception(f"Error: {e}")
tool_ids = form_data.pop("tool_ids", None) tool_ids = form_data.pop("tool_ids", None)
files = form_data.pop("files", None) files = form_data.pop("files", None)
# Remove files duplicates # Remove files duplicates