Improve prepopulate

allegroai 2021-01-05 17:30:37 +02:00
parent 5e0893dd80
commit 7dcc0f6df2


@@ -40,7 +40,7 @@ from apiserver.bll.task.param_utils import (
 )
 from apiserver.config_repo import config
 from apiserver.config.info import get_default_company
-from apiserver.database.model import EntityVisibility
+from apiserver.database.model import EntityVisibility, User
 from apiserver.database.model.model import Model
 from apiserver.database.model.project import Project
 from apiserver.database.model.task.task import Task, ArtifactModes, TaskStatus
@@ -48,10 +48,10 @@ from apiserver.database.utils import get_options
 from apiserver.tools import safe_get
 from apiserver.utilities import json
 from apiserver.utilities.dicts import nested_get, nested_set
-from .user import _ensure_backend_user
 
 
 class PrePopulate:
+    module_name_prefix = "apiserver."
     event_bll = EventBLL()
     events_file_suffix = "_events"
     export_tag_prefix = "Exported:"
@@ -63,6 +63,22 @@ class PrePopulate:
         r"['\"]source['\"]:\s?['\"](https?://(?:localhost:8081|files.*?)/.*?)['\"]",
         flags=re.IGNORECASE,
     )
+    task_cls: Type[Task]
+    project_cls: Type[Project]
+    model_cls: Type[Model]
+    user_cls: Type[User]
+
+    # noinspection PyTypeChecker
+    @classmethod
+    def _init_entity_types(cls):
+        if not hasattr(cls, "task_cls"):
+            cls.task_cls = cls._get_entity_type("database.model.task.task.Task")
+        if not hasattr(cls, "model_cls"):
+            cls.model_cls = cls._get_entity_type("database.model.model.Model")
+        if not hasattr(cls, "project_cls"):
+            cls.project_cls = cls._get_entity_type("database.model.project.Project")
+        if not hasattr(cls, "user_cls"):
+            cls.user_cls = cls._get_entity_type("database.model.User")
 
     class JsonLinesWriter:
         def __init__(self, file: BinaryIO):
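The hunk above switches PrePopulate from hard-coded references to Task/Model/Project/User onto class attributes that are resolved lazily the first time an export or import runs. A minimal sketch of that "resolve once, cache on the class" pattern, using stand-in entity classes and hypothetical names rather than the real mongoengine documents:

# Illustration only, not part of the commit: stand-in entity classes.
from typing import Type


class Task: ...
class Model: ...


class PrePopulateSketch:
    task_cls: Type[Task]
    model_cls: Type[Model]

    @classmethod
    def _init_entity_types(cls):
        # The hasattr() guards make the call idempotent: the types are resolved on
        # the first call only, and anything assigned earlier (e.g. by a subclass or
        # a test) is left untouched.
        if not hasattr(cls, "task_cls"):
            cls.task_cls = Task
        if not hasattr(cls, "model_cls"):
            cls.model_cls = Model


PrePopulateSketch._init_entity_types()
assert PrePopulateSketch.task_cls is Task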
@@ -179,6 +195,8 @@ class PrePopulate:
         tag_exported_entities: bool = False,
         metadata: Mapping[str, Any] = None,
     ) -> Sequence[str]:
+        cls._init_entity_types()
+
         if task_statuses and not set(task_statuses).issubset(get_options(TaskStatus)):
             raise ValueError("Invalid task statuses")
@@ -247,6 +265,8 @@ class PrePopulate:
         user_id: str = "",
         user_name: str = "",
     ):
+        cls._init_entity_types()
+
         metadata = None
 
         with ZipFile(filename) as zfile:
@@ -273,9 +293,7 @@ class PrePopulate:
                 company_id = ""
 
             # Always use a public user for pre-populated data
-            user_id = _ensure_backend_user(
-                user_id=user_id, user_name=user_name, company_id="",
-            )
+            cls.user_cls(id=user_id, name=user_name, company="").save()
 
             cls._import(zfile, company_id, user_id, metadata)
@@ -289,7 +307,7 @@ class PrePopulate:
     @classmethod
     def upgrade_zip(cls, filename) -> Sequence:
         hash_ = hashlib.md5()
-        task_file = cls._get_base_filename(Task) + ".json"
+        task_file = cls._get_base_filename(cls.task_cls) + ".json"
         temp_file = Path("temp.zip")
         file = Path(filename)
         with ZipFile(file) as reader, ZipFile(temp_file, **cls.zip_args) as writer:
@@ -357,9 +375,9 @@ class PrePopulate:
         with BytesIO() as temp:
             with cls.JsonLinesWriter(temp) as w:
                 for line in cls.json_lines(f):
-                    task_data = Task.from_json(line).to_proper_dict()
+                    task_data = cls.task_cls.from_json(line).to_proper_dict()
                     cls._upgrade_task_data(task_data)
-                    new_task = Task(**task_data)
+                    new_task = cls.task_cls(**task_data)
                     w.write(new_task.to_json())
             return temp.getvalue()
@@ -381,9 +399,9 @@ class PrePopulate:
                 return index
             return public_default
 
-        for project in Project.get_many_public(projection=["id", "name"]):
+        for project in cls.project_cls.get_many_public(projection=["id", "name"]):
             featured_index = get_index(project)
-            Project.objects(id=project.id).update(featured=featured_index)
+            cls.project_cls.objects(id=project.id).update(featured=featured_index)
 
     @staticmethod
     def _resolve_type(
@@ -415,36 +433,44 @@ class PrePopulate:
         if projects:
             print("Reading projects...")
-            entities[Project].update(cls._resolve_type(Project, projects))
+            entities[cls.project_cls].update(
+                cls._resolve_type(cls.project_cls, projects)
+            )
             print("--> Reading project experiments...")
             query = Q(
-                project__in=list(set(filter(None, (p.id for p in entities[Project])))),
+                project__in=list(
+                    set(filter(None, (p.id for p in entities[cls.project_cls])))
+                ),
                 system_tags__nin=[EntityVisibility.archived.value],
             )
             if task_statuses:
                 query &= Q(status__in=list(set(task_statuses)))
-            objs = Task.objects(query)
-            entities[Task].update(o for o in objs if o.id not in (experiments or []))
+            objs = cls.task_cls.objects(query)
+            entities[cls.task_cls].update(
+                o for o in objs if o.id not in (experiments or [])
+            )
 
         if experiments:
             print("Reading experiments...")
-            entities[Task].update(cls._resolve_type(Task, experiments))
+            entities[cls.task_cls].update(cls._resolve_type(cls.task_cls, experiments))
             print("--> Reading experiments projects...")
-            objs = Project.objects(
-                id__in=list(set(filter(None, (p.project for p in entities[Task]))))
+            objs = cls.project_cls.objects(
+                id__in=list(
+                    set(filter(None, (p.project for p in entities[cls.task_cls])))
+                )
             )
-            project_ids = {p.id for p in entities[Project]}
-            entities[Project].update(o for o in objs if o.id not in project_ids)
+            project_ids = {p.id for p in entities[cls.project_cls]}
+            entities[cls.project_cls].update(o for o in objs if o.id not in project_ids)
 
         model_ids = {
             model_id
-            for task in entities[Task]
+            for task in entities[cls.task_cls]
             for model_id in (task.output.model, task.execution.model)
             if model_id
         }
         if model_ids:
             print("Reading models...")
-            entities[Model] = set(Model.objects(id__in=list(model_ids)))
+            entities[cls.model_cls] = set(cls.model_cls.objects(id__in=list(model_ids)))
 
         return entities
@@ -479,11 +505,11 @@ class PrePopulate:
     @classmethod
     def _cleanup_entity(cls, entity_cls, entity):
-        if entity_cls == Task:
+        if entity_cls == cls.task_cls:
             cls._cleanup_task(entity)
-        elif entity_cls == Model:
+        elif entity_cls == cls.model_cls:
             cls._cleanup_model(entity)
-        elif entity == Project:
+        elif entity == cls.project_cls:
             cls._cleanup_project(entity)
 
     @classmethod
@@ -549,13 +575,13 @@ class PrePopulate:
     def _export_entity_related_data(
         cls, entity_cls, entity, base_filename: str, writer: ZipFile, hash_
     ):
-        if entity_cls == Task:
+        if entity_cls == cls.task_cls:
             return [
                 *cls._get_task_output_artifacts(entity),
                 *cls._export_task_events(entity, base_filename, writer, hash_),
             ]
 
-        if entity_cls == Model:
+        if entity_cls == cls.model_cls:
             entity.uri = cls._get_fixed_url(entity.uri)
             return [entity.uri] if entity.uri else []
@@ -590,9 +616,12 @@ class PrePopulate:
             else:
                 print(f"Artifact {full_path} not found")
 
-    @staticmethod
-    def _get_base_filename(cls_: type):
-        return f"{cls_.__module__}.{cls_.__name__}"
+    @classmethod
+    def _get_base_filename(cls, cls_: type):
+        name = f"{cls_.__module__}.{cls_.__name__}"
+        if cls.module_name_prefix and name.startswith(cls.module_name_prefix):
+            name = name[len(cls.module_name_prefix) :]
+        return name
 
     @classmethod
     def _export(
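_get_base_filename names the per-entity JSON files inside the export archive; with the change above, the "apiserver." package prefix is stripped so archive entries no longer embed the server's package layout. A small self-contained sketch of the same stripping logic (get_base_filename and DummyModel are illustrative names, not part of the commit):

MODULE_NAME_PREFIX = "apiserver."


def get_base_filename(cls_: type, prefix: str = MODULE_NAME_PREFIX) -> str:
    # Build "<module>.<class>" and drop the configured prefix when present.
    name = f"{cls_.__module__}.{cls_.__name__}"
    if prefix and name.startswith(prefix):
        name = name[len(prefix):]
    return name


class DummyModel:
    pass


print(get_base_filename(DummyModel))   # __main__.DummyModel (nothing to strip)
DummyModel.__module__ = "apiserver.database.model.model"
print(get_base_filename(DummyModel))   # database.model.model.DummyModel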
@@ -692,6 +721,16 @@ class PrePopulate:
                 print(f"Reading {reader.filename}:{full_name}...")
                 cls._import_events(f, full_name, company_id, user_id)
 
+    @classmethod
+    def _get_entity_type(cls, full_name) -> Type[mongoengine.Document]:
+        module_name, _, class_name = full_name.rpartition(".")
+        if cls.module_name_prefix and not module_name.startswith(
+            cls.module_name_prefix
+        ):
+            module_name = cls.module_name_prefix + module_name
+        module = importlib.import_module(module_name)
+        return getattr(module, class_name)
+
     @classmethod
     def _import_entity(
         cls,
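_get_entity_type is the import-side counterpart of _get_base_filename: it takes the (possibly prefix-stripped) dotted name stored in the archive, re-adds the "apiserver." prefix when missing, and imports the matching document class. A runnable sketch of the same lookup using only the standard library (the "collections." prefix stands in for "apiserver."; get_entity_type is an illustrative name):

import importlib


def get_entity_type(full_name: str, module_name_prefix: str = "collections.") -> type:
    # Split "pkg.module.Class" into module path and class name,
    # prepend the prefix if the stored name was relative, then import and resolve.
    module_name, _, class_name = full_name.rpartition(".")
    if module_name_prefix and not module_name.startswith(module_name_prefix):
        module_name = module_name_prefix + module_name
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


print(get_entity_type("abc.Mapping"))              # resolved as collections.abc.Mapping
print(get_entity_type("collections.abc.Sequence")) # already-prefixed names pass through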
@@ -701,14 +740,12 @@ class PrePopulate:
         user_id: str,
         metadata: Mapping[str, Any],
     ) -> Optional[Sequence[Task]]:
-        module_name, _, class_name = full_name.rpartition(".")
-        module = importlib.import_module(module_name)
-        cls_: Type[mongoengine.Document] = getattr(module, class_name)
+        cls_ = cls._get_entity_type(full_name)
         print(f"Writing {cls_.__name__.lower()}s into database")
         tasks = []
         override_project_count = 0
         for item in cls.json_lines(f):
-            if cls_ == Task:
+            if cls_ == cls.task_cls:
                 task_data = json.loads(item)
                 artifacts_path = ("execution", "artifacts")
                 artifacts = nested_get(task_data, artifacts_path)
@@ -725,7 +762,7 @@ class PrePopulate:
             doc.user = user_id
             if hasattr(doc, "company"):
                 doc.company = company_id
-            if isinstance(doc, Project):
+            if isinstance(doc, cls.project_cls):
                 override_project_name = metadata.get("project_name", None)
                 if override_project_name:
                     if override_project_count:
@@ -744,7 +781,7 @@ class PrePopulate:
             doc.save()
 
-            if isinstance(doc, Task):
+            if isinstance(doc, cls.task_cls):
                 tasks.append(doc)
                 cls.event_bll.delete_task_events(company_id, doc.id, allow_locked=True)