This commit is contained in:
Timothy J. Baek 2024-05-28 15:28:39 -07:00
parent 78e9bcf34c
commit ae0fe8f691

67
main.py
View File

@ -20,6 +20,7 @@ import importlib.util
import logging
from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor
@ -116,9 +117,6 @@ def on_startup():
on_startup()
from contextlib import asynccontextmanager
@asynccontextmanager
async def lifespan(app: FastAPI):
for module in PIPELINE_MODULES.values():
@ -185,55 +183,82 @@ async def get_models():
"valves": pipeline["valves"] != None,
},
}
for pipeline in PIPELINES.values()
for pipeline in app.state.PIPELINES.values()
]
}
@app.get("/{pipeline_id}/valves")
async def get_valves(pipeline_id: str):
if pipeline_id not in app.state.PIPELINES or not app.state.PIPELINES[
pipeline_id
].get("valves", False):
if pipeline_id not in app.state.PIPELINES:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Pipeline {pipeline_id} not found",
)
pipeline = app.state.PIPELINES[pipeline_id]
if not pipeline.get("valves", False):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Valves for {pipeline_id} not found",
)
if pipeline["type"] == "manifold":
manifold_id, pipeline_id = pipeline_id.split(".", 1)
pipeline_id = manifold_id
pipeline = PIPELINE_MODULES[pipeline_id]
return pipeline.valves
pipeline_module = PIPELINE_MODULES[pipeline_id]
return pipeline_module.valves
@app.get("/{pipeline_id}/valves/spec")
async def get_valves_spec(pipeline_id: str):
if pipeline_id not in app.state.PIPELINES or not app.state.PIPELINES[
pipeline_id
].get("valves", False):
if pipeline_id not in app.state.PIPELINES:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Pipeline {pipeline_id} not found",
)
pipeline = app.state.PIPELINES[pipeline_id]
if not pipeline.get("valves", False):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Valves for {pipeline_id} not found",
)
if pipeline["type"] == "manifold":
manifold_id, pipeline_id = pipeline_id.split(".", 1)
pipeline_id = manifold_id
pipeline = PIPELINE_MODULES[pipeline_id]
return pipeline.valves.schema()
pipeline_module = PIPELINE_MODULES[pipeline_id]
return pipeline_module.valves.schema()
@app.post("/{pipeline_id}/valves/update")
async def update_valves(pipeline_id: str, form_data: dict):
if pipeline_id not in app.state.PIPELINES or not app.state.PIPELINES[
pipeline_id
].get("valves", False):
if pipeline_id not in app.state.PIPELINES:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Pipeline {pipeline_id} not found",
)
pipeline = app.state.PIPELINES[pipeline_id]
if not pipeline.get("valves", False):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Valves for {pipeline_id} not found",
)
if pipeline["type"] == "manifold":
manifold_id, pipeline_id = pipeline_id.split(".", 1)
pipeline_id = manifold_id
pipeline = PIPELINE_MODULES[pipeline_id]
pipeline_module = PIPELINE_MODULES[pipeline_id]
try:
ValvesModel = pipeline.valves.__class__
ValvesModel = pipeline_module.valves.__class__
valves = ValvesModel(**form_data)
pipeline.valves = valves
pipeline_module.valves = valves
except Exception as e:
print(e)
raise HTTPException(
@ -241,7 +266,7 @@ async def update_valves(pipeline_id: str, form_data: dict):
detail=f"{str(e)}",
)
return pipeline.valves
return pipeline_module.valves
@app.post("/{pipeline_id}/filter")