refac: valves -> filters

Timothy J. Baek 2024-05-27 19:34:23 -07:00
parent 72749845dc
commit 4eabb0f4a4
21 changed files with 208 additions and 113 deletions

main.py
View File

@@ -13,7 +13,7 @@ import json
import uuid
from utils import get_last_user_message, stream_message_template
from schemas import ValveForm, OpenAIChatCompletionForm
from schemas import FilterForm, OpenAIChatCompletionForm
import os
import importlib.util
@@ -61,30 +61,43 @@ def on_startup():
PIPELINE_MODULES[pipeline_id] = pipeline
if hasattr(pipeline, "manifold") and pipeline.manifold:
for p in pipeline.pipelines:
manifold_pipeline_id = f'{pipeline_id}.{p["id"]}'
if hasattr(pipeline, "type"):
if pipeline.type == "manifold":
for p in pipeline.pipelines:
manifold_pipeline_id = f'{pipeline_id}.{p["id"]}'
manifold_pipeline_name = p["name"]
if hasattr(pipeline, "name"):
manifold_pipeline_name = f"{pipeline.name}{manifold_pipeline_name}"
manifold_pipeline_name = p["name"]
if hasattr(pipeline, "name"):
manifold_pipeline_name = (
f"{pipeline.name}{manifold_pipeline_name}"
)
PIPELINES[manifold_pipeline_id] = {
PIPELINES[manifold_pipeline_id] = {
"module": pipeline_id,
"id": manifold_pipeline_id,
"name": manifold_pipeline_name,
"manifold": True,
}
if pipeline.type == "filter":
PIPELINES[pipeline_id] = {
"module": pipeline_id,
"id": manifold_pipeline_id,
"name": manifold_pipeline_name,
"manifold": True,
"id": pipeline_id,
"name": (
pipeline.name if hasattr(pipeline, "name") else pipeline_id
),
"filter": True,
"pipelines": (
pipeline.pipelines if hasattr(pipeline, "pipelines") else []
),
"priority": (
pipeline.priority if hasattr(pipeline, "priority") else 0
),
}
else:
PIPELINES[loaded_module.__name__] = {
PIPELINES[pipeline_id] = {
"module": pipeline_id,
"id": pipeline_id,
"name": (pipeline.name if hasattr(pipeline, "name") else pipeline_id),
"valve": hasattr(pipeline, "valve"),
"pipelines": (
pipeline.pipelines if hasattr(pipeline, "pipelines") else []
),
"priority": pipeline.priority if hasattr(pipeline, "priority") else 0,
}
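
To make the branching above concrete, here is a hedged sketch of the registry entries it produces; the module ids, sub-pipeline ids, and display names below are illustrative, not taken from this commit.

# Sketch only: what PIPELINES might hold after on_startup() with one manifold and one filter loaded.
PIPELINES = {
    # type == "manifold": one entry per sub-pipeline, keyed by "<module id>.<sub id>"
    "manifold_pipeline.pipeline-1": {
        "module": "manifold_pipeline",
        "id": "manifold_pipeline.pipeline-1",
        "name": "Manifold: Pipeline 1",
        "manifold": True,
    },
    # type == "filter": a single entry flagged with "filter": True
    "filter_pipeline": {
        "module": "filter_pipeline",
        "id": "filter_pipeline",
        "name": "Filter",
        "filter": True,
        "pipelines": [{"id": "llama3:latest"}],
        "priority": 0,
    },
}
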
@@ -147,30 +160,38 @@ async def get_models():
"object": "model",
"created": int(time.time()),
"owned_by": "openai",
"pipeline": {
"type": "pipeline" if not pipeline.get("valve") else "valve",
"pipelines": pipeline.get("pipelines", []),
"priority": pipeline.get("priority", 0),
},
**(
{
"pipeline": {
"type": (
"pipeline" if not pipeline.get("filter") else "filter"
),
"pipelines": pipeline.get("pipelines", []),
"priority": pipeline.get("priority", 0),
}
}
if pipeline.get("filter", False)
else {}
),
}
for pipeline in PIPELINES.values()
]
}
@app.post("/valve")
@app.post("/v1/valve")
async def valve(form_data: ValveForm):
@app.post("/filter")
@app.post("/v1/filter")
async def filter(form_data: FilterForm):
if form_data.model not in app.state.PIPELINES or not app.state.PIPELINES[
form_data.model
].get("valve", False):
].get("filter", False):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Valve {form_data.model} not found",
detail=f"filter {form_data.model} not found",
)
pipeline = PIPELINE_MODULES[form_data.model]
return await pipeline.control_valve(form_data.body, form_data.user)
return await pipeline.filter(form_data.body, form_data.user)
@app.post("/chat/completions")
@@ -181,7 +202,7 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
if form_data.model not in app.state.PIPELINES or app.state.PIPELINES[
form_data.model
].get("valve", False):
].get("filter", False):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Pipeline {form_data.model} not found",
@@ -197,14 +218,14 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
if pipeline.get("manifold", False):
manifold_id, pipeline_id = pipeline_id.split(".", 1)
get_response = PIPELINE_MODULES[manifold_id].get_response
pipe = PIPELINE_MODULES[manifold_id].pipe
else:
get_response = PIPELINE_MODULES[pipeline_id].get_response
pipe = PIPELINE_MODULES[pipeline_id].pipe
if form_data.stream:
def stream_content():
res = get_response(
res = pipe(
user_message=user_message,
model_id=pipeline_id,
messages=messages,
@@ -258,7 +279,7 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
return StreamingResponse(stream_content(), media_type="text/event-stream")
else:
res = get_response(
res = pipe(
user_message=user_message,
model_id=pipeline_id,
messages=messages,
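
For orientation, a minimal client-side sketch of the renamed endpoint: the host and port (localhost:9099) and the model id are assumptions, and the JSON payload mirrors the FilterForm schema shown at the end of this diff.

import requests

# Hypothetical server address; point this at wherever the FastAPI app is served.
resp = requests.post(
    "http://localhost:9099/v1/filter",
    json={
        "model": "filter_pipeline",  # must be registered with "filter": True
        "body": {"messages": [{"role": "user", "content": "Hello"}]},
        "user": {"id": "user-1", "name": "Alice"},
    },
)
resp.raise_for_status()
print(resp.json())  # the filter returns the (possibly modified) body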

View File

@@ -24,11 +24,11 @@ class Pipeline:
print(f"on_shutdown:{__name__}")
pass
def get_response(
def pipe(
self, user_message: str, model_id: str, messages: List[dict], body: dict
) -> Union[str, Generator, Iterator]:
# This is where you can add your custom pipelines like RAG.'
print(f"get_response:{__name__}")
print(f"pipe:{__name__}")
OLLAMA_BASE_URL = "http://localhost:11434"
MODEL = "llama3"

View File

@@ -20,11 +20,11 @@ class Pipeline:
print(f"on_shutdown:{__name__}")
pass
def get_response(
def pipe(
self, user_message: str, model_id: str, messages: List[dict], body: dict
) -> Union[str, Generator, Iterator]:
# This is where you can add your custom pipelines like RAG.'
print(f"get_response:{__name__}")
print(f"pipe:{__name__}")
print(messages)
print(user_message)

View File

@@ -4,18 +4,19 @@ from schemas import OpenAIChatMessage
class Pipeline:
def __init__(self):
# Pipeline valves are only compatible with Open WebUI
# You can think of valve pipeline as a middleware that can be used to edit the form data before it is sent to the OpenAI API.
self.valve = True
self.id = "valve_pipeline"
self.name = "Valve"
# Pipeline filters are only compatible with Open WebUI
# You can think of filter pipeline as a middleware that can be used to edit the form data before it is sent to the OpenAI API.
self.type = "filter"
# Assign a priority level to the valve pipeline.
# The priority level determines the order in which the valve pipelines are executed.
self.id = "filter_pipeline"
self.name = "Filter"
# Assign a priority level to the filter pipeline.
# The priority level determines the order in which the filter pipelines are executed.
# The lower the number, the higher the priority.
self.priority = 0
# List target pipelines (models) that this valve will be connected to.
# List target pipelines (models) that this filter will be connected to.
self.pipelines = [
{"id": "llama3:latest"},
]
@@ -31,8 +32,8 @@ class Pipeline:
print(f"on_shutdown:{__name__}")
pass
async def control_valve(self, body: dict, user: Optional[dict] = None) -> dict:
print(f"get_response:{__name__}")
async def filter(self, body: dict, user: Optional[dict] = None) -> dict:
print(f"pipe:{__name__}")
print(body)
print(user)
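
The example above only prints the incoming data. As a hedged illustration of what a filter can actually do, the sketch below injects a system instruction into the request body before Open WebUI forwards it to the target pipeline; the instruction text is arbitrary.

from typing import Optional

# Sketch of a drop-in replacement for the print-only filter method above.
async def filter(self, body: dict, user: Optional[dict] = None) -> dict:
    messages = body.get("messages", [])
    if not any(m.get("role") == "system" for m in messages):
        # Purely illustrative instruction.
        messages.insert(0, {"role": "system", "content": "Answer concisely."})
    body["messages"] = messages
    return body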

View File

@@ -78,7 +78,7 @@ class Pipeline:
# This function is called when the server is stopped.
pass
def get_response(
def pipe(
self, user_message: str, model_id: str, messages: List[dict], body: dict
) -> Union[str, Generator, Iterator]:
# This is where you can add your custom RAG pipeline.

View File

@@ -29,11 +29,11 @@ class Pipeline:
print(f"on_shutdown:{__name__}")
pass
def get_response(
def pipe(
self, user_message: str, model_id: str, messages: List[dict], body: dict
) -> Union[str, Generator, Iterator]:
# This is where you can add your custom pipelines like RAG.'
print(f"get_response:{__name__}")
print(f"pipe:{__name__}")
print(messages)
print(user_message)

View File

@@ -69,7 +69,7 @@ class Pipeline:
# This function is called when the server is stopped.
pass
def get_response(
def pipe(
self, user_message: str, model_id: str, messages: List[dict], body: dict
) -> Union[str, Generator, Iterator]:
# This is where you can add your custom RAG pipeline.

View File

@@ -29,7 +29,7 @@ class Pipeline:
# This function is called when the server is stopped.
pass
def get_response(
def pipe(
self, user_message: str, model_id: str, messages: List[dict], body: dict
) -> Union[str, Generator, Iterator]:
# This is where you can add your custom RAG pipeline.

View File

@@ -24,7 +24,7 @@ class Pipeline:
# This function is called when the server is stopped.
pass
def get_response(
def pipe(
self, user_message: str, model_id: str, messages: List[dict], body: dict
) -> Union[str, Generator, Iterator]:
# This is where you can add your custom RAG pipeline.

View File

@@ -7,7 +7,7 @@ class Pipeline:
# You can also set the pipelines that are available in this pipeline.
# Set manifold to True if you want to use this pipeline as a manifold.
# Manifold pipelines can have multiple pipelines.
self.manifold = True
self.type = "manifold"
self.id = "manifold_pipeline"
# Optionally, you can set the name of the manifold pipeline.
@@ -34,11 +34,11 @@ class Pipeline:
print(f"on_shutdown:{__name__}")
pass
def get_response(
def pipe(
self, user_message: str, model_id: str, messages: List[dict], body: dict
) -> Union[str, Generator, Iterator]:
# This is where you can add your custom pipelines like RAG.'
print(f"get_response:{__name__}")
print(f"pipe:{__name__}")
print(messages)
print(user_message)
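
As a sketch of the dispatch this enables (the sub-pipeline id below is illustrative): main.py lists each sub-pipeline as "<module id>.<sub id>", and on a chat completion it splits that id, looks up this module in PIPELINE_MODULES, and calls its pipe with the bare sub id.

# Hedged sketch mirroring the manifold branch in main.py's generate_openai_chat_completion.
requested_model = "manifold_pipeline.pipeline-1"   # illustrative id from /models
manifold_id, sub_id = requested_model.split(".", 1)

pipeline_module = PIPELINE_MODULES[manifold_id]    # module loaded during on_startup()
res = pipeline_module.pipe(
    user_message="Hello",
    model_id=sub_id,  # the manifold decides which backend this sub id maps to
    messages=[{"role": "user", "content": "Hello"}],
    body={"stream": False},
)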

View File

@@ -17,6 +17,7 @@ import subprocess
import logging
from huggingface_hub import login
class Pipeline:
def __init__(self):
self.id = "mlx_pipeline"
@@ -24,7 +25,9 @@ class Pipeline:
self.host = os.getenv("MLX_HOST", "localhost")
self.port = os.getenv("MLX_PORT", "8080")
self.model = os.getenv("MLX_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")
self.stop_sequence = os.getenv("MLX_STOP", "[INST]").split(",") # Default stop sequence is [INST]
self.stop_sequence = os.getenv("MLX_STOP", "[INST]").split(
","
) # Default stop sequence is [INST]
self.subprocess = os.getenv("MLX_SUBPROCESS", "true").lower() == "true"
self.huggingface_token = os.getenv("HUGGINGFACE_TOKEN", None)
@@ -43,8 +46,9 @@ class Pipeline:
def find_free_port(self):
import socket
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(('', 0))
s.bind(("", 0))
port = s.getsockname()[1]
s.close()
return port
@@ -53,14 +57,14 @@ class Pipeline:
logging.info(f"on_startup:{__name__}")
async def on_shutdown(self):
if self.subprocess and hasattr(self, 'server_process'):
if self.subprocess and hasattr(self, "server_process"):
self.server_process.terminate()
logging.info(f"Terminated MLX server on port {self.port}")
def get_response(
def pipe(
self, user_message: str, model_id: str, messages: List[dict], body: dict
) -> Union[str, Generator, Iterator]:
logging.info(f"get_response:{__name__}")
logging.info(f"pipe:{__name__}")
url = f"http://{self.host}:{self.port}/v1/chat/completions"
headers = {"Content-Type": "application/json"}
@@ -84,11 +88,13 @@ class Pipeline:
"temperature": temperature,
"repetition_penalty": repeat_penalty,
"stop": self.stop_sequence,
"stream": body.get("stream", False)
"stream": body.get("stream", False),
}
try:
r = requests.post(url, headers=headers, json=payload, stream=body.get("stream", False))
r = requests.post(
url, headers=headers, json=payload, stream=body.get("stream", False)
)
r.raise_for_status()
if body.get("stream", False):
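
For reference, a hedged sketch of the environment variables this pipeline reads; the values are the defaults from __init__ or illustrative overrides, and they must be set before the class is instantiated. HUGGINGFACE_TOKEN can also be set for gated models.

import os

# Illustrative configuration; unset variables fall back to the defaults in __init__.
os.environ["MLX_HOST"] = "localhost"
os.environ["MLX_PORT"] = "8080"
os.environ["MLX_MODEL"] = "mistralai/Mistral-7B-Instruct-v0.2"
os.environ["MLX_STOP"] = "[INST]"        # comma-separated stop sequences
os.environ["MLX_SUBPROCESS"] = "true"    # let the pipeline spawn the MLX server itself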

View File

@@ -8,7 +8,7 @@ class Pipeline:
# You can also set the pipelines that are available in this pipeline.
# Set manifold to True if you want to use this pipeline as a manifold.
# Manifold pipelines can have multiple pipelines.
self.manifold = True
self.type = "manifold"
self.id = "ollama_manifold"
# Optionally, you can set the name of the manifold pipeline.
self.name = "Ollama: "
@@ -35,7 +35,7 @@ class Pipeline:
print(f"on_shutdown:{__name__}")
pass
def get_response(
def pipe(
self, user_message: str, model_id: str, messages: List[dict], body: dict
) -> Union[str, Generator, Iterator]:
# This is where you can add your custom pipelines like RAG.'

View File

@@ -20,11 +20,11 @@ class Pipeline:
print(f"on_shutdown:{__name__}")
pass
def get_response(
def pipe(
self, user_message: str, model_id: str, messages: List[dict], body: dict
) -> Union[str, Generator, Iterator]:
# This is where you can add your custom pipelines like RAG.'
print(f"get_response:{__name__}")
print(f"pipe:{__name__}")
OLLAMA_BASE_URL = "http://localhost:11434"
MODEL = "llama3"

View File

@@ -20,11 +20,11 @@ class Pipeline:
print(f"on_shutdown:{__name__}")
pass
def get_response(
def pipe(
self, user_message: str, model_id: str, messages: List[dict], body: dict
) -> Union[str, Generator, Iterator]:
# This is where you can add your custom pipelines like RAG.'
print(f"get_response:{__name__}")
print(f"pipe:{__name__}")
print(messages)
print(user_message)

View File

@@ -19,11 +19,11 @@ class Pipeline:
print(f"on_shutdown:{__name__}")
pass
def get_response(
def pipe(
self, user_message: str, model_id: str, messages: List[dict], body: dict
) -> Union[str, Generator, Iterator]:
# This is where you can add your custom pipelines like RAG.'
print(f"get_response:{__name__}")
print(f"pipe:{__name__}")
print(messages)
print(user_message)

View File

@@ -30,11 +30,11 @@ class Pipeline:
except subprocess.CalledProcessError as e:
return e.output.strip(), e.returncode
def get_response(
def pipe(
self, user_message: str, model_id: str, messages: List[dict], body: dict
) -> Union[str, Generator, Iterator]:
# This is where you can add your custom pipelines like RAG.'
print(f"get_response:{__name__}")
print(f"pipe:{__name__}")
print(messages)
print(user_message)

View File

@@ -1,37 +0,0 @@
from typing import List, Union, Generator, Iterator
from schemas import OpenAIChatMessage
class Pipeline:
def __init__(self):
# Pipeline valves are only compatible with Open WebUI
# You can think of valve pipeline as a middleware that can be used to edit the form data before it is sent to the OpenAI API.
self.valve = True
self.id = "valve_pipeline"
self.name = "Valve"
# Assign a priority level to the valve pipeline.
# The priority level determines the order in which the valve pipelines are executed.
# The lower the number, the higher the priority.
self.priority = 0
# List target pipelines (models) that this valve will be connected to.
self.pipelines = [
{"id": "llama3:latest"},
]
pass
async def on_startup(self):
# This function is called when the server is started.
print(f"on_startup:{__name__}")
pass
async def on_shutdown(self):
# This function is called when the server is stopped.
print(f"on_shutdown:{__name__}")
pass
async def control_valve(self, body: dict) -> dict:
print(f"get_response:{__name__}")
print(body)
return body

View File

@@ -0,0 +1,41 @@
from typing import List, Optional
from schemas import OpenAIChatMessage
class Pipeline:
def __init__(self):
# Pipeline filters are only compatible with Open WebUI
# You can think of filter pipeline as a middleware that can be used to edit the form data before it is sent to the OpenAI API.
self.type = "filter"
self.id = "filter_pipeline"
self.name = "Filter"
# Assign a priority level to the filter pipeline.
# The priority level determines the order in which the filter pipelines are executed.
# The lower the number, the higher the priority.
self.priority = 0
# List target pipelines (models) that this filter will be connected to.
self.pipelines = [
{"id": "llama3:latest"},
]
pass
async def on_startup(self):
# This function is called when the server is started.
print(f"on_startup:{__name__}")
pass
async def on_shutdown(self):
# This function is called when the server is stopped.
print(f"on_shutdown:{__name__}")
pass
async def filter(self, body: dict, user: Optional[dict] = None) -> dict:
print(f"pipe:{__name__}")
print(body)
print(user)
return body

View File

@@ -0,0 +1,63 @@
from typing import List, Union, Generator, Iterator
from schemas import OpenAIChatMessage
import requests
class Pipeline:
def __init__(self):
# You can also set the pipelines that are available in this pipeline.
# Set manifold to True if you want to use this pipeline as a manifold.
# Manifold pipelines can have multiple pipelines.
self.type = "manifold"
self.id = "ollama_manifold"
# Optionally, you can set the name of the manifold pipeline.
self.name = "Ollama: "
self.OLLAMA_BASE_URL = "http://localhost:11434"
self.pipelines = self.get_ollama_models()
pass
def get_ollama_models(self):
r = requests.get(f"{self.OLLAMA_BASE_URL}/api/tags")
models = r.json()
return [
{"id": model["model"], "name": model["name"]} for model in models["models"]
]
async def on_startup(self):
# This function is called when the server is started.
print(f"on_startup:{__name__}")
pass
async def on_shutdown(self):
# This function is called when the server is stopped.
print(f"on_shutdown:{__name__}")
pass
def pipe(
self, user_message: str, model_id: str, messages: List[dict], body: dict
) -> Union[str, Generator, Iterator]:
# This is where you can add your custom pipelines like RAG.'
if "user" in body:
print("######################################")
print(f'# User: {body["user"]["name"]} ({body["user"]["id"]})')
print(f"# Message: {user_message}")
print("######################################")
try:
r = requests.post(
url=f"{self.OLLAMA_BASE_URL}/v1/chat/completions",
json={**body, "model": model_id},
stream=True,
)
r.raise_for_status()
if body["stream"]:
return r.iter_lines()
else:
return r.json()
except Exception as e:
return f"Error: {e}"

View File

@@ -20,11 +20,11 @@ class Pipeline:
print(f"on_shutdown:{__name__}")
pass
def get_response(
def pipe(
self, user_message: str, model_id: str, messages: List[dict], body: dict
) -> Union[str, Generator, Iterator]:
# This is where you can add your custom pipelines like RAG.'
print(f"get_response:{__name__}")
print(f"pipe:{__name__}")
OLLAMA_BASE_URL = "http://localhost:11434"
MODEL = "llama3"

View File

@@ -17,7 +17,7 @@ class OpenAIChatCompletionForm(BaseModel):
model_config = ConfigDict(extra="allow")
class ValveForm(BaseModel):
class FilterForm(BaseModel):
model: str
body: dict
user: Optional[dict] = None
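
For completeness, a small sketch of constructing the renamed form directly; the field values are illustrative, and in practice FastAPI builds this object from the JSON posted to /filter and /v1/filter.

from schemas import FilterForm

form = FilterForm(
    model="filter_pipeline",
    body={"messages": [{"role": "user", "content": "Hello"}]},
    user={"id": "user-1", "name": "Alice"},  # optional; defaults to None
)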