mirror of
				https://github.com/open-webui/open-webui
				synced 2025-06-26 18:26:48 +00:00 
			
		
		
		
	feat: action function
This commit is contained in:
		
							parent
							
								
									90c3d68f00
								
							
						
					
					
						commit
						eb10001eb7
					
				@ -167,6 +167,15 @@ class FunctionsTable:
 | 
			
		||||
                .all()
 | 
			
		||||
            ]
 | 
			
		||||
 | 
			
		||||
    def get_global_action_functions(self) -> List[FunctionModel]:
 | 
			
		||||
        with get_db() as db:
 | 
			
		||||
            return [
 | 
			
		||||
                FunctionModel.model_validate(function)
 | 
			
		||||
                for function in db.query(Function)
 | 
			
		||||
                .filter_by(type="action", is_active=True, is_global=True)
 | 
			
		||||
                .all()
 | 
			
		||||
            ]
 | 
			
		||||
 | 
			
		||||
    def get_function_valves_by_id(self, id: str) -> Optional[dict]:
 | 
			
		||||
        with get_db() as db:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -79,6 +79,8 @@ def load_function_module_by_id(function_id):
 | 
			
		||||
            return module.Pipe(), "pipe", frontmatter
 | 
			
		||||
        elif hasattr(module, "Filter"):
 | 
			
		||||
            return module.Filter(), "filter", frontmatter
 | 
			
		||||
        elif hasattr(module, "Action"):
 | 
			
		||||
            return module.Action(), "action", frontmatter
 | 
			
		||||
        else:
 | 
			
		||||
            raise Exception("No Function class found")
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										155
									
								
								backend/main.py
									
									
									
									
									
								
							
							
						
						
									
										155
									
								
								backend/main.py
									
									
									
									
									
								
							@ -926,6 +926,7 @@ webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def get_all_models():
 | 
			
		||||
    # TODO: Optimize this function
 | 
			
		||||
    pipe_models = []
 | 
			
		||||
    openai_models = []
 | 
			
		||||
    ollama_models = []
 | 
			
		||||
@ -952,6 +953,14 @@ async def get_all_models():
 | 
			
		||||
 | 
			
		||||
    models = pipe_models + openai_models + ollama_models
 | 
			
		||||
 | 
			
		||||
    global_action_ids = [
 | 
			
		||||
        function.id for function in Functions.get_global_action_functions()
 | 
			
		||||
    ]
 | 
			
		||||
    enabled_action_ids = [
 | 
			
		||||
        function.id
 | 
			
		||||
        for function in Functions.get_functions_by_type("action", active_only=True)
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    custom_models = Models.get_all_models()
 | 
			
		||||
    for custom_model in custom_models:
 | 
			
		||||
        if custom_model.base_model_id == None:
 | 
			
		||||
@ -962,9 +971,32 @@ async def get_all_models():
 | 
			
		||||
                ):
 | 
			
		||||
                    model["name"] = custom_model.name
 | 
			
		||||
                    model["info"] = custom_model.model_dump()
 | 
			
		||||
 | 
			
		||||
                    action_ids = [] + global_action_ids
 | 
			
		||||
                    if "info" in model and "meta" in model["info"]:
 | 
			
		||||
                        action_ids.extend(model["info"]["meta"].get("actionIds", []))
 | 
			
		||||
                        action_ids = list(set(action_ids))
 | 
			
		||||
                    action_ids = [
 | 
			
		||||
                        action_id
 | 
			
		||||
                        for action_id in action_ids
 | 
			
		||||
                        if action_id in enabled_action_ids
 | 
			
		||||
                    ]
 | 
			
		||||
 | 
			
		||||
                    model["actions"] = [
 | 
			
		||||
                        {
 | 
			
		||||
                            "id": action_id,
 | 
			
		||||
                            "name": Functions.get_function_by_id(action_id).name,
 | 
			
		||||
                            "description": Functions.get_function_by_id(
 | 
			
		||||
                                action_id
 | 
			
		||||
                            ).meta.description,
 | 
			
		||||
                        }
 | 
			
		||||
                        for action_id in action_ids
 | 
			
		||||
                    ]
 | 
			
		||||
 | 
			
		||||
        else:
 | 
			
		||||
            owned_by = "openai"
 | 
			
		||||
            pipe = None
 | 
			
		||||
            actions = []
 | 
			
		||||
 | 
			
		||||
            for model in models:
 | 
			
		||||
                if (
 | 
			
		||||
@ -974,6 +1006,27 @@ async def get_all_models():
 | 
			
		||||
                    owned_by = model["owned_by"]
 | 
			
		||||
                    if "pipe" in model:
 | 
			
		||||
                        pipe = model["pipe"]
 | 
			
		||||
 | 
			
		||||
                    action_ids = [] + global_action_ids
 | 
			
		||||
                    if "info" in model and "meta" in model["info"]:
 | 
			
		||||
                        action_ids.extend(model["info"]["meta"].get("actionIds", []))
 | 
			
		||||
                        action_ids = list(set(action_ids))
 | 
			
		||||
                    action_ids = [
 | 
			
		||||
                        action_id
 | 
			
		||||
                        for action_id in action_ids
 | 
			
		||||
                        if action_id in enabled_action_ids
 | 
			
		||||
                    ]
 | 
			
		||||
 | 
			
		||||
                    actions = [
 | 
			
		||||
                        {
 | 
			
		||||
                            "id": action_id,
 | 
			
		||||
                            "name": Functions.get_function_by_id(action_id).name,
 | 
			
		||||
                            "description": Functions.get_function_by_id(
 | 
			
		||||
                                action_id
 | 
			
		||||
                            ).meta.description,
 | 
			
		||||
                        }
 | 
			
		||||
                        for action_id in action_ids
 | 
			
		||||
                    ]
 | 
			
		||||
                    break
 | 
			
		||||
 | 
			
		||||
            models.append(
 | 
			
		||||
@ -986,6 +1039,7 @@ async def get_all_models():
 | 
			
		||||
                    "info": custom_model.model_dump(),
 | 
			
		||||
                    "preset": True,
 | 
			
		||||
                    **({"pipe": pipe} if pipe is not None else {}),
 | 
			
		||||
                    "actions": actions,
 | 
			
		||||
                }
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
@ -1221,6 +1275,107 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
 | 
			
		||||
    return data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.post("/api/chat/actions/{action_id}")
 | 
			
		||||
async def chat_completed(
 | 
			
		||||
    action_id: str, form_data: dict, user=Depends(get_verified_user)
 | 
			
		||||
):
 | 
			
		||||
    action = Functions.get_function_by_id(action_id)
 | 
			
		||||
    if not action:
 | 
			
		||||
        raise HTTPException(
 | 
			
		||||
            status_code=status.HTTP_404_NOT_FOUND,
 | 
			
		||||
            detail="Action not found",
 | 
			
		||||
        )
 | 
			
		||||
    
 | 
			
		||||
    data = form_data
 | 
			
		||||
    model_id = data["model"]
 | 
			
		||||
    if model_id not in app.state.MODELS:
 | 
			
		||||
        raise HTTPException(
 | 
			
		||||
            status_code=status.HTTP_404_NOT_FOUND,
 | 
			
		||||
            detail="Model not found",
 | 
			
		||||
        )
 | 
			
		||||
    model = app.state.MODELS[model_id]
 | 
			
		||||
 | 
			
		||||
    __event_emitter__ = await get_event_emitter(
 | 
			
		||||
        {
 | 
			
		||||
            "chat_id": data["chat_id"],
 | 
			
		||||
            "message_id": data["id"],
 | 
			
		||||
            "session_id": data["session_id"],
 | 
			
		||||
        }
 | 
			
		||||
    )
 | 
			
		||||
    __event_call__ = await get_event_call(
 | 
			
		||||
        {
 | 
			
		||||
            "chat_id": data["chat_id"],
 | 
			
		||||
            "message_id": data["id"],
 | 
			
		||||
            "session_id": data["session_id"],
 | 
			
		||||
        }
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    if action_id in webui_app.state.FUNCTIONS:
 | 
			
		||||
        function_module = webui_app.state.FUNCTIONS[action_id]
 | 
			
		||||
    else:
 | 
			
		||||
        function_module, _, _ = load_function_module_by_id(action_id)
 | 
			
		||||
        webui_app.state.FUNCTIONS[action_id] = function_module
 | 
			
		||||
 | 
			
		||||
    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
 | 
			
		||||
        valves = Functions.get_function_valves_by_id(action_id)
 | 
			
		||||
        function_module.valves = function_module.Valves(**(valves if valves else {}))
 | 
			
		||||
 | 
			
		||||
    if hasattr(function_module, "action"):
 | 
			
		||||
        try:
 | 
			
		||||
            action = function_module.action
 | 
			
		||||
 | 
			
		||||
            # Get the signature of the function
 | 
			
		||||
            sig = inspect.signature(action)
 | 
			
		||||
            params = {"body": data}
 | 
			
		||||
 | 
			
		||||
            # Extra parameters to be passed to the function
 | 
			
		||||
            extra_params = {
 | 
			
		||||
                "__model__": model,
 | 
			
		||||
                "__id__": action_id,
 | 
			
		||||
                "__event_emitter__": __event_emitter__,
 | 
			
		||||
                "__event_call__": __event_call__,
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            # Add extra params in contained in function signature
 | 
			
		||||
            for key, value in extra_params.items():
 | 
			
		||||
                if key in sig.parameters:
 | 
			
		||||
                    params[key] = value
 | 
			
		||||
 | 
			
		||||
            if "__user__" in sig.parameters:
 | 
			
		||||
                __user__ = {
 | 
			
		||||
                    "id": user.id,
 | 
			
		||||
                    "email": user.email,
 | 
			
		||||
                    "name": user.name,
 | 
			
		||||
                    "role": user.role,
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                try:
 | 
			
		||||
                    if hasattr(function_module, "UserValves"):
 | 
			
		||||
                        __user__["valves"] = function_module.UserValves(
 | 
			
		||||
                            **Functions.get_user_valves_by_id_and_user_id(
 | 
			
		||||
                                action_id, user.id
 | 
			
		||||
                            )
 | 
			
		||||
                        )
 | 
			
		||||
                except Exception as e:
 | 
			
		||||
                    print(e)
 | 
			
		||||
 | 
			
		||||
                params = {**params, "__user__": __user__}
 | 
			
		||||
 | 
			
		||||
            if inspect.iscoroutinefunction(action):
 | 
			
		||||
                data = await action(**params)
 | 
			
		||||
            else:
 | 
			
		||||
                data = action(**params)
 | 
			
		||||
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            print(f"Error: {e}")
 | 
			
		||||
            return JSONResponse(
 | 
			
		||||
                status_code=status.HTTP_400_BAD_REQUEST,
 | 
			
		||||
                content={"detail": str(e)},
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    return data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
##################################
 | 
			
		||||
#
 | 
			
		||||
# Task Endpoints
 | 
			
		||||
 | 
			
		||||
@ -37,6 +37,7 @@
 | 
			
		||||
	import CitationsModal from '$lib/components/chat/Messages/CitationsModal.svelte';
 | 
			
		||||
	import Spinner from '$lib/components/common/Spinner.svelte';
 | 
			
		||||
	import WebSearchResults from './ResponseMessage/WebSearchResults.svelte';
 | 
			
		||||
	import Sparkles from '$lib/components/icons/Sparkles.svelte';
 | 
			
		||||
 | 
			
		||||
	export let message;
 | 
			
		||||
	export let siblings;
 | 
			
		||||
@ -1020,6 +1021,22 @@
 | 
			
		||||
														</svg>
 | 
			
		||||
													</button>
 | 
			
		||||
												</Tooltip>
 | 
			
		||||
 | 
			
		||||
												{#each model?.actions ?? [] as action}
 | 
			
		||||
													<Tooltip content={action.name} placement="bottom">
 | 
			
		||||
														<button
 | 
			
		||||
															type="button"
 | 
			
		||||
															class="{isLastMessage
 | 
			
		||||
																? 'visible'
 | 
			
		||||
																: 'invisible group-hover:visible'} p-1.5 hover:bg-black/5 dark:hover:bg-white/5 rounded-lg dark:hover:text-white hover:text-black transition regenerate-response-button"
 | 
			
		||||
															on:click={() => {
 | 
			
		||||
																console.log('action');
 | 
			
		||||
															}}
 | 
			
		||||
														>
 | 
			
		||||
															<Sparkles strokeWidth="2.1" className="size-4" />
 | 
			
		||||
														</button>
 | 
			
		||||
													</Tooltip>
 | 
			
		||||
												{/each}
 | 
			
		||||
											{/if}
 | 
			
		||||
										{/if}
 | 
			
		||||
									{/if}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										19
									
								
								src/lib/components/icons/Sparkles.svelte
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								src/lib/components/icons/Sparkles.svelte
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,19 @@
 | 
			
		||||
<script lang="ts">
 | 
			
		||||
	export let className = 'w-4 h-4';
 | 
			
		||||
	export let strokeWidth = '1.5';
 | 
			
		||||
</script>
 | 
			
		||||
 | 
			
		||||
<svg
 | 
			
		||||
	xmlns="http://www.w3.org/2000/svg"
 | 
			
		||||
	fill="none"
 | 
			
		||||
	viewBox="0 0 24 24"
 | 
			
		||||
	stroke-width={strokeWidth}
 | 
			
		||||
	stroke="currentColor"
 | 
			
		||||
	class={className}
 | 
			
		||||
>
 | 
			
		||||
	<path
 | 
			
		||||
		stroke-linecap="round"
 | 
			
		||||
		stroke-linejoin="round"
 | 
			
		||||
		d="M9.813 15.904 9 18.75l-.813-2.846a4.5 4.5 0 0 0-3.09-3.09L2.25 12l2.846-.813a4.5 4.5 0 0 0 3.09-3.09L9 5.25l.813 2.846a4.5 4.5 0 0 0 3.09 3.09L15.75 12l-2.846.813a4.5 4.5 0 0 0-3.09 3.09ZM18.259 8.715 18 9.75l-.259-1.035a3.375 3.375 0 0 0-2.455-2.456L14.25 6l1.036-.259a3.375 3.375 0 0 0 2.455-2.456L18 2.25l.259 1.035a3.375 3.375 0 0 0 2.456 2.456L21.75 6l-1.035.259a3.375 3.375 0 0 0-2.456 2.456ZM16.894 20.567 16.5 21.75l-.394-1.183a2.25 2.25 0 0 0-1.423-1.423L13.5 18.75l1.183-.394a2.25 2.25 0 0 0 1.423-1.423l.394-1.183.394 1.183a2.25 2.25 0 0 0 1.423 1.423l1.183.394-1.183.394a2.25 2.25 0 0 0-1.423 1.423Z"
 | 
			
		||||
	/>
 | 
			
		||||
</svg>
 | 
			
		||||
@ -122,12 +122,17 @@
 | 
			
		||||
 | 
			
		||||
		if (res) {
 | 
			
		||||
			if (func.is_global) {
 | 
			
		||||
				toast.success($i18n.t('Filter is now globally enabled'));
 | 
			
		||||
				func.type === 'filter'
 | 
			
		||||
					? toast.success($i18n.t('Filter is now globally enabled'))
 | 
			
		||||
					: toast.success($i18n.t('Function is now globally enabled'));
 | 
			
		||||
			} else {
 | 
			
		||||
				toast.success($i18n.t('Filter is now globally disabled'));
 | 
			
		||||
				func.type === 'filter'
 | 
			
		||||
					? toast.success($i18n.t('Filter is now globally disabled'))
 | 
			
		||||
					: toast.success($i18n.t('Function is now globally disabled'));
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			functions.set(await getFunctions(localStorage.token));
 | 
			
		||||
			models.set(await getModels(localStorage.token));
 | 
			
		||||
		}
 | 
			
		||||
	};
 | 
			
		||||
</script>
 | 
			
		||||
@ -294,7 +299,7 @@
 | 
			
		||||
						showDeleteConfirm = true;
 | 
			
		||||
					}}
 | 
			
		||||
					toggleGlobalHandler={() => {
 | 
			
		||||
						if (func.type === 'filter') {
 | 
			
		||||
						if (['filter', 'action'].includes(func.type)) {
 | 
			
		||||
							toggleGlobalHandler(func);
 | 
			
		||||
						}
 | 
			
		||||
					}}
 | 
			
		||||
 | 
			
		||||
@ -48,7 +48,7 @@
 | 
			
		||||
			align="start"
 | 
			
		||||
			transition={flyAndScale}
 | 
			
		||||
		>
 | 
			
		||||
			{#if func.type === 'filter'}
 | 
			
		||||
			{#if ['filter', 'action'].includes(func.type)}
 | 
			
		||||
				<div
 | 
			
		||||
					class="flex gap-2 justify-between items-center px-3 py-2 text-sm font-medium cursor-pointerrounded-md"
 | 
			
		||||
				>
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user