From 120b1857b21ea25f7a2dddc5e6e288f3ce0a6ebe Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 23 Jun 2024 18:05:33 -0700 Subject: [PATCH 1/6] enh: valves --- .../016_add_valves_and_is_active.py | 50 +++++++++++++++++++ backend/apps/webui/models/functions.py | 8 +++ backend/apps/webui/models/tools.py | 26 ++++++++++ src/routes/(app)/workspace/+layout.svelte | 2 +- 4 files changed, 85 insertions(+), 1 deletion(-) create mode 100644 backend/apps/webui/internal/migrations/016_add_valves_and_is_active.py diff --git a/backend/apps/webui/internal/migrations/016_add_valves_and_is_active.py b/backend/apps/webui/internal/migrations/016_add_valves_and_is_active.py new file mode 100644 index 000000000..e3af521b7 --- /dev/null +++ b/backend/apps/webui/internal/migrations/016_add_valves_and_is_active.py @@ -0,0 +1,50 @@ +"""Peewee migrations -- 009_add_models.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + migrator.add_fields("tool", valves=pw.TextField(null=True)) + migrator.add_fields("function", valves=pw.TextField(null=True)) + migrator.add_fields("function", is_active=pw.BooleanField(default=False)) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_fields("tool", "valves") + migrator.remove_fields("function", "valves") + migrator.remove_fields("function", "is_active") diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index cd6320f95..779eef9ef 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -28,6 +28,8 @@ class Function(Model): type = TextField() content = TextField() meta = JSONField() + valves = JSONField() + is_active = BooleanField(default=False) updated_at = BigIntegerField() created_at = BigIntegerField() @@ -46,6 +48,7 @@ class FunctionModel(BaseModel): type: str content: str meta: FunctionMeta + is_active: bool updated_at: int # timestamp in epoch created_at: int # timestamp in epoch @@ -61,6 +64,7 @@ class FunctionResponse(BaseModel): type: str name: str meta: FunctionMeta + is_active: bool updated_at: int # timestamp in epoch created_at: int # timestamp in epoch @@ -72,6 +76,10 @@ class FunctionForm(BaseModel): meta: FunctionMeta +class FunctionValves(BaseModel): + valves: Optional[dict] = None + + class FunctionsTable: def __init__(self, db): self.db = db diff --git a/backend/apps/webui/models/tools.py b/backend/apps/webui/models/tools.py index ab322ac14..0f5755e39 100644 --- a/backend/apps/webui/models/tools.py +++ b/backend/apps/webui/models/tools.py @@ -28,6 +28,7 @@ class Tool(Model): content = TextField() specs = JSONField() meta = JSONField() + valves = JSONField() updated_at = BigIntegerField() created_at = BigIntegerField() @@ -71,6 +72,10 @@ class ToolForm(BaseModel): meta: ToolMeta +class ToolValves(BaseModel): + valves: Optional[dict] = None + + class ToolsTable: def __init__(self, db): self.db = db @@ -109,6 +114,27 @@ class ToolsTable: def get_tools(self) -> List[ToolModel]: return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()] + def get_tool_valves_by_id(self, id: str) -> Optional[ToolValves]: + try: + tool = Tool.get(Tool.id == id) + return ToolValves(**model_to_dict(tool)) + except Exception as e: + print(f"An error occurred: {e}") + return None + + def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: + try: + query = Tool.update( + **{"valves": valves}, + updated_at=int(time.time()), + ).where(Tool.id == id) + query.execute() + + tool = Tool.get(Tool.id == id) + return ToolValves(**model_to_dict(tool)) + except: + return None + def get_user_valves_by_id_and_user_id( self, id: str, user_id: str ) -> Optional[dict]: diff --git a/src/routes/(app)/workspace/+layout.svelte b/src/routes/(app)/workspace/+layout.svelte index 46e0f63c4..794bd7ed5 100644 --- a/src/routes/(app)/workspace/+layout.svelte +++ b/src/routes/(app)/workspace/+layout.svelte @@ -9,7 +9,7 @@ const i18n = getContext('i18n'); onMount(async () => { - functions.set(await getFunctions(localStorage.token)); + // functions.set(await getFunctions(localStorage.token)); }); From 3034f3d310f22a992b26be2aed5ba199a61f24be Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 23 Jun 2024 18:06:41 -0700 Subject: [PATCH 2/6] refac: styling --- src/lib/components/chat/Settings/Valves.svelte | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/components/chat/Settings/Valves.svelte b/src/lib/components/chat/Settings/Valves.svelte index 1001ddb28..d834bd014 100644 --- a/src/lib/components/chat/Settings/Valves.svelte +++ b/src/lib/components/chat/Settings/Valves.svelte @@ -203,7 +203,7 @@ {#if (valves[property] ?? null) !== null} -
+
Date: Sun, 23 Jun 2024 18:34:42 -0700 Subject: [PATCH 3/6] feat: function toggle support --- backend/apps/webui/main.py | 2 +- backend/apps/webui/models/functions.py | 60 ++++- backend/apps/webui/routers/functions.py | 77 ++++++ backend/apps/webui/routers/tools.py | 50 ++++ backend/main.py | 1 - src/lib/apis/functions/index.ts | 99 +++++++ src/lib/components/workspace/Functions.svelte | 10 +- .../components/workspace/ValvesModal.svelte | 248 ++++-------------- 8 files changed, 338 insertions(+), 209 deletions(-) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index a9f7fb286..a8f45aff0 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -103,7 +103,7 @@ async def get_status(): async def get_pipe_models(): - pipes = Functions.get_functions_by_type("pipe") + pipes = Functions.get_functions_by_type("pipe", active_only=True) pipe_models = [] for pipe in pipes: diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index 779eef9ef..6741f2d10 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -48,7 +48,7 @@ class FunctionModel(BaseModel): type: str content: str meta: FunctionMeta - is_active: bool + is_active: bool = False updated_at: int # timestamp in epoch created_at: int # timestamp in epoch @@ -115,16 +115,56 @@ class FunctionsTable: except: return None - def get_functions(self) -> List[FunctionModel]: - return [ - FunctionModel(**model_to_dict(function)) for function in Function.select() - ] + def get_functions(self, active_only=False) -> List[FunctionModel]: + if active_only: + return [ + FunctionModel(**model_to_dict(function)) + for function in Function.select().where(Function.is_active == True) + ] + else: + return [ + FunctionModel(**model_to_dict(function)) + for function in Function.select() + ] - def get_functions_by_type(self, type: str) -> List[FunctionModel]: - return [ - FunctionModel(**model_to_dict(function)) - for function in Function.select().where(Function.type == type) - ] + def get_functions_by_type( + self, type: str, active_only=False + ) -> List[FunctionModel]: + if active_only: + return [ + FunctionModel(**model_to_dict(function)) + for function in Function.select().where( + Function.type == type, Function.is_active == True + ) + ] + else: + return [ + FunctionModel(**model_to_dict(function)) + for function in Function.select().where(Function.type == type) + ] + + def get_function_valves_by_id(self, id: str) -> Optional[FunctionValves]: + try: + function = Function.get(Function.id == id) + return FunctionValves(**model_to_dict(function)) + except Exception as e: + print(f"An error occurred: {e}") + return None + + def update_function_valves_by_id( + self, id: str, valves: dict + ) -> Optional[FunctionValves]: + try: + query = Function.update( + **{"valves": valves}, + updated_at=int(time.time()), + ).where(Function.id == id) + query.execute() + + function = Function.get(Function.id == id) + return FunctionValves(**model_to_dict(function)) + except: + return None def get_user_valves_by_id_and_user_id( self, id: str, user_id: str diff --git a/backend/apps/webui/routers/functions.py b/backend/apps/webui/routers/functions.py index acf1894fd..8c6454eac 100644 --- a/backend/apps/webui/routers/functions.py +++ b/backend/apps/webui/routers/functions.py @@ -117,6 +117,56 @@ async def get_function_by_id(id: str, user=Depends(get_admin_user)): ) +############################ +# GetFunctionValves +############################ + + +@router.get("/id/{id}/valves", response_model=Optional[dict]) +async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)): + function = Functions.get_function_by_id(id) + if function: + try: + valves = Functions.get_function_valves_by_id(id) + return valves + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# UpdateToolValves +############################ + + +@router.post("/id/{id}/valves/update", response_model=Optional[dict]) +async def update_toolkit_valves_by_id( + id: str, form_data: dict, user=Depends(get_admin_user) +): + function = Functions.get_function_by_id(id) + if function: + try: + valves = Functions.update_function_valves_by_id(id, form_data) + return valves + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + ############################ # FunctionUserValves ############################ @@ -204,6 +254,33 @@ async def update_function_user_valves_by_id( ) +############################ +# ToggleFunctionById +############################ + + +@router.post("/id/{id}/toggle", response_model=Optional[FunctionModel]) +async def toggle_function_by_id(id: str, user=Depends(get_admin_user)): + function = Functions.get_function_by_id(id) + if function: + function = Functions.update_function_by_id( + id, {"is_active": not function.is_active} + ) + + if function: + return function + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error updating function"), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + ############################ # UpdateFunctionById ############################ diff --git a/backend/apps/webui/routers/tools.py b/backend/apps/webui/routers/tools.py index c71d2f01c..7988acc86 100644 --- a/backend/apps/webui/routers/tools.py +++ b/backend/apps/webui/routers/tools.py @@ -123,6 +123,56 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)): ) +############################ +# GetToolValves +############################ + + +@router.get("/id/{id}/valves", response_model=Optional[dict]) +async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)): + toolkit = Tools.get_tool_by_id(id) + if toolkit: + try: + valves = Tools.get_tool_valves_by_id(id) + return valves + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# UpdateToolValves +############################ + + +@router.post("/id/{id}/valves/update", response_model=Optional[dict]) +async def update_toolkit_valves_by_id( + id: str, form_data: dict, user=Depends(get_admin_user) +): + toolkit = Tools.get_tool_by_id(id) + if toolkit: + try: + valves = Tools.update_tool_valves_by_id(id, form_data) + return valves + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + ############################ # ToolUserValves ############################ diff --git a/backend/main.py b/backend/main.py index 7dda4e557..02552f209 100644 --- a/backend/main.py +++ b/backend/main.py @@ -863,7 +863,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u pipe = model.get("pipe") if pipe: - async def job(): pipe_id = form_data["model"] if "." in pipe_id: diff --git a/src/lib/apis/functions/index.ts b/src/lib/apis/functions/index.ts index 530702a3e..c8607b091 100644 --- a/src/lib/apis/functions/index.ts +++ b/src/lib/apis/functions/index.ts @@ -192,6 +192,105 @@ export const deleteFunctionById = async (token: string, id: string) => { return res; }; +export const toggleFunctionById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/toggle`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getFunctionValvesById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateFunctionValvesById = async (token: string, id: string, valves: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...valves + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getUserValvesById = async (token: string, id: string) => { let error = null; diff --git a/src/lib/components/workspace/Functions.svelte b/src/lib/components/workspace/Functions.svelte index b389d6eab..fc8e7a451 100644 --- a/src/lib/components/workspace/Functions.svelte +++ b/src/lib/components/workspace/Functions.svelte @@ -13,7 +13,8 @@ deleteFunctionById, exportFunctions, getFunctionById, - getFunctions + getFunctions, + toggleFunctionById } from '$lib/apis/functions'; import ArrowDownTray from '../icons/ArrowDownTray.svelte'; @@ -224,7 +225,12 @@
- + { + toggleFunctionById(localStorage.token, func.id); + }} + />
diff --git a/src/lib/components/workspace/ValvesModal.svelte b/src/lib/components/workspace/ValvesModal.svelte index 58434a559..f67d948a9 100644 --- a/src/lib/components/workspace/ValvesModal.svelte +++ b/src/lib/components/workspace/ValvesModal.svelte @@ -5,16 +5,16 @@ import { addUser } from '$lib/apis/auths'; import Modal from '../common/Modal.svelte'; - import { WEBUI_BASE_URL } from '$lib/constants'; const i18n = getContext('i18n'); const dispatch = createEventDispatcher(); export let show = false; + export let type = 'tool'; + export let id = null; + let loading = false; - let tab = ''; - let inputFiles; let _user = { name: '', @@ -23,96 +23,11 @@ role: 'user' }; - $: if (show) { - _user = { - name: '', - email: '', - password: '', - role: 'user' - }; - } - const submitHandler = async () => { const stopLoading = () => { dispatch('save'); loading = false; }; - - if (tab === '') { - loading = true; - - const res = await addUser( - localStorage.token, - _user.name, - _user.email, - _user.password, - _user.role - ).catch((error) => { - toast.error(error); - }); - - if (res) { - stopLoading(); - show = false; - } - } else { - if (inputFiles) { - loading = true; - - const file = inputFiles[0]; - const reader = new FileReader(); - - reader.onload = async (e) => { - const csv = e.target.result; - const rows = csv.split('\n'); - - let userCount = 0; - - for (const [idx, row] of rows.entries()) { - const columns = row.split(',').map((col) => col.trim()); - console.log(idx, columns); - - if (idx > 0) { - if ( - columns.length === 4 && - ['admin', 'user', 'pending'].includes(columns[3].toLowerCase()) - ) { - const res = await addUser( - localStorage.token, - columns[0], - columns[1], - columns[2], - columns[3].toLowerCase() - ).catch((error) => { - toast.error(`Row ${idx + 1}: ${error}`); - return null; - }); - - if (res) { - userCount = userCount + 1; - } - } else { - toast.error(`Row ${idx + 1}: invalid format.`); - } - } - } - - toast.success(`Successfully imported ${userCount} users.`); - inputFiles = null; - const uploadInputElement = document.getElementById('upload-user-csv-input'); - - if (uploadInputElement) { - uploadInputElement.value = null; - } - - stopLoading(); - }; - - reader.readAsText(file); - } else { - toast.error($i18n.t('File not found.')); - } - } }; @@ -147,126 +62,69 @@ submitHandler(); }} > -
- - - -
- {#if tab === ''} -
-
{$i18n.t('Role')}
+
+
{$i18n.t('Role')}
-
- -
+
+
+
-
-
{$i18n.t('Name')}
+
+
{$i18n.t('Name')}
-
- -
+
+
+
-
+
-
-
{$i18n.t('Email')}
+
+
{$i18n.t('Email')}
-
- -
+
+
+
-
-
{$i18n.t('Password')}
+
+
{$i18n.t('Password')}
-
- -
+
+
- {:else if tab === 'import'} -
-
- - - -
- -
- ⓘ {$i18n.t( - 'Ensure your CSV file includes 4 columns in this order: Name, Email, Password, Role.' - )} - - {$i18n.t('Click here to download user import template file.')} - -
-
- {/if} +
From 3a629ffe0009cf3cbceccc6af53f0da03cbeb9c2 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 23 Jun 2024 18:39:27 -0700 Subject: [PATCH 4/6] feat: global filter --- backend/main.py | 112 ++++++++++-------- src/lib/components/workspace/Functions.svelte | 3 +- 2 files changed, 62 insertions(+), 53 deletions(-) diff --git a/backend/main.py b/backend/main.py index 02552f209..2a44d2029 100644 --- a/backend/main.py +++ b/backend/main.py @@ -376,70 +376,77 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ) model = app.state.MODELS[model_id] + filter_ids = [ + function.id + for function in Functions.get_functions_by_type( + "filter", active_only=True + ) + ] # Check if the model has any filters if "info" in model and "meta" in model["info"]: - for filter_id in model["info"]["meta"].get("filterIds", []): - filter = Functions.get_function_by_id(filter_id) - if filter: - if filter_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[filter_id] - else: - function_module, function_type = load_function_module_by_id( - filter_id - ) - webui_app.state.FUNCTIONS[filter_id] = function_module + filter_ids.extend(model["info"]["meta"].get("filterIds", [])) + filter_ids = list(set(filter_ids)) - # Check if the function has a file_handler variable - if hasattr(function_module, "file_handler"): - skip_files = function_module.file_handler + for filter_id in filter_ids: + filter = Functions.get_function_by_id(filter_id) + if filter: + if filter_id in webui_app.state.FUNCTIONS: + function_module = webui_app.state.FUNCTIONS[filter_id] + else: + function_module, function_type = load_function_module_by_id( + filter_id + ) + webui_app.state.FUNCTIONS[filter_id] = function_module - try: - if hasattr(function_module, "inlet"): - inlet = function_module.inlet + # Check if the function has a file_handler variable + if hasattr(function_module, "file_handler"): + skip_files = function_module.file_handler - # Get the signature of the function - sig = inspect.signature(inlet) - params = {"body": data} + try: + if hasattr(function_module, "inlet"): + inlet = function_module.inlet - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } + # Get the signature of the function + sig = inspect.signature(inlet) + params = {"body": data} - try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = ( - function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - filter_id, user.id - ) - ) + if "__user__" in sig.parameters: + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + + try: + if hasattr(function_module, "UserValves"): + __user__["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id( + filter_id, user.id ) - except Exception as e: - print(e) + ) + except Exception as e: + print(e) - params = {**params, "__user__": __user__} + params = {**params, "__user__": __user__} - if "__id__" in sig.parameters: - params = { - **params, - "__id__": filter_id, - } + if "__id__" in sig.parameters: + params = { + **params, + "__id__": filter_id, + } - if inspect.iscoroutinefunction(inlet): - data = await inlet(**params) - else: - data = inlet(**params) + if inspect.iscoroutinefunction(inlet): + data = await inlet(**params) + else: + data = inlet(**params) - except Exception as e: - print(f"Error: {e}") - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) + except Exception as e: + print(f"Error: {e}") + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) # Set the task model task_model_id = data["model"] @@ -863,6 +870,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u pipe = model.get("pipe") if pipe: + async def job(): pipe_id = form_data["model"] if "." in pipe_id: diff --git a/src/lib/components/workspace/Functions.svelte b/src/lib/components/workspace/Functions.svelte index fc8e7a451..75e0ce4ff 100644 --- a/src/lib/components/workspace/Functions.svelte +++ b/src/lib/components/workspace/Functions.svelte @@ -227,8 +227,9 @@
{ + on:change={async (e) => { toggleFunctionById(localStorage.token, func.id); + models.set(await getModels(localStorage.token)); }} />
From 627705a347bcddb6893b3a1b23cda6636d22d24e Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 23 Jun 2024 19:02:27 -0700 Subject: [PATCH 5/6] feat: valves --- backend/apps/webui/routers/functions.py | 68 +++++-- backend/apps/webui/routers/tools.py | 63 +++++- src/lib/apis/functions/index.ts | 32 +++ src/lib/apis/tools/index.ts | 99 ++++++++++ src/lib/components/workspace/Functions.svelte | 10 + src/lib/components/workspace/Tools.svelte | 10 + .../components/workspace/ValvesModal.svelte | 187 +++++++++++------- 7 files changed, 377 insertions(+), 92 deletions(-) diff --git a/backend/apps/webui/routers/functions.py b/backend/apps/webui/routers/functions.py index 8c6454eac..4da68a052 100644 --- a/backend/apps/webui/routers/functions.py +++ b/backend/apps/webui/routers/functions.py @@ -127,8 +127,8 @@ async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)): function = Functions.get_function_by_id(id) if function: try: - valves = Functions.get_function_valves_by_id(id) - return valves + function_valves = Functions.get_function_valves_by_id(id) + return function_valves.valves except Exception as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -142,24 +142,70 @@ async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)): ############################ -# UpdateToolValves +# GetFunctionValvesSpec +############################ + + +@router.get("/id/{id}/valves/spec", response_model=Optional[dict]) +async def get_function_valves_spec_by_id( + request: Request, id: str, user=Depends(get_admin_user) +): + function = Functions.get_function_by_id(id) + if function: + if id in request.app.state.FUNCTIONS: + function_module = request.app.state.FUNCTIONS[id] + else: + function_module, function_type = load_function_module_by_id(id) + request.app.state.FUNCTIONS[id] = function_module + + if hasattr(function_module, "Valves"): + Valves = function_module.Valves + return Valves.schema() + return None + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# UpdateFunctionValves ############################ @router.post("/id/{id}/valves/update", response_model=Optional[dict]) -async def update_toolkit_valves_by_id( - id: str, form_data: dict, user=Depends(get_admin_user) +async def update_function_valves_by_id( + request: Request, id: str, form_data: dict, user=Depends(get_admin_user) ): function = Functions.get_function_by_id(id) if function: - try: - valves = Functions.update_function_valves_by_id(id, form_data) - return valves - except Exception as e: + + if id in request.app.state.FUNCTIONS: + function_module = request.app.state.FUNCTIONS[id] + else: + function_module, function_type = load_function_module_by_id(id) + request.app.state.FUNCTIONS[id] = function_module + + if hasattr(function_module, "Valves"): + Valves = function_module.Valves + + try: + valves = Valves(**form_data) + Functions.update_function_valves_by_id(id, valves.model_dump()) + return valves.model_dump() + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + else: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT(e), + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, ) + else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, diff --git a/backend/apps/webui/routers/tools.py b/backend/apps/webui/routers/tools.py index 7988acc86..7ddcf3ed9 100644 --- a/backend/apps/webui/routers/tools.py +++ b/backend/apps/webui/routers/tools.py @@ -133,8 +133,8 @@ async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)): toolkit = Tools.get_tool_by_id(id) if toolkit: try: - valves = Tools.get_tool_valves_by_id(id) - return valves + tool_valves = Tools.get_tool_valves_by_id(id) + return tool_valves.valves except Exception as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -147,6 +147,34 @@ async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)): ) +############################ +# GetToolValvesSpec +############################ + + +@router.get("/id/{id}/valves/spec", response_model=Optional[dict]) +async def get_toolkit_valves_spec_by_id( + request: Request, id: str, user=Depends(get_admin_user) +): + toolkit = Tools.get_tool_by_id(id) + if toolkit: + if id in request.app.state.TOOLS: + toolkit_module = request.app.state.TOOLS[id] + else: + toolkit_module = load_toolkit_module_by_id(id) + request.app.state.TOOLS[id] = toolkit_module + + if hasattr(toolkit_module, "UserValves"): + UserValves = toolkit_module.UserValves + return UserValves.schema() + return None + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + ############################ # UpdateToolValves ############################ @@ -154,18 +182,35 @@ async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)): @router.post("/id/{id}/valves/update", response_model=Optional[dict]) async def update_toolkit_valves_by_id( - id: str, form_data: dict, user=Depends(get_admin_user) + request: Request, id: str, form_data: dict, user=Depends(get_admin_user) ): toolkit = Tools.get_tool_by_id(id) if toolkit: - try: - valves = Tools.update_tool_valves_by_id(id, form_data) - return valves - except Exception as e: + if id in request.app.state.TOOLS: + toolkit_module = request.app.state.TOOLS[id] + else: + toolkit_module = load_toolkit_module_by_id(id) + request.app.state.TOOLS[id] = toolkit_module + + if hasattr(toolkit_module, "Valves"): + Valves = toolkit_module.Valves + + try: + valves = Valves(**form_data) + Tools.update_tool_valves_by_id(id, valves.model_dump()) + return valves.model_dump() + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + else: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT(e), + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, ) + else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, diff --git a/src/lib/apis/functions/index.ts b/src/lib/apis/functions/index.ts index c8607b091..2d5ad16b7 100644 --- a/src/lib/apis/functions/index.ts +++ b/src/lib/apis/functions/index.ts @@ -256,6 +256,38 @@ export const getFunctionValvesById = async (token: string, id: string) => { return res; }; +export const getFunctionValvesSpecById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves/spec`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const updateFunctionValvesById = async (token: string, id: string, valves: object) => { let error = null; diff --git a/src/lib/apis/tools/index.ts b/src/lib/apis/tools/index.ts index 25d543feb..28e8dde86 100644 --- a/src/lib/apis/tools/index.ts +++ b/src/lib/apis/tools/index.ts @@ -192,6 +192,105 @@ export const deleteToolById = async (token: string, id: string) => { return res; }; +export const getToolValvesById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getToolValvesSpecById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves/spec`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateToolValvesById = async (token: string, id: string, valves: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...valves + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getUserValvesById = async (token: string, id: string) => { let error = null; diff --git a/src/lib/components/workspace/Functions.svelte b/src/lib/components/workspace/Functions.svelte index 75e0ce4ff..fb92cc32d 100644 --- a/src/lib/components/workspace/Functions.svelte +++ b/src/lib/components/workspace/Functions.svelte @@ -24,6 +24,7 @@ import FunctionMenu from './Functions/FunctionMenu.svelte'; import EllipsisHorizontal from '../icons/EllipsisHorizontal.svelte'; import Switch from '../common/Switch.svelte'; + import ValvesModal from './ValvesModal.svelte'; const i18n = getContext('i18n'); @@ -33,6 +34,9 @@ let showConfirm = false; let query = ''; + let showValvesModal = false; + let selectedFunction = null; + const shareHandler = async (tool) => { console.log(tool); }; @@ -175,6 +179,10 @@
+ + { diff --git a/src/lib/components/workspace/Tools.svelte b/src/lib/components/workspace/Tools.svelte index 184f5942a..687117f99 100644 --- a/src/lib/components/workspace/Tools.svelte +++ b/src/lib/components/workspace/Tools.svelte @@ -20,6 +20,7 @@ import ConfirmDialog from '../common/ConfirmDialog.svelte'; import ToolMenu from './Tools/ToolMenu.svelte'; import EllipsisHorizontal from '../icons/EllipsisHorizontal.svelte'; + import ValvesModal from './ValvesModal.svelte'; const i18n = getContext('i18n'); @@ -29,6 +30,9 @@ let showConfirm = false; let query = ''; + let showValvesModal = false; + let selectedTool = null; + const shareHandler = async (tool) => { console.log(tool); }; @@ -169,6 +173,10 @@
+ + { diff --git a/src/lib/components/workspace/ValvesModal.svelte b/src/lib/components/workspace/ValvesModal.svelte index f67d948a9..8c15cdc64 100644 --- a/src/lib/components/workspace/ValvesModal.svelte +++ b/src/lib/components/workspace/ValvesModal.svelte @@ -5,6 +5,13 @@ import { addUser } from '$lib/apis/auths'; import Modal from '../common/Modal.svelte'; + import { + getFunctionValvesById, + getFunctionValvesSpecById, + updateFunctionValvesById + } from '$lib/apis/functions'; + import { getToolValvesById, getToolValvesSpecById, updateToolValvesById } from '$lib/apis/tools'; + import Spinner from '../common/Spinner.svelte'; const i18n = getContext('i18n'); const dispatch = createEventDispatcher(); @@ -14,21 +21,57 @@ export let type = 'tool'; export let id = null; + let saving = false; let loading = false; - let _user = { - name: '', - email: '', - password: '', - role: 'user' - }; + let valvesSpec = null; + let valves = {}; const submitHandler = async () => { - const stopLoading = () => { - dispatch('save'); - loading = false; - }; + saving = true; + + let res = null; + + if (type === 'tool') { + res = await updateToolValvesById(localStorage.token, id, valves).catch((error) => { + toast.error(error); + }); + } else if (type === 'function') { + res = await updateFunctionValvesById(localStorage.token, id, valves).catch((error) => { + toast.error(error); + }); + } + + if (res) { + toast.success('Valves updated successfully'); + } + + saving = false; }; + + const initHandler = async () => { + loading = true; + valves = {}; + valvesSpec = null; + + if (type === 'tool') { + valves = await getToolValvesById(localStorage.token, id); + valvesSpec = await getToolValvesSpecById(localStorage.token, id); + } else if (type === 'function') { + valves = await getFunctionValvesById(localStorage.token, id); + valvesSpec = await getFunctionValvesSpecById(localStorage.token, id); + } + + if (!valves) { + valves = {}; + } + + loading = false; + }; + + $: if (show) { + initHandler(); + } @@ -63,81 +106,81 @@ }} >
-
-
{$i18n.t('Role')}
+ {#if !loading} + {#if valvesSpec} + {#each Object.keys(valvesSpec.properties) as property, idx} +
+
+
+ {valvesSpec.properties[property].title} -
- -
-
+ {#if (valvesSpec?.required ?? []).includes(property)} + *required + {/if} +
-
-
{$i18n.t('Name')}
+ +
-
- -
-
+ {#if (valves[property] ?? null) !== null} +
+
+ +
+
+ {/if} -
- -
-
{$i18n.t('Email')}
- -
- -
-
- -
-
{$i18n.t('Password')}
- -
- -
-
+ {#if (valvesSpec.properties[property]?.description ?? null) !== null} +
+ {valvesSpec.properties[property].description} +
+ {/if} +
+ {/each} + {:else} +
No valves
+ {/if} + {:else} + + {/if}