This commit is contained in:
Timothy J. Baek 2024-05-29 22:18:17 -07:00
parent b440a11d32
commit 32bf2a841f

63
main.py
View File

@ -257,7 +257,20 @@ async def list_pipelines(user: str = Depends(get_current_user)):
if user == API_KEY: if user == API_KEY:
return { return {
"data": [ "data": [
{"id": pipeline_id, "name": PIPELINE_NAMES[pipeline_id]} {
"id": pipeline_id,
"name": PIPELINE_NAMES[pipeline_id],
"type": (
PIPELINE_MODULES[pipeline_id].type
if hasattr(PIPELINE_MODULES[pipeline_id], "type")
else "pipe"
),
"valves": (
True
if hasattr(PIPELINE_MODULES[pipeline_id], "valves")
else False
),
}
for pipeline_id in list(PIPELINE_MODULES.keys()) for pipeline_id in list(PIPELINE_MODULES.keys())
] ]
} }
@ -378,80 +391,68 @@ async def reload_pipelines(user: str = Depends(get_current_user)):
@app.get("/v1/{pipeline_id}/valves") @app.get("/v1/{pipeline_id}/valves")
@app.get("/{pipeline_id}/valves") @app.get("/{pipeline_id}/valves")
async def get_valves(pipeline_id: str): async def get_valves(pipeline_id: str):
if pipeline_id not in app.state.PIPELINES: if pipeline_id not in PIPELINE_MODULES:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=f"Pipeline {pipeline_id} not found", detail=f"Pipeline {pipeline_id} not found",
) )
pipeline = app.state.PIPELINES[pipeline_id] pipeline = PIPELINE_MODULES[pipeline_id]
if not pipeline.get("valves", False):
if hasattr(pipeline, "valves") is False:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=f"Valves for {pipeline_id} not found", detail=f"Valves for {pipeline_id} not found",
) )
if pipeline["type"] == "manifold":
manifold_id, pipeline_id = pipeline_id.split(".", 1)
pipeline_id = manifold_id
pipeline_module = PIPELINE_MODULES[pipeline_id] return pipeline.valves
return pipeline_module.valves
@app.get("/v1/{pipeline_id}/valves/spec") @app.get("/v1/{pipeline_id}/valves/spec")
@app.get("/{pipeline_id}/valves/spec") @app.get("/{pipeline_id}/valves/spec")
async def get_valves_spec(pipeline_id: str): async def get_valves_spec(pipeline_id: str):
if pipeline_id not in PIPELINE_MODULES:
if pipeline_id not in app.state.PIPELINES:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=f"Pipeline {pipeline_id} not found", detail=f"Pipeline {pipeline_id} not found",
) )
pipeline = app.state.PIPELINES[pipeline_id] pipeline = PIPELINE_MODULES[pipeline_id]
if not pipeline.get("valves", False): if hasattr(pipeline, "valves") is False:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=f"Valves for {pipeline_id} not found", detail=f"Valves for {pipeline_id} not found",
) )
if pipeline["type"] == "manifold":
manifold_id, pipeline_id = pipeline_id.split(".", 1)
pipeline_id = manifold_id
pipeline_module = PIPELINE_MODULES[pipeline_id] return pipeline.valves.schema()
return pipeline_module.valves.schema()
@app.post("/v1/{pipeline_id}/valves/update") @app.post("/v1/{pipeline_id}/valves/update")
@app.post("/{pipeline_id}/valves/update") @app.post("/{pipeline_id}/valves/update")
async def update_valves(pipeline_id: str, form_data: dict): async def update_valves(pipeline_id: str, form_data: dict):
if pipeline_id not in app.state.PIPELINES: if pipeline_id not in PIPELINE_MODULES:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=f"Pipeline {pipeline_id} not found", detail=f"Pipeline {pipeline_id} not found",
) )
pipeline = app.state.PIPELINES[pipeline_id] pipeline = PIPELINE_MODULES[pipeline_id]
if not pipeline.get("valves", False):
if hasattr(pipeline, "valves") is False:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=f"Valves for {pipeline_id} not found", detail=f"Valves for {pipeline_id} not found",
) )
if pipeline["type"] == "manifold":
manifold_id, pipeline_id = pipeline_id.split(".", 1)
pipeline_id = manifold_id
pipeline_module = PIPELINE_MODULES[pipeline_id]
try: try:
ValvesModel = pipeline_module.valves.__class__ ValvesModel = pipeline.valves.__class__
valves = ValvesModel(**form_data) valves = ValvesModel(**form_data)
pipeline_module.valves = valves pipeline.valves = valves
if hasattr(pipeline_module, "on_valves_update"): if hasattr(pipeline, "on_valves_update"):
await pipeline_module.on_valves_update() await pipeline.on_valves_update()
except Exception as e: except Exception as e:
print(e) print(e)
raise HTTPException( raise HTTPException(
@ -459,7 +460,7 @@ async def update_valves(pipeline_id: str, form_data: dict):
detail=f"{str(e)}", detail=f"{str(e)}",
) )
return pipeline_module.valves return pipeline.valves
@app.post("/v1/{pipeline_id}/filter") @app.post("/v1/{pipeline_id}/filter")