feat: user valves endpoints

This commit is contained in:
Timothy J. Baek 2024-06-22 11:26:33 -07:00
parent 8345bb55d4
commit 15fc23df87
6 changed files with 304 additions and 25 deletions

View File

@ -5,6 +5,7 @@ from typing import List, Union, Optional
import time import time
import logging import logging
from apps.webui.internal.db import DB, JSONField from apps.webui.internal.db import DB, JSONField
from apps.webui.models.users import Users
import json import json
@ -115,6 +116,46 @@ class FunctionsTable:
for function in Function.select().where(Function.type == type) for function in Function.select().where(Function.type == type)
] ]
def get_user_valves_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[dict]:
try:
user = Users.get_user_by_id(user_id)
# Check if user has "functions" and "valves" settings
if "functions" not in user.settings:
user.settings["functions"] = {}
if "valves" not in user.settings["functions"]:
user.settings["functions"]["valves"] = {}
return user.settings["functions"]["valves"].get(id, {})
except Exception as e:
print(f"An error occurred: {e}")
return None
def update_user_valves_by_id_and_user_id(
self, id: str, user_id: str, valves: dict
) -> Optional[dict]:
try:
user = Users.get_user_by_id(user_id)
# Check if user has "functions" and "valves" settings
if "functions" not in user.settings:
user.settings["functions"] = {}
if "valves" not in user.settings["functions"]:
user.settings["functions"]["valves"] = {}
user.settings["functions"]["valves"][id] = valves
# Update the user settings in the database
query = Users.update_user_by_id(user_id, {"settings": user.settings})
query.execute()
return user.settings["functions"]["valves"][id]
except Exception as e:
print(f"An error occurred: {e}")
return None
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]: def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
try: try:
query = Function.update( query = Function.update(

View File

@ -5,6 +5,7 @@ from typing import List, Union, Optional
import time import time
import logging import logging
from apps.webui.internal.db import DB, JSONField from apps.webui.internal.db import DB, JSONField
from apps.webui.models.users import Users
import json import json
@ -106,6 +107,46 @@ class ToolsTable:
def get_tools(self) -> List[ToolModel]: def get_tools(self) -> List[ToolModel]:
return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()] return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()]
def get_user_valves_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[dict]:
try:
user = Users.get_user_by_id(user_id)
# Check if user has "tools" and "valves" settings
if "tools" not in user.settings:
user.settings["tools"] = {}
if "valves" not in user.settings["tools"]:
user.settings["tools"]["valves"] = {}
return user.settings["tools"]["valves"].get(id, {})
except Exception as e:
print(f"An error occurred: {e}")
return None
def update_user_valves_by_id_and_user_id(
self, id: str, user_id: str, valves: dict
) -> Optional[dict]:
try:
user = Users.get_user_by_id(user_id)
# Check if user has "tools" and "valves" settings
if "tools" not in user.settings:
user.settings["tools"] = {}
if "valves" not in user.settings["tools"]:
user.settings["tools"]["valves"] = {}
user.settings["tools"]["valves"][id] = valves
# Update the user settings in the database
query = Users.update_user_by_id(user_id, {"settings": user.settings})
query.execute()
return user.settings["tools"]["valves"][id]
except Exception as e:
print(f"An error occurred: {e}")
return None
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
try: try:
query = Tool.update( query = Tool.update(

View File

@ -117,6 +117,94 @@ async def get_function_by_id(id: str, user=Depends(get_admin_user)):
) )
############################
# FunctionUserValves
############################
@router.get("/id/{id}/valves/user", response_model=Optional[dict])
async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user)):
function = Functions.get_function_by_id(id)
if function:
try:
user_valves = Functions.get_user_valves_by_id_and_user_id(id, user.id)
return user_valves
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
async def get_function_user_valves_spec_by_id(
request: Request, id: str, user=Depends(get_verified_user)
):
function = Functions.get_tool_by_id(id)
if function:
if id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[id]
else:
function_module, function_type = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
if hasattr(function_module, "UserValves"):
UserValves = function_module.UserValves
return UserValves.schema()
return None
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
@router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
async def update_function_user_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
):
function = Functions.get_tool_by_id(id)
if function:
if id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[id]
else:
function_module, function_type = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
if hasattr(function_module, "UserValves"):
UserValves = function_module.UserValves
try:
user_valves = UserValves(**form_data)
Functions.update_user_valves_by_id_and_user_id(
id, user.id, user_valves.model_dump()
)
return user_valves.model_dump()
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################ ############################
# UpdateFunctionById # UpdateFunctionById
############################ ############################

View File

@ -6,10 +6,12 @@ from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
from apps.webui.models.users import Users
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id from apps.webui.utils import load_toolkit_module_by_id
from utils.utils import get_current_user, get_admin_user from utils.utils import get_admin_user, get_verified_user
from utils.tools import get_tools_specs from utils.tools import get_tools_specs
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
@ -32,7 +34,7 @@ router = APIRouter()
@router.get("/", response_model=List[ToolResponse]) @router.get("/", response_model=List[ToolResponse])
async def get_toolkits(user=Depends(get_current_user)): async def get_toolkits(user=Depends(get_verified_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()] toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits return toolkits
@ -121,6 +123,93 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
) )
############################
# ToolUserValves
############################
@router.get("/id/{id}/valves/user", response_model=Optional[dict])
async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
try:
user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
return user_valves
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
async def get_toolkit_user_valves_spec_by_id(
request: Request, id: str, user=Depends(get_verified_user)
):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id]
else:
toolkit_module = load_toolkit_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module
if hasattr(toolkit_module, "UserValves"):
UserValves = toolkit_module.UserValves
return UserValves.schema()
return None
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
@router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
async def update_toolkit_user_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id]
else:
toolkit_module = load_toolkit_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module
if hasattr(toolkit_module, "UserValves"):
UserValves = toolkit_module.UserValves
try:
user_valves = UserValves(**form_data)
Tools.update_user_valves_by_id_and_user_id(
id, user.id, user_valves.model_dump()
)
return user_valves.model_dump()
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################ ############################
# UpdateToolkitById # UpdateToolkitById
############################ ############################

View File

@ -29,25 +29,28 @@
}} }}
> >
<div class="flex flex-col pr-1.5 overflow-y-scroll max-h-[25rem]"> <div class="flex flex-col pr-1.5 overflow-y-scroll max-h-[25rem]">
<div class="flex text-center text-sm font-medium rounded-xl bg-transparent/10 p-1 mb-2"> <div>
<button <div class="flex items-center justify-between mb-2">
class="w-full rounded-lg p-1.5 {tab === 'tools' ? 'bg-gray-50 dark:bg-gray-850' : ''}" <Tooltip content="">
type="button" <div class="text-sm font-medium">
on:click={() => { {$i18n.t('Manage Valves')}
tab = 'tools'; </div>
}}>{$i18n.t('Tools')}</button </Tooltip>
>
<button <div class=" self-end">
class="w-full rounded-lg p-1 {tab === 'functions' ? 'bg-gray-50 dark:bg-gray-850' : ''}" <select
type="button" class=" dark:bg-gray-900 w-fit pr-8 rounded text-xs bg-transparent outline-none text-right"
on:click={() => { bind:value={tab}
tab = 'functions'; placeholder="Select"
}}>{$i18n.t('Functions')}</button
> >
<option value="tools">{$i18n.t('Tools')}</option>
<option value="functions">{$i18n.t('Functions')}</option>
</select>
</div>
</div>
</div> </div>
<div class="space-y-1 px-1"> <div class="space-y-1">
<div class="flex gap-2"> <div class="flex gap-2">
<div class="flex-1"> <div class="flex-1">
<select <select
@ -57,18 +60,30 @@
await tick(); await tick();
}} }}
> >
{#if tab === 'tools'}
<option value="" selected disabled class="bg-gray-100 dark:bg-gray-700" <option value="" selected disabled class="bg-gray-100 dark:bg-gray-700"
>{$i18n.t('Select a tool/function')}</option >{$i18n.t('Select a tool')}</option
> >
{#each $tools as tool, toolIdx} {#each $tools as tool, toolIdx}
<option value={tool.id} class="bg-gray-100 dark:bg-gray-700">{tool.name}</option> <option value={tool.id} class="bg-gray-100 dark:bg-gray-700">{tool.name}</option>
{/each} {/each}
{:else if tab === 'functions'}
<option value="" selected disabled class="bg-gray-100 dark:bg-gray-700"
>{$i18n.t('Select a function')}</option
>
{#each $functions as func, funcIdx}
<option value={func.id} class="bg-gray-100 dark:bg-700">{func.name}</option>
{/each}
{/if}
</select> </select>
</div> </div>
</div> </div>
</div> </div>
<hr class="dark:border-gray-800 my-3 w-full" />
<div> <div>
<div class="flex items-center justify-between mb-1" /> <div class="flex items-center justify-between mb-1" />
</div> </div>

View File

@ -29,13 +29,15 @@
showChangelog, showChangelog,
config, config,
showCallOverlay, showCallOverlay,
tools tools,
functions
} from '$lib/stores'; } from '$lib/stores';
import SettingsModal from '$lib/components/chat/SettingsModal.svelte'; import SettingsModal from '$lib/components/chat/SettingsModal.svelte';
import Sidebar from '$lib/components/layout/Sidebar.svelte'; import Sidebar from '$lib/components/layout/Sidebar.svelte';
import ChangelogModal from '$lib/components/ChangelogModal.svelte'; import ChangelogModal from '$lib/components/ChangelogModal.svelte';
import AccountPending from '$lib/components/layout/Overlay/AccountPending.svelte'; import AccountPending from '$lib/components/layout/Overlay/AccountPending.svelte';
import { getFunctions } from '$lib/apis/functions';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
@ -93,6 +95,9 @@
(async () => { (async () => {
tools.set(await getTools(localStorage.token)); tools.set(await getTools(localStorage.token));
})(), })(),
(async () => {
functions.set(await getFunctions(localStorage.token));
})(),
(async () => { (async () => {
banners.set(await getBanners(localStorage.token)); banners.set(await getBanners(localStorage.token));
})(), })(),