diff --git a/main.py b/main.py index 1afe21c..c13a83a 100644 --- a/main.py +++ b/main.py @@ -42,25 +42,10 @@ PIPELINES = {} PIPELINE_MODULES = {} -def on_startup(): - def load_modules_from_directory(directory): - for filename in os.listdir(directory): - if filename.endswith(".py"): - module_name = filename[:-3] # Remove the .py extension - module_path = os.path.join(directory, filename) - spec = importlib.util.spec_from_file_location(module_name, module_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - yield module - - for loaded_module in load_modules_from_directory("./pipelines"): - # Do something with the loaded module - logging.info("Loaded:", loaded_module.__name__) - - pipeline = loaded_module.Pipeline() - pipeline_id = pipeline.id if hasattr(pipeline, "id") else loaded_module.__name__ - - PIPELINE_MODULES[pipeline_id] = pipeline +def get_all_pipelines(): + pipelines = {} + for pipeline_id in PIPELINE_MODULES.keys(): + pipeline = PIPELINE_MODULES[pipeline_id] if hasattr(pipeline, "type"): if pipeline.type == "manifold": @@ -73,7 +58,7 @@ def on_startup(): f"{pipeline.name}{manifold_pipeline_name}" ) - PIPELINES[manifold_pipeline_id] = { + pipelines[manifold_pipeline_id] = { "module": pipeline_id, "type": pipeline.type if hasattr(pipeline, "type") else "pipe", "id": manifold_pipeline_id, @@ -83,7 +68,7 @@ def on_startup(): ), } if pipeline.type == "filter": - PIPELINES[pipeline_id] = { + pipelines[pipeline_id] = { "module": pipeline_id, "type": (pipeline.type if hasattr(pipeline, "type") else "pipe"), "id": pipeline_id, @@ -105,7 +90,7 @@ def on_startup(): "valves": pipeline.valves if hasattr(pipeline, "valves") else None, } else: - PIPELINES[pipeline_id] = { + pipelines[pipeline_id] = { "module": pipeline_id, "type": (pipeline.type if hasattr(pipeline, "type") else "pipe"), "id": pipeline_id, @@ -113,6 +98,31 @@ def on_startup(): "valves": pipeline.valves if hasattr(pipeline, "valves") else None, } + return pipelines + + +def on_startup(): + def load_modules_from_directory(directory): + for filename in os.listdir(directory): + if filename.endswith(".py"): + module_name = filename[:-3] # Remove the .py extension + module_path = os.path.join(directory, filename) + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + yield module + + for loaded_module in load_modules_from_directory("./pipelines"): + # Do something with the loaded module + logging.info("Loaded:", loaded_module.__name__) + + pipeline = loaded_module.Pipeline() + pipeline_id = pipeline.id if hasattr(pipeline, "id") else loaded_module.__name__ + + PIPELINE_MODULES[pipeline_id] = pipeline + + PIPELINES = get_all_pipelines() + on_startup() @@ -162,6 +172,8 @@ async def get_models(): """ Returns the available pipelines """ + app.state.PIPELINES = get_all_pipelines() + return { "data": [ { @@ -190,6 +202,8 @@ async def get_models(): @app.get("/{pipeline_id}/valves") async def get_valves(pipeline_id: str): + app.state.PIPELINES = get_all_pipelines() + if pipeline_id not in app.state.PIPELINES: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND,