Support exporting users with data tool

This commit is contained in:
allegroai 2024-01-10 15:12:07 +02:00
parent 3752db122b
commit 811ab2bf4f

View File

@ -44,7 +44,7 @@ from apiserver.bll.task.param_utils import (
from apiserver.config_repo import config
from apiserver.config.info import get_default_company
from apiserver.database.model import EntityVisibility, User
from apiserver.database.model.auth import Role
from apiserver.database.model.auth import Role, User as AuthUser
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import (
@ -68,6 +68,7 @@ class PrePopulate:
export_tag_prefix = "Exported:"
export_tag = f"{export_tag_prefix} %Y-%m-%d %H:%M:%S"
metadata_filename = "metadata.json"
users_filename = "users.json"
zip_args = dict(mode="w", compression=ZIP_BZIP2)
artifacts_ext = ".artifacts"
img_source_regex = re.compile(
@ -80,6 +81,7 @@ class PrePopulate:
project_cls: Type[Project]
model_cls: Type[Model]
user_cls: Type[User]
auth_user_cls: Type[AuthUser]
# noinspection PyTypeChecker
@classmethod
@ -92,6 +94,8 @@ class PrePopulate:
cls.project_cls = cls._get_entity_type("database.model.project.Project")
if not hasattr(cls, "user_cls"):
cls.user_cls = cls._get_entity_type("database.model.User")
if not hasattr(cls, "auth_user_cls"):
cls.auth_user_cls = cls._get_entity_type("database.model.auth.User")
class JsonLinesWriter:
def __init__(self, file: BinaryIO):
@ -207,6 +211,8 @@ class PrePopulate:
task_statuses: Sequence[str] = None,
tag_exported_entities: bool = False,
metadata: Mapping[str, Any] = None,
export_events: bool = True,
export_users: bool = False,
) -> Sequence[str]:
cls._init_entity_types()
@ -242,11 +248,15 @@ class PrePopulate:
with ZipFile(file, **cls.zip_args) as zfile:
if metadata:
zfile.writestr(cls.metadata_filename, meta_str)
if export_users:
cls._export_users(zfile)
artifacts = cls._export(
zfile,
entities=entities,
hash_=hash_,
tag_entities=tag_exported_entities,
export_events=export_events,
cleanup_users=not export_users,
)
file_with_hash = file.with_name(f"{file.stem}_{hash_.hexdigest()}{file.suffix}")
@ -267,6 +277,9 @@ class PrePopulate:
metadata_hash=metadata_hash,
)
if created_files:
print("Created files:\n" + "\n".join(file for file in created_files))
return created_files
@classmethod
@ -298,18 +311,26 @@ class PrePopulate:
except Exception:
pass
if not user_id:
user_id, user_name = "__allegroai__", "Allegro.ai"
# Make sure we won't end up with an invalid company ID
if company_id is None:
company_id = ""
user_mapping = cls._import_users(zfile, company_id)
if not user_id:
user_id, user_name = "__allegroai__", "Allegro.ai"
existing_user = cls.user_cls.objects(id=user_id).only("id").first()
if not existing_user:
cls.user_cls(id=user_id, name=user_name, company=company_id).save()
cls._import(zfile, company_id, user_id, metadata)
cls._import(
zfile,
company_id=company_id,
user_id=user_id,
metadata=metadata,
user_mapping=user_mapping,
)
if artifacts_path and os.path.isdir(artifacts_path):
artifacts_file = Path(filename).with_suffix(cls.artifacts_ext)
@ -440,7 +461,7 @@ class PrePopulate:
projects: Sequence[str] = None,
task_statuses: Sequence[str] = None,
) -> Dict[Type[mongoengine.Document], Set[mongoengine.Document]]:
entities = defaultdict(set)
entities: Dict[Any] = defaultdict(set)
if projects:
print("Reading projects...")
@ -499,7 +520,6 @@ class PrePopulate:
@classmethod
def _cleanup_model(cls, model: Model):
model.company = ""
model.user = ""
model.tags = cls._filter_out_export_tags(model.tags)
@classmethod
@ -507,7 +527,6 @@ class PrePopulate:
task.comment = "Auto generated by Allegro.ai"
task.status_message = ""
task.status_reason = ""
task.user = ""
task.company = ""
task.tags = cls._filter_out_export_tags(task.tags)
if task.output:
@ -515,17 +534,32 @@ class PrePopulate:
@classmethod
def _cleanup_project(cls, project: Project):
project.user = ""
project.company = ""
project.tags = cls._filter_out_export_tags(project.tags)
@classmethod
def _cleanup_entity(cls, entity_cls, entity):
def _cleanup_auth_user(cls, user: AuthUser):
user.company = ""
for cred in user.credentials:
if getattr(cred, "company", None):
cred["company"] = ""
return user
@classmethod
def _cleanup_be_user(cls, user: User):
user.company = ""
user.preferences = None
return user
@classmethod
def _cleanup_entity(cls, entity_cls, entity, cleanup_users):
if cleanup_users:
entity.user = ""
if entity_cls == cls.task_cls:
cls._cleanup_task(entity)
elif entity_cls == cls.model_cls:
cls._cleanup_model(entity)
elif entity == cls.project_cls:
elif entity_cls == cls.project_cls:
cls._cleanup_project(entity)
@classmethod
@ -635,6 +669,38 @@ class PrePopulate:
else:
print(f"Artifact {full_path} not found")
@classmethod
def _export_users(cls, writer: ZipFile):
auth_users = {
user.id: cls._cleanup_auth_user(user)
for user in cls.auth_user_cls.objects(role__in=(Role.admin, Role.user))
}
if not auth_users:
return
be_users = {
user.id: cls._cleanup_be_user(user)
for user in cls.user_cls.objects(id__in=list(auth_users))
}
if not be_users:
return
auth_users = {uid: data for uid, data in auth_users.items() if uid in be_users}
print(f"Writing {len(auth_users)} users into {writer.filename}")
data = {}
for field, users in (("auth", auth_users), ("backend", be_users)):
with BytesIO() as f:
with cls.JsonLinesWriter(f) as w:
for user in users.values():
w.write(user.to_json())
data[field] = f.getvalue()
def get_field_bytes(k: str, v: bytes) -> bytes:
return f'"{k}": '.encode("utf-8") + v
data_str = b",\n".join(get_field_bytes(k, v) for k, v in data.items())
writer.writestr(cls.users_filename, b"{\n" + data_str + b"\n}")
@classmethod
def _get_base_filename(cls, cls_: type):
name = f"{cls_.__module__}.{cls_.__name__}"
@ -644,7 +710,13 @@ class PrePopulate:
@classmethod
def _export(
cls, writer: ZipFile, entities: dict, hash_, tag_entities: bool = False
cls,
writer: ZipFile,
entities: dict,
hash_,
tag_entities: bool = False,
export_events: bool = True,
cleanup_users: bool = True,
) -> Sequence[str]:
"""
Export the requested experiments, projects and models and return the list of artifact files
@ -658,18 +730,19 @@ class PrePopulate:
if not items:
continue
base_filename = cls._get_base_filename(cls_)
for item in items:
artifacts.extend(
cls._export_entity_related_data(
cls_, item, base_filename, writer, hash_
if export_events:
for item in items:
artifacts.extend(
cls._export_entity_related_data(
cls_, item, base_filename, writer, hash_
)
)
)
filename = base_filename + ".json"
print(f"Writing {len(items)} items into {writer.filename}:{filename}")
with BytesIO() as f:
with cls.JsonLinesWriter(f) as w:
for item in items:
cls._cleanup_entity(cls_, item)
cls._cleanup_entity(cls_, item, cleanup_users=cleanup_users)
w.write(item.to_json())
data = f.getvalue()
hash_.update(data)
@ -750,6 +823,68 @@ class PrePopulate:
)
return ids
@classmethod
def _import_users(cls, reader: ZipFile, company_id: str = "") -> dict:
"""
Import users to db and return the mapping of old user ids to the new ones
If no users were in the users file then the mapping was empty
If the user in the file has the same email as one of the existing ones then this user is skipped
and its id is mapped to the existing user with the same email
If the user with the same id exists in backend or auth db then its creation is skipped
"""
users_file = first(
fi for fi in reader.filelist if fi.orig_filename == cls.users_filename
)
if not users_file:
return {}
existing_user_ids = set(cls.user_cls.objects().scalar("id")) | set(
cls.auth_user_cls.objects().scalar("id")
)
existing_user_emails = {u.email: u.id for u in cls.auth_user_cls.objects()}
user_id_mappings = {}
with reader.open(users_file) as f:
data = json.loads(f.read())
auth_users = {u["_id"]: u for u in data["auth"]}
be_users = {u["_id"]: u for u in data["backend"]}
for uid, user in auth_users.items():
email = user.get("email")
existing_user_id = existing_user_emails.get(email)
if existing_user_id:
user_id_mappings[uid] = existing_user_id
continue
user_id_mappings[uid] = uid
if uid in existing_user_ids:
continue
credentials = user.get("credentials", [])
for c in credentials:
if c.get("company") == "":
c["company"] = company_id
if hasattr(cls.auth_user_cls, "sec_groups"):
user_role = user.get("role", Role.user)
if user_role == Role.user:
user["sec_groups"] = ["30795571-a470-4717-a80d-e8705fc776bf"]
else:
user["sec_groups"] = [
"c14a3cc6-1144-4896-8ea6-fb186ee19896",
"30795571-a470-4717-a80d-e8705fc776bf",
"30795571a4704717a80de8705897ytuyg",
]
auth_user = cls.auth_user_cls.from_json(json.dumps(user))
auth_user.company = company_id
auth_user.save()
be_user = cls.user_cls.from_json(json.dumps(be_users[uid]))
be_user.company = company_id
be_user.save()
return user_id_mappings
@classmethod
def _import(
cls,
@ -758,6 +893,7 @@ class PrePopulate:
user_id: str = None,
metadata: Mapping[str, Any] = None,
sort_tasks_by_last_updated: bool = True,
user_mapping: Mapping[str, str] = None,
):
"""
Import entities and events from the zip file
@ -768,7 +904,7 @@ class PrePopulate:
fi
for fi in reader.filelist
if not fi.orig_filename.endswith(event_file_ending)
and fi.orig_filename != cls.metadata_filename
and fi.orig_filename not in (cls.metadata_filename, cls.users_filename)
]
metadata = metadata or {}
old_to_new_ids = cls._generate_new_ids(reader, entity_files, metadata)
@ -778,7 +914,13 @@ class PrePopulate:
full_name = splitext(entity_file.orig_filename)[0]
print(f"Reading {reader.filename}:{full_name}...")
res = cls._import_entity(
f, full_name, company_id, user_id, metadata, old_to_new_ids
f,
full_name=full_name,
company_id=company_id,
user_id=user_id,
metadata=metadata,
old_to_new_ids=old_to_new_ids,
user_mapping=user_mapping,
)
if res:
tasks = res
@ -799,7 +941,7 @@ class PrePopulate:
with reader.open(events_file) as f:
full_name = splitext(events_file.orig_filename)[0]
print(f"Reading {reader.filename}:{full_name}...")
cls._import_events(f, company_id, user_id, task.id)
cls._import_events(f, company_id, task.user, task.id)
@classmethod
def _get_entity_type(cls, full_name) -> Type[mongoengine.Document]:
@ -913,7 +1055,9 @@ class PrePopulate:
user_id: str,
metadata: Mapping[str, Any],
old_to_new_ids: Mapping[str, str] = None,
user_mapping: Mapping[str, str] = None,
) -> Optional[Sequence[Task]]:
user_mapping = user_mapping or {}
cls_ = cls._get_entity_type(full_name)
print(f"Writing {cls_.__name__.lower()}s into database")
tasks = []
@ -935,7 +1079,7 @@ class PrePopulate:
doc = cls_.from_json(item, created=True)
if hasattr(doc, "user"):
doc.user = user_id
doc.user = user_mapping.get(doc.user, user_id) if doc.user else user_id
if hasattr(doc, "company"):
doc.company = company_id
if isinstance(doc, cls.project_cls):