diff --git a/cli.py b/cli.py deleted file mode 100644 index 9b834cb..0000000 --- a/cli.py +++ /dev/null @@ -1,70 +0,0 @@ -import logging -from pathlib import Path -from typing import Optional - -import typer - - -import subprocess -import time - - -def start_process(app: str, host: str, port: int, reload: bool = False): - # Start the FastAPI application - command = [ - "uvicorn", - app, - "--host", - host, - "--port", - str(port), - "--forwarded-allow-ips", - "*", - ] - - if reload: - command.append("--reload") - - process = subprocess.Popen(command) - return process - - -main = typer.Typer() - - -@main.command() -def serve( - host: str = "0.0.0.0", - port: int = 9099, -): - while True: - process = start_process("main:app", host, port, reload=False) - process.wait() - - if process.returncode == 42: - print("Restarting due to restart request") - time.sleep(2) # optional delay to prevent tight restart loops - else: - print("Normal exit, stopping the manager") - break - - -@main.command() -def dev( - host: str = "0.0.0.0", - port: int = 9099, -): - while True: - process = start_process("main:app", host, port, reload=True) - process.wait() - - if process.returncode == 42: - print("Restarting due to restart request") - time.sleep(2) # optional delay to prevent tight restart loops - else: - print("Normal exit, stopping the manager") - break - - -if __name__ == "__main__": - main() diff --git a/main.py b/main.py index bd79f31..cd99d6d 100644 --- a/main.py +++ b/main.py @@ -15,15 +15,16 @@ from utils.main import get_last_user_message, stream_message_template from contextlib import asynccontextmanager from concurrent.futures import ThreadPoolExecutor from schemas import FilterForm, OpenAIChatCompletionForm +from urllib.parse import urlparse - -import sys +import aiohttp import os import importlib.util import logging import time import json import uuid +import sys #################################### @@ -37,6 +38,13 @@ try: except ImportError: print("dotenv not installed, skipping...") +API_KEY = os.getenv("API_KEY", "0p3n-w3bu!") + +PIPELINES_DIR = os.getenv("PIPELINES_DIR", "./pipelines") + +if not os.path.exists(PIPELINES_DIR): + os.makedirs(PIPELINES_DIR) + PIPELINES = {} PIPELINE_MODULES = {} @@ -109,42 +117,60 @@ def get_all_pipelines(): return pipelines -def on_startup(): - def load_modules_from_directory(directory): - for filename in os.listdir(directory): - if filename.endswith(".py"): - module_name = filename[:-3] # Remove the .py extension - module_path = os.path.join(directory, filename) - spec = importlib.util.spec_from_file_location(module_name, module_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - yield module +async def load_module_from_path(module_name, module_path): + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + if hasattr(module, "Pipeline"): + return module.Pipeline() + return None - for loaded_module in load_modules_from_directory("./pipelines"): - # Do something with the loaded module - logging.info("Loaded:", loaded_module.__name__) - pipeline = loaded_module.Pipeline() - pipeline_id = pipeline.id if hasattr(pipeline, "id") else loaded_module.__name__ - - PIPELINE_MODULES[pipeline_id] = pipeline +async def load_modules_from_directory(directory): + global PIPELINE_MODULES + for filename in os.listdir(directory): + if filename.endswith(".py"): + module_name = filename[:-3] # Remove the .py extension + module_path = os.path.join(directory, filename) + pipeline = await load_module_from_path(module_name, module_path) + if pipeline: + pipeline_id = pipeline.id if hasattr(pipeline, "id") else module_name + PIPELINE_MODULES[pipeline_id] = pipeline + logging.info(f"Loaded module: {module_name}") + else: + logging.warning(f"No Pipeline class found in {module_name}") + global PIPELINES PIPELINES = get_all_pipelines() -on_startup() +async def on_startup(): + await load_modules_from_directory(PIPELINES_DIR) + + for module in PIPELINE_MODULES.values(): + if hasattr(module, "on_startup"): + await module.on_startup() + + +async def on_shutdown(): + for module in PIPELINE_MODULES.values(): + if hasattr(module, "on_shutdown"): + await module.on_shutdown() + + +async def reload(): + await on_shutdown() + # Clear existing pipelines + PIPELINE_MODULES.clear() + # Load pipelines afresh + await on_startup() @asynccontextmanager async def lifespan(app: FastAPI): - for module in PIPELINE_MODULES.values(): - if hasattr(module, "on_startup"): - await module.on_startup() + await on_startup() yield - - for module in PIPELINE_MODULES.values(): - if hasattr(module, "on_shutdown"): - await module.on_shutdown() + await on_shutdown() app = FastAPI(docs_url="/docs", redoc_url=None, lifespan=lifespan) @@ -458,10 +484,118 @@ async def get_status(): return {"status": True} -@app.post("/v1/restart") -@app.post("/restart") -def restart_server(user: str = Depends(get_current_user)): +@app.get("/v1/pipelines") +@app.get("/pipelines") +async def list_pipelines(user: str = Depends(get_current_user)): + if user == API_KEY: + return {"data": list(app.state.PIPELINE_MODULES.keys())} + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + ) - print(user) - return True +class AddPipelineForm(BaseModel): + url: str + + +async def download_file(url: str, dest_folder: str): + filename = os.path.basename(urlparse(url).path) + if not filename.endswith(".py"): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="URL must point to a Python file", + ) + + file_path = os.path.join(dest_folder, filename) + + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + if response.status != 200: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Failed to download file", + ) + with open(file_path, "wb") as f: + f.write(await response.read()) + + return file_path + + +@app.post("/v1/pipelines/add") +@app.post("/pipelines/add") +async def add_pipeline( + form_data: AddPipelineForm, user: str = Depends(get_current_user) +): + if user != API_KEY: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + ) + + try: + file_path = await download_file(form_data.url, dest_folder=PIPELINES_DIR) + await reload() + return { + "status": True, + "detail": f"Pipeline added successfully from {file_path}", + } + except HTTPException as e: + raise e + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(e), + ) + + +class DeletePipelineForm(BaseModel): + id: str + + +@app.delete("/v1/pipelines/delete") +@app.delete("/pipelines/delete") +async def delete_pipeline( + form_data: DeletePipelineForm, user: str = Depends(get_current_user) +): + if user != API_KEY: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + ) + + pipeline_id = form_data.id + pipeline_module = PIPELINE_MODULES.get(pipeline_id, None) + + if pipeline_module: + if hasattr(pipeline_module, "on_shutdown"): + await pipeline_module.on_shutdown() + pipeline_id = pipeline_module.__name__.split(".")[0] + + pipeline_path = os.path.join(PIPELINES_DIR, f"{pipeline_id}.py") + if os.path.exists(pipeline_path): + os.remove(pipeline_path) + await reload() + return { + "status": True, + "detail": f"Pipeline {pipeline_id} deleted successfully", + } + else: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Pipeline {pipeline_id} not found", + ) + + +@app.post("/v1/reload") +@app.post("/reload") +async def reload_pipelines(user: str = Depends(get_current_user)): + if user == API_KEY: + await reload() + return {"message": "Pipelines reloaded successfully."} + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + )