This commit is contained in:
Timothy J. Baek 2024-05-28 11:43:40 -07:00
parent 9200e6e4f7
commit 2bafca58ff
3 changed files with 76 additions and 15 deletions

77
main.py
View File

@ -77,6 +77,9 @@ def on_startup():
"id": manifold_pipeline_id, "id": manifold_pipeline_id,
"name": manifold_pipeline_name, "name": manifold_pipeline_name,
"manifold": True, "manifold": True,
"valves": (
pipeline.valves if hasattr(pipeline, "valves") else None
),
} }
if pipeline.type == "filter": if pipeline.type == "filter":
PIPELINES[pipeline_id] = { PIPELINES[pipeline_id] = {
@ -92,12 +95,14 @@ def on_startup():
"priority": ( "priority": (
pipeline.priority if hasattr(pipeline, "priority") else 0 pipeline.priority if hasattr(pipeline, "priority") else 0
), ),
"valves": pipeline.valves if hasattr(pipeline, "valves") else None,
} }
else: else:
PIPELINES[pipeline_id] = { PIPELINES[pipeline_id] = {
"module": pipeline_id, "module": pipeline_id,
"id": pipeline_id, "id": pipeline_id,
"name": (pipeline.name if hasattr(pipeline, "name") else pipeline_id), "name": (pipeline.name if hasattr(pipeline, "name") else pipeline_id),
"valves": pipeline.valves if hasattr(pipeline, "valves") else None,
} }
@ -160,37 +165,91 @@ async def get_models():
"object": "model", "object": "model",
"created": int(time.time()), "created": int(time.time()),
"owned_by": "openai", "owned_by": "openai",
"pipeline": {
**( **(
{ {
"pipeline": {
"type": ( "type": (
"pipeline" if not pipeline.get("filter") else "filter" "pipeline" if not pipeline.get("filter") else "filter"
), ),
"pipelines": pipeline.get("pipelines", []), "pipelines": pipeline.get("pipelines", []),
"priority": pipeline.get("priority", 0), "priority": pipeline.get("priority", 0),
} }
}
if pipeline.get("filter", False) if pipeline.get("filter", False)
else {} else {}
), ),
"valves": "valves" in pipeline,
},
} }
for pipeline in PIPELINES.values() for pipeline in PIPELINES.values()
] ]
} }
@app.post("/filter") @app.get("/{pipeline_id}/valves")
@app.post("/v1/filter") async def get_valves(pipeline_id: str):
async def filter(form_data: FilterForm): if pipeline_id not in app.state.PIPELINES or not app.state.PIPELINES[
if form_data.model not in app.state.PIPELINES or not app.state.PIPELINES[ pipeline_id
form_data.model ].get("valves", False):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Valves {pipeline_id} not found",
)
pipeline = PIPELINE_MODULES[pipeline_id]
return pipeline.valves
@app.get("/{pipeline_id}/valves/spec")
async def get_valves_spec(pipeline_id: str):
if pipeline_id not in app.state.PIPELINES or not app.state.PIPELINES[
pipeline_id
].get("valves", False):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Valves {pipeline_id} not found",
)
pipeline = PIPELINE_MODULES[pipeline_id]
return pipeline.valves.schema()
@app.post("/{pipeline_id}/valves/update")
async def update_valves(pipeline_id: str, form_data: dict):
if pipeline_id not in app.state.PIPELINES or not app.state.PIPELINES[
pipeline_id
].get("valves", False):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Valves {pipeline_id} not found",
)
pipeline = PIPELINE_MODULES[pipeline_id]
try:
ValvesModel = pipeline.valves.__class__
valves = ValvesModel(**form_data.valves)
pipeline.valves = valves
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"{str(e)}",
)
return pipeline.valves
@app.post("/{pipeline_id}/filter")
async def filter(pipeline_id: str, form_data: FilterForm):
if pipeline_id not in app.state.PIPELINES or not app.state.PIPELINES[
pipeline_id
].get("filter", False): ].get("filter", False):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=f"filter {form_data.model} not found", detail=f"filter {pipeline_id} not found",
) )
pipeline = PIPELINE_MODULES[form_data.model] pipeline = PIPELINE_MODULES[pipeline_id]
try: try:
body = await pipeline.filter(form_data.body, form_data.user) body = await pipeline.filter(form_data.body, form_data.user)

View File

@ -10,6 +10,9 @@ class Pipeline:
# You can think of filter pipeline as a middleware that can be used to edit the form data before it is sent to the OpenAI API. # You can think of filter pipeline as a middleware that can be used to edit the form data before it is sent to the OpenAI API.
self.type = "filter" self.type = "filter"
# Assign a unique identifier to the filter pipeline.
# The identifier must be unique across all filter pipelines.
# The identifier must be an alphanumeric string that can include underscores or hyphens. It cannot contain spaces, special characters, slashes, or backslashes.
self.id = "rate_limit_filter_pipeline" self.id = "rate_limit_filter_pipeline"
self.name = "Rate Limit Filter" self.name = "Rate Limit Filter"

View File

@ -18,7 +18,6 @@ class OpenAIChatCompletionForm(BaseModel):
class FilterForm(BaseModel): class FilterForm(BaseModel):
model: str
body: dict body: dict
user: Optional[dict] = None user: Optional[dict] = None
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")