From ae0fe8f6914664203650da76a7effec1eb171063 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 28 May 2024 15:28:39 -0700 Subject: [PATCH] fix --- main.py | 67 +++++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 46 insertions(+), 21 deletions(-) diff --git a/main.py b/main.py index 36ced30..826260f 100644 --- a/main.py +++ b/main.py @@ -20,6 +20,7 @@ import importlib.util import logging +from contextlib import asynccontextmanager from concurrent.futures import ThreadPoolExecutor @@ -116,9 +117,6 @@ def on_startup(): on_startup() -from contextlib import asynccontextmanager - - @asynccontextmanager async def lifespan(app: FastAPI): for module in PIPELINE_MODULES.values(): @@ -185,55 +183,82 @@ async def get_models(): "valves": pipeline["valves"] != None, }, } - for pipeline in PIPELINES.values() + for pipeline in app.state.PIPELINES.values() ] } @app.get("/{pipeline_id}/valves") async def get_valves(pipeline_id: str): - if pipeline_id not in app.state.PIPELINES or not app.state.PIPELINES[ - pipeline_id - ].get("valves", False): + if pipeline_id not in app.state.PIPELINES: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Pipeline {pipeline_id} not found", + ) + + pipeline = app.state.PIPELINES[pipeline_id] + if not pipeline.get("valves", False): raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Valves for {pipeline_id} not found", ) + if pipeline["type"] == "manifold": + manifold_id, pipeline_id = pipeline_id.split(".", 1) + pipeline_id = manifold_id - pipeline = PIPELINE_MODULES[pipeline_id] - return pipeline.valves + pipeline_module = PIPELINE_MODULES[pipeline_id] + return pipeline_module.valves @app.get("/{pipeline_id}/valves/spec") async def get_valves_spec(pipeline_id: str): - if pipeline_id not in app.state.PIPELINES or not app.state.PIPELINES[ - pipeline_id - ].get("valves", False): + + if pipeline_id not in app.state.PIPELINES: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Pipeline {pipeline_id} not found", + ) + + pipeline = app.state.PIPELINES[pipeline_id] + + if not pipeline.get("valves", False): raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Valves for {pipeline_id} not found", ) + if pipeline["type"] == "manifold": + manifold_id, pipeline_id = pipeline_id.split(".", 1) + pipeline_id = manifold_id - pipeline = PIPELINE_MODULES[pipeline_id] - return pipeline.valves.schema() + pipeline_module = PIPELINE_MODULES[pipeline_id] + return pipeline_module.valves.schema() @app.post("/{pipeline_id}/valves/update") async def update_valves(pipeline_id: str, form_data: dict): - if pipeline_id not in app.state.PIPELINES or not app.state.PIPELINES[ - pipeline_id - ].get("valves", False): + + if pipeline_id not in app.state.PIPELINES: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Pipeline {pipeline_id} not found", + ) + + pipeline = app.state.PIPELINES[pipeline_id] + if not pipeline.get("valves", False): raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Valves for {pipeline_id} not found", ) + if pipeline["type"] == "manifold": + manifold_id, pipeline_id = pipeline_id.split(".", 1) + pipeline_id = manifold_id - pipeline = PIPELINE_MODULES[pipeline_id] + pipeline_module = PIPELINE_MODULES[pipeline_id] try: - ValvesModel = pipeline.valves.__class__ + ValvesModel = pipeline_module.valves.__class__ valves = ValvesModel(**form_data) - pipeline.valves = valves + pipeline_module.valves = valves except Exception as e: print(e) raise HTTPException( @@ -241,7 +266,7 @@ async def update_valves(pipeline_id: str, form_data: dict): detail=f"{str(e)}", ) - return pipeline.valves + return pipeline_module.valves @app.post("/{pipeline_id}/filter")