diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py
index 907576b80..cb73da694 100644
--- a/backend/apps/webui/models/functions.py
+++ b/backend/apps/webui/models/functions.py
@@ -167,6 +167,15 @@ class FunctionsTable:
.all()
]
+ def get_global_action_functions(self) -> List[FunctionModel]:
+ with get_db() as db:
+ return [
+ FunctionModel.model_validate(function)
+ for function in db.query(Function)
+ .filter_by(type="action", is_active=True, is_global=True)
+ .all()
+ ]
+
def get_function_valves_by_id(self, id: str) -> Optional[dict]:
with get_db() as db:
diff --git a/backend/apps/webui/utils.py b/backend/apps/webui/utils.py
index 545120835..96d2b29eb 100644
--- a/backend/apps/webui/utils.py
+++ b/backend/apps/webui/utils.py
@@ -79,6 +79,8 @@ def load_function_module_by_id(function_id):
return module.Pipe(), "pipe", frontmatter
elif hasattr(module, "Filter"):
return module.Filter(), "filter", frontmatter
+ elif hasattr(module, "Action"):
+ return module.Action(), "action", frontmatter
else:
raise Exception("No Function class found")
except Exception as e:
diff --git a/backend/main.py b/backend/main.py
index 01c2fde2a..aa0b6c956 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -926,6 +926,7 @@ webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
async def get_all_models():
+ # TODO: Optimize this function
pipe_models = []
openai_models = []
ollama_models = []
@@ -952,6 +953,14 @@ async def get_all_models():
models = pipe_models + openai_models + ollama_models
+ global_action_ids = [
+ function.id for function in Functions.get_global_action_functions()
+ ]
+ enabled_action_ids = [
+ function.id
+ for function in Functions.get_functions_by_type("action", active_only=True)
+ ]
+
custom_models = Models.get_all_models()
for custom_model in custom_models:
if custom_model.base_model_id == None:
@@ -962,9 +971,32 @@ async def get_all_models():
):
model["name"] = custom_model.name
model["info"] = custom_model.model_dump()
+
+ action_ids = [] + global_action_ids
+ if "info" in model and "meta" in model["info"]:
+ action_ids.extend(model["info"]["meta"].get("actionIds", []))
+ action_ids = list(set(action_ids))
+ action_ids = [
+ action_id
+ for action_id in action_ids
+ if action_id in enabled_action_ids
+ ]
+
+ model["actions"] = [
+ {
+ "id": action_id,
+ "name": Functions.get_function_by_id(action_id).name,
+ "description": Functions.get_function_by_id(
+ action_id
+ ).meta.description,
+ }
+ for action_id in action_ids
+ ]
+
else:
owned_by = "openai"
pipe = None
+ actions = []
for model in models:
if (
@@ -974,6 +1006,27 @@ async def get_all_models():
owned_by = model["owned_by"]
if "pipe" in model:
pipe = model["pipe"]
+
+ action_ids = [] + global_action_ids
+ if "info" in model and "meta" in model["info"]:
+ action_ids.extend(model["info"]["meta"].get("actionIds", []))
+ action_ids = list(set(action_ids))
+ action_ids = [
+ action_id
+ for action_id in action_ids
+ if action_id in enabled_action_ids
+ ]
+
+ actions = [
+ {
+ "id": action_id,
+ "name": Functions.get_function_by_id(action_id).name,
+ "description": Functions.get_function_by_id(
+ action_id
+ ).meta.description,
+ }
+ for action_id in action_ids
+ ]
break
models.append(
@@ -986,6 +1039,7 @@ async def get_all_models():
"info": custom_model.model_dump(),
"preset": True,
**({"pipe": pipe} if pipe is not None else {}),
+ "actions": actions,
}
)
@@ -1221,6 +1275,107 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
return data
+@app.post("/api/chat/actions/{action_id}")
+async def chat_completed(
+ action_id: str, form_data: dict, user=Depends(get_verified_user)
+):
+ action = Functions.get_function_by_id(action_id)
+ if not action:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="Action not found",
+ )
+
+ data = form_data
+ model_id = data["model"]
+ if model_id not in app.state.MODELS:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="Model not found",
+ )
+ model = app.state.MODELS[model_id]
+
+ __event_emitter__ = await get_event_emitter(
+ {
+ "chat_id": data["chat_id"],
+ "message_id": data["id"],
+ "session_id": data["session_id"],
+ }
+ )
+ __event_call__ = await get_event_call(
+ {
+ "chat_id": data["chat_id"],
+ "message_id": data["id"],
+ "session_id": data["session_id"],
+ }
+ )
+
+ if action_id in webui_app.state.FUNCTIONS:
+ function_module = webui_app.state.FUNCTIONS[action_id]
+ else:
+ function_module, _, _ = load_function_module_by_id(action_id)
+ webui_app.state.FUNCTIONS[action_id] = function_module
+
+ if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
+ valves = Functions.get_function_valves_by_id(action_id)
+ function_module.valves = function_module.Valves(**(valves if valves else {}))
+
+ if hasattr(function_module, "action"):
+ try:
+ action = function_module.action
+
+ # Get the signature of the function
+ sig = inspect.signature(action)
+ params = {"body": data}
+
+ # Extra parameters to be passed to the function
+ extra_params = {
+ "__model__": model,
+ "__id__": action_id,
+ "__event_emitter__": __event_emitter__,
+ "__event_call__": __event_call__,
+ }
+
+ # Add extra params in contained in function signature
+ for key, value in extra_params.items():
+ if key in sig.parameters:
+ params[key] = value
+
+ if "__user__" in sig.parameters:
+ __user__ = {
+ "id": user.id,
+ "email": user.email,
+ "name": user.name,
+ "role": user.role,
+ }
+
+ try:
+ if hasattr(function_module, "UserValves"):
+ __user__["valves"] = function_module.UserValves(
+ **Functions.get_user_valves_by_id_and_user_id(
+ action_id, user.id
+ )
+ )
+ except Exception as e:
+ print(e)
+
+ params = {**params, "__user__": __user__}
+
+ if inspect.iscoroutinefunction(action):
+ data = await action(**params)
+ else:
+ data = action(**params)
+
+ except Exception as e:
+ print(f"Error: {e}")
+ return JSONResponse(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ content={"detail": str(e)},
+ )
+
+ return data
+
+
##################################
#
# Task Endpoints
diff --git a/src/lib/components/chat/Messages/ResponseMessage.svelte b/src/lib/components/chat/Messages/ResponseMessage.svelte
index 69c387401..cb938dd52 100644
--- a/src/lib/components/chat/Messages/ResponseMessage.svelte
+++ b/src/lib/components/chat/Messages/ResponseMessage.svelte
@@ -37,6 +37,7 @@
import CitationsModal from '$lib/components/chat/Messages/CitationsModal.svelte';
import Spinner from '$lib/components/common/Spinner.svelte';
import WebSearchResults from './ResponseMessage/WebSearchResults.svelte';
+ import Sparkles from '$lib/components/icons/Sparkles.svelte';
export let message;
export let siblings;
@@ -1020,6 +1021,22 @@
+
+ {#each model?.actions ?? [] as action}
+