feat: filter func outlet

This commit is contained in:
Timothy J. Baek 2024-06-20 03:23:50 -07:00
parent 3101ff143b
commit afd270523c
2 changed files with 45 additions and 12 deletions

View File

@ -474,10 +474,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
], ],
] ]
response = await call_next(request) response = await call_next(request)
# If there are data_items to inject into the response
if len(data_items) > 0:
if isinstance(response, StreamingResponse): if isinstance(response, StreamingResponse):
# If it's a streaming response, inject it as SSE event or NDJSON line # If it's a streaming response, inject it as SSE event or NDJSON line
content_type = response.headers.get("Content-Type") content_type = response.headers.get("Content-Type")
@ -489,7 +486,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
return StreamingResponse( return StreamingResponse(
self.ollama_stream_wrapper(response.body_iterator, data_items), self.ollama_stream_wrapper(response.body_iterator, data_items),
) )
else:
return response
# If it's not a chat completion request, just pass it through
response = await call_next(request)
return response return response
async def _receive(self, body: bytes): async def _receive(self, body: bytes):
@ -800,6 +801,12 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
async def chat_completed(form_data: dict, user=Depends(get_verified_user)): async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
data = form_data data = form_data
model_id = data["model"] model_id = data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = app.state.MODELS[model_id]
filters = [ filters = [
model model
@ -815,14 +822,10 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
) )
) )
] ]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
if "pipeline" in model:
print(model_id) sorted_filters = [model] + sorted_filters
if model_id in app.state.MODELS:
model = app.state.MODELS[model_id]
if "pipeline" in model:
sorted_filters = [model] + sorted_filters
for filter in sorted_filters: for filter in sorted_filters:
r = None r = None
@ -863,6 +866,34 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
else: else:
pass pass
# Check if the model has any filters
for filter_id in model["info"]["meta"].get("filterIds", []):
filter = Functions.get_function_by_id(filter_id)
if filter:
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, function_type = load_function_module_by_id(filter_id)
webui_app.state.FUNCTIONS[filter_id] = function_module
try:
if hasattr(function_module, "outlet"):
data = function_module.outlet(
data,
{
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
return data return data

View File

@ -278,7 +278,9 @@
})), })),
chat_id: $chatId chat_id: $chatId
}).catch((error) => { }).catch((error) => {
console.error(error); toast.error(error);
messages.at(-1).error = { content: error };
return null; return null;
}); });