Support exporting users with data tool

This commit is contained in:
allegroai 2024-01-10 15:12:07 +02:00
parent 3752db122b
commit 811ab2bf4f

View File

@ -44,7 +44,7 @@ from apiserver.bll.task.param_utils import (
from apiserver.config_repo import config from apiserver.config_repo import config
from apiserver.config.info import get_default_company from apiserver.config.info import get_default_company
from apiserver.database.model import EntityVisibility, User from apiserver.database.model import EntityVisibility, User
from apiserver.database.model.auth import Role from apiserver.database.model.auth import Role, User as AuthUser
from apiserver.database.model.model import Model from apiserver.database.model.model import Model
from apiserver.database.model.project import Project from apiserver.database.model.project import Project
from apiserver.database.model.task.task import ( from apiserver.database.model.task.task import (
@ -68,6 +68,7 @@ class PrePopulate:
export_tag_prefix = "Exported:" export_tag_prefix = "Exported:"
export_tag = f"{export_tag_prefix} %Y-%m-%d %H:%M:%S" export_tag = f"{export_tag_prefix} %Y-%m-%d %H:%M:%S"
metadata_filename = "metadata.json" metadata_filename = "metadata.json"
users_filename = "users.json"
zip_args = dict(mode="w", compression=ZIP_BZIP2) zip_args = dict(mode="w", compression=ZIP_BZIP2)
artifacts_ext = ".artifacts" artifacts_ext = ".artifacts"
img_source_regex = re.compile( img_source_regex = re.compile(
@ -80,6 +81,7 @@ class PrePopulate:
project_cls: Type[Project] project_cls: Type[Project]
model_cls: Type[Model] model_cls: Type[Model]
user_cls: Type[User] user_cls: Type[User]
auth_user_cls: Type[AuthUser]
# noinspection PyTypeChecker # noinspection PyTypeChecker
@classmethod @classmethod
@ -92,6 +94,8 @@ class PrePopulate:
cls.project_cls = cls._get_entity_type("database.model.project.Project") cls.project_cls = cls._get_entity_type("database.model.project.Project")
if not hasattr(cls, "user_cls"): if not hasattr(cls, "user_cls"):
cls.user_cls = cls._get_entity_type("database.model.User") cls.user_cls = cls._get_entity_type("database.model.User")
if not hasattr(cls, "auth_user_cls"):
cls.auth_user_cls = cls._get_entity_type("database.model.auth.User")
class JsonLinesWriter: class JsonLinesWriter:
def __init__(self, file: BinaryIO): def __init__(self, file: BinaryIO):
@ -207,6 +211,8 @@ class PrePopulate:
task_statuses: Sequence[str] = None, task_statuses: Sequence[str] = None,
tag_exported_entities: bool = False, tag_exported_entities: bool = False,
metadata: Mapping[str, Any] = None, metadata: Mapping[str, Any] = None,
export_events: bool = True,
export_users: bool = False,
) -> Sequence[str]: ) -> Sequence[str]:
cls._init_entity_types() cls._init_entity_types()
@ -242,11 +248,15 @@ class PrePopulate:
with ZipFile(file, **cls.zip_args) as zfile: with ZipFile(file, **cls.zip_args) as zfile:
if metadata: if metadata:
zfile.writestr(cls.metadata_filename, meta_str) zfile.writestr(cls.metadata_filename, meta_str)
if export_users:
cls._export_users(zfile)
artifacts = cls._export( artifacts = cls._export(
zfile, zfile,
entities=entities, entities=entities,
hash_=hash_, hash_=hash_,
tag_entities=tag_exported_entities, tag_entities=tag_exported_entities,
export_events=export_events,
cleanup_users=not export_users,
) )
file_with_hash = file.with_name(f"{file.stem}_{hash_.hexdigest()}{file.suffix}") file_with_hash = file.with_name(f"{file.stem}_{hash_.hexdigest()}{file.suffix}")
@ -267,6 +277,9 @@ class PrePopulate:
metadata_hash=metadata_hash, metadata_hash=metadata_hash,
) )
if created_files:
print("Created files:\n" + "\n".join(file for file in created_files))
return created_files return created_files
@classmethod @classmethod
@ -298,18 +311,26 @@ class PrePopulate:
except Exception: except Exception:
pass pass
if not user_id:
user_id, user_name = "__allegroai__", "Allegro.ai"
# Make sure we won't end up with an invalid company ID # Make sure we won't end up with an invalid company ID
if company_id is None: if company_id is None:
company_id = "" company_id = ""
user_mapping = cls._import_users(zfile, company_id)
if not user_id:
user_id, user_name = "__allegroai__", "Allegro.ai"
existing_user = cls.user_cls.objects(id=user_id).only("id").first() existing_user = cls.user_cls.objects(id=user_id).only("id").first()
if not existing_user: if not existing_user:
cls.user_cls(id=user_id, name=user_name, company=company_id).save() cls.user_cls(id=user_id, name=user_name, company=company_id).save()
cls._import(zfile, company_id, user_id, metadata) cls._import(
zfile,
company_id=company_id,
user_id=user_id,
metadata=metadata,
user_mapping=user_mapping,
)
if artifacts_path and os.path.isdir(artifacts_path): if artifacts_path and os.path.isdir(artifacts_path):
artifacts_file = Path(filename).with_suffix(cls.artifacts_ext) artifacts_file = Path(filename).with_suffix(cls.artifacts_ext)
@ -440,7 +461,7 @@ class PrePopulate:
projects: Sequence[str] = None, projects: Sequence[str] = None,
task_statuses: Sequence[str] = None, task_statuses: Sequence[str] = None,
) -> Dict[Type[mongoengine.Document], Set[mongoengine.Document]]: ) -> Dict[Type[mongoengine.Document], Set[mongoengine.Document]]:
entities = defaultdict(set) entities: Dict[Any] = defaultdict(set)
if projects: if projects:
print("Reading projects...") print("Reading projects...")
@ -499,7 +520,6 @@ class PrePopulate:
@classmethod
def _cleanup_model(cls, model: Model):
    """
    Prepare a model document for export
    Export tags are filtered out of the model tags and the company is blanked
    The user field is not touched here, it is handled by _cleanup_entity
    """
    model.tags = cls._filter_out_export_tags(model.tags)
    model.company = ""
@classmethod @classmethod
@ -507,7 +527,6 @@ class PrePopulate:
task.comment = "Auto generated by Allegro.ai" task.comment = "Auto generated by Allegro.ai"
task.status_message = "" task.status_message = ""
task.status_reason = "" task.status_reason = ""
task.user = ""
task.company = "" task.company = ""
task.tags = cls._filter_out_export_tags(task.tags) task.tags = cls._filter_out_export_tags(task.tags)
if task.output: if task.output:
@ -515,17 +534,32 @@ class PrePopulate:
@classmethod
def _cleanup_project(cls, project: Project):
    """
    Prepare a project document for export
    Export tags are filtered out of the project tags and the company is blanked
    The user field is not touched here, it is handled by _cleanup_entity
    """
    project.tags = cls._filter_out_export_tags(project.tags)
    project.company = ""
@classmethod
def _cleanup_auth_user(cls, user: AuthUser):
    """
    Blank the company on the auth user and on each of its credentials
    that has a non-empty company, then return the same user object
    """
    user.company = ""
    company_bound = [
        credential
        for credential in user.credentials
        if getattr(credential, "company", None)
    ]
    for credential in company_bound:
        credential["company"] = ""
    return user
@classmethod
def _cleanup_be_user(cls, user: User):
    """
    Reset the preferences of the backend user to None and blank its company,
    then return the same user object
    """
    user.preferences = None
    user.company = ""
    return user
@classmethod
def _cleanup_entity(cls, entity_cls, entity, cleanup_users: bool = True):
    """
    Prepare an entity for export
    :param entity_cls: the document class of the entity, used to pick the cleanup routine
    :param entity: the entity to clean up in place
    :param cleanup_users: if True then the entity user is blanked. Defaults to True
        (same default as in _export) so that calls that pass only the class
        and the entity keep working
    Entities whose class is not task, model or project only get the user cleanup
    """
    if cleanup_users:
        entity.user = ""

    if entity_cls == cls.task_cls:
        cls._cleanup_task(entity)
    elif entity_cls == cls.model_cls:
        cls._cleanup_model(entity)
    elif entity_cls == cls.project_cls:
        cls._cleanup_project(entity)
@classmethod @classmethod
@ -635,6 +669,38 @@ class PrePopulate:
else: else:
print(f"Artifact {full_path} not found") print(f"Artifact {full_path} not found")
@classmethod
def _export_users(cls, writer: ZipFile):
    """
    Write the users file into the zip
    Auth users with the admin or user role are read together with the backend
    users that have the same ids. Both are cleaned up before writing.
    Nothing is written if either of the queries returns no users.
    The file content is an object with the "auth" and "backend" fields, each
    holding the output of JsonLinesWriter for the corresponding users
    """

    def dump_section(title: str, docs) -> bytes:
        # Serialize the documents and prefix them with the quoted field name
        with BytesIO() as buffer:
            with cls.JsonLinesWriter(buffer) as lines:
                for doc in docs:
                    lines.write(doc.to_json())
            return f'"{title}": '.encode("utf-8") + buffer.getvalue()

    auth_by_id = {}
    for account in cls.auth_user_cls.objects(role__in=(Role.admin, Role.user)):
        auth_by_id[account.id] = cls._cleanup_auth_user(account)
    if not auth_by_id:
        return

    backend_by_id = {}
    for profile in cls.user_cls.objects(id__in=list(auth_by_id)):
        backend_by_id[profile.id] = cls._cleanup_be_user(profile)
    if not backend_by_id:
        return

    # Keep only the auth users that have a matching backend user
    auth_by_id = {
        uid: account for uid, account in auth_by_id.items() if uid in backend_by_id
    }
    print(f"Writing {len(auth_by_id)} users into {writer.filename}")

    sections = [
        dump_section("auth", auth_by_id.values()),
        dump_section("backend", backend_by_id.values()),
    ]
    writer.writestr(cls.users_filename, b"{\n" + b",\n".join(sections) + b"\n}")
@classmethod @classmethod
def _get_base_filename(cls, cls_: type): def _get_base_filename(cls, cls_: type):
name = f"{cls_.__module__}.{cls_.__name__}" name = f"{cls_.__module__}.{cls_.__name__}"
@ -644,7 +710,13 @@ class PrePopulate:
@classmethod @classmethod
def _export( def _export(
cls, writer: ZipFile, entities: dict, hash_, tag_entities: bool = False cls,
writer: ZipFile,
entities: dict,
hash_,
tag_entities: bool = False,
export_events: bool = True,
cleanup_users: bool = True,
) -> Sequence[str]: ) -> Sequence[str]:
""" """
Export the requested experiments, projects and models and return the list of artifact files Export the requested experiments, projects and models and return the list of artifact files
@ -658,18 +730,19 @@ class PrePopulate:
if not items: if not items:
continue continue
base_filename = cls._get_base_filename(cls_) base_filename = cls._get_base_filename(cls_)
for item in items: if export_events:
artifacts.extend( for item in items:
cls._export_entity_related_data( artifacts.extend(
cls_, item, base_filename, writer, hash_ cls._export_entity_related_data(
cls_, item, base_filename, writer, hash_
)
) )
)
filename = base_filename + ".json" filename = base_filename + ".json"
print(f"Writing {len(items)} items into {writer.filename}:{filename}") print(f"Writing {len(items)} items into {writer.filename}:{filename}")
with BytesIO() as f: with BytesIO() as f:
with cls.JsonLinesWriter(f) as w: with cls.JsonLinesWriter(f) as w:
for item in items: for item in items:
cls._cleanup_entity(cls_, item) cls._cleanup_entity(cls_, item, cleanup_users=cleanup_users)
w.write(item.to_json()) w.write(item.to_json())
data = f.getvalue() data = f.getvalue()
hash_.update(data) hash_.update(data)
@ -750,6 +823,68 @@ class PrePopulate:
) )
return ids return ids
@classmethod
def _import_users(cls, reader: ZipFile, company_id: str = "") -> dict:
    """
    Import users to db and return the mapping of old user ids to the new ones
    If there is no users file in the zip then the mapping is empty
    If the user in the file has the same non-empty email as one of the existing ones
    (including users created earlier in this import) then this user is skipped
    and its id is mapped to the existing user with the same email
    Users without an email are never matched by email
    If the user with the same id exists in backend or auth db then its creation is skipped
    If the user has no backend record in the users file then it is skipped and not mapped
    :param reader: the zip file to read the users file from
    :param company_id: the company assigned to the created users and to their
        credentials that were exported with an empty company
    :return: mapping from the user id in the file to the user id in the db
    """
    users_file = first(
        fi for fi in reader.filelist if fi.orig_filename == cls.users_filename
    )
    if not users_file:
        return {}

    existing_user_ids = set(cls.user_cls.objects().scalar("id")) | set(
        cls.auth_user_cls.objects().scalar("id")
    )
    # Index only the non-empty emails: otherwise all the users without an email
    # would be treated as the same user
    existing_user_emails = {
        u.email: u.id for u in cls.auth_user_cls.objects() if u.email
    }

    with reader.open(users_file) as f:
        data = json.loads(f.read())

    auth_users = {u["_id"]: u for u in data["auth"]}
    be_users = {u["_id"]: u for u in data["backend"]}
    user_id_mappings = {}
    for uid, user in auth_users.items():
        email = user.get("email")
        existing_user_id = existing_user_emails.get(email) if email else None
        if existing_user_id:
            user_id_mappings[uid] = existing_user_id
            continue

        if uid in existing_user_ids:
            user_id_mappings[uid] = uid
            continue

        # Check for the backend record before saving anything so that
        # an auth user is never created without its backend counterpart
        be_user_data = be_users.get(uid)
        if be_user_data is None:
            print(
                f"User {uid} has no backend record in {cls.users_filename}, skipping"
            )
            continue

        user_id_mappings[uid] = uid

        # Credentials exported with an empty company are bound to the target company
        credentials = user.get("credentials", [])
        for c in credentials:
            if c.get("company") == "":
                c["company"] = company_id

        if hasattr(cls.auth_user_cls, "sec_groups"):
            user_role = user.get("role", Role.user)
            if user_role == Role.user:
                user["sec_groups"] = ["30795571-a470-4717-a80d-e8705fc776bf"]
            else:
                user["sec_groups"] = [
                    "c14a3cc6-1144-4896-8ea6-fb186ee19896",
                    "30795571-a470-4717-a80d-e8705fc776bf",
                    "30795571a4704717a80de8705897ytuyg",
                ]

        auth_user = cls.auth_user_cls.from_json(json.dumps(user))
        auth_user.company = company_id
        auth_user.save()

        be_user = cls.user_cls.from_json(json.dumps(be_user_data))
        be_user.company = company_id
        be_user.save()

        # Later entries in the same file with the same email map to this user
        if email:
            existing_user_emails[email] = uid

    return user_id_mappings
@classmethod @classmethod
def _import( def _import(
cls, cls,
@ -758,6 +893,7 @@ class PrePopulate:
user_id: str = None, user_id: str = None,
metadata: Mapping[str, Any] = None, metadata: Mapping[str, Any] = None,
sort_tasks_by_last_updated: bool = True, sort_tasks_by_last_updated: bool = True,
user_mapping: Mapping[str, str] = None,
): ):
""" """
Import entities and events from the zip file Import entities and events from the zip file
@ -768,7 +904,7 @@ class PrePopulate:
fi fi
for fi in reader.filelist for fi in reader.filelist
if not fi.orig_filename.endswith(event_file_ending) if not fi.orig_filename.endswith(event_file_ending)
and fi.orig_filename != cls.metadata_filename and fi.orig_filename not in (cls.metadata_filename, cls.users_filename)
] ]
metadata = metadata or {} metadata = metadata or {}
old_to_new_ids = cls._generate_new_ids(reader, entity_files, metadata) old_to_new_ids = cls._generate_new_ids(reader, entity_files, metadata)
@ -778,7 +914,13 @@ class PrePopulate:
full_name = splitext(entity_file.orig_filename)[0] full_name = splitext(entity_file.orig_filename)[0]
print(f"Reading {reader.filename}:{full_name}...") print(f"Reading {reader.filename}:{full_name}...")
res = cls._import_entity( res = cls._import_entity(
f, full_name, company_id, user_id, metadata, old_to_new_ids f,
full_name=full_name,
company_id=company_id,
user_id=user_id,
metadata=metadata,
old_to_new_ids=old_to_new_ids,
user_mapping=user_mapping,
) )
if res: if res:
tasks = res tasks = res
@ -799,7 +941,7 @@ class PrePopulate:
with reader.open(events_file) as f: with reader.open(events_file) as f:
full_name = splitext(events_file.orig_filename)[0] full_name = splitext(events_file.orig_filename)[0]
print(f"Reading {reader.filename}:{full_name}...") print(f"Reading {reader.filename}:{full_name}...")
cls._import_events(f, company_id, user_id, task.id) cls._import_events(f, company_id, task.user, task.id)
@classmethod @classmethod
def _get_entity_type(cls, full_name) -> Type[mongoengine.Document]: def _get_entity_type(cls, full_name) -> Type[mongoengine.Document]:
@ -913,7 +1055,9 @@ class PrePopulate:
user_id: str, user_id: str,
metadata: Mapping[str, Any], metadata: Mapping[str, Any],
old_to_new_ids: Mapping[str, str] = None, old_to_new_ids: Mapping[str, str] = None,
user_mapping: Mapping[str, str] = None,
) -> Optional[Sequence[Task]]: ) -> Optional[Sequence[Task]]:
user_mapping = user_mapping or {}
cls_ = cls._get_entity_type(full_name) cls_ = cls._get_entity_type(full_name)
print(f"Writing {cls_.__name__.lower()}s into database") print(f"Writing {cls_.__name__.lower()}s into database")
tasks = [] tasks = []
@ -935,7 +1079,7 @@ class PrePopulate:
doc = cls_.from_json(item, created=True) doc = cls_.from_json(item, created=True)
if hasattr(doc, "user"): if hasattr(doc, "user"):
doc.user = user_id doc.user = user_mapping.get(doc.user, user_id) if doc.user else user_id
if hasattr(doc, "company"): if hasattr(doc, "company"):
doc.company = company_id doc.company = company_id
if isinstance(doc, cls.project_cls): if isinstance(doc, cls.project_cls):