refac: pipelines

This commit is contained in:
Timothy Jaeryang Baek 2025-02-15 22:25:18 -08:00
parent ea15d91e29
commit 19c340d3fb
3 changed files with 84 additions and 71 deletions

View File

@ -9,6 +9,7 @@ from fastapi import (
status,
APIRouter,
)
import aiohttp
import os
import logging
import shutil
@ -56,96 +57,103 @@ def get_sorted_filters(model_id, models):
return sorted_filters
def process_pipeline_inlet_filter(request, payload, user, models):
async def process_pipeline_inlet_filter(request, payload, user, models):
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
model_id = payload["model"]
sorted_filters = get_sorted_filters(model_id, models)
model = models[model_id]
if "pipeline" in model:
sorted_filters.append(model)
async with aiohttp.ClientSession() as session:
for filter in sorted_filters:
r = None
try:
urlIdx = filter["urlIdx"]
urlIdx = filter.get("urlIdx")
if urlIdx is None:
continue
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
if key == "":
if not key:
continue
headers = {"Authorization": f"Bearer {key}"}
r = requests.post(
f"{url}/{filter['id']}/filter/inlet",
headers=headers,
json={
request_data = {
"user": user,
"body": payload,
},
}
try:
async with session.post(
f"{url}/{filter['id']}/filter/inlet",
headers=headers,
json=request_data,
) as response:
response.raise_for_status()
payload = await response.json()
except aiohttp.ClientResponseError as e:
res = (
await response.json()
if response.content_type == "application/json"
else {}
)
r.raise_for_status()
payload = r.json()
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
if r is not None:
res = r.json()
if "detail" in res:
raise Exception(r.status_code, res["detail"])
raise Exception(response.status, res["detail"])
except Exception as e:
print(f"Connection error: {e}")
return payload
def process_pipeline_outlet_filter(request, payload, user, models):
async def process_pipeline_outlet_filter(request, payload, user, models):
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
model_id = payload["model"]
sorted_filters = get_sorted_filters(model_id, models)
model = models[model_id]
if "pipeline" in model:
sorted_filters = [model] + sorted_filters
async with aiohttp.ClientSession() as session:
for filter in sorted_filters:
r = None
try:
urlIdx = filter["urlIdx"]
urlIdx = filter.get("urlIdx")
if urlIdx is None:
continue
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
if key != "":
r = requests.post(
f"{url}/{filter['id']}/filter/outlet",
headers={"Authorization": f"Bearer {key}"},
json={
if not key:
continue
headers = {"Authorization": f"Bearer {key}"}
request_data = {
"user": user,
"body": payload,
},
)
}
r.raise_for_status()
data = r.json()
payload = data
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
if r is not None:
try:
res = r.json()
async with session.post(
f"{url}/{filter['id']}/filter/outlet",
headers=headers,
json=request_data,
) as response:
response.raise_for_status()
payload = await response.json()
except aiohttp.ClientResponseError as e:
try:
res = (
await response.json()
if "application/json" in response.content_type
else {}
)
if "detail" in res:
return Exception(r.status_code, res)
raise Exception(response.status, res)
except Exception:
pass
else:
pass
except Exception as e:
print(f"Connection error: {e}")
return payload

View File

@ -186,12 +186,6 @@ async def generate_chat_completion(
if model_id not in models:
raise Exception("Model not found")
# Process the form_data through the pipeline
try:
form_data = process_pipeline_inlet_filter(request, form_data, user, models)
except Exception as e:
raise e
model = models[model_id]
if getattr(request.state, "direct", False):
@ -308,7 +302,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
model = models[model_id]
try:
data = process_pipeline_outlet_filter(request, data, user, models)
data = await process_pipeline_outlet_filter(request, data, user, models)
except Exception as e:
return Exception(f"Error: {e}")

View File

@ -39,7 +39,10 @@ from open_webui.routers.tasks import (
)
from open_webui.routers.retrieval import process_web_search, SearchForm
from open_webui.routers.images import image_generations, GenerateImageForm
from open_webui.routers.pipelines import (
process_pipeline_inlet_filter,
process_pipeline_outlet_filter,
)
from open_webui.utils.webhook import post_webhook
@ -676,6 +679,25 @@ async def process_chat_payload(request, form_data, metadata, user, model):
variables = form_data.pop("variables", None)
# Process the form_data through the pipeline
try:
form_data = await process_pipeline_inlet_filter(
request, form_data, user, models
)
except Exception as e:
raise e
try:
form_data, flags = await process_filter_functions(
request=request,
filter_ids=get_sorted_filter_ids(model),
filter_type="inlet",
form_data=form_data,
extra_params=extra_params,
)
except Exception as e:
raise Exception(f"Error: {e}")
features = form_data.pop("features", None)
if features:
if "web_search" in features and features["web_search"]:
@ -698,17 +720,6 @@ async def process_chat_payload(request, form_data, metadata, user, model):
form_data["messages"],
)
try:
form_data, flags = await process_filter_functions(
request=request,
filter_ids=get_sorted_filter_ids(model),
filter_type="inlet",
form_data=form_data,
extra_params=extra_params,
)
except Exception as e:
raise Exception(f"Error: {e}")
tool_ids = form_data.pop("tool_ids", None)
files = form_data.pop("files", None)
# Remove files duplicates