Timothy J. Baek 2024-06-20 01:51:39 -07:00
parent 08cc20cb93
commit 448ca9d836
2 changed files with 132 additions and 89 deletions

View File

@@ -170,6 +170,13 @@ app.state.MODELS = {}
origins = ["*"]
##################################
#
# ChatCompletion Middleware
#
##################################
async def get_function_call_response(
    messages, files, tool_id, template, task_model_id, user
):
@@ -469,6 +476,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
app.add_middleware(ChatCompletionMiddleware)
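For orientation, the general shape such a middleware takes is sketched below. This is a minimal illustration built on Starlette's `BaseHTTPMiddleware`, not the project's actual implementation; the class name and the rewrite step are hypothetical.

```python
# Minimal sketch (illustrative, not this project's middleware) of
# intercepting and rewriting a chat completion request body.
import json

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request


class ExampleChatMiddleware(BaseHTTPMiddleware):  # hypothetical name
    async def dispatch(self, request: Request, call_next):
        if request.url.path.endswith("/chat/completions"):
            data = json.loads(await request.body() or b"{}")
            # ... inspect or enrich `data` here (tools, files, templates) ...
            modified = json.dumps(data).encode("utf-8")

            # Re-inject the body so the downstream handler can read it again.
            async def receive():
                return {"type": "http.request", "body": modified}

            request = Request(request.scope, receive)
        return await call_next(request)
```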
##################################
#
# Pipeline Middleware
#
##################################
def filter_pipeline(payload, user):
    user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
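The outlet direction of this round trip appears in full in `chat_completed` further down. As a hedged sketch, the inlet direction presumably mirrors it, POSTing the payload to each filter's inlet endpoint in priority order; the `/filter/inlet` path and payload shape here are assumptions inferred from the outlet code below.

```python
# Hypothetical inlet call mirroring the outlet request shown below;
# the /filter/inlet path and payload shape are assumptions.
import requests


def call_filter_inlet(url: str, key: str, filter_id: str, payload: dict, user: dict) -> dict:
    r = requests.post(
        f"{url}/{filter_id}/filter/inlet",
        headers={"Authorization": f"Bearer {key}"},
        json={"user": user, "body": payload},
    )
    r.raise_for_status()
    return r.json()
```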
@@ -628,7 +641,6 @@ async def update_embedding_function(request: Request, call_next):
app.mount("/ws", socket_app)
app.mount("/ollama", ollama_app)
app.mount("/openai", openai_app)
@@ -730,6 +742,104 @@ async def get_models(user=Depends(get_verified_user)):
    return {"data": models}


@app.post("/api/chat/completions")
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
    model_id = form_data["model"]
    if model_id not in app.state.MODELS:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    model = app.state.MODELS[model_id]
    print(model)

    if model["owned_by"] == "ollama":
        return await generate_ollama_chat_completion(form_data, user=user)
    else:
        return await generate_openai_chat_completion(form_data, user=user)
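For reference, a client call against this endpoint might look like the following; the host, token, and model name are placeholders, and the payload mirrors the `form_data` fields the handler reads above.

```python
# Example client call; host, token, and model name are placeholders.
import requests

resp = requests.post(
    "http://localhost:8080/api/chat/completions",
    headers={"Authorization": "Bearer <token>"},
    json={
        "model": "llama3",
        "messages": [{"role": "user", "content": "Hello!"}],
    },
)
print(resp.json())
```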
@app.post("/api/chat/completed")
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
    data = form_data
    model_id = data["model"]

    filters = [
        model
        for model in app.state.MODELS.values()
        if "pipeline" in model
        and "type" in model["pipeline"]
        and model["pipeline"]["type"] == "filter"
        and (
            model["pipeline"]["pipelines"] == ["*"]
            or any(
                model_id == target_model_id
                for target_model_id in model["pipeline"]["pipelines"]
            )
        )
    ]
    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])

    print(model_id)

    if model_id in app.state.MODELS:
        model = app.state.MODELS[model_id]
        if "pipeline" in model:
            sorted_filters = [model] + sorted_filters

    for filter in sorted_filters:
        r = None
        try:
            urlIdx = filter["urlIdx"]

            url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
            key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

            if key != "":
                headers = {"Authorization": f"Bearer {key}"}
                r = requests.post(
                    f"{url}/{filter['id']}/filter/outlet",
                    headers=headers,
                    json={
                        "user": {"id": user.id, "name": user.name, "role": user.role},
                        "body": data,
                    },
                )

                r.raise_for_status()
                data = r.json()
        except Exception as e:
            # Handle connection error here
            print(f"Connection error: {e}")

            if r is not None:
                try:
                    res = r.json()
                    if "detail" in res:
                        return JSONResponse(
                            status_code=r.status_code,
                            content=res,
                        )
                except:
                    pass

            else:
                pass

    return data
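The ordering logic here is worth spelling out: outlet filters run in ascending `pipeline.priority`, and if the target model is itself a pipeline it is prepended so it runs first. A small self-contained illustration of the sort:

```python
# Self-contained illustration of the filter ordering used above;
# the filter ids are made up for the example.
filters = [
    {"id": "profanity_filter", "pipeline": {"priority": 10}},
    {"id": "rate_limiter", "pipeline": {"priority": 0}},
]
ordered = sorted(filters, key=lambda x: x["pipeline"]["priority"])
print([f["id"] for f in ordered])  # ['rate_limiter', 'profanity_filter']
```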
##################################
#
# Task Endpoints
#
##################################
# TODO: Refactor task API endpoints below into a separate file
@app.get("/api/task/config")
async def get_task_config(user=Depends(get_verified_user)):
    return {
@@ -1015,92 +1125,14 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
)
@app.post("/api/chat/completions")
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
    model_id = form_data["model"]
    if model_id not in app.state.MODELS:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    model = app.state.MODELS[model_id]
    print(model)

    if model["owned_by"] == "ollama":
        return await generate_ollama_chat_completion(form_data, user=user)
    else:
        return await generate_openai_chat_completion(form_data, user=user)
##################################
#
# Pipelines Endpoints
#
##################################
@app.post("/api/chat/completed")
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
    data = form_data
    model_id = data["model"]

    filters = [
        model
        for model in app.state.MODELS.values()
        if "pipeline" in model
        and "type" in model["pipeline"]
        and model["pipeline"]["type"] == "filter"
        and (
            model["pipeline"]["pipelines"] == ["*"]
            or any(
                model_id == target_model_id
                for target_model_id in model["pipeline"]["pipelines"]
            )
        )
    ]
    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])

    print(model_id)

    if model_id in app.state.MODELS:
        model = app.state.MODELS[model_id]
        if "pipeline" in model:
            sorted_filters = [model] + sorted_filters

    for filter in sorted_filters:
        r = None
        try:
            urlIdx = filter["urlIdx"]

            url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
            key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

            if key != "":
                headers = {"Authorization": f"Bearer {key}"}
                r = requests.post(
                    f"{url}/{filter['id']}/filter/outlet",
                    headers=headers,
                    json={
                        "user": {"id": user.id, "name": user.name, "role": user.role},
                        "body": data,
                    },
                )

                r.raise_for_status()
                data = r.json()
        except Exception as e:
            # Handle connection error here
            print(f"Connection error: {e}")

            if r is not None:
                try:
                    res = r.json()
                    if "detail" in res:
                        return JSONResponse(
                            status_code=r.status_code,
                            content=res,
                        )
                except:
                    pass

            else:
                pass

    return data
# TODO: Refactor pipelines API endpoints below into a separate file
@app.get("/api/pipelines/list")
@@ -1423,6 +1455,13 @@ async def update_pipeline_valves(
)
##################################
#
# Config Endpoints
#
##################################
@app.get("/api/config")
async def get_app_config():
    # Checking and Handling the Absence of 'ui' in CONFIG_DATA
@@ -1486,6 +1525,9 @@ async def update_model_filter_config(
}
# TODO: webhook endpoint should be under config endpoints
@app.get("/api/webhook")
async def get_webhook_url(user=Depends(get_admin_user)):
    return {

View File

@@ -30,9 +30,10 @@
let boilerplate = `from pydantic import BaseModel
from typing import Optional


class Filter:
    class Valves(BaseModel):
-        max_turns: int
+        max_turns: int = 4
        pass

    def __init__(self):
@@ -42,14 +43,14 @@ class Filter:
        # Initialize 'valves' with specific configurations. Using 'Valves' instance helps encapsulate settings,
        # which ensures settings are managed cohesively and not confused with operational flags like 'file_handler'.
-        self.valves = self.Valves(**{"max_turns": 10})
+        self.valves = self.Valves(**{"max_turns": 2})
        pass
    def inlet(self, body: dict, user: Optional[dict] = None) -> dict:
        # Modify the request body or validate it before processing by the chat completion API.
        # This function is the pre-processor for the API where various checks on the input can be performed.
        # It can also modify the request before sending it to the API.
        print("inlet")
        print(body)
        print(user)
@@ -65,7 +66,7 @@ class Filter:
    def outlet(self, body: dict, user: Optional[dict] = None) -> dict:
        # Modify or analyze the response body after processing by the API.
        # This function is the post-processor for the API, which can be used to modify the response
        # or perform additional checks and analytics.
        print(f"outlet")
        print(body)
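The valve above (`max_turns`, now given a default of 4 and overridden to 2 in `__init__`) suggests a turn-limiting filter. As a hedged sketch of how `inlet` might enforce it, assuming the `messages` shape used by the chat completion payloads elsewhere in this diff:

```python
# Hypothetical enforcement of max_turns inside inlet(); the message
# shape is assumed from the chat completion payloads in this diff.
def truncate_turns(messages: list[dict], max_turns: int) -> list[dict]:
    # One turn = a user message plus the assistant reply.
    return messages[-(max_turns * 2):]


body = {"messages": [{"role": "user", "content": f"msg {i}"} for i in range(10)]}
body["messages"] = truncate_turns(body["messages"], max_turns=2)
print(len(body["messages"]))  # 4
```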