diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index f5fab34db..c1e0dc5dd 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -5,6 +5,7 @@ from typing import List, Union, Optional import time import logging from apps.webui.internal.db import DB, JSONField +from apps.webui.models.users import Users import json @@ -115,6 +116,46 @@ class FunctionsTable: for function in Function.select().where(Function.type == type) ] + def get_user_valves_by_id_and_user_id( + self, id: str, user_id: str + ) -> Optional[dict]: + try: + user = Users.get_user_by_id(user_id) + + # Check if user has "functions" and "valves" settings + if "functions" not in user.settings: + user.settings["functions"] = {} + if "valves" not in user.settings["functions"]: + user.settings["functions"]["valves"] = {} + + return user.settings["functions"]["valves"].get(id, {}) + except Exception as e: + print(f"An error occurred: {e}") + return None + + def update_user_valves_by_id_and_user_id( + self, id: str, user_id: str, valves: dict + ) -> Optional[dict]: + try: + user = Users.get_user_by_id(user_id) + + # Check if user has "functions" and "valves" settings + if "functions" not in user.settings: + user.settings["functions"] = {} + if "valves" not in user.settings["functions"]: + user.settings["functions"]["valves"] = {} + + user.settings["functions"]["valves"][id] = valves + + # Update the user settings in the database + query = Users.update_user_by_id(user_id, {"settings": user.settings}) + query.execute() + + return user.settings["functions"]["valves"][id] + except Exception as e: + print(f"An error occurred: {e}") + return None + def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]: try: query = Function.update( diff --git a/backend/apps/webui/models/tools.py b/backend/apps/webui/models/tools.py index e2db1e35f..1592c228f 100644 --- a/backend/apps/webui/models/tools.py +++ b/backend/apps/webui/models/tools.py @@ -5,6 +5,7 @@ from typing import List, Union, Optional import time import logging from apps.webui.internal.db import DB, JSONField +from apps.webui.models.users import Users import json @@ -106,6 +107,46 @@ class ToolsTable: def get_tools(self) -> List[ToolModel]: return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()] + def get_user_valves_by_id_and_user_id( + self, id: str, user_id: str + ) -> Optional[dict]: + try: + user = Users.get_user_by_id(user_id) + + # Check if user has "tools" and "valves" settings + if "tools" not in user.settings: + user.settings["tools"] = {} + if "valves" not in user.settings["tools"]: + user.settings["tools"]["valves"] = {} + + return user.settings["tools"]["valves"].get(id, {}) + except Exception as e: + print(f"An error occurred: {e}") + return None + + def update_user_valves_by_id_and_user_id( + self, id: str, user_id: str, valves: dict + ) -> Optional[dict]: + try: + user = Users.get_user_by_id(user_id) + + # Check if user has "tools" and "valves" settings + if "tools" not in user.settings: + user.settings["tools"] = {} + if "valves" not in user.settings["tools"]: + user.settings["tools"]["valves"] = {} + + user.settings["tools"]["valves"][id] = valves + + # Update the user settings in the database + query = Users.update_user_by_id(user_id, {"settings": user.settings}) + query.execute() + + return user.settings["tools"]["valves"][id] + except Exception as e: + print(f"An error occurred: {e}") + return None + def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: try: query = Tool.update( diff --git a/backend/apps/webui/routers/functions.py b/backend/apps/webui/routers/functions.py index ea5fde336..53b9ab130 100644 --- a/backend/apps/webui/routers/functions.py +++ b/backend/apps/webui/routers/functions.py @@ -117,6 +117,94 @@ async def get_function_by_id(id: str, user=Depends(get_admin_user)): ) +############################ +# FunctionUserValves +############################ + + +@router.get("/id/{id}/valves/user", response_model=Optional[dict]) +async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user)): + function = Functions.get_function_by_id(id) + if function: + try: + user_valves = Functions.get_user_valves_by_id_and_user_id(id, user.id) + return user_valves + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict]) +async def get_function_user_valves_spec_by_id( + request: Request, id: str, user=Depends(get_verified_user) +): + function = Functions.get_tool_by_id(id) + if function: + if id in request.app.state.FUNCTIONS: + function_module = request.app.state.FUNCTIONS[id] + else: + function_module, function_type = load_function_module_by_id(id) + request.app.state.FUNCTIONS[id] = function_module + + if hasattr(function_module, "UserValves"): + UserValves = function_module.UserValves + return UserValves.schema() + return None + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +@router.post("/id/{id}/valves/user/update", response_model=Optional[dict]) +async def update_function_user_valves_by_id( + request: Request, id: str, form_data: dict, user=Depends(get_verified_user) +): + + function = Functions.get_tool_by_id(id) + + if function: + if id in request.app.state.FUNCTIONS: + function_module = request.app.state.FUNCTIONS[id] + else: + function_module, function_type = load_function_module_by_id(id) + request.app.state.FUNCTIONS[id] = function_module + + if hasattr(function_module, "UserValves"): + UserValves = function_module.UserValves + + try: + user_valves = UserValves(**form_data) + Functions.update_user_valves_by_id_and_user_id( + id, user.id, user_valves.model_dump() + ) + return user_valves.model_dump() + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + ############################ # UpdateFunctionById ############################ diff --git a/backend/apps/webui/routers/tools.py b/backend/apps/webui/routers/tools.py index 49ddf4af0..c71d2f01c 100644 --- a/backend/apps/webui/routers/tools.py +++ b/backend/apps/webui/routers/tools.py @@ -6,10 +6,12 @@ from fastapi import APIRouter from pydantic import BaseModel import json + +from apps.webui.models.users import Users from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse from apps.webui.utils import load_toolkit_module_by_id -from utils.utils import get_current_user, get_admin_user +from utils.utils import get_admin_user, get_verified_user from utils.tools import get_tools_specs from constants import ERROR_MESSAGES @@ -32,7 +34,7 @@ router = APIRouter() @router.get("/", response_model=List[ToolResponse]) -async def get_toolkits(user=Depends(get_current_user)): +async def get_toolkits(user=Depends(get_verified_user)): toolkits = [toolkit for toolkit in Tools.get_tools()] return toolkits @@ -121,6 +123,93 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)): ) +############################ +# ToolUserValves +############################ + + +@router.get("/id/{id}/valves/user", response_model=Optional[dict]) +async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)): + toolkit = Tools.get_tool_by_id(id) + if toolkit: + try: + user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id) + return user_valves + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict]) +async def get_toolkit_user_valves_spec_by_id( + request: Request, id: str, user=Depends(get_verified_user) +): + toolkit = Tools.get_tool_by_id(id) + if toolkit: + if id in request.app.state.TOOLS: + toolkit_module = request.app.state.TOOLS[id] + else: + toolkit_module = load_toolkit_module_by_id(id) + request.app.state.TOOLS[id] = toolkit_module + + if hasattr(toolkit_module, "UserValves"): + UserValves = toolkit_module.UserValves + return UserValves.schema() + return None + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +@router.post("/id/{id}/valves/user/update", response_model=Optional[dict]) +async def update_toolkit_user_valves_by_id( + request: Request, id: str, form_data: dict, user=Depends(get_verified_user) +): + toolkit = Tools.get_tool_by_id(id) + + if toolkit: + if id in request.app.state.TOOLS: + toolkit_module = request.app.state.TOOLS[id] + else: + toolkit_module = load_toolkit_module_by_id(id) + request.app.state.TOOLS[id] = toolkit_module + + if hasattr(toolkit_module, "UserValves"): + UserValves = toolkit_module.UserValves + + try: + user_valves = UserValves(**form_data) + Tools.update_user_valves_by_id_and_user_id( + id, user.id, user_valves.model_dump() + ) + return user_valves.model_dump() + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + ############################ # UpdateToolkitById ############################ diff --git a/src/lib/components/chat/Settings/Valves.svelte b/src/lib/components/chat/Settings/Valves.svelte index 159688ee9..cfa03d907 100644 --- a/src/lib/components/chat/Settings/Valves.svelte +++ b/src/lib/components/chat/Settings/Valves.svelte @@ -29,25 +29,28 @@ }} >
-
- +
+
+ +
+ {$i18n.t('Manage Valves')} +
+
- +
+ +
+
-
+
+
+
diff --git a/src/routes/(app)/+layout.svelte b/src/routes/(app)/+layout.svelte index 8973b9eb7..875ebf4eb 100644 --- a/src/routes/(app)/+layout.svelte +++ b/src/routes/(app)/+layout.svelte @@ -29,13 +29,15 @@ showChangelog, config, showCallOverlay, - tools + tools, + functions } from '$lib/stores'; import SettingsModal from '$lib/components/chat/SettingsModal.svelte'; import Sidebar from '$lib/components/layout/Sidebar.svelte'; import ChangelogModal from '$lib/components/ChangelogModal.svelte'; import AccountPending from '$lib/components/layout/Overlay/AccountPending.svelte'; + import { getFunctions } from '$lib/apis/functions'; const i18n = getContext('i18n'); @@ -93,6 +95,9 @@ (async () => { tools.set(await getTools(localStorage.token)); })(), + (async () => { + functions.set(await getFunctions(localStorage.token)); + })(), (async () => { banners.set(await getBanners(localStorage.token)); })(),