From c4bd60114eb6996f650d5ce0ab4b3ac0f41d6606 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 20 Jun 2024 02:30:00 -0700 Subject: [PATCH] feat: filter inlet support --- backend/main.py | 67 +++++++++++++++---- src/lib/components/chat/Chat.svelte | 5 +- .../workspace/Models/FiltersSelector.svelte | 1 + 3 files changed, 58 insertions(+), 15 deletions(-) diff --git a/backend/main.py b/backend/main.py index febda4ced..951bf9654 100644 --- a/backend/main.py +++ b/backend/main.py @@ -50,7 +50,9 @@ from typing import List, Optional from apps.webui.models.models import Models, ModelModel from apps.webui.models.tools import Tools -from apps.webui.utils import load_toolkit_module_by_id +from apps.webui.models.functions import Functions + +from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id from utils.utils import ( @@ -318,9 +320,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): data_items = [] - if request.method == "POST" and ( - "/ollama/api/chat" in request.url.path - or "/chat/completions" in request.url.path + if request.method == "POST" and any( + endpoint in request.url.path + for endpoint in ["/ollama/api/chat", "/chat/completions"] ): log.debug(f"request.url.path: {request.url.path}") @@ -328,23 +330,62 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): body = await request.body() body_str = body.decode("utf-8") data = json.loads(body_str) if body_str else {} - - model_id = data["model"] user = get_current_user( request, get_http_authorization_cred(request.headers.get("Authorization")), ) - # Set the task model - task_model_id = model_id - if task_model_id not in app.state.MODELS: + # Flag to skip RAG completions if file_handler is present in tools/functions + skip_files = False + + model_id = data["model"] + if model_id not in app.state.MODELS: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Model not found", ) + model = app.state.MODELS[model_id] - # Check if the user has a custom task model - # If the user has a custom task model, use that model + print(":", data) + + # Check if the model has any filters + for filter_id in model["info"]["meta"].get("filterIds", []): + filter = Functions.get_function_by_id(filter_id) + if filter: + if filter_id in webui_app.state.FUNCTIONS: + function_module = webui_app.state.FUNCTIONS[filter_id] + else: + function_module, function_type = load_function_module_by_id( + filter_id + ) + webui_app.state.FUNCTIONS[filter_id] = function_module + + # Check if the function has a file_handler variable + if getattr(function_module, "file_handler"): + skip_files = True + + try: + if hasattr(function_module, "inlet"): + data = function_module.inlet( + data, + { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + ) + except Exception as e: + print(f"Error: {e}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=e, + ) + + print("Filtered:", data) + # Set the task model + task_model_id = data["model"] + # Check if the user has a custom task model and use that model if app.state.MODELS[task_model_id]["owned_by"] == "ollama": if ( app.state.config.TASK_MODEL @@ -358,7 +399,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ): task_model_id = app.state.config.TASK_MODEL_EXTERNAL - skip_files = False prompt = get_last_user_message(data["messages"]) context = "" @@ -409,8 +449,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): log.debug(f"rag_context: {rag_context}, citations: {citations}") - if citations: + if citations and data.get("citations"): data_items.append({"citations": citations}) + del data["citations"] del data["files"] diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index b33b26fa3..1fae82415 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -630,7 +630,7 @@ keep_alive: $settings.keepAlive ?? undefined, tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, files: files.length > 0 ? files : undefined, - citations: files.length > 0, + citations: files.length > 0 ? true : undefined, chat_id: $chatId }); @@ -928,7 +928,8 @@ max_tokens: $settings?.params?.max_tokens ?? undefined, tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, files: files.length > 0 ? files : undefined, - citations: files.length > 0, + citations: files.length > 0 ? true : undefined, + chat_id: $chatId }, `${OPENAI_API_BASE_URL}` diff --git a/src/lib/components/workspace/Models/FiltersSelector.svelte b/src/lib/components/workspace/Models/FiltersSelector.svelte index 291bb8939..92f64c2cf 100644 --- a/src/lib/components/workspace/Models/FiltersSelector.svelte +++ b/src/lib/components/workspace/Models/FiltersSelector.svelte @@ -31,6 +31,7 @@ {$i18n.t('To select filters here, add them to the "Functions" workspace first.')} +
{#if filters.length > 0}