This commit is contained in:
Timothy J. Baek 2024-05-28 15:45:46 -07:00
parent 88f3d59fcb
commit cabd666152

58
main.py
View File

@ -42,25 +42,10 @@ PIPELINES = {}
PIPELINE_MODULES = {}
def on_startup():
def load_modules_from_directory(directory):
for filename in os.listdir(directory):
if filename.endswith(".py"):
module_name = filename[:-3] # Remove the .py extension
module_path = os.path.join(directory, filename)
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
yield module
for loaded_module in load_modules_from_directory("./pipelines"):
# Do something with the loaded module
logging.info("Loaded:", loaded_module.__name__)
pipeline = loaded_module.Pipeline()
pipeline_id = pipeline.id if hasattr(pipeline, "id") else loaded_module.__name__
PIPELINE_MODULES[pipeline_id] = pipeline
def get_all_pipelines():
pipelines = {}
for pipeline_id in PIPELINE_MODULES.keys():
pipeline = PIPELINE_MODULES[pipeline_id]
if hasattr(pipeline, "type"):
if pipeline.type == "manifold":
@ -73,7 +58,7 @@ def on_startup():
f"{pipeline.name}{manifold_pipeline_name}"
)
PIPELINES[manifold_pipeline_id] = {
pipelines[manifold_pipeline_id] = {
"module": pipeline_id,
"type": pipeline.type if hasattr(pipeline, "type") else "pipe",
"id": manifold_pipeline_id,
@ -83,7 +68,7 @@ def on_startup():
),
}
if pipeline.type == "filter":
PIPELINES[pipeline_id] = {
pipelines[pipeline_id] = {
"module": pipeline_id,
"type": (pipeline.type if hasattr(pipeline, "type") else "pipe"),
"id": pipeline_id,
@ -105,7 +90,7 @@ def on_startup():
"valves": pipeline.valves if hasattr(pipeline, "valves") else None,
}
else:
PIPELINES[pipeline_id] = {
pipelines[pipeline_id] = {
"module": pipeline_id,
"type": (pipeline.type if hasattr(pipeline, "type") else "pipe"),
"id": pipeline_id,
@ -113,6 +98,31 @@ def on_startup():
"valves": pipeline.valves if hasattr(pipeline, "valves") else None,
}
return pipelines
def on_startup():
def load_modules_from_directory(directory):
for filename in os.listdir(directory):
if filename.endswith(".py"):
module_name = filename[:-3] # Remove the .py extension
module_path = os.path.join(directory, filename)
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
yield module
for loaded_module in load_modules_from_directory("./pipelines"):
# Do something with the loaded module
logging.info("Loaded:", loaded_module.__name__)
pipeline = loaded_module.Pipeline()
pipeline_id = pipeline.id if hasattr(pipeline, "id") else loaded_module.__name__
PIPELINE_MODULES[pipeline_id] = pipeline
PIPELINES = get_all_pipelines()
on_startup()
@ -162,6 +172,8 @@ async def get_models():
"""
Returns the available pipelines
"""
app.state.PIPELINES = get_all_pipelines()
return {
"data": [
{
@ -190,6 +202,8 @@ async def get_models():
@app.get("/{pipeline_id}/valves")
async def get_valves(pipeline_id: str):
app.state.PIPELINES = get_all_pipelines()
if pipeline_id not in app.state.PIPELINES:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,