pipelines/main.py
Timothy J. Baek 91f58c52e5 feat: valves
2024-05-27 19:04:35 -07:00

309 lines
9.5 KiB
Python

from fastapi import FastAPI, Request, Depends, status, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.concurrency import run_in_threadpool
from starlette.responses import StreamingResponse, Response
from pydantic import BaseModel, ConfigDict
from typing import List, Union, Generator, Iterator
import time
import json
import uuid
from utils import get_last_user_message, stream_message_template
from schemas import ValveForm, OpenAIChatCompletionForm
import os
import importlib.util
import logging
from concurrent.futures import ThreadPoolExecutor
import os
####################################
# Load .env file
####################################
# Best-effort: python-dotenv is an optional dependency; when it is missing we
# run with whatever is already in the process environment.
try:
    from dotenv import load_dotenv, find_dotenv

    load_dotenv(find_dotenv(".env"))
except ImportError:
    print("dotenv not installed, skipping...")

# Pipeline metadata entries, keyed by pipeline id (populated by on_startup).
PIPELINES = {}
# Loaded Pipeline instances, keyed by pipeline id (populated by on_startup).
PIPELINE_MODULES = {}
def on_startup():
    """Discover and register every pipeline module in ./pipelines.

    Each ``*.py`` file is imported, its ``Pipeline`` class instantiated, the
    instance stored in ``PIPELINE_MODULES`` and its metadata in ``PIPELINES``.
    A manifold pipeline expands into one registry entry per sub-pipeline.
    """

    def load_modules_from_directory(directory):
        # Import every .py file in `directory` as a standalone module and
        # yield it; the module name is the filename without its extension.
        for filename in os.listdir(directory):
            if not filename.endswith(".py"):
                continue
            module_name = filename[:-3]  # strip ".py"
            module_path = os.path.join(directory, filename)
            spec = importlib.util.spec_from_file_location(module_name, module_path)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            yield module

    for loaded_module in load_modules_from_directory("./pipelines"):
        # Lazy %-formatting: the original `logging.info("Loaded:", name)`
        # passed an argument to a format string with no placeholder, which
        # raises a TypeError inside the logging machinery on every emit.
        logging.info("Loaded: %s", loaded_module.__name__)

        pipeline = loaded_module.Pipeline()
        # Prefer an explicit `id` attribute; fall back to the module name.
        pipeline_id = pipeline.id if hasattr(pipeline, "id") else loaded_module.__name__
        PIPELINE_MODULES[pipeline_id] = pipeline

        if hasattr(pipeline, "manifold") and pipeline.manifold:
            # A manifold exposes several sub-pipelines; register each under
            # "<manifold_id>.<sub_id>" so requests can be routed back to the
            # owning module.
            for p in pipeline.pipelines:
                manifold_pipeline_id = f'{pipeline_id}.{p["id"]}'

                manifold_pipeline_name = p["name"]
                if hasattr(pipeline, "name"):
                    manifold_pipeline_name = f"{pipeline.name}{manifold_pipeline_name}"

                PIPELINES[manifold_pipeline_id] = {
                    "module": pipeline_id,
                    "id": manifold_pipeline_id,
                    "name": manifold_pipeline_name,
                    "manifold": True,
                }
        else:
            # Key by pipeline_id (the original keyed by module __name__) so a
            # pipeline that declares a custom `id` lands under the same key
            # that PIPELINE_MODULES and the request handlers look up.
            PIPELINES[pipeline_id] = {
                "module": pipeline_id,
                "id": pipeline_id,
                "name": (pipeline.name if hasattr(pipeline, "name") else pipeline_id),
                "valve": hasattr(pipeline, "valve"),
                "pipelines": (
                    pipeline.pipelines if hasattr(pipeline, "pipelines") else []
                ),
                "priority": pipeline.priority if hasattr(pipeline, "priority") else 0,
            }
# Populate PIPELINES / PIPELINE_MODULES at import time, before the FastAPI
# app below is constructed (its lifespan iterates PIPELINE_MODULES).
on_startup()
from contextlib import asynccontextmanager
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run every pipeline's on_startup hook before serving requests, and its
    on_shutdown hook once the server is stopping."""
    for pipeline in PIPELINE_MODULES.values():
        if hasattr(pipeline, "on_startup"):
            await pipeline.on_startup()

    yield  # the application serves requests here

    for pipeline in PIPELINE_MODULES.values():
        if hasattr(pipeline, "on_shutdown"):
            await pipeline.on_shutdown()
# FastAPI application: Swagger UI at /docs, ReDoc disabled; `lifespan` drives
# the per-pipeline startup/shutdown hooks defined above.
app = FastAPI(docs_url="/docs", redoc_url=None, lifespan=lifespan)
# Expose the registry on app state so request handlers can reach it.
app.state.PIPELINES = PIPELINES

# CORS: wildcard origin, with credentials, all methods and headers allowed.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.middleware("http")
async def check_url(request: Request, call_next):
    """Timing middleware: report the request's wall-clock duration in seconds
    via the X-Process-Time response header.

    The original truncated time.time() to whole seconds on both ends, so any
    request faster than one second reported 0; keep the float timestamps.
    """
    start_time = time.time()
    response = await call_next(request)
    process_time = time.time() - start_time
    response.headers["X-Process-Time"] = str(process_time)
    return response
@app.get("/models")
@app.get("/v1/models")
async def get_models():
    """
    Returns the available pipelines
    """
    models = []
    for pipeline in PIPELINES.values():
        # Present each registered pipeline as an OpenAI-style model object.
        models.append(
            {
                "id": pipeline["id"],
                "name": pipeline["name"],
                "object": "model",
                "created": int(time.time()),
                "owned_by": "openai",
                "pipeline": {
                    "type": "valve" if pipeline.get("valve") else "pipeline",
                    "pipelines": pipeline.get("pipelines", []),
                    "priority": pipeline.get("priority", 0),
                },
            }
        )
    return {"data": models}
@app.post("/valve")
@app.post("/v1/valve")
async def valve(form_data: ValveForm):
    """Forward a valve-control request to the matching pipeline instance.

    Responds 404 unless the model is registered and flagged as a valve.
    """
    entry = app.state.PIPELINES.get(form_data.model)
    if entry is None or not entry.get("valve", False):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Valve {form_data.model} not found",
        )

    pipeline = PIPELINE_MODULES[form_data.model]
    return await pipeline.control_valve(form_data.body)
@app.post("/chat/completions")
@app.post("/v1/chat/completions")
async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
    """OpenAI-compatible chat-completion endpoint.

    Resolves `form_data.model` to a registered pipeline (including manifold
    sub-pipelines, whose ids look like "<module>.<sub_id>") and returns either
    a streaming SSE response or a single chat.completion object depending on
    `form_data.stream`. Valve pseudo-pipelines are not chat models and 404.
    """
    user_message = get_last_user_message(form_data.messages)
    messages = [message.model_dump() for message in form_data.messages]

    if form_data.model not in app.state.PIPELINES or app.state.PIPELINES[
        form_data.model
    ].get("valve", False):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Pipeline {form_data.model} not found",
        )

    def job():
        # Runs on a worker thread (run_in_threadpool) because pipeline
        # get_response implementations are synchronous and may block.
        print(form_data.model)

        pipeline = app.state.PIPELINES[form_data.model]
        pipeline_id = form_data.model
        print(pipeline_id)

        if pipeline.get("manifold", False):
            # The manifold module owns get_response; the sub-pipeline id is
            # passed through to it as model_id.
            manifold_id, pipeline_id = pipeline_id.split(".", 1)
            get_response = PIPELINE_MODULES[manifold_id].get_response
        else:
            get_response = PIPELINE_MODULES[pipeline_id].get_response

        if form_data.stream:

            def stream_content():
                res = get_response(
                    user_message=user_message,
                    model_id=pipeline_id,
                    messages=messages,
                    body=form_data.model_dump(),
                )

                logging.info(f"stream:true:{res}")

                if isinstance(res, str):
                    message = stream_message_template(form_data.model, res)
                    logging.info(f"stream_content:str:{message}")
                    yield f"data: {json.dumps(message)}\n\n"

                if isinstance(res, Iterator):
                    for line in res:
                        if isinstance(line, BaseModel):
                            line = line.model_dump_json()
                            line = f"data: {line}"

                        # Pipelines may yield raw bytes (e.g. a proxied SSE
                        # stream). Decode them explicitly: the original bare
                        # `except: pass` hid real errors, and an undecoded
                        # bytes value would crash on startswith() below anyway.
                        if isinstance(line, bytes):
                            line = line.decode("utf-8")

                        logging.info(f"stream_content:Generator:{line}")

                        if line.startswith("data:"):
                            yield f"{line}\n\n"
                        else:
                            line = stream_message_template(form_data.model, line)
                            yield f"data: {json.dumps(line)}\n\n"

                # Close the stream with a finish chunk. Check Iterator (not
                # Generator) so non-generator iterators — which the loop above
                # already streams — also get a proper finish_reason.
                if isinstance(res, (str, Iterator)):
                    finish_message = {
                        "id": f"{form_data.model}-{str(uuid.uuid4())}",
                        "object": "chat.completion.chunk",
                        "created": int(time.time()),
                        "model": form_data.model,
                        "choices": [
                            {
                                "index": 0,
                                "delta": {},
                                "logprobs": None,
                                "finish_reason": "stop",
                            }
                        ],
                    }

                    yield f"data: {json.dumps(finish_message)}\n\n"

                # Plain string, not an f-string (no placeholder to fill).
                yield "data: [DONE]"

            return StreamingResponse(stream_content(), media_type="text/event-stream")
        else:
            res = get_response(
                user_message=user_message,
                model_id=pipeline_id,
                messages=messages,
                body=form_data.model_dump(),
            )
            logging.info(f"stream:false:{res}")

            if isinstance(res, dict):
                # Pipeline already returned a full OpenAI-shaped response.
                return res
            elif isinstance(res, BaseModel):
                return res.model_dump()
            else:
                message = ""

                if isinstance(res, str):
                    message = res

                # Iterator (not Generator) for parity with the streaming path:
                # drain any iterator into a single assistant message.
                if isinstance(res, Iterator):
                    for stream in res:
                        message = f"{message}{stream}"

                logging.info(f"stream:false:{message}")
                return {
                    "id": f"{form_data.model}-{str(uuid.uuid4())}",
                    "object": "chat.completion",
                    "created": int(time.time()),
                    "model": form_data.model,
                    "choices": [
                        {
                            "index": 0,
                            "message": {
                                "role": "assistant",
                                "content": message,
                            },
                            "logprobs": None,
                            "finish_reason": "stop",
                        }
                    ],
                }

    return await run_in_threadpool(job)
@app.get("/")
async def get_status():
    # Simple liveness probe for the pipelines server.
    return {"status": True}