feat: pipeline filter wildcard support

This commit is contained in:
Timothy J. Baek 2024-05-27 20:26:24 -07:00
parent 966f10e715
commit ec36493d61

View File

@ -249,21 +249,23 @@ class PipelineMiddleware(BaseHTTPMiddleware):
data = json.loads(body_str) if body_str else {} data = json.loads(body_str) if body_str else {}
model_id = data["model"] model_id = data["model"]
valves = [ filters = [
model model
for model in app.state.MODELS.values() for model in app.state.MODELS.values()
if "pipeline" in model if "pipeline" in model
and model["pipeline"]["type"] == "filter" and model["pipeline"]["type"] == "filter"
and model_id and (
in [ model["pipeline"]["pipelines"] == ["*"]
target_model["id"] or any(
for target_model in model["pipeline"]["pipelines"] model_id == target_model["id"]
] for target_model in model["pipeline"]["pipelines"]
)
)
] ]
sorted_valves = sorted(valves, key=lambda x: x["pipeline"]["priority"]) sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
user = None user = None
if len(sorted_valves) > 0: if len(sorted_filters) > 0:
try: try:
user = get_current_user( user = get_current_user(
get_http_authorization_cred( get_http_authorization_cred(
@ -274,10 +276,12 @@ class PipelineMiddleware(BaseHTTPMiddleware):
except: except:
pass pass
for valve in sorted_valves: print(sorted_filters)
for filter in sorted_filters:
try: try:
urlIdx = valve["urlIdx"] urlIdx = filter["urlIdx"]
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
@ -289,7 +293,7 @@ class PipelineMiddleware(BaseHTTPMiddleware):
headers=headers, headers=headers,
json={ json={
"user": user, "user": user,
"model": valve["id"], "model": filter["id"],
"body": data, "body": data,
}, },
) )
@ -298,7 +302,7 @@ class PipelineMiddleware(BaseHTTPMiddleware):
data = r.json() data = r.json()
except Exception as e: except Exception as e:
# Handle connection error here # Handle connection error here
log.error(f"Connection error: {e}") print(f"Connection error: {e}")
pass pass
modified_body_bytes = json.dumps(data).encode("utf-8") modified_body_bytes = json.dumps(data).encode("utf-8")