pipelines/main.py

from fastapi import FastAPI, Request, Depends, status, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.concurrency import run_in_threadpool
from starlette.responses import StreamingResponse, Response
from pydantic import BaseModel, ConfigDict
from typing import List, Union, Generator, Iterator
import time
import json
import uuid
from utils import get_last_user_message, stream_message_template
from schemas import FilterForm, OpenAIChatCompletionForm
import os
import importlib.util
import logging
from concurrent.futures import ThreadPoolExecutor
####################################
# Load .env file
####################################
try:
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv(".env"))
except ImportError:
print("dotenv not installed, skipping...")
PIPELINES = {}
PIPELINE_MODULES = {}
def on_startup():
def load_modules_from_directory(directory):
for filename in os.listdir(directory):
if filename.endswith(".py"):
module_name = filename[:-3] # Remove the .py extension
module_path = os.path.join(directory, filename)
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
yield module
for loaded_module in load_modules_from_directory("./pipelines"):
        # Instantiate the module's Pipeline class and register it below
        logging.info("Loaded: %s", loaded_module.__name__)
pipeline = loaded_module.Pipeline()
pipeline_id = pipeline.id if hasattr(pipeline, "id") else loaded_module.__name__
PIPELINE_MODULES[pipeline_id] = pipeline
if hasattr(pipeline, "type"):
if pipeline.type == "manifold":
for p in pipeline.pipelines:
manifold_pipeline_id = f'{pipeline_id}.{p["id"]}'
manifold_pipeline_name = p["name"]
if hasattr(pipeline, "name"):
manifold_pipeline_name = (
f"{pipeline.name}{manifold_pipeline_name}"
)
PIPELINES[manifold_pipeline_id] = {
"module": pipeline_id,
"type": pipeline.type if hasattr(pipeline, "type") else "pipe",
"id": manifold_pipeline_id,
"name": manifold_pipeline_name,
"valves": (
pipeline.valves if hasattr(pipeline, "valves") else None
),
}
if pipeline.type == "filter":
PIPELINES[pipeline_id] = {
"module": pipeline_id,
"type": (pipeline.type if hasattr(pipeline, "type") else "pipe"),
"id": pipeline_id,
"name": (
pipeline.name if hasattr(pipeline, "name") else pipeline_id
),
"pipelines": (
pipeline.valves.pipelines
if hasattr(pipeline, "valves")
and hasattr(pipeline.valves, "pipelines")
else []
),
"priority": (
pipeline.valves.priority
if hasattr(pipeline, "valves")
and hasattr(pipeline.valves, "priority")
else 0
),
"valves": pipeline.valves if hasattr(pipeline, "valves") else None,
}
else:
PIPELINES[pipeline_id] = {
"module": pipeline_id,
"type": (pipeline.type if hasattr(pipeline, "type") else "pipe"),
"id": pipeline_id,
"name": (pipeline.name if hasattr(pipeline, "name") else pipeline_id),
"valves": pipeline.valves if hasattr(pipeline, "valves") else None,
}
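
# The loader above only assumes that each ./pipelines/*.py module exposes a
# `Pipeline` class; `id`, `name`, `type`, `valves`, and `pipelines` are optional
# attributes probed with hasattr, the on_startup/on_shutdown hooks are awaited
# if present, and `pipe` is invoked for chat requests. Illustrative sketch of a
# minimal "pipe" module (identifiers are hypothetical, not part of this repo):
#
#     class Pipeline:
#         def __init__(self):
#             self.id = "example_pipeline"
#             self.name = "Example Pipeline"
#
#         async def on_startup(self):
#             pass
#
#         async def on_shutdown(self):
#             pass
#
#         def pipe(self, user_message, model_id, messages, body):
#             # May return a str, dict, BaseModel, or an iterator of chunks.
#             return f"echo: {user_message}"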
on_startup()
from contextlib import asynccontextmanager
@asynccontextmanager
async def lifespan(app: FastAPI):
for module in PIPELINE_MODULES.values():
if hasattr(module, "on_startup"):
await module.on_startup()
yield
for module in PIPELINE_MODULES.values():
if hasattr(module, "on_shutdown"):
await module.on_shutdown()
app = FastAPI(docs_url="/docs", redoc_url=None, lifespan=lifespan)
app.state.PIPELINES = PIPELINES
origins = ["*"]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.middleware("http")
async def check_url(request: Request, call_next):
    start_time = time.time()
    response = await call_next(request)
    process_time = time.time() - start_time
    response.headers["X-Process-Time"] = str(process_time)
return response
@app.get("/models")
@app.get("/v1/models")
async def get_models():
"""
Returns the available pipelines
"""
return {
"data": [
{
"id": pipeline["id"],
"name": pipeline["name"],
"object": "model",
"created": int(time.time()),
"owned_by": "openai",
"pipeline": {
"type": pipeline["type"],
**(
{
"pipelines": pipeline.get("pipelines", []),
"priority": pipeline.get("priority", 0),
}
if pipeline.get("type", "pipe") == "filter"
else {}
),
"valves": pipeline["valves"] != None,
},
}
for pipeline in PIPELINES.values()
]
}
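
# Example (assuming a local deployment, e.g. `uvicorn main:app --port 9099`;
# the actual host/port depend on how this app is served):
#
#     curl http://localhost:9099/v1/models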
@app.get("/{pipeline_id}/valves")
async def get_valves(pipeline_id: str):
    if (
        pipeline_id not in app.state.PIPELINES
        or not app.state.PIPELINES[pipeline_id].get("valves", False)
    ):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Valves {pipeline_id} not found",
)
pipeline = PIPELINE_MODULES[pipeline_id]
return pipeline.valves
@app.get("/{pipeline_id}/valves/spec")
async def get_valves_spec(pipeline_id: str):
    if (
        pipeline_id not in app.state.PIPELINES
        or not app.state.PIPELINES[pipeline_id].get("valves", False)
    ):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Valves {pipeline_id} not found",
)
pipeline = PIPELINE_MODULES[pipeline_id]
    return pipeline.valves.model_json_schema()
@app.post("/{pipeline_id}/valves/update")
async def update_valves(pipeline_id: str, form_data: dict):
    if (
        pipeline_id not in app.state.PIPELINES
        or not app.state.PIPELINES[pipeline_id].get("valves", False)
    ):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Valves {pipeline_id} not found",
)
pipeline = PIPELINE_MODULES[pipeline_id]
try:
ValvesModel = pipeline.valves.__class__
valves = ValvesModel(**form_data)
pipeline.valves = valves
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"{str(e)}",
)
return pipeline.valves
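
# Example valves round-trip (same assumed local deployment as above; pipeline
# id and field name are hypothetical):
#
#     curl http://localhost:9099/my_pipeline/valves
#     curl -X POST http://localhost:9099/my_pipeline/valves/update \
#          -H "Content-Type: application/json" \
#          -d '{"SOME_SETTING": "value"}'
#
# The posted JSON is validated against the pipeline's own Valves model; a
# validation failure is returned as a 500 with the error message.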
@app.post("/{pipeline_id}/filter")
async def filter(pipeline_id: str, form_data: FilterForm):
if (
pipeline_id not in app.state.PIPELINES
or app.state.PIPELINES[pipeline_id].get("type", "pipe") != "filter"
):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Filter {pipeline_id} not found",
)
pipeline = PIPELINE_MODULES[pipeline_id]
try:
body = await pipeline.filter(form_data.body, form_data.user)
return body
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"{str(e)}",
)
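
# A filter pipeline receives the request body and the requesting user, and must
# return the (possibly modified) body; any exception raised by the pipeline's
# `filter` coroutine is surfaced to the client as a 500.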
@app.post("/chat/completions")
@app.post("/v1/chat/completions")
async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
user_message = get_last_user_message(form_data.messages)
messages = [message.model_dump() for message in form_data.messages]
    if (
        form_data.model not in app.state.PIPELINES
        or app.state.PIPELINES[form_data.model].get("type", "pipe") == "filter"
    ):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Pipeline {form_data.model} not found",
)
def job():
print(form_data.model)
pipeline = app.state.PIPELINES[form_data.model]
pipeline_id = form_data.model
print(pipeline_id)
if pipeline.get("manifold", False):
manifold_id, pipeline_id = pipeline_id.split(".", 1)
pipe = PIPELINE_MODULES[manifold_id].pipe
else:
pipe = PIPELINE_MODULES[pipeline_id].pipe
if form_data.stream:
def stream_content():
res = pipe(
user_message=user_message,
model_id=pipeline_id,
messages=messages,
body=form_data.model_dump(),
)
logging.info(f"stream:true:{res}")
if isinstance(res, str):
message = stream_message_template(form_data.model, res)
logging.info(f"stream_content:str:{message}")
yield f"data: {json.dumps(message)}\n\n"
if isinstance(res, Iterator):
for line in res:
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
                        # Iterators may yield bytes (e.g. raw SSE lines); decode to str if so
                        try:
                            line = line.decode("utf-8")
                        except Exception:
                            pass
logging.info(f"stream_content:Generator:{line}")
if line.startswith("data:"):
yield f"{line}\n\n"
else:
line = stream_message_template(form_data.model, line)
yield f"data: {json.dumps(line)}\n\n"
                # also close the stream for plain iterators, not just generators
                if isinstance(res, (str, Iterator)):
finish_message = {
"id": f"{form_data.model}-{str(uuid.uuid4())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": form_data.model,
"choices": [
{
"index": 0,
"delta": {},
"logprobs": None,
"finish_reason": "stop",
}
],
}
yield f"data: {json.dumps(finish_message)}\n\n"
yield f"data: [DONE]"
return StreamingResponse(stream_content(), media_type="text/event-stream")
else:
res = pipe(
user_message=user_message,
model_id=pipeline_id,
messages=messages,
body=form_data.model_dump(),
)
logging.info(f"stream:false:{res}")
if isinstance(res, dict):
return res
elif isinstance(res, BaseModel):
return res.model_dump()
else:
message = ""
if isinstance(res, str):
message = res
                if isinstance(res, Iterator):
for stream in res:
message = f"{message}{stream}"
logging.info(f"stream:false:{message}")
return {
"id": f"{form_data.model}-{str(uuid.uuid4())}",
"object": "chat.completion",
"created": int(time.time()),
"model": form_data.model,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": message,
},
"logprobs": None,
"finish_reason": "stop",
}
],
}
return await run_in_threadpool(job)
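
# Example chat request (same assumed local deployment as above; "model" must be
# a registered pipeline id):
#
#     curl http://localhost:9099/v1/chat/completions \
#          -H "Content-Type: application/json" \
#          -d '{"model": "example_pipeline", "stream": true,
#               "messages": [{"role": "user", "content": "Hello"}]}'
#
# With "stream": true the response is streamed as OpenAI-style SSE chunks ending
# in "data: [DONE]"; otherwise a single JSON object is returned (a wrapped
# chat.completion response, or whatever dict/model the pipeline itself produced).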
@app.get("/")
async def get_status():
return {"status": True}