This commit is contained in:
Timothy J. Baek 2024-05-29 00:07:36 -07:00
parent 21bab00496
commit 6f9c8592c5
2 changed files with 166 additions and 102 deletions

70
cli.py
View File

@ -1,70 +0,0 @@
import logging
from pathlib import Path
from typing import Optional
import typer
import subprocess
import time
def start_process(app: str, host: str, port: int, reload: bool = False):
# Start the FastAPI application
command = [
"uvicorn",
app,
"--host",
host,
"--port",
str(port),
"--forwarded-allow-ips",
"*",
]
if reload:
command.append("--reload")
process = subprocess.Popen(command)
return process
main = typer.Typer()
@main.command()
def serve(
host: str = "0.0.0.0",
port: int = 9099,
):
while True:
process = start_process("main:app", host, port, reload=False)
process.wait()
if process.returncode == 42:
print("Restarting due to restart request")
time.sleep(2) # optional delay to prevent tight restart loops
else:
print("Normal exit, stopping the manager")
break
@main.command()
def dev(
host: str = "0.0.0.0",
port: int = 9099,
):
while True:
process = start_process("main:app", host, port, reload=True)
process.wait()
if process.returncode == 42:
print("Restarting due to restart request")
time.sleep(2) # optional delay to prevent tight restart loops
else:
print("Normal exit, stopping the manager")
break
if __name__ == "__main__":
main()

198
main.py
View File

@ -15,15 +15,16 @@ from utils.main import get_last_user_message, stream_message_template
from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor
from schemas import FilterForm, OpenAIChatCompletionForm
from urllib.parse import urlparse
import sys
import aiohttp
import os
import importlib.util
import logging
import time
import json
import uuid
import sys
####################################
@ -37,6 +38,13 @@ try:
except ImportError:
print("dotenv not installed, skipping...")
API_KEY = os.getenv("API_KEY", "0p3n-w3bu!")
PIPELINES_DIR = os.getenv("PIPELINES_DIR", "./pipelines")
if not os.path.exists(PIPELINES_DIR):
os.makedirs(PIPELINES_DIR)
PIPELINES = {}
PIPELINE_MODULES = {}
@ -109,42 +117,60 @@ def get_all_pipelines():
return pipelines
def on_startup():
def load_modules_from_directory(directory):
for filename in os.listdir(directory):
if filename.endswith(".py"):
module_name = filename[:-3] # Remove the .py extension
module_path = os.path.join(directory, filename)
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
yield module
async def load_module_from_path(module_name, module_path):
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
if hasattr(module, "Pipeline"):
return module.Pipeline()
return None
for loaded_module in load_modules_from_directory("./pipelines"):
# Do something with the loaded module
logging.info("Loaded:", loaded_module.__name__)
pipeline = loaded_module.Pipeline()
pipeline_id = pipeline.id if hasattr(pipeline, "id") else loaded_module.__name__
PIPELINE_MODULES[pipeline_id] = pipeline
async def load_modules_from_directory(directory):
global PIPELINE_MODULES
for filename in os.listdir(directory):
if filename.endswith(".py"):
module_name = filename[:-3] # Remove the .py extension
module_path = os.path.join(directory, filename)
pipeline = await load_module_from_path(module_name, module_path)
if pipeline:
pipeline_id = pipeline.id if hasattr(pipeline, "id") else module_name
PIPELINE_MODULES[pipeline_id] = pipeline
logging.info(f"Loaded module: {module_name}")
else:
logging.warning(f"No Pipeline class found in {module_name}")
global PIPELINES
PIPELINES = get_all_pipelines()
on_startup()
async def on_startup():
await load_modules_from_directory(PIPELINES_DIR)
for module in PIPELINE_MODULES.values():
if hasattr(module, "on_startup"):
await module.on_startup()
async def on_shutdown():
for module in PIPELINE_MODULES.values():
if hasattr(module, "on_shutdown"):
await module.on_shutdown()
async def reload():
await on_shutdown()
# Clear existing pipelines
PIPELINE_MODULES.clear()
# Load pipelines afresh
await on_startup()
@asynccontextmanager
async def lifespan(app: FastAPI):
for module in PIPELINE_MODULES.values():
if hasattr(module, "on_startup"):
await module.on_startup()
await on_startup()
yield
for module in PIPELINE_MODULES.values():
if hasattr(module, "on_shutdown"):
await module.on_shutdown()
await on_shutdown()
app = FastAPI(docs_url="/docs", redoc_url=None, lifespan=lifespan)
@ -458,10 +484,118 @@ async def get_status():
return {"status": True}
@app.post("/v1/restart")
@app.post("/restart")
def restart_server(user: str = Depends(get_current_user)):
@app.get("/v1/pipelines")
@app.get("/pipelines")
async def list_pipelines(user: str = Depends(get_current_user)):
if user == API_KEY:
return {"data": list(app.state.PIPELINE_MODULES.keys())}
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)
print(user)
return True
class AddPipelineForm(BaseModel):
url: str
async def download_file(url: str, dest_folder: str):
filename = os.path.basename(urlparse(url).path)
if not filename.endswith(".py"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="URL must point to a Python file",
)
file_path = os.path.join(dest_folder, filename)
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
if response.status != 200:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Failed to download file",
)
with open(file_path, "wb") as f:
f.write(await response.read())
return file_path
@app.post("/v1/pipelines/add")
@app.post("/pipelines/add")
async def add_pipeline(
form_data: AddPipelineForm, user: str = Depends(get_current_user)
):
if user != API_KEY:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)
try:
file_path = await download_file(form_data.url, dest_folder=PIPELINES_DIR)
await reload()
return {
"status": True,
"detail": f"Pipeline added successfully from {file_path}",
}
except HTTPException as e:
raise e
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
)
class DeletePipelineForm(BaseModel):
id: str
@app.delete("/v1/pipelines/delete")
@app.delete("/pipelines/delete")
async def delete_pipeline(
form_data: DeletePipelineForm, user: str = Depends(get_current_user)
):
if user != API_KEY:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)
pipeline_id = form_data.id
pipeline_module = PIPELINE_MODULES.get(pipeline_id, None)
if pipeline_module:
if hasattr(pipeline_module, "on_shutdown"):
await pipeline_module.on_shutdown()
pipeline_id = pipeline_module.__name__.split(".")[0]
pipeline_path = os.path.join(PIPELINES_DIR, f"{pipeline_id}.py")
if os.path.exists(pipeline_path):
os.remove(pipeline_path)
await reload()
return {
"status": True,
"detail": f"Pipeline {pipeline_id} deleted successfully",
}
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Pipeline {pipeline_id} not found",
)
@app.post("/v1/reload")
@app.post("/reload")
async def reload_pipelines(user: str = Depends(get_current_user)):
if user == API_KEY:
await reload()
return {"message": "Pipelines reloaded successfully."}
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)