This commit is contained in:
Timothy J. Baek 2024-05-28 15:28:39 -07:00
parent 78e9bcf34c
commit ae0fe8f691

67
main.py
View File

@ -20,6 +20,7 @@ import importlib.util
import logging import logging
from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
@ -116,9 +117,6 @@ def on_startup():
on_startup() on_startup()
from contextlib import asynccontextmanager
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
for module in PIPELINE_MODULES.values(): for module in PIPELINE_MODULES.values():
@ -185,55 +183,82 @@ async def get_models():
"valves": pipeline["valves"] != None, "valves": pipeline["valves"] != None,
}, },
} }
for pipeline in PIPELINES.values() for pipeline in app.state.PIPELINES.values()
] ]
} }
@app.get("/{pipeline_id}/valves") @app.get("/{pipeline_id}/valves")
async def get_valves(pipeline_id: str): async def get_valves(pipeline_id: str):
if pipeline_id not in app.state.PIPELINES or not app.state.PIPELINES[ if pipeline_id not in app.state.PIPELINES:
pipeline_id raise HTTPException(
].get("valves", False): status_code=status.HTTP_404_NOT_FOUND,
detail=f"Pipeline {pipeline_id} not found",
)
pipeline = app.state.PIPELINES[pipeline_id]
if not pipeline.get("valves", False):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=f"Valves for {pipeline_id} not found", detail=f"Valves for {pipeline_id} not found",
) )
if pipeline["type"] == "manifold":
manifold_id, pipeline_id = pipeline_id.split(".", 1)
pipeline_id = manifold_id
pipeline = PIPELINE_MODULES[pipeline_id] pipeline_module = PIPELINE_MODULES[pipeline_id]
return pipeline.valves return pipeline_module.valves
@app.get("/{pipeline_id}/valves/spec") @app.get("/{pipeline_id}/valves/spec")
async def get_valves_spec(pipeline_id: str): async def get_valves_spec(pipeline_id: str):
if pipeline_id not in app.state.PIPELINES or not app.state.PIPELINES[
pipeline_id if pipeline_id not in app.state.PIPELINES:
].get("valves", False): raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Pipeline {pipeline_id} not found",
)
pipeline = app.state.PIPELINES[pipeline_id]
if not pipeline.get("valves", False):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=f"Valves for {pipeline_id} not found", detail=f"Valves for {pipeline_id} not found",
) )
if pipeline["type"] == "manifold":
manifold_id, pipeline_id = pipeline_id.split(".", 1)
pipeline_id = manifold_id
pipeline = PIPELINE_MODULES[pipeline_id] pipeline_module = PIPELINE_MODULES[pipeline_id]
return pipeline.valves.schema() return pipeline_module.valves.schema()
@app.post("/{pipeline_id}/valves/update") @app.post("/{pipeline_id}/valves/update")
async def update_valves(pipeline_id: str, form_data: dict): async def update_valves(pipeline_id: str, form_data: dict):
if pipeline_id not in app.state.PIPELINES or not app.state.PIPELINES[
pipeline_id if pipeline_id not in app.state.PIPELINES:
].get("valves", False): raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Pipeline {pipeline_id} not found",
)
pipeline = app.state.PIPELINES[pipeline_id]
if not pipeline.get("valves", False):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=f"Valves for {pipeline_id} not found", detail=f"Valves for {pipeline_id} not found",
) )
if pipeline["type"] == "manifold":
manifold_id, pipeline_id = pipeline_id.split(".", 1)
pipeline_id = manifold_id
pipeline = PIPELINE_MODULES[pipeline_id] pipeline_module = PIPELINE_MODULES[pipeline_id]
try: try:
ValvesModel = pipeline.valves.__class__ ValvesModel = pipeline_module.valves.__class__
valves = ValvesModel(**form_data) valves = ValvesModel(**form_data)
pipeline.valves = valves pipeline_module.valves = valves
except Exception as e: except Exception as e:
print(e) print(e)
raise HTTPException( raise HTTPException(
@ -241,7 +266,7 @@ async def update_valves(pipeline_id: str, form_data: dict):
detail=f"{str(e)}", detail=f"{str(e)}",
) )
return pipeline.valves return pipeline_module.valves
@app.post("/{pipeline_id}/filter") @app.post("/{pipeline_id}/filter")