From afd270523c20af86138224bb2b20151bdb5984c0 Mon Sep 17 00:00:00 2001
From: "Timothy J. Baek"
Date: Thu, 20 Jun 2024 03:23:50 -0700
Subject: [PATCH] feat: filter func outlet

---
 backend/main.py                     | 53 +++++++++++++++++++++++------
 src/lib/components/chat/Chat.svelte |  4 ++-
 2 files changed, 45 insertions(+), 12 deletions(-)

diff --git a/backend/main.py b/backend/main.py
index 5f845877e..dade596a4 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -474,10 +474,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 ],
             ]
 
-            response = await call_next(request)
-
-            # If there are data_items to inject into the response
-            if len(data_items) > 0:
+            response = await call_next(request)
             if isinstance(response, StreamingResponse):
                 # If it's a streaming response, inject it as SSE event or NDJSON line
                 content_type = response.headers.get("Content-Type")
@@ -489,7 +486,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                     return StreamingResponse(
                         self.ollama_stream_wrapper(response.body_iterator, data_items),
                     )
+                else:
+                    return response
 
+        # If it's not a chat completion request, just pass it through
+        response = await call_next(request)
         return response
 
     async def _receive(self, body: bytes):
@@ -800,6 +801,12 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
 async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
     data = form_data
     model_id = data["model"]
+    if model_id not in app.state.MODELS:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Model not found",
+        )
+    model = app.state.MODELS[model_id]
 
     filters = [
         model
@@ -815,14 +822,10 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
             )
         )
     ]
+
     sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
-
-    print(model_id)
-
-    if model_id in app.state.MODELS:
-        model = app.state.MODELS[model_id]
-        if "pipeline" in model:
-            sorted_filters = [model] + sorted_filters
+    if "pipeline" in model:
+        sorted_filters = [model] + sorted_filters
 
     for filter in sorted_filters:
         r = None
@@ -863,6 +866,34 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
         else:
             pass
 
+    # Check if the model has any filters
+    for filter_id in model["info"]["meta"].get("filterIds", []):
+        filter = Functions.get_function_by_id(filter_id)
+        if filter:
+            if filter_id in webui_app.state.FUNCTIONS:
+                function_module = webui_app.state.FUNCTIONS[filter_id]
+            else:
+                function_module, function_type = load_function_module_by_id(filter_id)
+                webui_app.state.FUNCTIONS[filter_id] = function_module
+
+            try:
+                if hasattr(function_module, "outlet"):
+                    data = function_module.outlet(
+                        data,
+                        {
+                            "id": user.id,
+                            "email": user.email,
+                            "name": user.name,
+                            "role": user.role,
+                        },
+                    )
+            except Exception as e:
+                print(f"Error: {e}")
+                return JSONResponse(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    content={"detail": str(e)},
+                )
+
     return data
 
 
diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte
index a60aef51a..9cf2201fc 100644
--- a/src/lib/components/chat/Chat.svelte
+++ b/src/lib/components/chat/Chat.svelte
@@ -278,7 +278,9 @@
 				})),
 				chat_id: $chatId
 			}).catch((error) => {
-				console.error(error);
+				toast.error(error);
+				messages.at(-1).error = { content: error };
+				return null;
 			});
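
Note on the new outlet hook (an illustrative sketch, not part of the patch): chat_completed now looks up each function id listed in model["info"]["meta"]["filterIds"], loads it via load_function_module_by_id (caching it in webui_app.state.FUNCTIONS), and, when the loaded module defines an outlet callable, invokes it as function_module.outlet(data, {"id": ..., "email": ..., "name": ..., "role": ...}); any exception raised by a filter is returned to the client as a 400 JSONResponse. A minimal filter module compatible with that call shape might look like the sketch below; only the outlet name and the (body, user) signature come from the patch, the function body is made up for illustration.

    # Hypothetical filter function module for the outlet hook above.
    # Only the "outlet" attribute and the (body, user) call shape are taken
    # from the patch; everything else is illustrative.
    def outlet(body: dict, user: dict) -> dict:
        # "body" is the completed chat payload ("data" in chat_completed);
        # "user" carries the id/email/name/role fields built by the endpoint.
        for message in body.get("messages", []):
            if message.get("role") == "assistant":
                # e.g. append a marker so the UI shows the filter ran
                message["content"] = message.get("content", "") + "\n\n[filtered]"
        return body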