feat: dynamic loading endpoints

This commit is contained in:
Timothy J. Baek 2024-05-29 00:34:31 -07:00
parent 6f9c8592c5
commit e34e8c4fc0
2 changed files with 58 additions and 11 deletions

34
main.py
View File

@ -10,7 +10,7 @@ from typing import List, Union, Generator, Iterator
from utils.auth import bearer_security, get_current_user from utils.auth import bearer_security, get_current_user
from utils.main import get_last_user_message, stream_message_template from utils.main import get_last_user_message, stream_message_template
from utils.misc import convert_to_raw_url
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
@ -38,7 +38,7 @@ try:
except ImportError: except ImportError:
print("dotenv not installed, skipping...") print("dotenv not installed, skipping...")
API_KEY = os.getenv("API_KEY", "0p3n-w3bu!") API_KEY = os.getenv("PIPELINES_API_KEY", "0p3n-w3bu!")
PIPELINES_DIR = os.getenv("PIPELINES_DIR", "./pipelines") PIPELINES_DIR = os.getenv("PIPELINES_DIR", "./pipelines")
@ -48,6 +48,7 @@ if not os.path.exists(PIPELINES_DIR):
PIPELINES = {} PIPELINES = {}
PIPELINE_MODULES = {} PIPELINE_MODULES = {}
PIPELINE_NAMES = {}
def get_all_pipelines(): def get_all_pipelines():
@ -121,6 +122,8 @@ async def load_module_from_path(module_name, module_path):
spec = importlib.util.spec_from_file_location(module_name, module_path) spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) spec.loader.exec_module(module)
print(f"Loaded module: {module.__name__}")
if hasattr(module, "Pipeline"): if hasattr(module, "Pipeline"):
return module.Pipeline() return module.Pipeline()
return None return None
@ -128,6 +131,8 @@ async def load_module_from_path(module_name, module_path):
async def load_modules_from_directory(directory): async def load_modules_from_directory(directory):
global PIPELINE_MODULES global PIPELINE_MODULES
global PIPELINE_NAMES
for filename in os.listdir(directory): for filename in os.listdir(directory):
if filename.endswith(".py"): if filename.endswith(".py"):
module_name = filename[:-3] # Remove the .py extension module_name = filename[:-3] # Remove the .py extension
@ -136,6 +141,7 @@ async def load_modules_from_directory(directory):
if pipeline: if pipeline:
pipeline_id = pipeline.id if hasattr(pipeline, "id") else module_name pipeline_id = pipeline.id if hasattr(pipeline, "id") else module_name
PIPELINE_MODULES[pipeline_id] = pipeline PIPELINE_MODULES[pipeline_id] = pipeline
PIPELINE_NAMES[pipeline_id] = module_name
logging.info(f"Loaded module: {module_name}") logging.info(f"Loaded module: {module_name}")
else: else:
logging.warning(f"No Pipeline class found in {module_name}") logging.warning(f"No Pipeline class found in {module_name}")
@ -161,7 +167,9 @@ async def on_shutdown():
async def reload(): async def reload():
await on_shutdown() await on_shutdown()
# Clear existing pipelines # Clear existing pipelines
PIPELINES.clear()
PIPELINE_MODULES.clear() PIPELINE_MODULES.clear()
PIPELINE_NAMES.clear()
# Load pipelines afresh # Load pipelines afresh
await on_startup() await on_startup()
@ -231,7 +239,9 @@ async def get_models():
}, },
} }
for pipeline in app.state.PIPELINES.values() for pipeline in app.state.PIPELINES.values()
] ],
"object": "list",
"pipelines": True,
} }
@ -488,7 +498,7 @@ async def get_status():
@app.get("/pipelines") @app.get("/pipelines")
async def list_pipelines(user: str = Depends(get_current_user)): async def list_pipelines(user: str = Depends(get_current_user)):
if user == API_KEY: if user == API_KEY:
return {"data": list(app.state.PIPELINE_MODULES.keys())} return {"data": list(PIPELINE_MODULES.keys())}
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -535,7 +545,10 @@ async def add_pipeline(
) )
try: try:
file_path = await download_file(form_data.url, dest_folder=PIPELINES_DIR) url = convert_to_raw_url(form_data.url)
print(url)
file_path = await download_file(url, dest_folder=PIPELINES_DIR)
await reload() await reload()
return { return {
"status": True, "status": True,
@ -566,14 +579,13 @@ async def delete_pipeline(
) )
pipeline_id = form_data.id pipeline_id = form_data.id
pipeline_module = PIPELINE_MODULES.get(pipeline_id, None) pipeline_name = PIPELINE_NAMES.get(pipeline_id.split(".")[0], None)
if pipeline_module: if PIPELINE_MODULES[pipeline_id]:
if hasattr(pipeline_module, "on_shutdown"): if hasattr(PIPELINE_MODULES[pipeline_id], "on_shutdown"):
await pipeline_module.on_shutdown() await PIPELINE_MODULES[pipeline_id].on_shutdown()
pipeline_id = pipeline_module.__name__.split(".")[0]
pipeline_path = os.path.join(PIPELINES_DIR, f"{pipeline_id}.py") pipeline_path = os.path.join(PIPELINES_DIR, f"{pipeline_name}.py")
if os.path.exists(pipeline_path): if os.path.exists(pipeline_path):
os.remove(pipeline_path) os.remove(pipeline_path)
await reload() await reload()

35
utils/misc.py Normal file
View File

@ -0,0 +1,35 @@
import re
def convert_to_raw_url(github_url):
"""
Converts a GitHub URL to a raw URL.
Example:
https://github.com/user/repo/blob/branch/path/to/file.ext
becomes
https://raw.githubusercontent.com/user/repo/branch/path/to/file.ext
Parameters:
github_url (str): The GitHub URL to convert.
Returns:
str: The converted raw URL.
"""
# Define the regular expression pattern
pattern = r"https://github\.com/(.+?)/(.+?)/blob/(.+?)/(.+)"
# Use the pattern to match and extract parts of the URL
match = re.match(pattern, github_url)
if match:
user_repo = match.group(1) + "/" + match.group(2)
branch = match.group(3)
file_path = match.group(4)
# Construct the raw URL
raw_url = f"https://raw.githubusercontent.com/{user_repo}/{branch}/{file_path}"
return raw_url
# If the URL does not match the expected pattern, return the original URL or raise an error
return github_url