From edbd07f893dcf6fdd996a098a74264f106e4129f Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 27 Jun 2024 13:04:12 -0700 Subject: [PATCH] feat: global filter --- .../migrations/018_add_function_is_global.py | 49 +++++++++++++++++++ backend/apps/webui/models/functions.py | 13 +++++ backend/apps/webui/routers/functions.py | 27 ++++++++++ backend/main.py | 49 +++++++++---------- src/lib/apis/functions/index.ts | 32 ++++++++++++ src/lib/components/icons/GlobeAlt.svelte | 19 +++++++ src/lib/components/workspace/Functions.svelte | 25 +++++++++- .../workspace/Functions/FunctionMenu.svelte | 27 ++++++++-- 8 files changed, 212 insertions(+), 29 deletions(-) create mode 100644 backend/apps/webui/internal/migrations/018_add_function_is_global.py create mode 100644 src/lib/components/icons/GlobeAlt.svelte diff --git a/backend/apps/webui/internal/migrations/018_add_function_is_global.py b/backend/apps/webui/internal/migrations/018_add_function_is_global.py new file mode 100644 index 000000000..04cdab705 --- /dev/null +++ b/backend/apps/webui/internal/migrations/018_add_function_is_global.py @@ -0,0 +1,49 @@ +"""Peewee migrations -- 017_add_user_oauth_sub.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + migrator.add_fields( + "function", + is_global=pw.BooleanField(default=False), + ) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_fields("function", "is_global") diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index 261987981..2cace54c4 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -30,6 +30,7 @@ class Function(Model): meta = JSONField() valves = JSONField() is_active = BooleanField(default=False) + is_global = BooleanField(default=False) updated_at = BigIntegerField() created_at = BigIntegerField() @@ -50,6 +51,7 @@ class FunctionModel(BaseModel): content: str meta: FunctionMeta is_active: bool = False + is_global: bool = False updated_at: int # timestamp in epoch created_at: int # timestamp in epoch @@ -66,6 +68,7 @@ class FunctionResponse(BaseModel): name: str meta: FunctionMeta is_active: bool + is_global: bool updated_at: int # timestamp in epoch created_at: int # timestamp in epoch @@ -144,6 +147,16 @@ class FunctionsTable: for function in Function.select().where(Function.type == type) ] + def get_global_filter_functions(self) -> List[FunctionModel]: + return [ + FunctionModel(**model_to_dict(function)) + for function in Function.select().where( + Function.type == "filter", + Function.is_active == True, + Function.is_global == True, + ) + ] + def get_function_valves_by_id(self, id: str) -> Optional[dict]: try: function = Function.get(Function.id == id) diff --git a/backend/apps/webui/routers/functions.py b/backend/apps/webui/routers/functions.py index 4c89ca487..f01133a35 100644 --- a/backend/apps/webui/routers/functions.py +++ b/backend/apps/webui/routers/functions.py @@ -147,6 +147,33 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user)): ) +############################ +# ToggleGlobalById +############################ + + +@router.post("/id/{id}/toggle/global", response_model=Optional[FunctionModel]) +async def toggle_global_by_id(id: str, user=Depends(get_admin_user)): + function = Functions.get_function_by_id(id) + if function: + function = Functions.update_function_by_id( + id, {"is_global": not function.is_global} + ) + + if function: + return function + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error updating function"), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + ############################ # UpdateFunctionById ############################ diff --git a/backend/main.py b/backend/main.py index d0e85ddda..aae305c5e 100644 --- a/backend/main.py +++ b/backend/main.py @@ -416,21 +416,23 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ) return 0 - filter_ids = [] + filter_ids = [ + function.id for function in Functions.get_global_filter_functions() + ] if "info" in model and "meta" in model["info"]: - enabled_filter_ids = [ - function.id - for function in Functions.get_functions_by_type( - "filter", active_only=True - ) - ] - filter_ids = [ - filter_id - for filter_id in enabled_filter_ids - if filter_id in model["info"]["meta"].get("filterIds", []) - ] + filter_ids.extend(model["info"]["meta"].get("filterIds", [])) filter_ids = list(set(filter_ids)) + enabled_filter_ids = [ + function.id + for function in Functions.get_functions_by_type( + "filter", active_only=True + ) + ] + filter_ids = [ + filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids + ] + filter_ids.sort(key=get_priority) for filter_id in filter_ids: filter = Functions.get_function_by_id(filter_id) @@ -919,7 +921,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u ) model = app.state.MODELS[model_id] - print(model) pipe = model.get("pipe") if pipe: @@ -1010,21 +1011,19 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): return (function.valves if function.valves else {}).get("priority", 0) return 0 - filter_ids = [] + filter_ids = [function.id for function in Functions.get_global_filter_functions()] if "info" in model and "meta" in model["info"]: - enabled_filter_ids = [ - function.id - for function in Functions.get_functions_by_type( - "filter", active_only=True - ) - ] - filter_ids = [ - filter_id - for filter_id in enabled_filter_ids - if filter_id in model["info"]["meta"].get("filterIds", []) - ] + filter_ids.extend(model["info"]["meta"].get("filterIds", [])) filter_ids = list(set(filter_ids)) + enabled_filter_ids = [ + function.id + for function in Functions.get_functions_by_type("filter", active_only=True) + ] + filter_ids = [ + filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids + ] + # Sort filter_ids by priority, using the get_priority function filter_ids.sort(key=get_priority) diff --git a/src/lib/apis/functions/index.ts b/src/lib/apis/functions/index.ts index 2d5ad16b7..ed3306b32 100644 --- a/src/lib/apis/functions/index.ts +++ b/src/lib/apis/functions/index.ts @@ -224,6 +224,38 @@ export const toggleFunctionById = async (token: string, id: string) => { return res; }; +export const toggleGlobalById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/toggle/global`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getFunctionValvesById = async (token: string, id: string) => { let error = null; diff --git a/src/lib/components/icons/GlobeAlt.svelte b/src/lib/components/icons/GlobeAlt.svelte new file mode 100644 index 000000000..d2f86f438 --- /dev/null +++ b/src/lib/components/icons/GlobeAlt.svelte @@ -0,0 +1,19 @@ + + + + + diff --git a/src/lib/components/workspace/Functions.svelte b/src/lib/components/workspace/Functions.svelte index 8a8f88812..8c6389711 100644 --- a/src/lib/components/workspace/Functions.svelte +++ b/src/lib/components/workspace/Functions.svelte @@ -14,7 +14,8 @@ exportFunctions, getFunctionById, getFunctions, - toggleFunctionById + toggleFunctionById, + toggleGlobalById } from '$lib/apis/functions'; import ArrowDownTray from '../icons/ArrowDownTray.svelte'; @@ -113,6 +114,22 @@ models.set(await getModels(localStorage.token)); } }; + + const toggleGlobalHandler = async (func) => { + const res = await toggleGlobalById(localStorage.token, func.id).catch((error) => { + toast.error(error); + }); + + if (res) { + if (func.is_global) { + toast.success($i18n.t('Filter is now globally enabled')); + } else { + toast.success($i18n.t('Filter is now globally disabled')); + } + + functions.set(await getFunctions(localStorage.token)); + } + }; @@ -259,6 +276,7 @@ { goto(`/workspace/functions/edit?id=${encodeURIComponent(func.id)}`); }} @@ -275,6 +293,11 @@ selectedFunction = func; showDeleteConfirm = true; }} + toggleGlobalHandler={() => { + if (func.type === 'filter') { + toggleGlobalHandler(func); + } + }} onClose={() => {}} >