From 3a629ffe0009cf3cbceccc6af53f0da03cbeb9c2 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 23 Jun 2024 18:39:27 -0700 Subject: [PATCH] feat: global filter --- backend/main.py | 112 ++++++++++-------- src/lib/components/workspace/Functions.svelte | 3 +- 2 files changed, 62 insertions(+), 53 deletions(-) diff --git a/backend/main.py b/backend/main.py index 02552f209..2a44d2029 100644 --- a/backend/main.py +++ b/backend/main.py @@ -376,70 +376,77 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ) model = app.state.MODELS[model_id] + filter_ids = [ + function.id + for function in Functions.get_functions_by_type( + "filter", active_only=True + ) + ] # Check if the model has any filters if "info" in model and "meta" in model["info"]: - for filter_id in model["info"]["meta"].get("filterIds", []): - filter = Functions.get_function_by_id(filter_id) - if filter: - if filter_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[filter_id] - else: - function_module, function_type = load_function_module_by_id( - filter_id - ) - webui_app.state.FUNCTIONS[filter_id] = function_module + filter_ids.extend(model["info"]["meta"].get("filterIds", [])) + filter_ids = list(set(filter_ids)) - # Check if the function has a file_handler variable - if hasattr(function_module, "file_handler"): - skip_files = function_module.file_handler + for filter_id in filter_ids: + filter = Functions.get_function_by_id(filter_id) + if filter: + if filter_id in webui_app.state.FUNCTIONS: + function_module = webui_app.state.FUNCTIONS[filter_id] + else: + function_module, function_type = load_function_module_by_id( + filter_id + ) + webui_app.state.FUNCTIONS[filter_id] = function_module - try: - if hasattr(function_module, "inlet"): - inlet = function_module.inlet + # Check if the function has a file_handler variable + if hasattr(function_module, "file_handler"): + skip_files = function_module.file_handler - # Get the signature of the function - sig = inspect.signature(inlet) - params = {"body": data} + try: + if hasattr(function_module, "inlet"): + inlet = function_module.inlet - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } + # Get the signature of the function + sig = inspect.signature(inlet) + params = {"body": data} - try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = ( - function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - filter_id, user.id - ) - ) + if "__user__" in sig.parameters: + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + + try: + if hasattr(function_module, "UserValves"): + __user__["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id( + filter_id, user.id ) - except Exception as e: - print(e) + ) + except Exception as e: + print(e) - params = {**params, "__user__": __user__} + params = {**params, "__user__": __user__} - if "__id__" in sig.parameters: - params = { - **params, - "__id__": filter_id, - } + if "__id__" in sig.parameters: + params = { + **params, + "__id__": filter_id, + } - if inspect.iscoroutinefunction(inlet): - data = await inlet(**params) - else: - data = inlet(**params) + if inspect.iscoroutinefunction(inlet): + data = await inlet(**params) + else: + data = inlet(**params) - except Exception as e: - print(f"Error: {e}") - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) + except Exception as e: + print(f"Error: {e}") + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) # Set the task model task_model_id = data["model"] @@ -863,6 +870,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u pipe = model.get("pipe") if pipe: + async def job(): pipe_id = form_data["model"] if "." in pipe_id: diff --git a/src/lib/components/workspace/Functions.svelte b/src/lib/components/workspace/Functions.svelte index fc8e7a451..75e0ce4ff 100644 --- a/src/lib/components/workspace/Functions.svelte +++ b/src/lib/components/workspace/Functions.svelte @@ -227,8 +227,9 @@
{ + on:change={async (e) => { toggleFunctionById(localStorage.token, func.id); + models.set(await getModels(localStorage.token)); }} />