From d8c112d8b0779011e17cfed39f9dacd7b4b47b72 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 23 Jun 2024 18:34:42 -0700 Subject: [PATCH] feat: function toggle support --- backend/apps/webui/main.py | 2 +- backend/apps/webui/models/functions.py | 60 ++++- backend/apps/webui/routers/functions.py | 77 ++++++ backend/apps/webui/routers/tools.py | 50 ++++ backend/main.py | 1 - src/lib/apis/functions/index.ts | 99 +++++++ src/lib/components/workspace/Functions.svelte | 10 +- .../components/workspace/ValvesModal.svelte | 248 ++++-------------- 8 files changed, 338 insertions(+), 209 deletions(-) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index a9f7fb286..a8f45aff0 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -103,7 +103,7 @@ async def get_status(): async def get_pipe_models(): - pipes = Functions.get_functions_by_type("pipe") + pipes = Functions.get_functions_by_type("pipe", active_only=True) pipe_models = [] for pipe in pipes: diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index 779eef9ef..6741f2d10 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -48,7 +48,7 @@ class FunctionModel(BaseModel): type: str content: str meta: FunctionMeta - is_active: bool + is_active: bool = False updated_at: int # timestamp in epoch created_at: int # timestamp in epoch @@ -115,16 +115,56 @@ class FunctionsTable: except: return None - def get_functions(self) -> List[FunctionModel]: - return [ - FunctionModel(**model_to_dict(function)) for function in Function.select() - ] + def get_functions(self, active_only=False) -> List[FunctionModel]: + if active_only: + return [ + FunctionModel(**model_to_dict(function)) + for function in Function.select().where(Function.is_active == True) + ] + else: + return [ + FunctionModel(**model_to_dict(function)) + for function in Function.select() + ] - def get_functions_by_type(self, type: str) -> List[FunctionModel]: - return [ - FunctionModel(**model_to_dict(function)) - for function in Function.select().where(Function.type == type) - ] + def get_functions_by_type( + self, type: str, active_only=False + ) -> List[FunctionModel]: + if active_only: + return [ + FunctionModel(**model_to_dict(function)) + for function in Function.select().where( + Function.type == type, Function.is_active == True + ) + ] + else: + return [ + FunctionModel(**model_to_dict(function)) + for function in Function.select().where(Function.type == type) + ] + + def get_function_valves_by_id(self, id: str) -> Optional[FunctionValves]: + try: + function = Function.get(Function.id == id) + return FunctionValves(**model_to_dict(function)) + except Exception as e: + print(f"An error occurred: {e}") + return None + + def update_function_valves_by_id( + self, id: str, valves: dict + ) -> Optional[FunctionValves]: + try: + query = Function.update( + **{"valves": valves}, + updated_at=int(time.time()), + ).where(Function.id == id) + query.execute() + + function = Function.get(Function.id == id) + return FunctionValves(**model_to_dict(function)) + except: + return None def get_user_valves_by_id_and_user_id( self, id: str, user_id: str diff --git a/backend/apps/webui/routers/functions.py b/backend/apps/webui/routers/functions.py index acf1894fd..8c6454eac 100644 --- a/backend/apps/webui/routers/functions.py +++ b/backend/apps/webui/routers/functions.py @@ -117,6 +117,56 @@ async def get_function_by_id(id: str, user=Depends(get_admin_user)): ) +############################ +# GetFunctionValves +############################ + + +@router.get("/id/{id}/valves", response_model=Optional[dict]) +async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)): + function = Functions.get_function_by_id(id) + if function: + try: + valves = Functions.get_function_valves_by_id(id) + return valves + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# UpdateToolValves +############################ + + +@router.post("/id/{id}/valves/update", response_model=Optional[dict]) +async def update_toolkit_valves_by_id( + id: str, form_data: dict, user=Depends(get_admin_user) +): + function = Functions.get_function_by_id(id) + if function: + try: + valves = Functions.update_function_valves_by_id(id, form_data) + return valves + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + ############################ # FunctionUserValves ############################ @@ -204,6 +254,33 @@ async def update_function_user_valves_by_id( ) +############################ +# ToggleFunctionById +############################ + + +@router.post("/id/{id}/toggle", response_model=Optional[FunctionModel]) +async def toggle_function_by_id(id: str, user=Depends(get_admin_user)): + function = Functions.get_function_by_id(id) + if function: + function = Functions.update_function_by_id( + id, {"is_active": not function.is_active} + ) + + if function: + return function + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error updating function"), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + ############################ # UpdateFunctionById ############################ diff --git a/backend/apps/webui/routers/tools.py b/backend/apps/webui/routers/tools.py index c71d2f01c..7988acc86 100644 --- a/backend/apps/webui/routers/tools.py +++ b/backend/apps/webui/routers/tools.py @@ -123,6 +123,56 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)): ) +############################ +# GetToolValves +############################ + + +@router.get("/id/{id}/valves", response_model=Optional[dict]) +async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)): + toolkit = Tools.get_tool_by_id(id) + if toolkit: + try: + valves = Tools.get_tool_valves_by_id(id) + return valves + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# UpdateToolValves +############################ + + +@router.post("/id/{id}/valves/update", response_model=Optional[dict]) +async def update_toolkit_valves_by_id( + id: str, form_data: dict, user=Depends(get_admin_user) +): + toolkit = Tools.get_tool_by_id(id) + if toolkit: + try: + valves = Tools.update_tool_valves_by_id(id, form_data) + return valves + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + ############################ # ToolUserValves ############################ diff --git a/backend/main.py b/backend/main.py index 7dda4e557..02552f209 100644 --- a/backend/main.py +++ b/backend/main.py @@ -863,7 +863,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u pipe = model.get("pipe") if pipe: - async def job(): pipe_id = form_data["model"] if "." in pipe_id: diff --git a/src/lib/apis/functions/index.ts b/src/lib/apis/functions/index.ts index 530702a3e..c8607b091 100644 --- a/src/lib/apis/functions/index.ts +++ b/src/lib/apis/functions/index.ts @@ -192,6 +192,105 @@ export const deleteFunctionById = async (token: string, id: string) => { return res; }; +export const toggleFunctionById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/toggle`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getFunctionValvesById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateFunctionValvesById = async (token: string, id: string, valves: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...valves + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getUserValvesById = async (token: string, id: string) => { let error = null; diff --git a/src/lib/components/workspace/Functions.svelte b/src/lib/components/workspace/Functions.svelte index b389d6eab..fc8e7a451 100644 --- a/src/lib/components/workspace/Functions.svelte +++ b/src/lib/components/workspace/Functions.svelte @@ -13,7 +13,8 @@ deleteFunctionById, exportFunctions, getFunctionById, - getFunctions + getFunctions, + toggleFunctionById } from '$lib/apis/functions'; import ArrowDownTray from '../icons/ArrowDownTray.svelte'; @@ -224,7 +225,12 @@
- + { + toggleFunctionById(localStorage.token, func.id); + }} + />
diff --git a/src/lib/components/workspace/ValvesModal.svelte b/src/lib/components/workspace/ValvesModal.svelte index 58434a559..f67d948a9 100644 --- a/src/lib/components/workspace/ValvesModal.svelte +++ b/src/lib/components/workspace/ValvesModal.svelte @@ -5,16 +5,16 @@ import { addUser } from '$lib/apis/auths'; import Modal from '../common/Modal.svelte'; - import { WEBUI_BASE_URL } from '$lib/constants'; const i18n = getContext('i18n'); const dispatch = createEventDispatcher(); export let show = false; + export let type = 'tool'; + export let id = null; + let loading = false; - let tab = ''; - let inputFiles; let _user = { name: '', @@ -23,96 +23,11 @@ role: 'user' }; - $: if (show) { - _user = { - name: '', - email: '', - password: '', - role: 'user' - }; - } - const submitHandler = async () => { const stopLoading = () => { dispatch('save'); loading = false; }; - - if (tab === '') { - loading = true; - - const res = await addUser( - localStorage.token, - _user.name, - _user.email, - _user.password, - _user.role - ).catch((error) => { - toast.error(error); - }); - - if (res) { - stopLoading(); - show = false; - } - } else { - if (inputFiles) { - loading = true; - - const file = inputFiles[0]; - const reader = new FileReader(); - - reader.onload = async (e) => { - const csv = e.target.result; - const rows = csv.split('\n'); - - let userCount = 0; - - for (const [idx, row] of rows.entries()) { - const columns = row.split(',').map((col) => col.trim()); - console.log(idx, columns); - - if (idx > 0) { - if ( - columns.length === 4 && - ['admin', 'user', 'pending'].includes(columns[3].toLowerCase()) - ) { - const res = await addUser( - localStorage.token, - columns[0], - columns[1], - columns[2], - columns[3].toLowerCase() - ).catch((error) => { - toast.error(`Row ${idx + 1}: ${error}`); - return null; - }); - - if (res) { - userCount = userCount + 1; - } - } else { - toast.error(`Row ${idx + 1}: invalid format.`); - } - } - } - - toast.success(`Successfully imported ${userCount} users.`); - inputFiles = null; - const uploadInputElement = document.getElementById('upload-user-csv-input'); - - if (uploadInputElement) { - uploadInputElement.value = null; - } - - stopLoading(); - }; - - reader.readAsText(file); - } else { - toast.error($i18n.t('File not found.')); - } - } }; @@ -147,126 +62,69 @@ submitHandler(); }} > -
- - - -
- {#if tab === ''} -
-
{$i18n.t('Role')}
+
+
{$i18n.t('Role')}
-
- -
+
+
+
-
-
{$i18n.t('Name')}
+
+
{$i18n.t('Name')}
-
- -
+
+
+
-
+
-
-
{$i18n.t('Email')}
+
+
{$i18n.t('Email')}
-
- -
+
+
+
-
-
{$i18n.t('Password')}
+
+
{$i18n.t('Password')}
-
- -
+
+
- {:else if tab === 'import'} -
-
- - - -
- -
- ⓘ {$i18n.t( - 'Ensure your CSV file includes 4 columns in this order: Name, Email, Password, Role.' - )} - - {$i18n.t('Click here to download user import template file.')} - -
-
- {/if} +