From 144581a7df164b4521d4803c3d558f77a5cf42ac Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Tue, 9 Jul 2024 12:51:13 +0100 Subject: [PATCH] refac: get_sorted_pipelines() --- backend/main.py | 41 ++++++++++++----------------------------- 1 file changed, 12 insertions(+), 29 deletions(-) diff --git a/backend/main.py b/backend/main.py index e7210bc0a..e1890b85f 100644 --- a/backend/main.py +++ b/backend/main.py @@ -764,9 +764,7 @@ app.add_middleware(ChatCompletionMiddleware) ################################## -def filter_pipeline(payload, user): - user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role} - model_id = payload["model"] +def get_sorted_filters(model_id): filters = [ model for model in app.state.MODELS.values() @@ -782,6 +780,13 @@ def filter_pipeline(payload, user): ) ] sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) + return sorted_filters + + +def filter_pipeline(payload, user): + user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role} + model_id = payload["model"] + sorted_filters = get_sorted_filters(model_id) model = app.state.MODELS[model_id] @@ -814,19 +819,12 @@ def filter_pipeline(payload, user): print(f"Connection error: {e}") if r is not None: - try: - res = r.json() - except: - pass + res = r.json() if "detail" in res: raise Exception(r.status_code, res["detail"]) - else: - pass - - if "pipeline" not in app.state.MODELS[model_id]: - if "task" in payload: - del payload["task"] + if "pipeline" not in app.state.MODELS[model_id] and "task" in payload: + del payload["task"] return payload @@ -1061,22 +1059,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): ) model = app.state.MODELS[model_id] - filters = [ - model - for model in app.state.MODELS.values() - if "pipeline" in model - and "type" in model["pipeline"] - and model["pipeline"]["type"] == "filter" - and ( - model["pipeline"]["pipelines"] == ["*"] - or any( - model_id == target_model_id - for target_model_id in model["pipeline"]["pipelines"] - ) - ) - ] - - sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) + sorted_filters = get_sorted_filters(model_id) if "pipeline" in model: sorted_filters = [model] + sorted_filters