mirror of
https://github.com/open-webui/open-webui
synced 2025-03-27 07:50:37 +00:00
feat: action function
This commit is contained in:
parent
90c3d68f00
commit
eb10001eb7
backend
src/lib/components
chat/Messages
icons
workspace
@ -167,6 +167,15 @@ class FunctionsTable:
|
|||||||
.all()
|
.all()
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def get_global_action_functions(self) -> List[FunctionModel]:
|
||||||
|
with get_db() as db:
|
||||||
|
return [
|
||||||
|
FunctionModel.model_validate(function)
|
||||||
|
for function in db.query(Function)
|
||||||
|
.filter_by(type="action", is_active=True, is_global=True)
|
||||||
|
.all()
|
||||||
|
]
|
||||||
|
|
||||||
def get_function_valves_by_id(self, id: str) -> Optional[dict]:
|
def get_function_valves_by_id(self, id: str) -> Optional[dict]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
||||||
|
@ -79,6 +79,8 @@ def load_function_module_by_id(function_id):
|
|||||||
return module.Pipe(), "pipe", frontmatter
|
return module.Pipe(), "pipe", frontmatter
|
||||||
elif hasattr(module, "Filter"):
|
elif hasattr(module, "Filter"):
|
||||||
return module.Filter(), "filter", frontmatter
|
return module.Filter(), "filter", frontmatter
|
||||||
|
elif hasattr(module, "Action"):
|
||||||
|
return module.Action(), "action", frontmatter
|
||||||
else:
|
else:
|
||||||
raise Exception("No Function class found")
|
raise Exception("No Function class found")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
155
backend/main.py
155
backend/main.py
@ -926,6 +926,7 @@ webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
|
|||||||
|
|
||||||
|
|
||||||
async def get_all_models():
|
async def get_all_models():
|
||||||
|
# TODO: Optimize this function
|
||||||
pipe_models = []
|
pipe_models = []
|
||||||
openai_models = []
|
openai_models = []
|
||||||
ollama_models = []
|
ollama_models = []
|
||||||
@ -952,6 +953,14 @@ async def get_all_models():
|
|||||||
|
|
||||||
models = pipe_models + openai_models + ollama_models
|
models = pipe_models + openai_models + ollama_models
|
||||||
|
|
||||||
|
global_action_ids = [
|
||||||
|
function.id for function in Functions.get_global_action_functions()
|
||||||
|
]
|
||||||
|
enabled_action_ids = [
|
||||||
|
function.id
|
||||||
|
for function in Functions.get_functions_by_type("action", active_only=True)
|
||||||
|
]
|
||||||
|
|
||||||
custom_models = Models.get_all_models()
|
custom_models = Models.get_all_models()
|
||||||
for custom_model in custom_models:
|
for custom_model in custom_models:
|
||||||
if custom_model.base_model_id == None:
|
if custom_model.base_model_id == None:
|
||||||
@ -962,9 +971,32 @@ async def get_all_models():
|
|||||||
):
|
):
|
||||||
model["name"] = custom_model.name
|
model["name"] = custom_model.name
|
||||||
model["info"] = custom_model.model_dump()
|
model["info"] = custom_model.model_dump()
|
||||||
|
|
||||||
|
action_ids = [] + global_action_ids
|
||||||
|
if "info" in model and "meta" in model["info"]:
|
||||||
|
action_ids.extend(model["info"]["meta"].get("actionIds", []))
|
||||||
|
action_ids = list(set(action_ids))
|
||||||
|
action_ids = [
|
||||||
|
action_id
|
||||||
|
for action_id in action_ids
|
||||||
|
if action_id in enabled_action_ids
|
||||||
|
]
|
||||||
|
|
||||||
|
model["actions"] = [
|
||||||
|
{
|
||||||
|
"id": action_id,
|
||||||
|
"name": Functions.get_function_by_id(action_id).name,
|
||||||
|
"description": Functions.get_function_by_id(
|
||||||
|
action_id
|
||||||
|
).meta.description,
|
||||||
|
}
|
||||||
|
for action_id in action_ids
|
||||||
|
]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
owned_by = "openai"
|
owned_by = "openai"
|
||||||
pipe = None
|
pipe = None
|
||||||
|
actions = []
|
||||||
|
|
||||||
for model in models:
|
for model in models:
|
||||||
if (
|
if (
|
||||||
@ -974,6 +1006,27 @@ async def get_all_models():
|
|||||||
owned_by = model["owned_by"]
|
owned_by = model["owned_by"]
|
||||||
if "pipe" in model:
|
if "pipe" in model:
|
||||||
pipe = model["pipe"]
|
pipe = model["pipe"]
|
||||||
|
|
||||||
|
action_ids = [] + global_action_ids
|
||||||
|
if "info" in model and "meta" in model["info"]:
|
||||||
|
action_ids.extend(model["info"]["meta"].get("actionIds", []))
|
||||||
|
action_ids = list(set(action_ids))
|
||||||
|
action_ids = [
|
||||||
|
action_id
|
||||||
|
for action_id in action_ids
|
||||||
|
if action_id in enabled_action_ids
|
||||||
|
]
|
||||||
|
|
||||||
|
actions = [
|
||||||
|
{
|
||||||
|
"id": action_id,
|
||||||
|
"name": Functions.get_function_by_id(action_id).name,
|
||||||
|
"description": Functions.get_function_by_id(
|
||||||
|
action_id
|
||||||
|
).meta.description,
|
||||||
|
}
|
||||||
|
for action_id in action_ids
|
||||||
|
]
|
||||||
break
|
break
|
||||||
|
|
||||||
models.append(
|
models.append(
|
||||||
@ -986,6 +1039,7 @@ async def get_all_models():
|
|||||||
"info": custom_model.model_dump(),
|
"info": custom_model.model_dump(),
|
||||||
"preset": True,
|
"preset": True,
|
||||||
**({"pipe": pipe} if pipe is not None else {}),
|
**({"pipe": pipe} if pipe is not None else {}),
|
||||||
|
"actions": actions,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1221,6 +1275,107 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/chat/actions/{action_id}")
|
||||||
|
async def chat_completed(
|
||||||
|
action_id: str, form_data: dict, user=Depends(get_verified_user)
|
||||||
|
):
|
||||||
|
action = Functions.get_function_by_id(action_id)
|
||||||
|
if not action:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Action not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
data = form_data
|
||||||
|
model_id = data["model"]
|
||||||
|
if model_id not in app.state.MODELS:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Model not found",
|
||||||
|
)
|
||||||
|
model = app.state.MODELS[model_id]
|
||||||
|
|
||||||
|
__event_emitter__ = await get_event_emitter(
|
||||||
|
{
|
||||||
|
"chat_id": data["chat_id"],
|
||||||
|
"message_id": data["id"],
|
||||||
|
"session_id": data["session_id"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
__event_call__ = await get_event_call(
|
||||||
|
{
|
||||||
|
"chat_id": data["chat_id"],
|
||||||
|
"message_id": data["id"],
|
||||||
|
"session_id": data["session_id"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if action_id in webui_app.state.FUNCTIONS:
|
||||||
|
function_module = webui_app.state.FUNCTIONS[action_id]
|
||||||
|
else:
|
||||||
|
function_module, _, _ = load_function_module_by_id(action_id)
|
||||||
|
webui_app.state.FUNCTIONS[action_id] = function_module
|
||||||
|
|
||||||
|
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||||
|
valves = Functions.get_function_valves_by_id(action_id)
|
||||||
|
function_module.valves = function_module.Valves(**(valves if valves else {}))
|
||||||
|
|
||||||
|
if hasattr(function_module, "action"):
|
||||||
|
try:
|
||||||
|
action = function_module.action
|
||||||
|
|
||||||
|
# Get the signature of the function
|
||||||
|
sig = inspect.signature(action)
|
||||||
|
params = {"body": data}
|
||||||
|
|
||||||
|
# Extra parameters to be passed to the function
|
||||||
|
extra_params = {
|
||||||
|
"__model__": model,
|
||||||
|
"__id__": action_id,
|
||||||
|
"__event_emitter__": __event_emitter__,
|
||||||
|
"__event_call__": __event_call__,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add extra params in contained in function signature
|
||||||
|
for key, value in extra_params.items():
|
||||||
|
if key in sig.parameters:
|
||||||
|
params[key] = value
|
||||||
|
|
||||||
|
if "__user__" in sig.parameters:
|
||||||
|
__user__ = {
|
||||||
|
"id": user.id,
|
||||||
|
"email": user.email,
|
||||||
|
"name": user.name,
|
||||||
|
"role": user.role,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
if hasattr(function_module, "UserValves"):
|
||||||
|
__user__["valves"] = function_module.UserValves(
|
||||||
|
**Functions.get_user_valves_by_id_and_user_id(
|
||||||
|
action_id, user.id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
params = {**params, "__user__": __user__}
|
||||||
|
|
||||||
|
if inspect.iscoroutinefunction(action):
|
||||||
|
data = await action(**params)
|
||||||
|
else:
|
||||||
|
data = action(**params)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
content={"detail": str(e)},
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
##################################
|
##################################
|
||||||
#
|
#
|
||||||
# Task Endpoints
|
# Task Endpoints
|
||||||
|
@ -37,6 +37,7 @@
|
|||||||
import CitationsModal from '$lib/components/chat/Messages/CitationsModal.svelte';
|
import CitationsModal from '$lib/components/chat/Messages/CitationsModal.svelte';
|
||||||
import Spinner from '$lib/components/common/Spinner.svelte';
|
import Spinner from '$lib/components/common/Spinner.svelte';
|
||||||
import WebSearchResults from './ResponseMessage/WebSearchResults.svelte';
|
import WebSearchResults from './ResponseMessage/WebSearchResults.svelte';
|
||||||
|
import Sparkles from '$lib/components/icons/Sparkles.svelte';
|
||||||
|
|
||||||
export let message;
|
export let message;
|
||||||
export let siblings;
|
export let siblings;
|
||||||
@ -1020,6 +1021,22 @@
|
|||||||
</svg>
|
</svg>
|
||||||
</button>
|
</button>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
|
|
||||||
|
{#each model?.actions ?? [] as action}
|
||||||
|
<Tooltip content={action.name} placement="bottom">
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="{isLastMessage
|
||||||
|
? 'visible'
|
||||||
|
: 'invisible group-hover:visible'} p-1.5 hover:bg-black/5 dark:hover:bg-white/5 rounded-lg dark:hover:text-white hover:text-black transition regenerate-response-button"
|
||||||
|
on:click={() => {
|
||||||
|
console.log('action');
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Sparkles strokeWidth="2.1" className="size-4" />
|
||||||
|
</button>
|
||||||
|
</Tooltip>
|
||||||
|
{/each}
|
||||||
{/if}
|
{/if}
|
||||||
{/if}
|
{/if}
|
||||||
{/if}
|
{/if}
|
||||||
|
19
src/lib/components/icons/Sparkles.svelte
Normal file
19
src/lib/components/icons/Sparkles.svelte
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
<script lang="ts">
|
||||||
|
export let className = 'w-4 h-4';
|
||||||
|
export let strokeWidth = '1.5';
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<svg
|
||||||
|
xmlns="http://www.w3.org/2000/svg"
|
||||||
|
fill="none"
|
||||||
|
viewBox="0 0 24 24"
|
||||||
|
stroke-width={strokeWidth}
|
||||||
|
stroke="currentColor"
|
||||||
|
class={className}
|
||||||
|
>
|
||||||
|
<path
|
||||||
|
stroke-linecap="round"
|
||||||
|
stroke-linejoin="round"
|
||||||
|
d="M9.813 15.904 9 18.75l-.813-2.846a4.5 4.5 0 0 0-3.09-3.09L2.25 12l2.846-.813a4.5 4.5 0 0 0 3.09-3.09L9 5.25l.813 2.846a4.5 4.5 0 0 0 3.09 3.09L15.75 12l-2.846.813a4.5 4.5 0 0 0-3.09 3.09ZM18.259 8.715 18 9.75l-.259-1.035a3.375 3.375 0 0 0-2.455-2.456L14.25 6l1.036-.259a3.375 3.375 0 0 0 2.455-2.456L18 2.25l.259 1.035a3.375 3.375 0 0 0 2.456 2.456L21.75 6l-1.035.259a3.375 3.375 0 0 0-2.456 2.456ZM16.894 20.567 16.5 21.75l-.394-1.183a2.25 2.25 0 0 0-1.423-1.423L13.5 18.75l1.183-.394a2.25 2.25 0 0 0 1.423-1.423l.394-1.183.394 1.183a2.25 2.25 0 0 0 1.423 1.423l1.183.394-1.183.394a2.25 2.25 0 0 0-1.423 1.423Z"
|
||||||
|
/>
|
||||||
|
</svg>
|
@ -122,12 +122,17 @@
|
|||||||
|
|
||||||
if (res) {
|
if (res) {
|
||||||
if (func.is_global) {
|
if (func.is_global) {
|
||||||
toast.success($i18n.t('Filter is now globally enabled'));
|
func.type === 'filter'
|
||||||
|
? toast.success($i18n.t('Filter is now globally enabled'))
|
||||||
|
: toast.success($i18n.t('Function is now globally enabled'));
|
||||||
} else {
|
} else {
|
||||||
toast.success($i18n.t('Filter is now globally disabled'));
|
func.type === 'filter'
|
||||||
|
? toast.success($i18n.t('Filter is now globally disabled'))
|
||||||
|
: toast.success($i18n.t('Function is now globally disabled'));
|
||||||
}
|
}
|
||||||
|
|
||||||
functions.set(await getFunctions(localStorage.token));
|
functions.set(await getFunctions(localStorage.token));
|
||||||
|
models.set(await getModels(localStorage.token));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
</script>
|
</script>
|
||||||
@ -294,7 +299,7 @@
|
|||||||
showDeleteConfirm = true;
|
showDeleteConfirm = true;
|
||||||
}}
|
}}
|
||||||
toggleGlobalHandler={() => {
|
toggleGlobalHandler={() => {
|
||||||
if (func.type === 'filter') {
|
if (['filter', 'action'].includes(func.type)) {
|
||||||
toggleGlobalHandler(func);
|
toggleGlobalHandler(func);
|
||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
|
@ -48,7 +48,7 @@
|
|||||||
align="start"
|
align="start"
|
||||||
transition={flyAndScale}
|
transition={flyAndScale}
|
||||||
>
|
>
|
||||||
{#if func.type === 'filter'}
|
{#if ['filter', 'action'].includes(func.type)}
|
||||||
<div
|
<div
|
||||||
class="flex gap-2 justify-between items-center px-3 py-2 text-sm font-medium cursor-pointerrounded-md"
|
class="flex gap-2 justify-between items-center px-3 py-2 text-sm font-medium cursor-pointerrounded-md"
|
||||||
>
|
>
|
||||||
|
Loading…
Reference in New Issue
Block a user