refac: valves -> filters

This commit is contained in:
Timothy J. Baek
2024-05-27 19:34:23 -07:00
parent 72749845dc
commit 4eabb0f4a4
21 changed files with 208 additions and 113 deletions

87
main.py
View File

@@ -13,7 +13,7 @@ import json
import uuid
from utils import get_last_user_message, stream_message_template
from schemas import ValveForm, OpenAIChatCompletionForm
from schemas import FilterForm, OpenAIChatCompletionForm
import os
import importlib.util
@@ -61,30 +61,43 @@ def on_startup():
PIPELINE_MODULES[pipeline_id] = pipeline
if hasattr(pipeline, "manifold") and pipeline.manifold:
for p in pipeline.pipelines:
manifold_pipeline_id = f'{pipeline_id}.{p["id"]}'
if hasattr(pipeline, "type"):
if pipeline.type == "manifold":
for p in pipeline.pipelines:
manifold_pipeline_id = f'{pipeline_id}.{p["id"]}'
manifold_pipeline_name = p["name"]
if hasattr(pipeline, "name"):
manifold_pipeline_name = f"{pipeline.name}{manifold_pipeline_name}"
manifold_pipeline_name = p["name"]
if hasattr(pipeline, "name"):
manifold_pipeline_name = (
f"{pipeline.name}{manifold_pipeline_name}"
)
PIPELINES[manifold_pipeline_id] = {
PIPELINES[manifold_pipeline_id] = {
"module": pipeline_id,
"id": manifold_pipeline_id,
"name": manifold_pipeline_name,
"manifold": True,
}
if pipeline.type == "filter":
PIPELINES[pipeline_id] = {
"module": pipeline_id,
"id": manifold_pipeline_id,
"name": manifold_pipeline_name,
"manifold": True,
"id": pipeline_id,
"name": (
pipeline.name if hasattr(pipeline, "name") else pipeline_id
),
"filter": True,
"pipelines": (
pipeline.pipelines if hasattr(pipeline, "pipelines") else []
),
"priority": (
pipeline.priority if hasattr(pipeline, "priority") else 0
),
}
else:
PIPELINES[loaded_module.__name__] = {
PIPELINES[pipeline_id] = {
"module": pipeline_id,
"id": pipeline_id,
"name": (pipeline.name if hasattr(pipeline, "name") else pipeline_id),
"valve": hasattr(pipeline, "valve"),
"pipelines": (
pipeline.pipelines if hasattr(pipeline, "pipelines") else []
),
"priority": pipeline.priority if hasattr(pipeline, "priority") else 0,
}
@@ -147,30 +160,38 @@ async def get_models():
"object": "model",
"created": int(time.time()),
"owned_by": "openai",
"pipeline": {
"type": "pipeline" if not pipeline.get("valve") else "valve",
"pipelines": pipeline.get("pipelines", []),
"priority": pipeline.get("priority", 0),
},
**(
{
"pipeline": {
"type": (
"pipeline" if not pipeline.get("filter") else "filter"
),
"pipelines": pipeline.get("pipelines", []),
"priority": pipeline.get("priority", 0),
}
}
if pipeline.get("filter", False)
else {}
),
}
for pipeline in PIPELINES.values()
]
}
@app.post("/valve")
@app.post("/v1/valve")
async def valve(form_data: ValveForm):
@app.post("/filter")
@app.post("/v1/filter")
async def filter(form_data: FilterForm):
if form_data.model not in app.state.PIPELINES or not app.state.PIPELINES[
form_data.model
].get("valve", False):
].get("filter", False):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Valve {form_data.model} not found",
detail=f"filter {form_data.model} not found",
)
pipeline = PIPELINE_MODULES[form_data.model]
return await pipeline.control_valve(form_data.body, form_data.user)
return await pipeline.filter(form_data.body, form_data.user)
@app.post("/chat/completions")
@@ -181,7 +202,7 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
if form_data.model not in app.state.PIPELINES or app.state.PIPELINES[
form_data.model
].get("valve", False):
].get("filter", False):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Pipeline {form_data.model} not found",
@@ -197,14 +218,14 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
if pipeline.get("manifold", False):
manifold_id, pipeline_id = pipeline_id.split(".", 1)
get_response = PIPELINE_MODULES[manifold_id].get_response
pipe = PIPELINE_MODULES[manifold_id].pipe
else:
get_response = PIPELINE_MODULES[pipeline_id].get_response
pipe = PIPELINE_MODULES[pipeline_id].pipe
if form_data.stream:
def stream_content():
res = get_response(
res = pipe(
user_message=user_message,
model_id=pipeline_id,
messages=messages,
@@ -258,7 +279,7 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
return StreamingResponse(stream_content(), media_type="text/event-stream")
else:
res = get_response(
res = pipe(
user_message=user_message,
model_id=pipeline_id,
messages=messages,