feat: dynamic loading endpoints

This commit is contained in:
Timothy J. Baek 2024-05-29 00:34:31 -07:00
parent 6f9c8592c5
commit e34e8c4fc0
2 changed files with 58 additions and 11 deletions

34
main.py
View File

@ -10,7 +10,7 @@ from typing import List, Union, Generator, Iterator
from utils.auth import bearer_security, get_current_user
from utils.main import get_last_user_message, stream_message_template
from utils.misc import convert_to_raw_url
from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor
@ -38,7 +38,7 @@ try:
except ImportError:
print("dotenv not installed, skipping...")
API_KEY = os.getenv("API_KEY", "0p3n-w3bu!")
API_KEY = os.getenv("PIPELINES_API_KEY", "0p3n-w3bu!")
PIPELINES_DIR = os.getenv("PIPELINES_DIR", "./pipelines")
@ -48,6 +48,7 @@ if not os.path.exists(PIPELINES_DIR):
PIPELINES = {}
PIPELINE_MODULES = {}
PIPELINE_NAMES = {}
def get_all_pipelines():
@ -121,6 +122,8 @@ async def load_module_from_path(module_name, module_path):
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
print(f"Loaded module: {module.__name__}")
if hasattr(module, "Pipeline"):
return module.Pipeline()
return None
@ -128,6 +131,8 @@ async def load_module_from_path(module_name, module_path):
async def load_modules_from_directory(directory):
global PIPELINE_MODULES
global PIPELINE_NAMES
for filename in os.listdir(directory):
if filename.endswith(".py"):
module_name = filename[:-3] # Remove the .py extension
@ -136,6 +141,7 @@ async def load_modules_from_directory(directory):
if pipeline:
pipeline_id = pipeline.id if hasattr(pipeline, "id") else module_name
PIPELINE_MODULES[pipeline_id] = pipeline
PIPELINE_NAMES[pipeline_id] = module_name
logging.info(f"Loaded module: {module_name}")
else:
logging.warning(f"No Pipeline class found in {module_name}")
@ -161,7 +167,9 @@ async def on_shutdown():
async def reload():
await on_shutdown()
# Clear existing pipelines
PIPELINES.clear()
PIPELINE_MODULES.clear()
PIPELINE_NAMES.clear()
# Load pipelines afresh
await on_startup()
@ -231,7 +239,9 @@ async def get_models():
},
}
for pipeline in app.state.PIPELINES.values()
]
],
"object": "list",
"pipelines": True,
}
@ -488,7 +498,7 @@ async def get_status():
@app.get("/pipelines")
async def list_pipelines(user: str = Depends(get_current_user)):
if user == API_KEY:
return {"data": list(app.state.PIPELINE_MODULES.keys())}
return {"data": list(PIPELINE_MODULES.keys())}
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -535,7 +545,10 @@ async def add_pipeline(
)
try:
file_path = await download_file(form_data.url, dest_folder=PIPELINES_DIR)
url = convert_to_raw_url(form_data.url)
print(url)
file_path = await download_file(url, dest_folder=PIPELINES_DIR)
await reload()
return {
"status": True,
@ -566,14 +579,13 @@ async def delete_pipeline(
)
pipeline_id = form_data.id
pipeline_module = PIPELINE_MODULES.get(pipeline_id, None)
pipeline_name = PIPELINE_NAMES.get(pipeline_id.split(".")[0], None)
if pipeline_module:
if hasattr(pipeline_module, "on_shutdown"):
await pipeline_module.on_shutdown()
pipeline_id = pipeline_module.__name__.split(".")[0]
if PIPELINE_MODULES[pipeline_id]:
if hasattr(PIPELINE_MODULES[pipeline_id], "on_shutdown"):
await PIPELINE_MODULES[pipeline_id].on_shutdown()
pipeline_path = os.path.join(PIPELINES_DIR, f"{pipeline_id}.py")
pipeline_path = os.path.join(PIPELINES_DIR, f"{pipeline_name}.py")
if os.path.exists(pipeline_path):
os.remove(pipeline_path)
await reload()

35
utils/misc.py Normal file
View File

@ -0,0 +1,35 @@
import re
def convert_to_raw_url(github_url):
"""
Converts a GitHub URL to a raw URL.
Example:
https://github.com/user/repo/blob/branch/path/to/file.ext
becomes
https://raw.githubusercontent.com/user/repo/branch/path/to/file.ext
Parameters:
github_url (str): The GitHub URL to convert.
Returns:
str: The converted raw URL.
"""
# Define the regular expression pattern
pattern = r"https://github\.com/(.+?)/(.+?)/blob/(.+?)/(.+)"
# Use the pattern to match and extract parts of the URL
match = re.match(pattern, github_url)
if match:
user_repo = match.group(1) + "/" + match.group(2)
branch = match.group(3)
file_path = match.group(4)
# Construct the raw URL
raw_url = f"https://raw.githubusercontent.com/{user_repo}/{branch}/{file_path}"
return raw_url
# If the URL does not match the expected pattern, return the original URL or raise an error
return github_url