This commit is contained in:
Timothy J. Baek 2024-05-29 00:07:36 -07:00
parent 21bab00496
commit 6f9c8592c5
2 changed files with 166 additions and 102 deletions

70
cli.py
View File

@ -1,70 +0,0 @@
import logging
from pathlib import Path
from typing import Optional
import typer
import subprocess
import time
def start_process(app: str, host: str, port: int, reload: bool = False):
# Start the FastAPI application
command = [
"uvicorn",
app,
"--host",
host,
"--port",
str(port),
"--forwarded-allow-ips",
"*",
]
if reload:
command.append("--reload")
process = subprocess.Popen(command)
return process
main = typer.Typer()
@main.command()
def serve(
host: str = "0.0.0.0",
port: int = 9099,
):
while True:
process = start_process("main:app", host, port, reload=False)
process.wait()
if process.returncode == 42:
print("Restarting due to restart request")
time.sleep(2) # optional delay to prevent tight restart loops
else:
print("Normal exit, stopping the manager")
break
@main.command()
def dev(
host: str = "0.0.0.0",
port: int = 9099,
):
while True:
process = start_process("main:app", host, port, reload=True)
process.wait()
if process.returncode == 42:
print("Restarting due to restart request")
time.sleep(2) # optional delay to prevent tight restart loops
else:
print("Normal exit, stopping the manager")
break
if __name__ == "__main__":
main()

198
main.py
View File

@ -15,15 +15,16 @@ from utils.main import get_last_user_message, stream_message_template
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from schemas import FilterForm, OpenAIChatCompletionForm from schemas import FilterForm, OpenAIChatCompletionForm
from urllib.parse import urlparse
import aiohttp
import sys
import os import os
import importlib.util import importlib.util
import logging import logging
import time import time
import json import json
import uuid import uuid
import sys
#################################### ####################################
@ -37,6 +38,13 @@ try:
except ImportError: except ImportError:
print("dotenv not installed, skipping...") print("dotenv not installed, skipping...")
API_KEY = os.getenv("API_KEY", "0p3n-w3bu!")
PIPELINES_DIR = os.getenv("PIPELINES_DIR", "./pipelines")
if not os.path.exists(PIPELINES_DIR):
os.makedirs(PIPELINES_DIR)
PIPELINES = {} PIPELINES = {}
PIPELINE_MODULES = {} PIPELINE_MODULES = {}
@ -109,42 +117,60 @@ def get_all_pipelines():
return pipelines return pipelines
def on_startup(): async def load_module_from_path(module_name, module_path):
def load_modules_from_directory(directory): spec = importlib.util.spec_from_file_location(module_name, module_path)
for filename in os.listdir(directory): module = importlib.util.module_from_spec(spec)
if filename.endswith(".py"): spec.loader.exec_module(module)
module_name = filename[:-3] # Remove the .py extension if hasattr(module, "Pipeline"):
module_path = os.path.join(directory, filename) return module.Pipeline()
spec = importlib.util.spec_from_file_location(module_name, module_path) return None
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
yield module
for loaded_module in load_modules_from_directory("./pipelines"):
# Do something with the loaded module
logging.info("Loaded:", loaded_module.__name__)
pipeline = loaded_module.Pipeline() async def load_modules_from_directory(directory):
pipeline_id = pipeline.id if hasattr(pipeline, "id") else loaded_module.__name__ global PIPELINE_MODULES
for filename in os.listdir(directory):
PIPELINE_MODULES[pipeline_id] = pipeline if filename.endswith(".py"):
module_name = filename[:-3] # Remove the .py extension
module_path = os.path.join(directory, filename)
pipeline = await load_module_from_path(module_name, module_path)
if pipeline:
pipeline_id = pipeline.id if hasattr(pipeline, "id") else module_name
PIPELINE_MODULES[pipeline_id] = pipeline
logging.info(f"Loaded module: {module_name}")
else:
logging.warning(f"No Pipeline class found in {module_name}")
global PIPELINES
PIPELINES = get_all_pipelines() PIPELINES = get_all_pipelines()
on_startup() async def on_startup():
await load_modules_from_directory(PIPELINES_DIR)
for module in PIPELINE_MODULES.values():
if hasattr(module, "on_startup"):
await module.on_startup()
async def on_shutdown():
for module in PIPELINE_MODULES.values():
if hasattr(module, "on_shutdown"):
await module.on_shutdown()
async def reload():
await on_shutdown()
# Clear existing pipelines
PIPELINE_MODULES.clear()
# Load pipelines afresh
await on_startup()
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
for module in PIPELINE_MODULES.values(): await on_startup()
if hasattr(module, "on_startup"):
await module.on_startup()
yield yield
await on_shutdown()
for module in PIPELINE_MODULES.values():
if hasattr(module, "on_shutdown"):
await module.on_shutdown()
app = FastAPI(docs_url="/docs", redoc_url=None, lifespan=lifespan) app = FastAPI(docs_url="/docs", redoc_url=None, lifespan=lifespan)
@ -458,10 +484,118 @@ async def get_status():
return {"status": True} return {"status": True}
@app.post("/v1/restart") @app.get("/v1/pipelines")
@app.post("/restart") @app.get("/pipelines")
def restart_server(user: str = Depends(get_current_user)): async def list_pipelines(user: str = Depends(get_current_user)):
if user == API_KEY:
return {"data": list(app.state.PIPELINE_MODULES.keys())}
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)
print(user)
return True class AddPipelineForm(BaseModel):
url: str
async def download_file(url: str, dest_folder: str):
filename = os.path.basename(urlparse(url).path)
if not filename.endswith(".py"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="URL must point to a Python file",
)
file_path = os.path.join(dest_folder, filename)
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
if response.status != 200:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Failed to download file",
)
with open(file_path, "wb") as f:
f.write(await response.read())
return file_path
@app.post("/v1/pipelines/add")
@app.post("/pipelines/add")
async def add_pipeline(
form_data: AddPipelineForm, user: str = Depends(get_current_user)
):
if user != API_KEY:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)
try:
file_path = await download_file(form_data.url, dest_folder=PIPELINES_DIR)
await reload()
return {
"status": True,
"detail": f"Pipeline added successfully from {file_path}",
}
except HTTPException as e:
raise e
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
)
class DeletePipelineForm(BaseModel):
id: str
@app.delete("/v1/pipelines/delete")
@app.delete("/pipelines/delete")
async def delete_pipeline(
form_data: DeletePipelineForm, user: str = Depends(get_current_user)
):
if user != API_KEY:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)
pipeline_id = form_data.id
pipeline_module = PIPELINE_MODULES.get(pipeline_id, None)
if pipeline_module:
if hasattr(pipeline_module, "on_shutdown"):
await pipeline_module.on_shutdown()
pipeline_id = pipeline_module.__name__.split(".")[0]
pipeline_path = os.path.join(PIPELINES_DIR, f"{pipeline_id}.py")
if os.path.exists(pipeline_path):
os.remove(pipeline_path)
await reload()
return {
"status": True,
"detail": f"Pipeline {pipeline_id} deleted successfully",
}
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Pipeline {pipeline_id} not found",
)
@app.post("/v1/reload")
@app.post("/reload")
async def reload_pipelines(user: str = Depends(get_current_user)):
if user == API_KEY:
await reload()
return {"message": "Pipelines reloaded successfully."}
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)