From 8b998701896ef6e1b67ac76fb08392e019751b97 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 23 Jun 2024 20:11:08 -0700 Subject: [PATCH] enh: filter function priority valve support --- backend/main.py | 131 ++++++++++++++++++++++++++++-------------------- 1 file changed, 77 insertions(+), 54 deletions(-) diff --git a/backend/main.py b/backend/main.py index 2f2e549bd..85fe16d3b 100644 --- a/backend/main.py +++ b/backend/main.py @@ -389,6 +389,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ) model = app.state.MODELS[model_id] + def get_priority(function_id): + function = Functions.get_function_by_id(function_id) + if function is not None and hasattr(function, "valves"): + return (function.valves if function.valves else {}).get( + "priority", 0 + ) + return 0 + filter_ids = [ function.id for function in Functions.get_functions_by_type( @@ -400,6 +408,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): filter_ids.extend(model["info"]["meta"].get("filterIds", [])) filter_ids = list(set(filter_ids)) + filter_ids.sort(key=get_priority) for filter_id in filter_ids: filter = Functions.get_function_by_id(filter_id) if filter: @@ -1122,72 +1131,86 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): else: pass + def get_priority(function_id): + function = Functions.get_function_by_id(function_id) + if function is not None and hasattr(function, "valves"): + return (function.valves if function.valves else {}).get("priority", 0) + return 0 + + filter_ids = [ + function.id + for function in Functions.get_functions_by_type("filter", active_only=True) + ] # Check if the model has any filters if "info" in model and "meta" in model["info"]: - for filter_id in model["info"]["meta"].get("filterIds", []): - filter = Functions.get_function_by_id(filter_id) - if filter: - if filter_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[filter_id] - else: - function_module, function_type = load_function_module_by_id( - filter_id - ) - webui_app.state.FUNCTIONS[filter_id] = function_module + filter_ids.extend(model["info"]["meta"].get("filterIds", [])) + filter_ids = list(set(filter_ids)) - if hasattr(function_module, "valves") and hasattr( - function_module, "Valves" - ): - valves = Functions.get_function_valves_by_id(filter_id) - function_module.valves = function_module.Valves( - **(valves if valves else {}) - ) + # Sort filter_ids by priority, using the get_priority function + filter_ids.sort(key=get_priority) - try: - if hasattr(function_module, "outlet"): - outlet = function_module.outlet + for filter_id in filter_ids: + filter = Functions.get_function_by_id(filter_id) + if filter: + if filter_id in webui_app.state.FUNCTIONS: + function_module = webui_app.state.FUNCTIONS[filter_id] + else: + function_module, function_type = load_function_module_by_id(filter_id) + webui_app.state.FUNCTIONS[filter_id] = function_module - # Get the signature of the function - sig = inspect.signature(outlet) - params = {"body": data} + if hasattr(function_module, "valves") and hasattr( + function_module, "Valves" + ): + valves = Functions.get_function_valves_by_id(filter_id) + function_module.valves = function_module.Valves( + **(valves if valves else {}) + ) - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } + try: + if hasattr(function_module, "outlet"): + outlet = function_module.outlet - try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - filter_id, user.id - ) + # Get the signature of the function + sig = inspect.signature(outlet) + params = {"body": data} + + if "__user__" in sig.parameters: + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + + try: + if hasattr(function_module, "UserValves"): + __user__["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id( + filter_id, user.id ) - except Exception as e: - print(e) + ) + except Exception as e: + print(e) - params = {**params, "__user__": __user__} + params = {**params, "__user__": __user__} - if "__id__" in sig.parameters: - params = { - **params, - "__id__": filter_id, - } + if "__id__" in sig.parameters: + params = { + **params, + "__id__": filter_id, + } - if inspect.iscoroutinefunction(outlet): - data = await outlet(**params) - else: - data = outlet(**params) + if inspect.iscoroutinefunction(outlet): + data = await outlet(**params) + else: + data = outlet(**params) - except Exception as e: - print(f"Error: {e}") - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) + except Exception as e: + print(f"Error: {e}") + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) return data