mirror of
https://github.com/open-webui/pipelines
synced 2025-05-15 09:55:45 +00:00
feat: dynamic loading endpoints
This commit is contained in:
parent
6f9c8592c5
commit
e34e8c4fc0
34
main.py
34
main.py
@ -10,7 +10,7 @@ from typing import List, Union, Generator, Iterator
|
|||||||
|
|
||||||
from utils.auth import bearer_security, get_current_user
|
from utils.auth import bearer_security, get_current_user
|
||||||
from utils.main import get_last_user_message, stream_message_template
|
from utils.main import get_last_user_message, stream_message_template
|
||||||
|
from utils.misc import convert_to_raw_url
|
||||||
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
@ -38,7 +38,7 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
print("dotenv not installed, skipping...")
|
print("dotenv not installed, skipping...")
|
||||||
|
|
||||||
API_KEY = os.getenv("API_KEY", "0p3n-w3bu!")
|
API_KEY = os.getenv("PIPELINES_API_KEY", "0p3n-w3bu!")
|
||||||
|
|
||||||
PIPELINES_DIR = os.getenv("PIPELINES_DIR", "./pipelines")
|
PIPELINES_DIR = os.getenv("PIPELINES_DIR", "./pipelines")
|
||||||
|
|
||||||
@ -48,6 +48,7 @@ if not os.path.exists(PIPELINES_DIR):
|
|||||||
|
|
||||||
PIPELINES = {}
|
PIPELINES = {}
|
||||||
PIPELINE_MODULES = {}
|
PIPELINE_MODULES = {}
|
||||||
|
PIPELINE_NAMES = {}
|
||||||
|
|
||||||
|
|
||||||
def get_all_pipelines():
|
def get_all_pipelines():
|
||||||
@ -121,6 +122,8 @@ async def load_module_from_path(module_name, module_path):
|
|||||||
spec = importlib.util.spec_from_file_location(module_name, module_path)
|
spec = importlib.util.spec_from_file_location(module_name, module_path)
|
||||||
module = importlib.util.module_from_spec(spec)
|
module = importlib.util.module_from_spec(spec)
|
||||||
spec.loader.exec_module(module)
|
spec.loader.exec_module(module)
|
||||||
|
|
||||||
|
print(f"Loaded module: {module.__name__}")
|
||||||
if hasattr(module, "Pipeline"):
|
if hasattr(module, "Pipeline"):
|
||||||
return module.Pipeline()
|
return module.Pipeline()
|
||||||
return None
|
return None
|
||||||
@ -128,6 +131,8 @@ async def load_module_from_path(module_name, module_path):
|
|||||||
|
|
||||||
async def load_modules_from_directory(directory):
|
async def load_modules_from_directory(directory):
|
||||||
global PIPELINE_MODULES
|
global PIPELINE_MODULES
|
||||||
|
global PIPELINE_NAMES
|
||||||
|
|
||||||
for filename in os.listdir(directory):
|
for filename in os.listdir(directory):
|
||||||
if filename.endswith(".py"):
|
if filename.endswith(".py"):
|
||||||
module_name = filename[:-3] # Remove the .py extension
|
module_name = filename[:-3] # Remove the .py extension
|
||||||
@ -136,6 +141,7 @@ async def load_modules_from_directory(directory):
|
|||||||
if pipeline:
|
if pipeline:
|
||||||
pipeline_id = pipeline.id if hasattr(pipeline, "id") else module_name
|
pipeline_id = pipeline.id if hasattr(pipeline, "id") else module_name
|
||||||
PIPELINE_MODULES[pipeline_id] = pipeline
|
PIPELINE_MODULES[pipeline_id] = pipeline
|
||||||
|
PIPELINE_NAMES[pipeline_id] = module_name
|
||||||
logging.info(f"Loaded module: {module_name}")
|
logging.info(f"Loaded module: {module_name}")
|
||||||
else:
|
else:
|
||||||
logging.warning(f"No Pipeline class found in {module_name}")
|
logging.warning(f"No Pipeline class found in {module_name}")
|
||||||
@ -161,7 +167,9 @@ async def on_shutdown():
|
|||||||
async def reload():
|
async def reload():
|
||||||
await on_shutdown()
|
await on_shutdown()
|
||||||
# Clear existing pipelines
|
# Clear existing pipelines
|
||||||
|
PIPELINES.clear()
|
||||||
PIPELINE_MODULES.clear()
|
PIPELINE_MODULES.clear()
|
||||||
|
PIPELINE_NAMES.clear()
|
||||||
# Load pipelines afresh
|
# Load pipelines afresh
|
||||||
await on_startup()
|
await on_startup()
|
||||||
|
|
||||||
@ -231,7 +239,9 @@ async def get_models():
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
for pipeline in app.state.PIPELINES.values()
|
for pipeline in app.state.PIPELINES.values()
|
||||||
]
|
],
|
||||||
|
"object": "list",
|
||||||
|
"pipelines": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -488,7 +498,7 @@ async def get_status():
|
|||||||
@app.get("/pipelines")
|
@app.get("/pipelines")
|
||||||
async def list_pipelines(user: str = Depends(get_current_user)):
|
async def list_pipelines(user: str = Depends(get_current_user)):
|
||||||
if user == API_KEY:
|
if user == API_KEY:
|
||||||
return {"data": list(app.state.PIPELINE_MODULES.keys())}
|
return {"data": list(PIPELINE_MODULES.keys())}
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
@ -535,7 +545,10 @@ async def add_pipeline(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
file_path = await download_file(form_data.url, dest_folder=PIPELINES_DIR)
|
url = convert_to_raw_url(form_data.url)
|
||||||
|
|
||||||
|
print(url)
|
||||||
|
file_path = await download_file(url, dest_folder=PIPELINES_DIR)
|
||||||
await reload()
|
await reload()
|
||||||
return {
|
return {
|
||||||
"status": True,
|
"status": True,
|
||||||
@ -566,14 +579,13 @@ async def delete_pipeline(
|
|||||||
)
|
)
|
||||||
|
|
||||||
pipeline_id = form_data.id
|
pipeline_id = form_data.id
|
||||||
pipeline_module = PIPELINE_MODULES.get(pipeline_id, None)
|
pipeline_name = PIPELINE_NAMES.get(pipeline_id.split(".")[0], None)
|
||||||
|
|
||||||
if pipeline_module:
|
if PIPELINE_MODULES[pipeline_id]:
|
||||||
if hasattr(pipeline_module, "on_shutdown"):
|
if hasattr(PIPELINE_MODULES[pipeline_id], "on_shutdown"):
|
||||||
await pipeline_module.on_shutdown()
|
await PIPELINE_MODULES[pipeline_id].on_shutdown()
|
||||||
pipeline_id = pipeline_module.__name__.split(".")[0]
|
|
||||||
|
|
||||||
pipeline_path = os.path.join(PIPELINES_DIR, f"{pipeline_id}.py")
|
pipeline_path = os.path.join(PIPELINES_DIR, f"{pipeline_name}.py")
|
||||||
if os.path.exists(pipeline_path):
|
if os.path.exists(pipeline_path):
|
||||||
os.remove(pipeline_path)
|
os.remove(pipeline_path)
|
||||||
await reload()
|
await reload()
|
||||||
|
35
utils/misc.py
Normal file
35
utils/misc.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_raw_url(github_url):
|
||||||
|
"""
|
||||||
|
Converts a GitHub URL to a raw URL.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
https://github.com/user/repo/blob/branch/path/to/file.ext
|
||||||
|
becomes
|
||||||
|
https://raw.githubusercontent.com/user/repo/branch/path/to/file.ext
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
github_url (str): The GitHub URL to convert.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The converted raw URL.
|
||||||
|
"""
|
||||||
|
# Define the regular expression pattern
|
||||||
|
pattern = r"https://github\.com/(.+?)/(.+?)/blob/(.+?)/(.+)"
|
||||||
|
|
||||||
|
# Use the pattern to match and extract parts of the URL
|
||||||
|
match = re.match(pattern, github_url)
|
||||||
|
|
||||||
|
if match:
|
||||||
|
user_repo = match.group(1) + "/" + match.group(2)
|
||||||
|
branch = match.group(3)
|
||||||
|
file_path = match.group(4)
|
||||||
|
|
||||||
|
# Construct the raw URL
|
||||||
|
raw_url = f"https://raw.githubusercontent.com/{user_repo}/{branch}/{file_path}"
|
||||||
|
return raw_url
|
||||||
|
|
||||||
|
# If the URL does not match the expected pattern, return the original URL or raise an error
|
||||||
|
return github_url
|
Loading…
Reference in New Issue
Block a user