Improve prepopulate

This commit is contained in:
allegroai 2021-01-05 17:30:37 +02:00
parent 5e0893dd80
commit 7dcc0f6df2

View File

@ -40,7 +40,7 @@ from apiserver.bll.task.param_utils import (
)
from apiserver.config_repo import config
from apiserver.config.info import get_default_company
from apiserver.database.model import EntityVisibility
from apiserver.database.model import EntityVisibility, User
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, ArtifactModes, TaskStatus
@ -48,10 +48,10 @@ from apiserver.database.utils import get_options
from apiserver.tools import safe_get
from apiserver.utilities import json
from apiserver.utilities.dicts import nested_get, nested_set
from .user import _ensure_backend_user
class PrePopulate:
module_name_prefix = "apiserver."
event_bll = EventBLL()
events_file_suffix = "_events"
export_tag_prefix = "Exported:"
@ -63,6 +63,22 @@ class PrePopulate:
r"['\"]source['\"]:\s?['\"](https?://(?:localhost:8081|files.*?)/.*?)['\"]",
flags=re.IGNORECASE,
)
task_cls: Type[Task]
project_cls: Type[Project]
model_cls: Type[Model]
user_cls: Type[User]
# noinspection PyTypeChecker
@classmethod
def _init_entity_types(cls):
if not hasattr(cls, "task_cls"):
cls.task_cls = cls._get_entity_type("database.model.task.task.Task")
if not hasattr(cls, "model_cls"):
cls.model_cls = cls._get_entity_type("database.model.model.Model")
if not hasattr(cls, "project_cls"):
cls.project_cls = cls._get_entity_type("database.model.project.Project")
if not hasattr(cls, "user_cls"):
cls.user_cls = cls._get_entity_type("database.model.User")
class JsonLinesWriter:
def __init__(self, file: BinaryIO):
@ -179,6 +195,8 @@ class PrePopulate:
tag_exported_entities: bool = False,
metadata: Mapping[str, Any] = None,
) -> Sequence[str]:
cls._init_entity_types()
if task_statuses and not set(task_statuses).issubset(get_options(TaskStatus)):
raise ValueError("Invalid task statuses")
@ -247,6 +265,8 @@ class PrePopulate:
user_id: str = "",
user_name: str = "",
):
cls._init_entity_types()
metadata = None
with ZipFile(filename) as zfile:
@ -273,9 +293,7 @@ class PrePopulate:
company_id = ""
# Always use a public user for pre-populated data
user_id = _ensure_backend_user(
user_id=user_id, user_name=user_name, company_id="",
)
cls.user_cls(id=user_id, name=user_name, company="").save()
cls._import(zfile, company_id, user_id, metadata)
@ -289,7 +307,7 @@ class PrePopulate:
@classmethod
def upgrade_zip(cls, filename) -> Sequence:
hash_ = hashlib.md5()
task_file = cls._get_base_filename(Task) + ".json"
task_file = cls._get_base_filename(cls.task_cls) + ".json"
temp_file = Path("temp.zip")
file = Path(filename)
with ZipFile(file) as reader, ZipFile(temp_file, **cls.zip_args) as writer:
@ -357,9 +375,9 @@ class PrePopulate:
with BytesIO() as temp:
with cls.JsonLinesWriter(temp) as w:
for line in cls.json_lines(f):
task_data = Task.from_json(line).to_proper_dict()
task_data = cls.task_cls.from_json(line).to_proper_dict()
cls._upgrade_task_data(task_data)
new_task = Task(**task_data)
new_task = cls.task_cls(**task_data)
w.write(new_task.to_json())
return temp.getvalue()
@ -381,9 +399,9 @@ class PrePopulate:
return index
return public_default
for project in Project.get_many_public(projection=["id", "name"]):
for project in cls.project_cls.get_many_public(projection=["id", "name"]):
featured_index = get_index(project)
Project.objects(id=project.id).update(featured=featured_index)
cls.project_cls.objects(id=project.id).update(featured=featured_index)
@staticmethod
def _resolve_type(
@ -415,36 +433,44 @@ class PrePopulate:
if projects:
print("Reading projects...")
entities[Project].update(cls._resolve_type(Project, projects))
entities[cls.project_cls].update(
cls._resolve_type(cls.project_cls, projects)
)
print("--> Reading project experiments...")
query = Q(
project__in=list(set(filter(None, (p.id for p in entities[Project])))),
project__in=list(
set(filter(None, (p.id for p in entities[cls.project_cls])))
),
system_tags__nin=[EntityVisibility.archived.value],
)
if task_statuses:
query &= Q(status__in=list(set(task_statuses)))
objs = Task.objects(query)
entities[Task].update(o for o in objs if o.id not in (experiments or []))
objs = cls.task_cls.objects(query)
entities[cls.task_cls].update(
o for o in objs if o.id not in (experiments or [])
)
if experiments:
print("Reading experiments...")
entities[Task].update(cls._resolve_type(Task, experiments))
entities[cls.task_cls].update(cls._resolve_type(cls.task_cls, experiments))
print("--> Reading experiments projects...")
objs = Project.objects(
id__in=list(set(filter(None, (p.project for p in entities[Task]))))
objs = cls.project_cls.objects(
id__in=list(
set(filter(None, (p.project for p in entities[cls.task_cls])))
)
)
project_ids = {p.id for p in entities[Project]}
entities[Project].update(o for o in objs if o.id not in project_ids)
project_ids = {p.id for p in entities[cls.project_cls]}
entities[cls.project_cls].update(o for o in objs if o.id not in project_ids)
model_ids = {
model_id
for task in entities[Task]
for task in entities[cls.task_cls]
for model_id in (task.output.model, task.execution.model)
if model_id
}
if model_ids:
print("Reading models...")
entities[Model] = set(Model.objects(id__in=list(model_ids)))
entities[cls.model_cls] = set(cls.model_cls.objects(id__in=list(model_ids)))
return entities
@ -479,11 +505,11 @@ class PrePopulate:
@classmethod
def _cleanup_entity(cls, entity_cls, entity):
if entity_cls == Task:
if entity_cls == cls.task_cls:
cls._cleanup_task(entity)
elif entity_cls == Model:
elif entity_cls == cls.model_cls:
cls._cleanup_model(entity)
elif entity == Project:
elif entity == cls.project_cls:
cls._cleanup_project(entity)
@classmethod
@ -549,13 +575,13 @@ class PrePopulate:
def _export_entity_related_data(
cls, entity_cls, entity, base_filename: str, writer: ZipFile, hash_
):
if entity_cls == Task:
if entity_cls == cls.task_cls:
return [
*cls._get_task_output_artifacts(entity),
*cls._export_task_events(entity, base_filename, writer, hash_),
]
if entity_cls == Model:
if entity_cls == cls.model_cls:
entity.uri = cls._get_fixed_url(entity.uri)
return [entity.uri] if entity.uri else []
@ -590,9 +616,12 @@ class PrePopulate:
else:
print(f"Artifact {full_path} not found")
@staticmethod
def _get_base_filename(cls_: type):
return f"{cls_.__module__}.{cls_.__name__}"
@classmethod
def _get_base_filename(cls, cls_: type):
name = f"{cls_.__module__}.{cls_.__name__}"
if cls.module_name_prefix and name.startswith(cls.module_name_prefix):
name = name[len(cls.module_name_prefix) :]
return name
@classmethod
def _export(
@ -692,6 +721,16 @@ class PrePopulate:
print(f"Reading {reader.filename}:{full_name}...")
cls._import_events(f, full_name, company_id, user_id)
@classmethod
def _get_entity_type(cls, full_name) -> Type[mongoengine.Document]:
module_name, _, class_name = full_name.rpartition(".")
if cls.module_name_prefix and not module_name.startswith(
cls.module_name_prefix
):
module_name = cls.module_name_prefix + module_name
module = importlib.import_module(module_name)
return getattr(module, class_name)
@classmethod
def _import_entity(
cls,
@ -701,14 +740,12 @@ class PrePopulate:
user_id: str,
metadata: Mapping[str, Any],
) -> Optional[Sequence[Task]]:
module_name, _, class_name = full_name.rpartition(".")
module = importlib.import_module(module_name)
cls_: Type[mongoengine.Document] = getattr(module, class_name)
cls_ = cls._get_entity_type(full_name)
print(f"Writing {cls_.__name__.lower()}s into database")
tasks = []
override_project_count = 0
for item in cls.json_lines(f):
if cls_ == Task:
if cls_ == cls.task_cls:
task_data = json.loads(item)
artifacts_path = ("execution", "artifacts")
artifacts = nested_get(task_data, artifacts_path)
@ -725,7 +762,7 @@ class PrePopulate:
doc.user = user_id
if hasattr(doc, "company"):
doc.company = company_id
if isinstance(doc, Project):
if isinstance(doc, cls.project_cls):
override_project_name = metadata.get("project_name", None)
if override_project_name:
if override_project_count:
@ -744,7 +781,7 @@ class PrePopulate:
doc.save()
if isinstance(doc, Task):
if isinstance(doc, cls.task_cls):
tasks.append(doc)
cls.event_bll.delete_task_events(company_id, doc.id, allow_locked=True)