mirror of
https://github.com/open-webui/pipelines
synced 2025-05-12 08:30:43 +00:00
refac
This commit is contained in:
parent
21bab00496
commit
6f9c8592c5
70
cli.py
70
cli.py
@ -1,70 +0,0 @@
|
|||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import typer
|
|
||||||
|
|
||||||
|
|
||||||
import subprocess
|
|
||||||
import time
|
|
||||||
|
|
||||||
|
|
||||||
def start_process(app: str, host: str, port: int, reload: bool = False):
|
|
||||||
# Start the FastAPI application
|
|
||||||
command = [
|
|
||||||
"uvicorn",
|
|
||||||
app,
|
|
||||||
"--host",
|
|
||||||
host,
|
|
||||||
"--port",
|
|
||||||
str(port),
|
|
||||||
"--forwarded-allow-ips",
|
|
||||||
"*",
|
|
||||||
]
|
|
||||||
|
|
||||||
if reload:
|
|
||||||
command.append("--reload")
|
|
||||||
|
|
||||||
process = subprocess.Popen(command)
|
|
||||||
return process
|
|
||||||
|
|
||||||
|
|
||||||
main = typer.Typer()
|
|
||||||
|
|
||||||
|
|
||||||
@main.command()
|
|
||||||
def serve(
|
|
||||||
host: str = "0.0.0.0",
|
|
||||||
port: int = 9099,
|
|
||||||
):
|
|
||||||
while True:
|
|
||||||
process = start_process("main:app", host, port, reload=False)
|
|
||||||
process.wait()
|
|
||||||
|
|
||||||
if process.returncode == 42:
|
|
||||||
print("Restarting due to restart request")
|
|
||||||
time.sleep(2) # optional delay to prevent tight restart loops
|
|
||||||
else:
|
|
||||||
print("Normal exit, stopping the manager")
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
@main.command()
|
|
||||||
def dev(
|
|
||||||
host: str = "0.0.0.0",
|
|
||||||
port: int = 9099,
|
|
||||||
):
|
|
||||||
while True:
|
|
||||||
process = start_process("main:app", host, port, reload=True)
|
|
||||||
process.wait()
|
|
||||||
|
|
||||||
if process.returncode == 42:
|
|
||||||
print("Restarting due to restart request")
|
|
||||||
time.sleep(2) # optional delay to prevent tight restart loops
|
|
||||||
else:
|
|
||||||
print("Normal exit, stopping the manager")
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
198
main.py
198
main.py
@ -15,15 +15,16 @@ from utils.main import get_last_user_message, stream_message_template
|
|||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from schemas import FilterForm, OpenAIChatCompletionForm
|
from schemas import FilterForm, OpenAIChatCompletionForm
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
import sys
|
|
||||||
import os
|
import os
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
@ -37,6 +38,13 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
print("dotenv not installed, skipping...")
|
print("dotenv not installed, skipping...")
|
||||||
|
|
||||||
|
API_KEY = os.getenv("API_KEY", "0p3n-w3bu!")
|
||||||
|
|
||||||
|
PIPELINES_DIR = os.getenv("PIPELINES_DIR", "./pipelines")
|
||||||
|
|
||||||
|
if not os.path.exists(PIPELINES_DIR):
|
||||||
|
os.makedirs(PIPELINES_DIR)
|
||||||
|
|
||||||
|
|
||||||
PIPELINES = {}
|
PIPELINES = {}
|
||||||
PIPELINE_MODULES = {}
|
PIPELINE_MODULES = {}
|
||||||
@ -109,42 +117,60 @@ def get_all_pipelines():
|
|||||||
return pipelines
|
return pipelines
|
||||||
|
|
||||||
|
|
||||||
def on_startup():
|
async def load_module_from_path(module_name, module_path):
|
||||||
def load_modules_from_directory(directory):
|
spec = importlib.util.spec_from_file_location(module_name, module_path)
|
||||||
for filename in os.listdir(directory):
|
module = importlib.util.module_from_spec(spec)
|
||||||
if filename.endswith(".py"):
|
spec.loader.exec_module(module)
|
||||||
module_name = filename[:-3] # Remove the .py extension
|
if hasattr(module, "Pipeline"):
|
||||||
module_path = os.path.join(directory, filename)
|
return module.Pipeline()
|
||||||
spec = importlib.util.spec_from_file_location(module_name, module_path)
|
return None
|
||||||
module = importlib.util.module_from_spec(spec)
|
|
||||||
spec.loader.exec_module(module)
|
|
||||||
yield module
|
|
||||||
|
|
||||||
for loaded_module in load_modules_from_directory("./pipelines"):
|
|
||||||
# Do something with the loaded module
|
|
||||||
logging.info("Loaded:", loaded_module.__name__)
|
|
||||||
|
|
||||||
pipeline = loaded_module.Pipeline()
|
async def load_modules_from_directory(directory):
|
||||||
pipeline_id = pipeline.id if hasattr(pipeline, "id") else loaded_module.__name__
|
global PIPELINE_MODULES
|
||||||
|
for filename in os.listdir(directory):
|
||||||
PIPELINE_MODULES[pipeline_id] = pipeline
|
if filename.endswith(".py"):
|
||||||
|
module_name = filename[:-3] # Remove the .py extension
|
||||||
|
module_path = os.path.join(directory, filename)
|
||||||
|
pipeline = await load_module_from_path(module_name, module_path)
|
||||||
|
if pipeline:
|
||||||
|
pipeline_id = pipeline.id if hasattr(pipeline, "id") else module_name
|
||||||
|
PIPELINE_MODULES[pipeline_id] = pipeline
|
||||||
|
logging.info(f"Loaded module: {module_name}")
|
||||||
|
else:
|
||||||
|
logging.warning(f"No Pipeline class found in {module_name}")
|
||||||
|
|
||||||
|
global PIPELINES
|
||||||
PIPELINES = get_all_pipelines()
|
PIPELINES = get_all_pipelines()
|
||||||
|
|
||||||
|
|
||||||
on_startup()
|
async def on_startup():
|
||||||
|
await load_modules_from_directory(PIPELINES_DIR)
|
||||||
|
|
||||||
|
for module in PIPELINE_MODULES.values():
|
||||||
|
if hasattr(module, "on_startup"):
|
||||||
|
await module.on_startup()
|
||||||
|
|
||||||
|
|
||||||
|
async def on_shutdown():
|
||||||
|
for module in PIPELINE_MODULES.values():
|
||||||
|
if hasattr(module, "on_shutdown"):
|
||||||
|
await module.on_shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
async def reload():
|
||||||
|
await on_shutdown()
|
||||||
|
# Clear existing pipelines
|
||||||
|
PIPELINE_MODULES.clear()
|
||||||
|
# Load pipelines afresh
|
||||||
|
await on_startup()
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
for module in PIPELINE_MODULES.values():
|
await on_startup()
|
||||||
if hasattr(module, "on_startup"):
|
|
||||||
await module.on_startup()
|
|
||||||
yield
|
yield
|
||||||
|
await on_shutdown()
|
||||||
for module in PIPELINE_MODULES.values():
|
|
||||||
if hasattr(module, "on_shutdown"):
|
|
||||||
await module.on_shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(docs_url="/docs", redoc_url=None, lifespan=lifespan)
|
app = FastAPI(docs_url="/docs", redoc_url=None, lifespan=lifespan)
|
||||||
@ -458,10 +484,118 @@ async def get_status():
|
|||||||
return {"status": True}
|
return {"status": True}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/restart")
|
@app.get("/v1/pipelines")
|
||||||
@app.post("/restart")
|
@app.get("/pipelines")
|
||||||
def restart_server(user: str = Depends(get_current_user)):
|
async def list_pipelines(user: str = Depends(get_current_user)):
|
||||||
|
if user == API_KEY:
|
||||||
|
return {"data": list(app.state.PIPELINE_MODULES.keys())}
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid API key",
|
||||||
|
)
|
||||||
|
|
||||||
print(user)
|
|
||||||
|
|
||||||
return True
|
class AddPipelineForm(BaseModel):
|
||||||
|
url: str
|
||||||
|
|
||||||
|
|
||||||
|
async def download_file(url: str, dest_folder: str):
|
||||||
|
filename = os.path.basename(urlparse(url).path)
|
||||||
|
if not filename.endswith(".py"):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="URL must point to a Python file",
|
||||||
|
)
|
||||||
|
|
||||||
|
file_path = os.path.join(dest_folder, filename)
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(url) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Failed to download file",
|
||||||
|
)
|
||||||
|
with open(file_path, "wb") as f:
|
||||||
|
f.write(await response.read())
|
||||||
|
|
||||||
|
return file_path
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/pipelines/add")
|
||||||
|
@app.post("/pipelines/add")
|
||||||
|
async def add_pipeline(
|
||||||
|
form_data: AddPipelineForm, user: str = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
if user != API_KEY:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid API key",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
file_path = await download_file(form_data.url, dest_folder=PIPELINES_DIR)
|
||||||
|
await reload()
|
||||||
|
return {
|
||||||
|
"status": True,
|
||||||
|
"detail": f"Pipeline added successfully from {file_path}",
|
||||||
|
}
|
||||||
|
except HTTPException as e:
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=str(e),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DeletePipelineForm(BaseModel):
|
||||||
|
id: str
|
||||||
|
|
||||||
|
|
||||||
|
@app.delete("/v1/pipelines/delete")
|
||||||
|
@app.delete("/pipelines/delete")
|
||||||
|
async def delete_pipeline(
|
||||||
|
form_data: DeletePipelineForm, user: str = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
if user != API_KEY:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid API key",
|
||||||
|
)
|
||||||
|
|
||||||
|
pipeline_id = form_data.id
|
||||||
|
pipeline_module = PIPELINE_MODULES.get(pipeline_id, None)
|
||||||
|
|
||||||
|
if pipeline_module:
|
||||||
|
if hasattr(pipeline_module, "on_shutdown"):
|
||||||
|
await pipeline_module.on_shutdown()
|
||||||
|
pipeline_id = pipeline_module.__name__.split(".")[0]
|
||||||
|
|
||||||
|
pipeline_path = os.path.join(PIPELINES_DIR, f"{pipeline_id}.py")
|
||||||
|
if os.path.exists(pipeline_path):
|
||||||
|
os.remove(pipeline_path)
|
||||||
|
await reload()
|
||||||
|
return {
|
||||||
|
"status": True,
|
||||||
|
"detail": f"Pipeline {pipeline_id} deleted successfully",
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Pipeline {pipeline_id} not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/reload")
|
||||||
|
@app.post("/reload")
|
||||||
|
async def reload_pipelines(user: str = Depends(get_current_user)):
|
||||||
|
if user == API_KEY:
|
||||||
|
await reload()
|
||||||
|
return {"message": "Pipelines reloaded successfully."}
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid API key",
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user