diff --git a/backend/apps/webui/internal/migrations/018_add_function_is_global.py b/backend/apps/webui/internal/migrations/018_add_function_is_global.py
new file mode 100644
index 000000000..04cdab705
--- /dev/null
+++ b/backend/apps/webui/internal/migrations/018_add_function_is_global.py
@@ -0,0 +1,49 @@
+"""Peewee migrations -- 017_add_user_oauth_sub.py.
+
+Some examples (model - class or model name)::
+
+ > Model = migrator.orm['table_name'] # Return model in current state by name
+ > Model = migrator.ModelClass # Return model in current state by name
+
+ > migrator.sql(sql) # Run custom SQL
+ > migrator.run(func, *args, **kwargs) # Run python function with the given args
+ > migrator.create_model(Model) # Create a model (could be used as decorator)
+ > migrator.remove_model(model, cascade=True) # Remove a model
+ > migrator.add_fields(model, **fields) # Add fields to a model
+ > migrator.change_fields(model, **fields) # Change fields
+ > migrator.remove_fields(model, *field_names, cascade=True)
+ > migrator.rename_field(model, old_field_name, new_field_name)
+ > migrator.rename_table(model, new_table_name)
+ > migrator.add_index(model, *col_names, unique=False)
+ > migrator.add_not_null(model, *field_names)
+ > migrator.add_default(model, field_name, default)
+ > migrator.add_constraint(model, name, sql)
+ > migrator.drop_index(model, *col_names)
+ > migrator.drop_not_null(model, *field_names)
+ > migrator.drop_constraints(model, *constraints)
+
+"""
+
+from contextlib import suppress
+
+import peewee as pw
+from peewee_migrate import Migrator
+
+
+with suppress(ImportError):
+ import playhouse.postgres_ext as pw_pext
+
+
+def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
+ """Write your migrations here."""
+
+ migrator.add_fields(
+ "function",
+ is_global=pw.BooleanField(default=False),
+ )
+
+
+def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
+ """Write your rollback migrations here."""
+
+ migrator.remove_fields("function", "is_global")
diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py
index 261987981..2cace54c4 100644
--- a/backend/apps/webui/models/functions.py
+++ b/backend/apps/webui/models/functions.py
@@ -30,6 +30,7 @@ class Function(Model):
meta = JSONField()
valves = JSONField()
is_active = BooleanField(default=False)
+ is_global = BooleanField(default=False)
updated_at = BigIntegerField()
created_at = BigIntegerField()
@@ -50,6 +51,7 @@ class FunctionModel(BaseModel):
content: str
meta: FunctionMeta
is_active: bool = False
+ is_global: bool = False
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
@@ -66,6 +68,7 @@ class FunctionResponse(BaseModel):
name: str
meta: FunctionMeta
is_active: bool
+ is_global: bool
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
@@ -144,6 +147,16 @@ class FunctionsTable:
for function in Function.select().where(Function.type == type)
]
+ def get_global_filter_functions(self) -> List[FunctionModel]:
+ return [
+ FunctionModel(**model_to_dict(function))
+ for function in Function.select().where(
+ Function.type == "filter",
+ Function.is_active == True,
+ Function.is_global == True,
+ )
+ ]
+
def get_function_valves_by_id(self, id: str) -> Optional[dict]:
try:
function = Function.get(Function.id == id)
diff --git a/backend/apps/webui/routers/functions.py b/backend/apps/webui/routers/functions.py
index 4c89ca487..f01133a35 100644
--- a/backend/apps/webui/routers/functions.py
+++ b/backend/apps/webui/routers/functions.py
@@ -147,6 +147,33 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
)
+############################
+# ToggleGlobalById
+############################
+
+
+@router.post("/id/{id}/toggle/global", response_model=Optional[FunctionModel])
+async def toggle_global_by_id(id: str, user=Depends(get_admin_user)):
+ function = Functions.get_function_by_id(id)
+ if function:
+ function = Functions.update_function_by_id(
+ id, {"is_global": not function.is_global}
+ )
+
+ if function:
+ return function
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+
+
############################
# UpdateFunctionById
############################
diff --git a/backend/main.py b/backend/main.py
index d0e85ddda..aae305c5e 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -416,21 +416,23 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
)
return 0
- filter_ids = []
+ filter_ids = [
+ function.id for function in Functions.get_global_filter_functions()
+ ]
if "info" in model and "meta" in model["info"]:
- enabled_filter_ids = [
- function.id
- for function in Functions.get_functions_by_type(
- "filter", active_only=True
- )
- ]
- filter_ids = [
- filter_id
- for filter_id in enabled_filter_ids
- if filter_id in model["info"]["meta"].get("filterIds", [])
- ]
+ filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
+ enabled_filter_ids = [
+ function.id
+ for function in Functions.get_functions_by_type(
+ "filter", active_only=True
+ )
+ ]
+ filter_ids = [
+ filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
+ ]
+
filter_ids.sort(key=get_priority)
for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id)
@@ -919,7 +921,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
)
model = app.state.MODELS[model_id]
- print(model)
pipe = model.get("pipe")
if pipe:
@@ -1010,21 +1011,19 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
return (function.valves if function.valves else {}).get("priority", 0)
return 0
- filter_ids = []
+ filter_ids = [function.id for function in Functions.get_global_filter_functions()]
if "info" in model and "meta" in model["info"]:
- enabled_filter_ids = [
- function.id
- for function in Functions.get_functions_by_type(
- "filter", active_only=True
- )
- ]
- filter_ids = [
- filter_id
- for filter_id in enabled_filter_ids
- if filter_id in model["info"]["meta"].get("filterIds", [])
- ]
+ filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
+ enabled_filter_ids = [
+ function.id
+ for function in Functions.get_functions_by_type("filter", active_only=True)
+ ]
+ filter_ids = [
+ filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
+ ]
+
# Sort filter_ids by priority, using the get_priority function
filter_ids.sort(key=get_priority)
diff --git a/src/lib/apis/functions/index.ts b/src/lib/apis/functions/index.ts
index 2d5ad16b7..ed3306b32 100644
--- a/src/lib/apis/functions/index.ts
+++ b/src/lib/apis/functions/index.ts
@@ -224,6 +224,38 @@ export const toggleFunctionById = async (token: string, id: string) => {
return res;
};
+export const toggleGlobalById = async (token: string, id: string) => {
+ let error = null;
+
+ const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/toggle/global`, {
+ method: 'POST',
+ headers: {
+ Accept: 'application/json',
+ 'Content-Type': 'application/json',
+ authorization: `Bearer ${token}`
+ }
+ })
+ .then(async (res) => {
+ if (!res.ok) throw await res.json();
+ return res.json();
+ })
+ .then((json) => {
+ return json;
+ })
+ .catch((err) => {
+ error = err.detail;
+
+ console.log(err);
+ return null;
+ });
+
+ if (error) {
+ throw error;
+ }
+
+ return res;
+};
+
export const getFunctionValvesById = async (token: string, id: string) => {
let error = null;
diff --git a/src/lib/components/icons/GlobeAlt.svelte b/src/lib/components/icons/GlobeAlt.svelte
new file mode 100644
index 000000000..d2f86f438
--- /dev/null
+++ b/src/lib/components/icons/GlobeAlt.svelte
@@ -0,0 +1,19 @@
+
+
+
diff --git a/src/lib/components/workspace/Functions.svelte b/src/lib/components/workspace/Functions.svelte
index 8a8f88812..8c6389711 100644
--- a/src/lib/components/workspace/Functions.svelte
+++ b/src/lib/components/workspace/Functions.svelte
@@ -14,7 +14,8 @@
exportFunctions,
getFunctionById,
getFunctions,
- toggleFunctionById
+ toggleFunctionById,
+ toggleGlobalById
} from '$lib/apis/functions';
import ArrowDownTray from '../icons/ArrowDownTray.svelte';
@@ -113,6 +114,22 @@
models.set(await getModels(localStorage.token));
}
};
+
+ const toggleGlobalHandler = async (func) => {
+ const res = await toggleGlobalById(localStorage.token, func.id).catch((error) => {
+ toast.error(error);
+ });
+
+ if (res) {
+ if (func.is_global) {
+ toast.success($i18n.t('Filter is now globally enabled'));
+ } else {
+ toast.success($i18n.t('Filter is now globally disabled'));
+ }
+
+ functions.set(await getFunctions(localStorage.token));
+ }
+ };
@@ -259,6 +276,7 @@
{
goto(`/workspace/functions/edit?id=${encodeURIComponent(func.id)}`);
}}
@@ -275,6 +293,11 @@
selectedFunction = func;
showDeleteConfirm = true;
}}
+ toggleGlobalHandler={() => {
+ if (func.type === 'filter') {
+ toggleGlobalHandler(func);
+ }
+ }}
onClose={() => {}}
>