mirror of
https://github.com/open-webui/open-webui
synced 2025-03-24 06:37:14 +00:00
feat: global filter
This commit is contained in:
parent
c8c85ba7fc
commit
edbd07f893
@ -0,0 +1,49 @@
|
||||
"""Peewee migrations -- 017_add_user_oauth_sub.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
migrator.add_fields(
|
||||
"function",
|
||||
is_global=pw.BooleanField(default=False),
|
||||
)
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
migrator.remove_fields("function", "is_global")
|
@ -30,6 +30,7 @@ class Function(Model):
|
||||
meta = JSONField()
|
||||
valves = JSONField()
|
||||
is_active = BooleanField(default=False)
|
||||
is_global = BooleanField(default=False)
|
||||
updated_at = BigIntegerField()
|
||||
created_at = BigIntegerField()
|
||||
|
||||
@ -50,6 +51,7 @@ class FunctionModel(BaseModel):
|
||||
content: str
|
||||
meta: FunctionMeta
|
||||
is_active: bool = False
|
||||
is_global: bool = False
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
@ -66,6 +68,7 @@ class FunctionResponse(BaseModel):
|
||||
name: str
|
||||
meta: FunctionMeta
|
||||
is_active: bool
|
||||
is_global: bool
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
@ -144,6 +147,16 @@ class FunctionsTable:
|
||||
for function in Function.select().where(Function.type == type)
|
||||
]
|
||||
|
||||
def get_global_filter_functions(self) -> List[FunctionModel]:
|
||||
return [
|
||||
FunctionModel(**model_to_dict(function))
|
||||
for function in Function.select().where(
|
||||
Function.type == "filter",
|
||||
Function.is_active == True,
|
||||
Function.is_global == True,
|
||||
)
|
||||
]
|
||||
|
||||
def get_function_valves_by_id(self, id: str) -> Optional[dict]:
|
||||
try:
|
||||
function = Function.get(Function.id == id)
|
||||
|
@ -147,6 +147,33 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# ToggleGlobalById
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/id/{id}/toggle/global", response_model=Optional[FunctionModel])
|
||||
async def toggle_global_by_id(id: str, user=Depends(get_admin_user)):
|
||||
function = Functions.get_function_by_id(id)
|
||||
if function:
|
||||
function = Functions.update_function_by_id(
|
||||
id, {"is_global": not function.is_global}
|
||||
)
|
||||
|
||||
if function:
|
||||
return function
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# UpdateFunctionById
|
||||
############################
|
||||
|
@ -416,21 +416,23 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
)
|
||||
return 0
|
||||
|
||||
filter_ids = []
|
||||
filter_ids = [
|
||||
function.id for function in Functions.get_global_filter_functions()
|
||||
]
|
||||
if "info" in model and "meta" in model["info"]:
|
||||
enabled_filter_ids = [
|
||||
function.id
|
||||
for function in Functions.get_functions_by_type(
|
||||
"filter", active_only=True
|
||||
)
|
||||
]
|
||||
filter_ids = [
|
||||
filter_id
|
||||
for filter_id in enabled_filter_ids
|
||||
if filter_id in model["info"]["meta"].get("filterIds", [])
|
||||
]
|
||||
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
||||
filter_ids = list(set(filter_ids))
|
||||
|
||||
enabled_filter_ids = [
|
||||
function.id
|
||||
for function in Functions.get_functions_by_type(
|
||||
"filter", active_only=True
|
||||
)
|
||||
]
|
||||
filter_ids = [
|
||||
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
|
||||
]
|
||||
|
||||
filter_ids.sort(key=get_priority)
|
||||
for filter_id in filter_ids:
|
||||
filter = Functions.get_function_by_id(filter_id)
|
||||
@ -919,7 +921,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
|
||||
)
|
||||
|
||||
model = app.state.MODELS[model_id]
|
||||
print(model)
|
||||
|
||||
pipe = model.get("pipe")
|
||||
if pipe:
|
||||
@ -1010,21 +1011,19 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
|
||||
return (function.valves if function.valves else {}).get("priority", 0)
|
||||
return 0
|
||||
|
||||
filter_ids = []
|
||||
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
|
||||
if "info" in model and "meta" in model["info"]:
|
||||
enabled_filter_ids = [
|
||||
function.id
|
||||
for function in Functions.get_functions_by_type(
|
||||
"filter", active_only=True
|
||||
)
|
||||
]
|
||||
filter_ids = [
|
||||
filter_id
|
||||
for filter_id in enabled_filter_ids
|
||||
if filter_id in model["info"]["meta"].get("filterIds", [])
|
||||
]
|
||||
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
||||
filter_ids = list(set(filter_ids))
|
||||
|
||||
enabled_filter_ids = [
|
||||
function.id
|
||||
for function in Functions.get_functions_by_type("filter", active_only=True)
|
||||
]
|
||||
filter_ids = [
|
||||
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
|
||||
]
|
||||
|
||||
# Sort filter_ids by priority, using the get_priority function
|
||||
filter_ids.sort(key=get_priority)
|
||||
|
||||
|
@ -224,6 +224,38 @@ export const toggleFunctionById = async (token: string, id: string) => {
|
||||
return res;
|
||||
};
|
||||
|
||||
export const toggleGlobalById = async (token: string, id: string) => {
|
||||
let error = null;
|
||||
|
||||
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/toggle/global`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Accept: 'application/json',
|
||||
'Content-Type': 'application/json',
|
||||
authorization: `Bearer ${token}`
|
||||
}
|
||||
})
|
||||
.then(async (res) => {
|
||||
if (!res.ok) throw await res.json();
|
||||
return res.json();
|
||||
})
|
||||
.then((json) => {
|
||||
return json;
|
||||
})
|
||||
.catch((err) => {
|
||||
error = err.detail;
|
||||
|
||||
console.log(err);
|
||||
return null;
|
||||
});
|
||||
|
||||
if (error) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
return res;
|
||||
};
|
||||
|
||||
export const getFunctionValvesById = async (token: string, id: string) => {
|
||||
let error = null;
|
||||
|
||||
|
19
src/lib/components/icons/GlobeAlt.svelte
Normal file
19
src/lib/components/icons/GlobeAlt.svelte
Normal file
@ -0,0 +1,19 @@
|
||||
<script lang="ts">
|
||||
export let className = 'w-4 h-4';
|
||||
export let strokeWidth = '1.5';
|
||||
</script>
|
||||
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke-width={strokeWidth}
|
||||
stroke="currentColor"
|
||||
class={className}
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M12 21a9.004 9.004 0 0 0 8.716-6.747M12 21a9.004 9.004 0 0 1-8.716-6.747M12 21c2.485 0 4.5-4.03 4.5-9S14.485 3 12 3m0 18c-2.485 0-4.5-4.03-4.5-9S9.515 3 12 3m0 0a8.997 8.997 0 0 1 7.843 4.582M12 3a8.997 8.997 0 0 0-7.843 4.582m15.686 0A11.953 11.953 0 0 1 12 10.5c-2.998 0-5.74-1.1-7.843-2.918m15.686 0A8.959 8.959 0 0 1 21 12c0 .778-.099 1.533-.284 2.253m0 0A17.919 17.919 0 0 1 12 16.5c-3.162 0-6.133-.815-8.716-2.247m0 0A9.015 9.015 0 0 1 3 12c0-1.605.42-3.113 1.157-4.418"
|
||||
/>
|
||||
</svg>
|
@ -14,7 +14,8 @@
|
||||
exportFunctions,
|
||||
getFunctionById,
|
||||
getFunctions,
|
||||
toggleFunctionById
|
||||
toggleFunctionById,
|
||||
toggleGlobalById
|
||||
} from '$lib/apis/functions';
|
||||
|
||||
import ArrowDownTray from '../icons/ArrowDownTray.svelte';
|
||||
@ -113,6 +114,22 @@
|
||||
models.set(await getModels(localStorage.token));
|
||||
}
|
||||
};
|
||||
|
||||
const toggleGlobalHandler = async (func) => {
|
||||
const res = await toggleGlobalById(localStorage.token, func.id).catch((error) => {
|
||||
toast.error(error);
|
||||
});
|
||||
|
||||
if (res) {
|
||||
if (func.is_global) {
|
||||
toast.success($i18n.t('Filter is now globally enabled'));
|
||||
} else {
|
||||
toast.success($i18n.t('Filter is now globally disabled'));
|
||||
}
|
||||
|
||||
functions.set(await getFunctions(localStorage.token));
|
||||
}
|
||||
};
|
||||
</script>
|
||||
|
||||
<svelte:head>
|
||||
@ -259,6 +276,7 @@
|
||||
</Tooltip>
|
||||
|
||||
<FunctionMenu
|
||||
{func}
|
||||
editHandler={() => {
|
||||
goto(`/workspace/functions/edit?id=${encodeURIComponent(func.id)}`);
|
||||
}}
|
||||
@ -275,6 +293,11 @@
|
||||
selectedFunction = func;
|
||||
showDeleteConfirm = true;
|
||||
}}
|
||||
toggleGlobalHandler={() => {
|
||||
if (func.type === 'filter') {
|
||||
toggleGlobalHandler(func);
|
||||
}
|
||||
}}
|
||||
onClose={() => {}}
|
||||
>
|
||||
<button
|
||||
|
@ -5,21 +5,24 @@
|
||||
|
||||
import Dropdown from '$lib/components/common/Dropdown.svelte';
|
||||
import GarbageBin from '$lib/components/icons/GarbageBin.svelte';
|
||||
import Pencil from '$lib/components/icons/Pencil.svelte';
|
||||
import Tooltip from '$lib/components/common/Tooltip.svelte';
|
||||
import Tags from '$lib/components/chat/Tags.svelte';
|
||||
import Share from '$lib/components/icons/Share.svelte';
|
||||
import ArchiveBox from '$lib/components/icons/ArchiveBox.svelte';
|
||||
import DocumentDuplicate from '$lib/components/icons/DocumentDuplicate.svelte';
|
||||
import ArrowDownTray from '$lib/components/icons/ArrowDownTray.svelte';
|
||||
import Switch from '$lib/components/common/Switch.svelte';
|
||||
import GlobeAlt from '$lib/components/icons/GlobeAlt.svelte';
|
||||
|
||||
const i18n = getContext('i18n');
|
||||
|
||||
export let func;
|
||||
|
||||
export let editHandler: Function;
|
||||
export let shareHandler: Function;
|
||||
export let cloneHandler: Function;
|
||||
export let exportHandler: Function;
|
||||
export let deleteHandler: Function;
|
||||
export let toggleGlobalHandler: Function;
|
||||
|
||||
export let onClose: Function;
|
||||
|
||||
let show = false;
|
||||
@ -45,6 +48,24 @@
|
||||
align="start"
|
||||
transition={flyAndScale}
|
||||
>
|
||||
{#if func.type === 'filter'}
|
||||
<div
|
||||
class="flex gap-2 justify-between items-center px-3 py-2 text-sm font-medium cursor-pointer hover:bg-gray-50 dark:hover:bg-gray-800 rounded-md"
|
||||
>
|
||||
<div class="flex gap-2 items-center">
|
||||
<GlobeAlt />
|
||||
|
||||
<div class="flex items-center">{$i18n.t('Global')}</div>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Switch on:change={toggleGlobalHandler} bind:state={func.is_global} />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<hr class="border-gray-100 dark:border-gray-800 my-1" />
|
||||
{/if}
|
||||
|
||||
<DropdownMenu.Item
|
||||
class="flex gap-2 items-center px-3 py-2 text-sm font-medium cursor-pointer hover:bg-gray-50 dark:hover:bg-gray-800 rounded-md"
|
||||
on:click={() => {
|
||||
|
Loading…
Reference in New Issue
Block a user