mirror of
				https://github.com/open-webui/open-webui
				synced 2025-06-26 18:26:48 +00:00 
			
		
		
		
	feat: action function
This commit is contained in:
		
							parent
							
								
									90c3d68f00
								
							
						
					
					
						commit
						eb10001eb7
					
				@ -167,6 +167,15 @@ class FunctionsTable:
 | 
				
			|||||||
                .all()
 | 
					                .all()
 | 
				
			||||||
            ]
 | 
					            ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_global_action_functions(self) -> List[FunctionModel]:
 | 
				
			||||||
 | 
					        with get_db() as db:
 | 
				
			||||||
 | 
					            return [
 | 
				
			||||||
 | 
					                FunctionModel.model_validate(function)
 | 
				
			||||||
 | 
					                for function in db.query(Function)
 | 
				
			||||||
 | 
					                .filter_by(type="action", is_active=True, is_global=True)
 | 
				
			||||||
 | 
					                .all()
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_function_valves_by_id(self, id: str) -> Optional[dict]:
 | 
					    def get_function_valves_by_id(self, id: str) -> Optional[dict]:
 | 
				
			||||||
        with get_db() as db:
 | 
					        with get_db() as db:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -79,6 +79,8 @@ def load_function_module_by_id(function_id):
 | 
				
			|||||||
            return module.Pipe(), "pipe", frontmatter
 | 
					            return module.Pipe(), "pipe", frontmatter
 | 
				
			||||||
        elif hasattr(module, "Filter"):
 | 
					        elif hasattr(module, "Filter"):
 | 
				
			||||||
            return module.Filter(), "filter", frontmatter
 | 
					            return module.Filter(), "filter", frontmatter
 | 
				
			||||||
 | 
					        elif hasattr(module, "Action"):
 | 
				
			||||||
 | 
					            return module.Action(), "action", frontmatter
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            raise Exception("No Function class found")
 | 
					            raise Exception("No Function class found")
 | 
				
			||||||
    except Exception as e:
 | 
					    except Exception as e:
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										155
									
								
								backend/main.py
									
									
									
									
									
								
							
							
						
						
									
										155
									
								
								backend/main.py
									
									
									
									
									
								
							@ -926,6 +926,7 @@ webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def get_all_models():
 | 
					async def get_all_models():
 | 
				
			||||||
 | 
					    # TODO: Optimize this function
 | 
				
			||||||
    pipe_models = []
 | 
					    pipe_models = []
 | 
				
			||||||
    openai_models = []
 | 
					    openai_models = []
 | 
				
			||||||
    ollama_models = []
 | 
					    ollama_models = []
 | 
				
			||||||
@ -952,6 +953,14 @@ async def get_all_models():
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    models = pipe_models + openai_models + ollama_models
 | 
					    models = pipe_models + openai_models + ollama_models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    global_action_ids = [
 | 
				
			||||||
 | 
					        function.id for function in Functions.get_global_action_functions()
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					    enabled_action_ids = [
 | 
				
			||||||
 | 
					        function.id
 | 
				
			||||||
 | 
					        for function in Functions.get_functions_by_type("action", active_only=True)
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    custom_models = Models.get_all_models()
 | 
					    custom_models = Models.get_all_models()
 | 
				
			||||||
    for custom_model in custom_models:
 | 
					    for custom_model in custom_models:
 | 
				
			||||||
        if custom_model.base_model_id == None:
 | 
					        if custom_model.base_model_id == None:
 | 
				
			||||||
@ -962,9 +971,32 @@ async def get_all_models():
 | 
				
			|||||||
                ):
 | 
					                ):
 | 
				
			||||||
                    model["name"] = custom_model.name
 | 
					                    model["name"] = custom_model.name
 | 
				
			||||||
                    model["info"] = custom_model.model_dump()
 | 
					                    model["info"] = custom_model.model_dump()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    action_ids = [] + global_action_ids
 | 
				
			||||||
 | 
					                    if "info" in model and "meta" in model["info"]:
 | 
				
			||||||
 | 
					                        action_ids.extend(model["info"]["meta"].get("actionIds", []))
 | 
				
			||||||
 | 
					                        action_ids = list(set(action_ids))
 | 
				
			||||||
 | 
					                    action_ids = [
 | 
				
			||||||
 | 
					                        action_id
 | 
				
			||||||
 | 
					                        for action_id in action_ids
 | 
				
			||||||
 | 
					                        if action_id in enabled_action_ids
 | 
				
			||||||
 | 
					                    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    model["actions"] = [
 | 
				
			||||||
 | 
					                        {
 | 
				
			||||||
 | 
					                            "id": action_id,
 | 
				
			||||||
 | 
					                            "name": Functions.get_function_by_id(action_id).name,
 | 
				
			||||||
 | 
					                            "description": Functions.get_function_by_id(
 | 
				
			||||||
 | 
					                                action_id
 | 
				
			||||||
 | 
					                            ).meta.description,
 | 
				
			||||||
 | 
					                        }
 | 
				
			||||||
 | 
					                        for action_id in action_ids
 | 
				
			||||||
 | 
					                    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            owned_by = "openai"
 | 
					            owned_by = "openai"
 | 
				
			||||||
            pipe = None
 | 
					            pipe = None
 | 
				
			||||||
 | 
					            actions = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            for model in models:
 | 
					            for model in models:
 | 
				
			||||||
                if (
 | 
					                if (
 | 
				
			||||||
@ -974,6 +1006,27 @@ async def get_all_models():
 | 
				
			|||||||
                    owned_by = model["owned_by"]
 | 
					                    owned_by = model["owned_by"]
 | 
				
			||||||
                    if "pipe" in model:
 | 
					                    if "pipe" in model:
 | 
				
			||||||
                        pipe = model["pipe"]
 | 
					                        pipe = model["pipe"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    action_ids = [] + global_action_ids
 | 
				
			||||||
 | 
					                    if "info" in model and "meta" in model["info"]:
 | 
				
			||||||
 | 
					                        action_ids.extend(model["info"]["meta"].get("actionIds", []))
 | 
				
			||||||
 | 
					                        action_ids = list(set(action_ids))
 | 
				
			||||||
 | 
					                    action_ids = [
 | 
				
			||||||
 | 
					                        action_id
 | 
				
			||||||
 | 
					                        for action_id in action_ids
 | 
				
			||||||
 | 
					                        if action_id in enabled_action_ids
 | 
				
			||||||
 | 
					                    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    actions = [
 | 
				
			||||||
 | 
					                        {
 | 
				
			||||||
 | 
					                            "id": action_id,
 | 
				
			||||||
 | 
					                            "name": Functions.get_function_by_id(action_id).name,
 | 
				
			||||||
 | 
					                            "description": Functions.get_function_by_id(
 | 
				
			||||||
 | 
					                                action_id
 | 
				
			||||||
 | 
					                            ).meta.description,
 | 
				
			||||||
 | 
					                        }
 | 
				
			||||||
 | 
					                        for action_id in action_ids
 | 
				
			||||||
 | 
					                    ]
 | 
				
			||||||
                    break
 | 
					                    break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            models.append(
 | 
					            models.append(
 | 
				
			||||||
@ -986,6 +1039,7 @@ async def get_all_models():
 | 
				
			|||||||
                    "info": custom_model.model_dump(),
 | 
					                    "info": custom_model.model_dump(),
 | 
				
			||||||
                    "preset": True,
 | 
					                    "preset": True,
 | 
				
			||||||
                    **({"pipe": pipe} if pipe is not None else {}),
 | 
					                    **({"pipe": pipe} if pipe is not None else {}),
 | 
				
			||||||
 | 
					                    "actions": actions,
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -1221,6 +1275,107 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
 | 
				
			|||||||
    return data
 | 
					    return data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@app.post("/api/chat/actions/{action_id}")
 | 
				
			||||||
 | 
					async def chat_completed(
 | 
				
			||||||
 | 
					    action_id: str, form_data: dict, user=Depends(get_verified_user)
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    action = Functions.get_function_by_id(action_id)
 | 
				
			||||||
 | 
					    if not action:
 | 
				
			||||||
 | 
					        raise HTTPException(
 | 
				
			||||||
 | 
					            status_code=status.HTTP_404_NOT_FOUND,
 | 
				
			||||||
 | 
					            detail="Action not found",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    data = form_data
 | 
				
			||||||
 | 
					    model_id = data["model"]
 | 
				
			||||||
 | 
					    if model_id not in app.state.MODELS:
 | 
				
			||||||
 | 
					        raise HTTPException(
 | 
				
			||||||
 | 
					            status_code=status.HTTP_404_NOT_FOUND,
 | 
				
			||||||
 | 
					            detail="Model not found",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    model = app.state.MODELS[model_id]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    __event_emitter__ = await get_event_emitter(
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            "chat_id": data["chat_id"],
 | 
				
			||||||
 | 
					            "message_id": data["id"],
 | 
				
			||||||
 | 
					            "session_id": data["session_id"],
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    __event_call__ = await get_event_call(
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            "chat_id": data["chat_id"],
 | 
				
			||||||
 | 
					            "message_id": data["id"],
 | 
				
			||||||
 | 
					            "session_id": data["session_id"],
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if action_id in webui_app.state.FUNCTIONS:
 | 
				
			||||||
 | 
					        function_module = webui_app.state.FUNCTIONS[action_id]
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        function_module, _, _ = load_function_module_by_id(action_id)
 | 
				
			||||||
 | 
					        webui_app.state.FUNCTIONS[action_id] = function_module
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
 | 
				
			||||||
 | 
					        valves = Functions.get_function_valves_by_id(action_id)
 | 
				
			||||||
 | 
					        function_module.valves = function_module.Valves(**(valves if valves else {}))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if hasattr(function_module, "action"):
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            action = function_module.action
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # Get the signature of the function
 | 
				
			||||||
 | 
					            sig = inspect.signature(action)
 | 
				
			||||||
 | 
					            params = {"body": data}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # Extra parameters to be passed to the function
 | 
				
			||||||
 | 
					            extra_params = {
 | 
				
			||||||
 | 
					                "__model__": model,
 | 
				
			||||||
 | 
					                "__id__": action_id,
 | 
				
			||||||
 | 
					                "__event_emitter__": __event_emitter__,
 | 
				
			||||||
 | 
					                "__event_call__": __event_call__,
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # Add extra params in contained in function signature
 | 
				
			||||||
 | 
					            for key, value in extra_params.items():
 | 
				
			||||||
 | 
					                if key in sig.parameters:
 | 
				
			||||||
 | 
					                    params[key] = value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if "__user__" in sig.parameters:
 | 
				
			||||||
 | 
					                __user__ = {
 | 
				
			||||||
 | 
					                    "id": user.id,
 | 
				
			||||||
 | 
					                    "email": user.email,
 | 
				
			||||||
 | 
					                    "name": user.name,
 | 
				
			||||||
 | 
					                    "role": user.role,
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                try:
 | 
				
			||||||
 | 
					                    if hasattr(function_module, "UserValves"):
 | 
				
			||||||
 | 
					                        __user__["valves"] = function_module.UserValves(
 | 
				
			||||||
 | 
					                            **Functions.get_user_valves_by_id_and_user_id(
 | 
				
			||||||
 | 
					                                action_id, user.id
 | 
				
			||||||
 | 
					                            )
 | 
				
			||||||
 | 
					                        )
 | 
				
			||||||
 | 
					                except Exception as e:
 | 
				
			||||||
 | 
					                    print(e)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                params = {**params, "__user__": __user__}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if inspect.iscoroutinefunction(action):
 | 
				
			||||||
 | 
					                data = await action(**params)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                data = action(**params)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        except Exception as e:
 | 
				
			||||||
 | 
					            print(f"Error: {e}")
 | 
				
			||||||
 | 
					            return JSONResponse(
 | 
				
			||||||
 | 
					                status_code=status.HTTP_400_BAD_REQUEST,
 | 
				
			||||||
 | 
					                content={"detail": str(e)},
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
##################################
 | 
					##################################
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Task Endpoints
 | 
					# Task Endpoints
 | 
				
			||||||
 | 
				
			|||||||
@ -37,6 +37,7 @@
 | 
				
			|||||||
	import CitationsModal from '$lib/components/chat/Messages/CitationsModal.svelte';
 | 
						import CitationsModal from '$lib/components/chat/Messages/CitationsModal.svelte';
 | 
				
			||||||
	import Spinner from '$lib/components/common/Spinner.svelte';
 | 
						import Spinner from '$lib/components/common/Spinner.svelte';
 | 
				
			||||||
	import WebSearchResults from './ResponseMessage/WebSearchResults.svelte';
 | 
						import WebSearchResults from './ResponseMessage/WebSearchResults.svelte';
 | 
				
			||||||
 | 
						import Sparkles from '$lib/components/icons/Sparkles.svelte';
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	export let message;
 | 
						export let message;
 | 
				
			||||||
	export let siblings;
 | 
						export let siblings;
 | 
				
			||||||
@ -1020,6 +1021,22 @@
 | 
				
			|||||||
														</svg>
 | 
																			</svg>
 | 
				
			||||||
													</button>
 | 
																		</button>
 | 
				
			||||||
												</Tooltip>
 | 
																	</Tooltip>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
																	{#each model?.actions ?? [] as action}
 | 
				
			||||||
 | 
																		<Tooltip content={action.name} placement="bottom">
 | 
				
			||||||
 | 
																			<button
 | 
				
			||||||
 | 
																				type="button"
 | 
				
			||||||
 | 
																				class="{isLastMessage
 | 
				
			||||||
 | 
																					? 'visible'
 | 
				
			||||||
 | 
																					: 'invisible group-hover:visible'} p-1.5 hover:bg-black/5 dark:hover:bg-white/5 rounded-lg dark:hover:text-white hover:text-black transition regenerate-response-button"
 | 
				
			||||||
 | 
																				on:click={() => {
 | 
				
			||||||
 | 
																					console.log('action');
 | 
				
			||||||
 | 
																				}}
 | 
				
			||||||
 | 
																			>
 | 
				
			||||||
 | 
																				<Sparkles strokeWidth="2.1" className="size-4" />
 | 
				
			||||||
 | 
																			</button>
 | 
				
			||||||
 | 
																		</Tooltip>
 | 
				
			||||||
 | 
																	{/each}
 | 
				
			||||||
											{/if}
 | 
																{/if}
 | 
				
			||||||
										{/if}
 | 
															{/if}
 | 
				
			||||||
									{/if}
 | 
														{/if}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										19
									
								
								src/lib/components/icons/Sparkles.svelte
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								src/lib/components/icons/Sparkles.svelte
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,19 @@
 | 
				
			|||||||
 | 
					<script lang="ts">
 | 
				
			||||||
 | 
						export let className = 'w-4 h-4';
 | 
				
			||||||
 | 
						export let strokeWidth = '1.5';
 | 
				
			||||||
 | 
					</script>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					<svg
 | 
				
			||||||
 | 
						xmlns="http://www.w3.org/2000/svg"
 | 
				
			||||||
 | 
						fill="none"
 | 
				
			||||||
 | 
						viewBox="0 0 24 24"
 | 
				
			||||||
 | 
						stroke-width={strokeWidth}
 | 
				
			||||||
 | 
						stroke="currentColor"
 | 
				
			||||||
 | 
						class={className}
 | 
				
			||||||
 | 
					>
 | 
				
			||||||
 | 
						<path
 | 
				
			||||||
 | 
							stroke-linecap="round"
 | 
				
			||||||
 | 
							stroke-linejoin="round"
 | 
				
			||||||
 | 
							d="M9.813 15.904 9 18.75l-.813-2.846a4.5 4.5 0 0 0-3.09-3.09L2.25 12l2.846-.813a4.5 4.5 0 0 0 3.09-3.09L9 5.25l.813 2.846a4.5 4.5 0 0 0 3.09 3.09L15.75 12l-2.846.813a4.5 4.5 0 0 0-3.09 3.09ZM18.259 8.715 18 9.75l-.259-1.035a3.375 3.375 0 0 0-2.455-2.456L14.25 6l1.036-.259a3.375 3.375 0 0 0 2.455-2.456L18 2.25l.259 1.035a3.375 3.375 0 0 0 2.456 2.456L21.75 6l-1.035.259a3.375 3.375 0 0 0-2.456 2.456ZM16.894 20.567 16.5 21.75l-.394-1.183a2.25 2.25 0 0 0-1.423-1.423L13.5 18.75l1.183-.394a2.25 2.25 0 0 0 1.423-1.423l.394-1.183.394 1.183a2.25 2.25 0 0 0 1.423 1.423l1.183.394-1.183.394a2.25 2.25 0 0 0-1.423 1.423Z"
 | 
				
			||||||
 | 
						/>
 | 
				
			||||||
 | 
					</svg>
 | 
				
			||||||
@ -122,12 +122,17 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		if (res) {
 | 
							if (res) {
 | 
				
			||||||
			if (func.is_global) {
 | 
								if (func.is_global) {
 | 
				
			||||||
				toast.success($i18n.t('Filter is now globally enabled'));
 | 
									func.type === 'filter'
 | 
				
			||||||
 | 
										? toast.success($i18n.t('Filter is now globally enabled'))
 | 
				
			||||||
 | 
										: toast.success($i18n.t('Function is now globally enabled'));
 | 
				
			||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				toast.success($i18n.t('Filter is now globally disabled'));
 | 
									func.type === 'filter'
 | 
				
			||||||
 | 
										? toast.success($i18n.t('Filter is now globally disabled'))
 | 
				
			||||||
 | 
										: toast.success($i18n.t('Function is now globally disabled'));
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			functions.set(await getFunctions(localStorage.token));
 | 
								functions.set(await getFunctions(localStorage.token));
 | 
				
			||||||
 | 
								models.set(await getModels(localStorage.token));
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	};
 | 
						};
 | 
				
			||||||
</script>
 | 
					</script>
 | 
				
			||||||
@ -294,7 +299,7 @@
 | 
				
			|||||||
						showDeleteConfirm = true;
 | 
											showDeleteConfirm = true;
 | 
				
			||||||
					}}
 | 
										}}
 | 
				
			||||||
					toggleGlobalHandler={() => {
 | 
										toggleGlobalHandler={() => {
 | 
				
			||||||
						if (func.type === 'filter') {
 | 
											if (['filter', 'action'].includes(func.type)) {
 | 
				
			||||||
							toggleGlobalHandler(func);
 | 
												toggleGlobalHandler(func);
 | 
				
			||||||
						}
 | 
											}
 | 
				
			||||||
					}}
 | 
										}}
 | 
				
			||||||
 | 
				
			|||||||
@ -48,7 +48,7 @@
 | 
				
			|||||||
			align="start"
 | 
								align="start"
 | 
				
			||||||
			transition={flyAndScale}
 | 
								transition={flyAndScale}
 | 
				
			||||||
		>
 | 
							>
 | 
				
			||||||
			{#if func.type === 'filter'}
 | 
								{#if ['filter', 'action'].includes(func.type)}
 | 
				
			||||||
				<div
 | 
									<div
 | 
				
			||||||
					class="flex gap-2 justify-between items-center px-3 py-2 text-sm font-medium cursor-pointerrounded-md"
 | 
										class="flex gap-2 justify-between items-center px-3 py-2 text-sm font-medium cursor-pointerrounded-md"
 | 
				
			||||||
				>
 | 
									>
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user