From e34e8c4fc0f85017b512ed032693e44ea719c79c Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Wed, 29 May 2024 00:34:31 -0700 Subject: [PATCH] feat: dynamic loading endpoints --- main.py | 34 +++++++++++++++++++++++----------- utils/misc.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 11 deletions(-) create mode 100644 utils/misc.py diff --git a/main.py b/main.py index cd99d6d..f0c68de 100644 --- a/main.py +++ b/main.py @@ -10,7 +10,7 @@ from typing import List, Union, Generator, Iterator from utils.auth import bearer_security, get_current_user from utils.main import get_last_user_message, stream_message_template - +from utils.misc import convert_to_raw_url from contextlib import asynccontextmanager from concurrent.futures import ThreadPoolExecutor @@ -38,7 +38,7 @@ try: except ImportError: print("dotenv not installed, skipping...") -API_KEY = os.getenv("API_KEY", "0p3n-w3bu!") +API_KEY = os.getenv("PIPELINES_API_KEY", "0p3n-w3bu!") PIPELINES_DIR = os.getenv("PIPELINES_DIR", "./pipelines") @@ -48,6 +48,7 @@ if not os.path.exists(PIPELINES_DIR): PIPELINES = {} PIPELINE_MODULES = {} +PIPELINE_NAMES = {} def get_all_pipelines(): @@ -121,6 +122,8 @@ async def load_module_from_path(module_name, module_path): spec = importlib.util.spec_from_file_location(module_name, module_path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) + + print(f"Loaded module: {module.__name__}") if hasattr(module, "Pipeline"): return module.Pipeline() return None @@ -128,6 +131,8 @@ async def load_module_from_path(module_name, module_path): async def load_modules_from_directory(directory): global PIPELINE_MODULES + global PIPELINE_NAMES + for filename in os.listdir(directory): if filename.endswith(".py"): module_name = filename[:-3] # Remove the .py extension @@ -136,6 +141,7 @@ async def load_modules_from_directory(directory): if pipeline: pipeline_id = pipeline.id if hasattr(pipeline, "id") else module_name PIPELINE_MODULES[pipeline_id] = pipeline + PIPELINE_NAMES[pipeline_id] = module_name logging.info(f"Loaded module: {module_name}") else: logging.warning(f"No Pipeline class found in {module_name}") @@ -161,7 +167,9 @@ async def on_shutdown(): async def reload(): await on_shutdown() # Clear existing pipelines + PIPELINES.clear() PIPELINE_MODULES.clear() + PIPELINE_NAMES.clear() # Load pipelines afresh await on_startup() @@ -231,7 +239,9 @@ async def get_models(): }, } for pipeline in app.state.PIPELINES.values() - ] + ], + "object": "list", + "pipelines": True, } @@ -488,7 +498,7 @@ async def get_status(): @app.get("/pipelines") async def list_pipelines(user: str = Depends(get_current_user)): if user == API_KEY: - return {"data": list(app.state.PIPELINE_MODULES.keys())} + return {"data": list(PIPELINE_MODULES.keys())} else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -535,7 +545,10 @@ async def add_pipeline( ) try: - file_path = await download_file(form_data.url, dest_folder=PIPELINES_DIR) + url = convert_to_raw_url(form_data.url) + + print(url) + file_path = await download_file(url, dest_folder=PIPELINES_DIR) await reload() return { "status": True, @@ -566,14 +579,13 @@ async def delete_pipeline( ) pipeline_id = form_data.id - pipeline_module = PIPELINE_MODULES.get(pipeline_id, None) + pipeline_name = PIPELINE_NAMES.get(pipeline_id.split(".")[0], None) - if pipeline_module: - if hasattr(pipeline_module, "on_shutdown"): - await pipeline_module.on_shutdown() - pipeline_id = pipeline_module.__name__.split(".")[0] + if PIPELINE_MODULES[pipeline_id]: + if hasattr(PIPELINE_MODULES[pipeline_id], "on_shutdown"): + await PIPELINE_MODULES[pipeline_id].on_shutdown() - pipeline_path = os.path.join(PIPELINES_DIR, f"{pipeline_id}.py") + pipeline_path = os.path.join(PIPELINES_DIR, f"{pipeline_name}.py") if os.path.exists(pipeline_path): os.remove(pipeline_path) await reload() diff --git a/utils/misc.py b/utils/misc.py new file mode 100644 index 0000000..e0c1e1c --- /dev/null +++ b/utils/misc.py @@ -0,0 +1,35 @@ +import re + + +def convert_to_raw_url(github_url): + """ + Converts a GitHub URL to a raw URL. + + Example: + https://github.com/user/repo/blob/branch/path/to/file.ext + becomes + https://raw.githubusercontent.com/user/repo/branch/path/to/file.ext + + Parameters: + github_url (str): The GitHub URL to convert. + + Returns: + str: The converted raw URL. + """ + # Define the regular expression pattern + pattern = r"https://github\.com/(.+?)/(.+?)/blob/(.+?)/(.+)" + + # Use the pattern to match and extract parts of the URL + match = re.match(pattern, github_url) + + if match: + user_repo = match.group(1) + "/" + match.group(2) + branch = match.group(3) + file_path = match.group(4) + + # Construct the raw URL + raw_url = f"https://raw.githubusercontent.com/{user_repo}/{branch}/{file_path}" + return raw_url + + # If the URL does not match the expected pattern, return the original URL or raise an error + return github_url