From eb10001eb71f8aa9cb741c700de7f6f954c19270 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 11 Jul 2024 18:41:00 -0700 Subject: [PATCH] feat: action function --- backend/apps/webui/models/functions.py | 9 + backend/apps/webui/utils.py | 2 + backend/main.py | 155 ++++++++++++++++++ .../chat/Messages/ResponseMessage.svelte | 17 ++ src/lib/components/icons/Sparkles.svelte | 19 +++ src/lib/components/workspace/Functions.svelte | 11 +- .../workspace/Functions/FunctionMenu.svelte | 2 +- 7 files changed, 211 insertions(+), 4 deletions(-) create mode 100644 src/lib/components/icons/Sparkles.svelte diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index 907576b80..cb73da694 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -167,6 +167,15 @@ class FunctionsTable: .all() ] + def get_global_action_functions(self) -> List[FunctionModel]: + with get_db() as db: + return [ + FunctionModel.model_validate(function) + for function in db.query(Function) + .filter_by(type="action", is_active=True, is_global=True) + .all() + ] + def get_function_valves_by_id(self, id: str) -> Optional[dict]: with get_db() as db: diff --git a/backend/apps/webui/utils.py b/backend/apps/webui/utils.py index 545120835..96d2b29eb 100644 --- a/backend/apps/webui/utils.py +++ b/backend/apps/webui/utils.py @@ -79,6 +79,8 @@ def load_function_module_by_id(function_id): return module.Pipe(), "pipe", frontmatter elif hasattr(module, "Filter"): return module.Filter(), "filter", frontmatter + elif hasattr(module, "Action"): + return module.Action(), "action", frontmatter else: raise Exception("No Function class found") except Exception as e: diff --git a/backend/main.py b/backend/main.py index 01c2fde2a..aa0b6c956 100644 --- a/backend/main.py +++ b/backend/main.py @@ -926,6 +926,7 @@ webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION async def get_all_models(): + # TODO: Optimize this function pipe_models = [] openai_models = [] ollama_models = [] @@ -952,6 +953,14 @@ async def get_all_models(): models = pipe_models + openai_models + ollama_models + global_action_ids = [ + function.id for function in Functions.get_global_action_functions() + ] + enabled_action_ids = [ + function.id + for function in Functions.get_functions_by_type("action", active_only=True) + ] + custom_models = Models.get_all_models() for custom_model in custom_models: if custom_model.base_model_id == None: @@ -962,9 +971,32 @@ async def get_all_models(): ): model["name"] = custom_model.name model["info"] = custom_model.model_dump() + + action_ids = [] + global_action_ids + if "info" in model and "meta" in model["info"]: + action_ids.extend(model["info"]["meta"].get("actionIds", [])) + action_ids = list(set(action_ids)) + action_ids = [ + action_id + for action_id in action_ids + if action_id in enabled_action_ids + ] + + model["actions"] = [ + { + "id": action_id, + "name": Functions.get_function_by_id(action_id).name, + "description": Functions.get_function_by_id( + action_id + ).meta.description, + } + for action_id in action_ids + ] + else: owned_by = "openai" pipe = None + actions = [] for model in models: if ( @@ -974,6 +1006,27 @@ async def get_all_models(): owned_by = model["owned_by"] if "pipe" in model: pipe = model["pipe"] + + action_ids = [] + global_action_ids + if "info" in model and "meta" in model["info"]: + action_ids.extend(model["info"]["meta"].get("actionIds", [])) + action_ids = list(set(action_ids)) + action_ids = [ + action_id + for action_id in action_ids + if action_id in enabled_action_ids + ] + + actions = [ + { + "id": action_id, + "name": Functions.get_function_by_id(action_id).name, + "description": Functions.get_function_by_id( + action_id + ).meta.description, + } + for action_id in action_ids + ] break models.append( @@ -986,6 +1039,7 @@ async def get_all_models(): "info": custom_model.model_dump(), "preset": True, **({"pipe": pipe} if pipe is not None else {}), + "actions": actions, } ) @@ -1221,6 +1275,107 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): return data +@app.post("/api/chat/actions/{action_id}") +async def chat_completed( + action_id: str, form_data: dict, user=Depends(get_verified_user) +): + action = Functions.get_function_by_id(action_id) + if not action: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Action not found", + ) + + data = form_data + model_id = data["model"] + if model_id not in app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + model = app.state.MODELS[model_id] + + __event_emitter__ = await get_event_emitter( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + } + ) + __event_call__ = await get_event_call( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + } + ) + + if action_id in webui_app.state.FUNCTIONS: + function_module = webui_app.state.FUNCTIONS[action_id] + else: + function_module, _, _ = load_function_module_by_id(action_id) + webui_app.state.FUNCTIONS[action_id] = function_module + + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + valves = Functions.get_function_valves_by_id(action_id) + function_module.valves = function_module.Valves(**(valves if valves else {})) + + if hasattr(function_module, "action"): + try: + action = function_module.action + + # Get the signature of the function + sig = inspect.signature(action) + params = {"body": data} + + # Extra parameters to be passed to the function + extra_params = { + "__model__": model, + "__id__": action_id, + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + } + + # Add extra params in contained in function signature + for key, value in extra_params.items(): + if key in sig.parameters: + params[key] = value + + if "__user__" in sig.parameters: + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + + try: + if hasattr(function_module, "UserValves"): + __user__["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id( + action_id, user.id + ) + ) + except Exception as e: + print(e) + + params = {**params, "__user__": __user__} + + if inspect.iscoroutinefunction(action): + data = await action(**params) + else: + data = action(**params) + + except Exception as e: + print(f"Error: {e}") + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + + return data + + ################################## # # Task Endpoints diff --git a/src/lib/components/chat/Messages/ResponseMessage.svelte b/src/lib/components/chat/Messages/ResponseMessage.svelte index 69c387401..cb938dd52 100644 --- a/src/lib/components/chat/Messages/ResponseMessage.svelte +++ b/src/lib/components/chat/Messages/ResponseMessage.svelte @@ -37,6 +37,7 @@ import CitationsModal from '$lib/components/chat/Messages/CitationsModal.svelte'; import Spinner from '$lib/components/common/Spinner.svelte'; import WebSearchResults from './ResponseMessage/WebSearchResults.svelte'; + import Sparkles from '$lib/components/icons/Sparkles.svelte'; export let message; export let siblings; @@ -1020,6 +1021,22 @@ + + {#each model?.actions ?? [] as action} + + + + {/each} {/if} {/if} {/if} diff --git a/src/lib/components/icons/Sparkles.svelte b/src/lib/components/icons/Sparkles.svelte new file mode 100644 index 000000000..0f9034d26 --- /dev/null +++ b/src/lib/components/icons/Sparkles.svelte @@ -0,0 +1,19 @@ + + + + + diff --git a/src/lib/components/workspace/Functions.svelte b/src/lib/components/workspace/Functions.svelte index 84ad65853..6fc730bc6 100644 --- a/src/lib/components/workspace/Functions.svelte +++ b/src/lib/components/workspace/Functions.svelte @@ -122,12 +122,17 @@ if (res) { if (func.is_global) { - toast.success($i18n.t('Filter is now globally enabled')); + func.type === 'filter' + ? toast.success($i18n.t('Filter is now globally enabled')) + : toast.success($i18n.t('Function is now globally enabled')); } else { - toast.success($i18n.t('Filter is now globally disabled')); + func.type === 'filter' + ? toast.success($i18n.t('Filter is now globally disabled')) + : toast.success($i18n.t('Function is now globally disabled')); } functions.set(await getFunctions(localStorage.token)); + models.set(await getModels(localStorage.token)); } }; @@ -294,7 +299,7 @@ showDeleteConfirm = true; }} toggleGlobalHandler={() => { - if (func.type === 'filter') { + if (['filter', 'action'].includes(func.type)) { toggleGlobalHandler(func); } }} diff --git a/src/lib/components/workspace/Functions/FunctionMenu.svelte b/src/lib/components/workspace/Functions/FunctionMenu.svelte index ca76ed222..ad82e4a5f 100644 --- a/src/lib/components/workspace/Functions/FunctionMenu.svelte +++ b/src/lib/components/workspace/Functions/FunctionMenu.svelte @@ -48,7 +48,7 @@ align="start" transition={flyAndScale} > - {#if func.type === 'filter'} + {#if ['filter', 'action'].includes(func.type)}