enh: knowledge access control

This commit is contained in:
Timothy Jaeryang Baek 2024-11-16 16:51:55 -08:00
parent 8da24d81a4
commit 227cca35e8
23 changed files with 241 additions and 149 deletions

View File

@ -36,7 +36,9 @@ from open_webui.utils.payload import (
apply_model_system_prompt_to_body,
)
from open_webui.utils.utils import get_admin_user, get_verified_user, has_access
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OPENAI"])

View File

@ -13,6 +13,7 @@ from open_webui.apps.webui.models.files import FileMetadataResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON
from open_webui.utils.access_control import has_access
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -129,7 +130,7 @@ class KnowledgeTable:
except Exception:
return None
def get_knowledge_items(self) -> list[KnowledgeModel]:
def get_knowledge_bases(self) -> list[KnowledgeModel]:
with get_db() as db:
return [
KnowledgeModel.model_validate(knowledge)
@ -138,6 +139,17 @@ class KnowledgeTable:
.all()
]
def get_knowledge_bases_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[KnowledgeModel]:
knowledge_bases = self.get_knowledge_bases()
return [
knowledge_base
for knowledge_base in knowledge_bases
if knowledge_base.user_id == user_id
or has_access(user_id, permission, knowledge_base.access_control)
]
def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:
try:
with get_db() as db:

View File

@ -15,7 +15,7 @@ from sqlalchemy.dialects import postgresql, sqlite
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
from open_webui.utils.utils import has_access
from open_webui.utils.access_control import has_access
log = logging.getLogger(__name__)

View File

@ -26,64 +26,98 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter()
############################
# GetKnowledgeItems
# getKnowledgeBases
############################
@router.get(
"/", response_model=Optional[Union[list[KnowledgeResponse], KnowledgeResponse]]
)
async def get_knowledge_items(
id: Optional[str] = None, user=Depends(get_verified_user)
):
if id:
knowledge = Knowledges.get_knowledge_by_id(id=id)
@router.get("/", response_model=list[KnowledgeResponse])
async def get_knowledge(user=Depends(get_verified_user)):
knowledge_bases = []
if knowledge:
return knowledge
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if user.role == "admin":
knowledge_bases = Knowledges.get_knowledge_bases()
else:
knowledge_bases = []
knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(user.id, "read")
for knowledge in Knowledges.get_knowledge_items():
files = []
if knowledge.data:
files = Files.get_file_metadatas_by_ids(
knowledge.data.get("file_ids", [])
)
# Check if all files exist
if len(files) != len(knowledge.data.get("file_ids", [])):
missing_files = list(
set(knowledge.data.get("file_ids", []))
- set([file.id for file in files])
)
if missing_files:
data = knowledge.data or {}
file_ids = data.get("file_ids", [])
for missing_file in missing_files:
file_ids.remove(missing_file)
data["file_ids"] = file_ids
Knowledges.update_knowledge_by_id(
id=knowledge.id, form_data=KnowledgeUpdateForm(data=data)
)
files = Files.get_file_metadatas_by_ids(file_ids)
knowledge_bases.append(
KnowledgeResponse(
**knowledge.model_dump(),
files=files,
)
# Get files for each knowledge base
for knowledge_base in knowledge_bases:
files = []
if knowledge_base.data:
files = Files.get_file_metadatas_by_ids(
knowledge_base.data.get("file_ids", [])
)
return knowledge_bases
# Check if all files exist
if len(files) != len(knowledge_base.data.get("file_ids", [])):
missing_files = list(
set(knowledge_base.data.get("file_ids", []))
- set([file.id for file in files])
)
if missing_files:
data = knowledge_base.data or {}
file_ids = data.get("file_ids", [])
for missing_file in missing_files:
file_ids.remove(missing_file)
data["file_ids"] = file_ids
Knowledges.update_knowledge_by_id(
id=knowledge_base.id, form_data=KnowledgeUpdateForm(data=data)
)
files = Files.get_file_metadatas_by_ids(file_ids)
knowledge_base = KnowledgeResponse(
**knowledge_base.model_dump(),
files=files,
)
return knowledge_bases
@router.get("/list", response_model=list[KnowledgeResponse])
async def get_knowledge_list(user=Depends(get_verified_user)):
knowledge_bases = []
if user.role == "admin":
knowledge_bases = Knowledges.get_knowledge_bases()
else:
knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(user.id, "write")
# Get files for each knowledge base
for knowledge_base in knowledge_bases:
files = []
if knowledge_base.data:
files = Files.get_file_metadatas_by_ids(
knowledge_base.data.get("file_ids", [])
)
# Check if all files exist
if len(files) != len(knowledge_base.data.get("file_ids", [])):
missing_files = list(
set(knowledge_base.data.get("file_ids", []))
- set([file.id for file in files])
)
if missing_files:
data = knowledge_base.data or {}
file_ids = data.get("file_ids", [])
for missing_file in missing_files:
file_ids.remove(missing_file)
data["file_ids"] = file_ids
Knowledges.update_knowledge_by_id(
id=knowledge_base.id, form_data=KnowledgeUpdateForm(data=data)
)
files = Files.get_file_metadatas_by_ids(file_ids)
knowledge_base = KnowledgeResponse(
**knowledge_base.model_dump(),
files=files,
)
return knowledge_bases
############################
@ -92,7 +126,9 @@ async def get_knowledge_items(
@router.post("/create", response_model=Optional[KnowledgeResponse])
async def create_new_knowledge(form_data: KnowledgeForm, user=Depends(get_admin_user)):
async def create_new_knowledge(
form_data: KnowledgeForm, user=Depends(get_verified_user)
):
knowledge = Knowledges.insert_new_knowledge(user.id, form_data)
if knowledge:
@ -141,7 +177,7 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)):
async def update_knowledge_by_id(
id: str,
form_data: KnowledgeUpdateForm,
user=Depends(get_admin_user),
user=Depends(get_verified_user),
):
knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data)
@ -173,7 +209,7 @@ class KnowledgeFileIdForm(BaseModel):
def add_file_to_knowledge_by_id(
id: str,
form_data: KnowledgeFileIdForm,
user=Depends(get_admin_user),
user=Depends(get_verified_user),
):
knowledge = Knowledges.get_knowledge_by_id(id=id)
file = Files.get_file_by_id(form_data.file_id)
@ -238,7 +274,7 @@ def add_file_to_knowledge_by_id(
def update_file_from_knowledge_by_id(
id: str,
form_data: KnowledgeFileIdForm,
user=Depends(get_admin_user),
user=Depends(get_verified_user),
):
knowledge = Knowledges.get_knowledge_by_id(id=id)
file = Files.get_file_by_id(form_data.file_id)
@ -288,7 +324,7 @@ def update_file_from_knowledge_by_id(
def remove_file_from_knowledge_by_id(
id: str,
form_data: KnowledgeFileIdForm,
user=Depends(get_admin_user),
user=Depends(get_verified_user),
):
knowledge = Knowledges.get_knowledge_by_id(id=id)
file = Files.get_file_by_id(form_data.file_id)
@ -371,7 +407,7 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_admin_user)):
@router.delete("/{id}/delete", response_model=bool)
async def delete_knowledge_by_id(id: str, user=Depends(get_admin_user)):
async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
try:
VECTOR_DB_CLIENT.delete_collection(collection_name=id)
except Exception as e:

View File

@ -10,7 +10,9 @@ from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.utils import get_admin_user, get_verified_user, has_access
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access
router = APIRouter()

View File

@ -134,8 +134,8 @@ from open_webui.utils.utils import (
get_current_user,
get_http_authorization_cred,
get_verified_user,
has_access,
)
from open_webui.utils.access_control import has_access
if SAFE_MODE:
print("SAFE MODE ENABLED")

View File

@ -0,0 +1,57 @@
from typing import Optional, Union, List, Dict
from open_webui.apps.webui.models.groups import Groups
def has_permission(
user_id: str,
permission_key: str,
default_permissions: Dict[str, bool] = {},
) -> bool:
"""
Check if a user has a specific permission by checking the group permissions
and falls back to default permissions if not found in any group.
Permission keys can be hierarchical and separated by dots ('.').
"""
def get_permission(permissions: Dict[str, bool], keys: List[str]) -> bool:
"""Traverse permissions dict using a list of keys (from dot-split permission_key)."""
for key in keys:
if key not in permissions:
return False # If any part of the hierarchy is missing, deny access
permissions = permissions[key] # Go one level deeper
return bool(permissions) # Return the boolean at the final level
permission_hierarchy = permission_key.split(".")
# Retrieve user group permissions
user_groups = Groups.get_groups_by_member_id(user_id)
for group in user_groups:
group_permissions = group.permissions
if get_permission(group_permissions, permission_hierarchy):
return True
# Check default permissions afterwards if the group permissions don't allow it
return get_permission(default_permissions, permission_hierarchy)
def has_access(
user_id: str,
type: str = "write",
access_control: Optional[dict] = None,
) -> bool:
print("user_id", user_id, "type", type, "access_control", access_control)
if access_control is None:
return type == "read"
user_groups = Groups.get_groups_by_member_id(user_id)
user_group_ids = [group.id for group in user_groups]
permission_access = access_control.get(type, {})
permitted_group_ids = permission_access.get("group_ids", [])
permitted_user_ids = permission_access.get("user_ids", [])
return user_id in permitted_user_ids or any(
group_id in permitted_group_ids for group_id in user_group_ids
)

View File

@ -7,7 +7,6 @@ from typing import Optional, Union, List, Dict
from open_webui.apps.webui.models.users import Users
from open_webui.apps.webui.models.groups import Groups
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import WEBUI_SECRET_KEY
@ -153,58 +152,3 @@ def get_admin_user(user=Depends(get_current_user)):
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
return user
def has_permission(
user_id: str,
permission_key: str,
default_permissions: Dict[str, bool] = {},
) -> bool:
"""
Check if a user has a specific permission by checking the group permissions
and falls back to default permissions if not found in any group.
Permission keys can be hierarchical and separated by dots ('.').
"""
def get_permission(permissions: Dict[str, bool], keys: List[str]) -> bool:
"""Traverse permissions dict using a list of keys (from dot-split permission_key)."""
for key in keys:
if key not in permissions:
return False # If any part of the hierarchy is missing, deny access
permissions = permissions[key] # Go one level deeper
return bool(permissions) # Return the boolean at the final level
permission_hierarchy = permission_key.split(".")
# Retrieve user group permissions
user_groups = Groups.get_groups_by_member_id(user_id)
for group in user_groups:
group_permissions = group.permissions
if get_permission(group_permissions, permission_hierarchy):
return True
# Check default permissions afterwards if the group permissions don't allow it
return get_permission(default_permissions, permission_hierarchy)
def has_access(
user_id: str,
type: str = "write",
access_control: Optional[dict] = None,
) -> bool:
print("user_id", user_id, "type", type, "access_control", access_control)
if access_control is None:
return type == "read"
user_groups = Groups.get_groups_by_member_id(user_id)
user_group_ids = [group.id for group in user_groups]
permission_access = access_control.get(type, {})
permitted_group_ids = permission_access.get("group_ids", [])
permitted_user_ids = permission_access.get("user_ids", [])
return user_id in permitted_user_ids or any(
group_id in permitted_group_ids for group_id in user_group_ids
)

View File

@ -32,7 +32,7 @@ export const createNewKnowledge = async (token: string, name: string, descriptio
return res;
};
export const getKnowledgeItems = async (token: string = '') => {
export const getKnowledgeBases = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/knowledge/`, {
@ -63,6 +63,37 @@ export const getKnowledgeItems = async (token: string = '') => {
return res;
};
export const getKnowledgeBaseList = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/knowledge/list`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getKnowledgeById = async (token: string, id: string) => {
let error = null;

View File

@ -19,7 +19,7 @@
} from '$lib/apis/retrieval';
import { knowledge, models } from '$lib/stores';
import { getKnowledgeItems } from '$lib/apis/knowledge';
import { getKnowledgeBases } from '$lib/apis/knowledge';
import { uploadDir, deleteAllFiles, deleteFileById } from '$lib/apis/files';
import ResetUploadDirConfirmDialog from '$lib/components/common/ConfirmDialog.svelte';

View File

@ -17,9 +17,11 @@
user as _user,
showControls
} from '$lib/stores';
import { blobToFile, findWordIndices } from '$lib/utils';
import { transcribeAudio } from '$lib/apis/audio';
import { uploadFile } from '$lib/apis/files';
import { getTools } from '$lib/apis/tools';
import { WEBUI_BASE_URL, WEBUI_API_BASE_URL } from '$lib/constants';
@ -32,7 +34,6 @@
import Commands from './MessageInput/Commands.svelte';
import XMark from '../icons/XMark.svelte';
import RichTextInput from '../common/RichTextInput.svelte';
import { getTools } from '$lib/apis/tools';
const i18n = getContext('i18n');

View File

@ -8,7 +8,7 @@
import { removeLastWordFromString } from '$lib/utils';
import { getPrompts } from '$lib/apis/prompts';
import { getKnowledgeItems } from '$lib/apis/knowledge';
import { getKnowledgeBases } from '$lib/apis/knowledge';
import Prompts from './Commands/Prompts.svelte';
import Knowledge from './Commands/Knowledge.svelte';
@ -46,7 +46,7 @@
prompts.set(await getPrompts(localStorage.token));
})(),
(async () => {
knowledge.set(await getKnowledgeItems(localStorage.token));
knowledge.set(await getKnowledgeBases(localStorage.token));
})()
]);
loading = false;

View File

@ -10,7 +10,11 @@
const i18n = getContext('i18n');
import { WEBUI_NAME, knowledge } from '$lib/stores';
import { getKnowledgeItems, deleteKnowledgeById } from '$lib/apis/knowledge';
import {
getKnowledgeBases,
deleteKnowledgeById,
getKnowledgeBaseList
} from '$lib/apis/knowledge';
import { goto } from '$app/navigation';
@ -26,13 +30,21 @@
let fuse = null;
let knowledgeBases = [];
let filteredItems = [];
$: if (knowledgeBases) {
fuse = new Fuse(knowledgeBases, {
keys: ['name', 'description']
});
}
$: if (fuse) {
filteredItems = query
? fuse.search(query).map((e) => {
return e.item;
})
: $knowledge;
: knowledgeBases;
}
const deleteHandler = async (item) => {
@ -41,19 +53,14 @@
});
if (res) {
knowledge.set(await getKnowledgeItems(localStorage.token));
knowledgeBases = await getKnowledgeBaseList(localStorage.token);
knowledge.set(await getKnowledgeBases(localStorage.token));
toast.success($i18n.t('Knowledge deleted successfully.'));
}
};
onMount(async () => {
knowledge.set(await getKnowledgeItems(localStorage.token));
knowledge.subscribe((value) => {
fuse = new Fuse(value, {
keys: ['name', 'description']
});
});
knowledgeBases = await getKnowledgeBaseList(localStorage.token);
});
</script>

View File

@ -3,7 +3,7 @@
import { getContext } from 'svelte';
const i18n = getContext('i18n');
import { createNewKnowledge, getKnowledgeItems } from '$lib/apis/knowledge';
import { createNewKnowledge, getKnowledgeBases } from '$lib/apis/knowledge';
import { toast } from 'svelte-sonner';
import { knowledge } from '$lib/stores';
import AccessControl from '../common/AccessControl.svelte';
@ -30,7 +30,7 @@
if (res) {
toast.success($i18n.t('Knowledge created successfully.'));
knowledge.set(await getKnowledgeItems(localStorage.token));
knowledge.set(await getKnowledgeBases(localStorage.token));
goto(`/workspace/knowledge/${res.id}`);
}

View File

@ -15,7 +15,7 @@
import {
addFileToKnowledgeById,
getKnowledgeById,
getKnowledgeItems,
getKnowledgeBases,
removeFileFromKnowledgeById,
resetKnowledgeById,
updateFileFromKnowledgeById,
@ -27,11 +27,11 @@
import { processFile } from '$lib/apis/retrieval';
import Spinner from '$lib/components/common/Spinner.svelte';
import Files from './Collection/Files.svelte';
import Files from './KnowledgeBase/Files.svelte';
import AddFilesPlaceholder from '$lib/components/AddFilesPlaceholder.svelte';
import AddContentMenu from './Collection/AddContentMenu.svelte';
import AddTextContentModal from './Collection/AddTextContentModal.svelte';
import AddContentMenu from './KnowledgeBase/AddContentMenu.svelte';
import AddTextContentModal from './KnowledgeBase/AddTextContentModal.svelte';
import SyncConfirmDialog from '../../common/ConfirmDialog.svelte';
import RichTextInput from '$lib/components/common/RichTextInput.svelte';
@ -428,7 +428,7 @@
if (res) {
toast.success($i18n.t('Knowledge updated successfully'));
_knowledge.set(await getKnowledgeItems(localStorage.token));
_knowledge.set(await getKnowledgeBases(localStorage.token));
}
}, 1000);
};

View File

@ -12,7 +12,7 @@
import Textarea from '$lib/components/common/Textarea.svelte';
import { getTools } from '$lib/apis/tools';
import { getFunctions } from '$lib/apis/functions';
import { getKnowledgeItems } from '$lib/apis/knowledge';
import { getKnowledgeBases } from '$lib/apis/knowledge';
import AccessControl from '../common/AccessControl.svelte';
import { stringify } from 'postcss';
@ -151,7 +151,7 @@
onMount(async () => {
await tools.set(await getTools(localStorage.token));
await functions.set(await getFunctions(localStorage.token));
await knowledgeCollections.set(await getKnowledgeItems(localStorage.token));
await knowledgeCollections.set(await getKnowledgeBases(localStorage.token));
// Scroll to top 'workspace-container' element
const workspaceContainer = document.getElementById('workspace-container');

View File

@ -10,7 +10,7 @@
import { page } from '$app/stores';
import { fade } from 'svelte/transition';
import { getKnowledgeItems } from '$lib/apis/knowledge';
import { getKnowledgeBases } from '$lib/apis/knowledge';
import { getFunctions } from '$lib/apis/functions';
import { getModels, getVersionUpdates } from '$lib/apis';
import { getAllTags } from '$lib/apis/chats';

View File

@ -2,13 +2,13 @@
import { onMount } from 'svelte';
import { knowledge } from '$lib/stores';
import { getKnowledgeItems } from '$lib/apis/knowledge';
import { getKnowledgeBases } from '$lib/apis/knowledge';
import Knowledge from '$lib/components/workspace/Knowledge.svelte';
onMount(async () => {
await Promise.all([
(async () => {
knowledge.set(await getKnowledgeItems(localStorage.token));
knowledge.set(await getKnowledgeBases(localStorage.token));
})()
]);
});

View File

@ -1,5 +1,5 @@
<script>
import Collection from '$lib/components/workspace/Knowledge/Collection.svelte';
import KnowledgeBase from '$lib/components/workspace/Knowledge/KnowledgeBase.svelte';
</script>
<Collection />
<KnowledgeBase />

View File

@ -1,5 +1,5 @@
<script>
import CreateCollection from '$lib/components/workspace/Knowledge/CreateCollection.svelte';
import CreateKnowledgeBase from '$lib/components/workspace/Knowledge/CreateKnowledgeBase.svelte';
</script>
<CreateCollection />
<CreateKnowledgeBase />