mirror of
https://github.com/open-webui/pipelines
synced 2025-05-12 00:20:48 +00:00
feat: dynamic loading endpoints
This commit is contained in:
parent
6f9c8592c5
commit
e34e8c4fc0
34
main.py
34
main.py
@ -10,7 +10,7 @@ from typing import List, Union, Generator, Iterator
|
||||
|
||||
from utils.auth import bearer_security, get_current_user
|
||||
from utils.main import get_last_user_message, stream_message_template
|
||||
|
||||
from utils.misc import convert_to_raw_url
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
@ -38,7 +38,7 @@ try:
|
||||
except ImportError:
|
||||
print("dotenv not installed, skipping...")
|
||||
|
||||
API_KEY = os.getenv("API_KEY", "0p3n-w3bu!")
|
||||
API_KEY = os.getenv("PIPELINES_API_KEY", "0p3n-w3bu!")
|
||||
|
||||
PIPELINES_DIR = os.getenv("PIPELINES_DIR", "./pipelines")
|
||||
|
||||
@ -48,6 +48,7 @@ if not os.path.exists(PIPELINES_DIR):
|
||||
|
||||
PIPELINES = {}
|
||||
PIPELINE_MODULES = {}
|
||||
PIPELINE_NAMES = {}
|
||||
|
||||
|
||||
def get_all_pipelines():
|
||||
@ -121,6 +122,8 @@ async def load_module_from_path(module_name, module_path):
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
print(f"Loaded module: {module.__name__}")
|
||||
if hasattr(module, "Pipeline"):
|
||||
return module.Pipeline()
|
||||
return None
|
||||
@ -128,6 +131,8 @@ async def load_module_from_path(module_name, module_path):
|
||||
|
||||
async def load_modules_from_directory(directory):
|
||||
global PIPELINE_MODULES
|
||||
global PIPELINE_NAMES
|
||||
|
||||
for filename in os.listdir(directory):
|
||||
if filename.endswith(".py"):
|
||||
module_name = filename[:-3] # Remove the .py extension
|
||||
@ -136,6 +141,7 @@ async def load_modules_from_directory(directory):
|
||||
if pipeline:
|
||||
pipeline_id = pipeline.id if hasattr(pipeline, "id") else module_name
|
||||
PIPELINE_MODULES[pipeline_id] = pipeline
|
||||
PIPELINE_NAMES[pipeline_id] = module_name
|
||||
logging.info(f"Loaded module: {module_name}")
|
||||
else:
|
||||
logging.warning(f"No Pipeline class found in {module_name}")
|
||||
@ -161,7 +167,9 @@ async def on_shutdown():
|
||||
async def reload():
|
||||
await on_shutdown()
|
||||
# Clear existing pipelines
|
||||
PIPELINES.clear()
|
||||
PIPELINE_MODULES.clear()
|
||||
PIPELINE_NAMES.clear()
|
||||
# Load pipelines afresh
|
||||
await on_startup()
|
||||
|
||||
@ -231,7 +239,9 @@ async def get_models():
|
||||
},
|
||||
}
|
||||
for pipeline in app.state.PIPELINES.values()
|
||||
]
|
||||
],
|
||||
"object": "list",
|
||||
"pipelines": True,
|
||||
}
|
||||
|
||||
|
||||
@ -488,7 +498,7 @@ async def get_status():
|
||||
@app.get("/pipelines")
|
||||
async def list_pipelines(user: str = Depends(get_current_user)):
|
||||
if user == API_KEY:
|
||||
return {"data": list(app.state.PIPELINE_MODULES.keys())}
|
||||
return {"data": list(PIPELINE_MODULES.keys())}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@ -535,7 +545,10 @@ async def add_pipeline(
|
||||
)
|
||||
|
||||
try:
|
||||
file_path = await download_file(form_data.url, dest_folder=PIPELINES_DIR)
|
||||
url = convert_to_raw_url(form_data.url)
|
||||
|
||||
print(url)
|
||||
file_path = await download_file(url, dest_folder=PIPELINES_DIR)
|
||||
await reload()
|
||||
return {
|
||||
"status": True,
|
||||
@ -566,14 +579,13 @@ async def delete_pipeline(
|
||||
)
|
||||
|
||||
pipeline_id = form_data.id
|
||||
pipeline_module = PIPELINE_MODULES.get(pipeline_id, None)
|
||||
pipeline_name = PIPELINE_NAMES.get(pipeline_id.split(".")[0], None)
|
||||
|
||||
if pipeline_module:
|
||||
if hasattr(pipeline_module, "on_shutdown"):
|
||||
await pipeline_module.on_shutdown()
|
||||
pipeline_id = pipeline_module.__name__.split(".")[0]
|
||||
if PIPELINE_MODULES[pipeline_id]:
|
||||
if hasattr(PIPELINE_MODULES[pipeline_id], "on_shutdown"):
|
||||
await PIPELINE_MODULES[pipeline_id].on_shutdown()
|
||||
|
||||
pipeline_path = os.path.join(PIPELINES_DIR, f"{pipeline_id}.py")
|
||||
pipeline_path = os.path.join(PIPELINES_DIR, f"{pipeline_name}.py")
|
||||
if os.path.exists(pipeline_path):
|
||||
os.remove(pipeline_path)
|
||||
await reload()
|
||||
|
35
utils/misc.py
Normal file
35
utils/misc.py
Normal file
@ -0,0 +1,35 @@
|
||||
import re
|
||||
|
||||
|
||||
def convert_to_raw_url(github_url):
|
||||
"""
|
||||
Converts a GitHub URL to a raw URL.
|
||||
|
||||
Example:
|
||||
https://github.com/user/repo/blob/branch/path/to/file.ext
|
||||
becomes
|
||||
https://raw.githubusercontent.com/user/repo/branch/path/to/file.ext
|
||||
|
||||
Parameters:
|
||||
github_url (str): The GitHub URL to convert.
|
||||
|
||||
Returns:
|
||||
str: The converted raw URL.
|
||||
"""
|
||||
# Define the regular expression pattern
|
||||
pattern = r"https://github\.com/(.+?)/(.+?)/blob/(.+?)/(.+)"
|
||||
|
||||
# Use the pattern to match and extract parts of the URL
|
||||
match = re.match(pattern, github_url)
|
||||
|
||||
if match:
|
||||
user_repo = match.group(1) + "/" + match.group(2)
|
||||
branch = match.group(3)
|
||||
file_path = match.group(4)
|
||||
|
||||
# Construct the raw URL
|
||||
raw_url = f"https://raw.githubusercontent.com/{user_repo}/{branch}/{file_path}"
|
||||
return raw_url
|
||||
|
||||
# If the URL does not match the expected pattern, return the original URL or raise an error
|
||||
return github_url
|
Loading…
Reference in New Issue
Block a user