mirror of
https://github.com/open-webui/open-webui
synced 2025-06-23 02:16:52 +00:00
refac: pipelines
This commit is contained in:
parent
ea15d91e29
commit
19c340d3fb
@ -9,6 +9,7 @@ from fastapi import (
|
|||||||
status,
|
status,
|
||||||
APIRouter,
|
APIRouter,
|
||||||
)
|
)
|
||||||
|
import aiohttp
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import shutil
|
import shutil
|
||||||
@ -56,96 +57,103 @@ def get_sorted_filters(model_id, models):
|
|||||||
return sorted_filters
|
return sorted_filters
|
||||||
|
|
||||||
|
|
||||||
def process_pipeline_inlet_filter(request, payload, user, models):
|
async def process_pipeline_inlet_filter(request, payload, user, models):
|
||||||
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
|
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
|
||||||
model_id = payload["model"]
|
model_id = payload["model"]
|
||||||
|
|
||||||
sorted_filters = get_sorted_filters(model_id, models)
|
sorted_filters = get_sorted_filters(model_id, models)
|
||||||
model = models[model_id]
|
model = models[model_id]
|
||||||
|
|
||||||
if "pipeline" in model:
|
if "pipeline" in model:
|
||||||
sorted_filters.append(model)
|
sorted_filters.append(model)
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
for filter in sorted_filters:
|
for filter in sorted_filters:
|
||||||
r = None
|
urlIdx = filter.get("urlIdx")
|
||||||
try:
|
if urlIdx is None:
|
||||||
urlIdx = filter["urlIdx"]
|
continue
|
||||||
|
|
||||||
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||||
key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
|
key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||||
|
|
||||||
if key == "":
|
if not key:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
headers = {"Authorization": f"Bearer {key}"}
|
headers = {"Authorization": f"Bearer {key}"}
|
||||||
r = requests.post(
|
request_data = {
|
||||||
f"{url}/{filter['id']}/filter/inlet",
|
|
||||||
headers=headers,
|
|
||||||
json={
|
|
||||||
"user": user,
|
"user": user,
|
||||||
"body": payload,
|
"body": payload,
|
||||||
},
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with session.post(
|
||||||
|
f"{url}/{filter['id']}/filter/inlet",
|
||||||
|
headers=headers,
|
||||||
|
json=request_data,
|
||||||
|
) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
payload = await response.json()
|
||||||
|
except aiohttp.ClientResponseError as e:
|
||||||
|
res = (
|
||||||
|
await response.json()
|
||||||
|
if response.content_type == "application/json"
|
||||||
|
else {}
|
||||||
)
|
)
|
||||||
|
|
||||||
r.raise_for_status()
|
|
||||||
payload = r.json()
|
|
||||||
except Exception as e:
|
|
||||||
# Handle connection error here
|
|
||||||
print(f"Connection error: {e}")
|
|
||||||
|
|
||||||
if r is not None:
|
|
||||||
res = r.json()
|
|
||||||
if "detail" in res:
|
if "detail" in res:
|
||||||
raise Exception(r.status_code, res["detail"])
|
raise Exception(response.status, res["detail"])
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Connection error: {e}")
|
||||||
|
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
|
|
||||||
def process_pipeline_outlet_filter(request, payload, user, models):
|
async def process_pipeline_outlet_filter(request, payload, user, models):
|
||||||
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
|
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
|
||||||
model_id = payload["model"]
|
model_id = payload["model"]
|
||||||
|
|
||||||
sorted_filters = get_sorted_filters(model_id, models)
|
sorted_filters = get_sorted_filters(model_id, models)
|
||||||
model = models[model_id]
|
model = models[model_id]
|
||||||
|
|
||||||
if "pipeline" in model:
|
if "pipeline" in model:
|
||||||
sorted_filters = [model] + sorted_filters
|
sorted_filters = [model] + sorted_filters
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
for filter in sorted_filters:
|
for filter in sorted_filters:
|
||||||
r = None
|
urlIdx = filter.get("urlIdx")
|
||||||
try:
|
if urlIdx is None:
|
||||||
urlIdx = filter["urlIdx"]
|
continue
|
||||||
|
|
||||||
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||||
key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
|
key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||||
|
|
||||||
if key != "":
|
if not key:
|
||||||
r = requests.post(
|
continue
|
||||||
f"{url}/{filter['id']}/filter/outlet",
|
|
||||||
headers={"Authorization": f"Bearer {key}"},
|
headers = {"Authorization": f"Bearer {key}"}
|
||||||
json={
|
request_data = {
|
||||||
"user": user,
|
"user": user,
|
||||||
"body": payload,
|
"body": payload,
|
||||||
},
|
}
|
||||||
)
|
|
||||||
|
|
||||||
r.raise_for_status()
|
|
||||||
data = r.json()
|
|
||||||
payload = data
|
|
||||||
except Exception as e:
|
|
||||||
# Handle connection error here
|
|
||||||
print(f"Connection error: {e}")
|
|
||||||
|
|
||||||
if r is not None:
|
|
||||||
try:
|
try:
|
||||||
res = r.json()
|
async with session.post(
|
||||||
|
f"{url}/{filter['id']}/filter/outlet",
|
||||||
|
headers=headers,
|
||||||
|
json=request_data,
|
||||||
|
) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
payload = await response.json()
|
||||||
|
except aiohttp.ClientResponseError as e:
|
||||||
|
try:
|
||||||
|
res = (
|
||||||
|
await response.json()
|
||||||
|
if "application/json" in response.content_type
|
||||||
|
else {}
|
||||||
|
)
|
||||||
if "detail" in res:
|
if "detail" in res:
|
||||||
return Exception(r.status_code, res)
|
raise Exception(response.status, res)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
except Exception as e:
|
||||||
else:
|
print(f"Connection error: {e}")
|
||||||
pass
|
|
||||||
|
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
|
@ -186,12 +186,6 @@ async def generate_chat_completion(
|
|||||||
if model_id not in models:
|
if model_id not in models:
|
||||||
raise Exception("Model not found")
|
raise Exception("Model not found")
|
||||||
|
|
||||||
# Process the form_data through the pipeline
|
|
||||||
try:
|
|
||||||
form_data = process_pipeline_inlet_filter(request, form_data, user, models)
|
|
||||||
except Exception as e:
|
|
||||||
raise e
|
|
||||||
|
|
||||||
model = models[model_id]
|
model = models[model_id]
|
||||||
|
|
||||||
if getattr(request.state, "direct", False):
|
if getattr(request.state, "direct", False):
|
||||||
@ -308,7 +302,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
|
|||||||
model = models[model_id]
|
model = models[model_id]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = process_pipeline_outlet_filter(request, data, user, models)
|
data = await process_pipeline_outlet_filter(request, data, user, models)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return Exception(f"Error: {e}")
|
return Exception(f"Error: {e}")
|
||||||
|
|
||||||
|
@ -39,7 +39,10 @@ from open_webui.routers.tasks import (
|
|||||||
)
|
)
|
||||||
from open_webui.routers.retrieval import process_web_search, SearchForm
|
from open_webui.routers.retrieval import process_web_search, SearchForm
|
||||||
from open_webui.routers.images import image_generations, GenerateImageForm
|
from open_webui.routers.images import image_generations, GenerateImageForm
|
||||||
|
from open_webui.routers.pipelines import (
|
||||||
|
process_pipeline_inlet_filter,
|
||||||
|
process_pipeline_outlet_filter,
|
||||||
|
)
|
||||||
|
|
||||||
from open_webui.utils.webhook import post_webhook
|
from open_webui.utils.webhook import post_webhook
|
||||||
|
|
||||||
@ -676,6 +679,25 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
|||||||
|
|
||||||
variables = form_data.pop("variables", None)
|
variables = form_data.pop("variables", None)
|
||||||
|
|
||||||
|
# Process the form_data through the pipeline
|
||||||
|
try:
|
||||||
|
form_data = await process_pipeline_inlet_filter(
|
||||||
|
request, form_data, user, models
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
try:
|
||||||
|
form_data, flags = await process_filter_functions(
|
||||||
|
request=request,
|
||||||
|
filter_ids=get_sorted_filter_ids(model),
|
||||||
|
filter_type="inlet",
|
||||||
|
form_data=form_data,
|
||||||
|
extra_params=extra_params,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Error: {e}")
|
||||||
|
|
||||||
features = form_data.pop("features", None)
|
features = form_data.pop("features", None)
|
||||||
if features:
|
if features:
|
||||||
if "web_search" in features and features["web_search"]:
|
if "web_search" in features and features["web_search"]:
|
||||||
@ -698,17 +720,6 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
|||||||
form_data["messages"],
|
form_data["messages"],
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
|
||||||
form_data, flags = await process_filter_functions(
|
|
||||||
request=request,
|
|
||||||
filter_ids=get_sorted_filter_ids(model),
|
|
||||||
filter_type="inlet",
|
|
||||||
form_data=form_data,
|
|
||||||
extra_params=extra_params,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
raise Exception(f"Error: {e}")
|
|
||||||
|
|
||||||
tool_ids = form_data.pop("tool_ids", None)
|
tool_ids = form_data.pop("tool_ids", None)
|
||||||
files = form_data.pop("files", None)
|
files = form_data.pop("files", None)
|
||||||
# Remove files duplicates
|
# Remove files duplicates
|
||||||
|
Loading…
Reference in New Issue
Block a user