mirror of
https://github.com/clearml/clearml-server
synced 2025-05-29 17:38:50 +00:00
Improve prepopulate
This commit is contained in:
parent
5e0893dd80
commit
7dcc0f6df2
@ -40,7 +40,7 @@ from apiserver.bll.task.param_utils import (
|
||||
)
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.config.info import get_default_company
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model import EntityVisibility, User
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, ArtifactModes, TaskStatus
|
||||
@ -48,10 +48,10 @@ from apiserver.database.utils import get_options
|
||||
from apiserver.tools import safe_get
|
||||
from apiserver.utilities import json
|
||||
from apiserver.utilities.dicts import nested_get, nested_set
|
||||
from .user import _ensure_backend_user
|
||||
|
||||
|
||||
class PrePopulate:
|
||||
module_name_prefix = "apiserver."
|
||||
event_bll = EventBLL()
|
||||
events_file_suffix = "_events"
|
||||
export_tag_prefix = "Exported:"
|
||||
@ -63,6 +63,22 @@ class PrePopulate:
|
||||
r"['\"]source['\"]:\s?['\"](https?://(?:localhost:8081|files.*?)/.*?)['\"]",
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
task_cls: Type[Task]
|
||||
project_cls: Type[Project]
|
||||
model_cls: Type[Model]
|
||||
user_cls: Type[User]
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
@classmethod
|
||||
def _init_entity_types(cls):
|
||||
if not hasattr(cls, "task_cls"):
|
||||
cls.task_cls = cls._get_entity_type("database.model.task.task.Task")
|
||||
if not hasattr(cls, "model_cls"):
|
||||
cls.model_cls = cls._get_entity_type("database.model.model.Model")
|
||||
if not hasattr(cls, "project_cls"):
|
||||
cls.project_cls = cls._get_entity_type("database.model.project.Project")
|
||||
if not hasattr(cls, "user_cls"):
|
||||
cls.user_cls = cls._get_entity_type("database.model.User")
|
||||
|
||||
class JsonLinesWriter:
|
||||
def __init__(self, file: BinaryIO):
|
||||
@ -179,6 +195,8 @@ class PrePopulate:
|
||||
tag_exported_entities: bool = False,
|
||||
metadata: Mapping[str, Any] = None,
|
||||
) -> Sequence[str]:
|
||||
cls._init_entity_types()
|
||||
|
||||
if task_statuses and not set(task_statuses).issubset(get_options(TaskStatus)):
|
||||
raise ValueError("Invalid task statuses")
|
||||
|
||||
@ -247,6 +265,8 @@ class PrePopulate:
|
||||
user_id: str = "",
|
||||
user_name: str = "",
|
||||
):
|
||||
cls._init_entity_types()
|
||||
|
||||
metadata = None
|
||||
|
||||
with ZipFile(filename) as zfile:
|
||||
@ -273,9 +293,7 @@ class PrePopulate:
|
||||
company_id = ""
|
||||
|
||||
# Always use a public user for pre-populated data
|
||||
user_id = _ensure_backend_user(
|
||||
user_id=user_id, user_name=user_name, company_id="",
|
||||
)
|
||||
cls.user_cls(id=user_id, name=user_name, company="").save()
|
||||
|
||||
cls._import(zfile, company_id, user_id, metadata)
|
||||
|
||||
@ -289,7 +307,7 @@ class PrePopulate:
|
||||
@classmethod
|
||||
def upgrade_zip(cls, filename) -> Sequence:
|
||||
hash_ = hashlib.md5()
|
||||
task_file = cls._get_base_filename(Task) + ".json"
|
||||
task_file = cls._get_base_filename(cls.task_cls) + ".json"
|
||||
temp_file = Path("temp.zip")
|
||||
file = Path(filename)
|
||||
with ZipFile(file) as reader, ZipFile(temp_file, **cls.zip_args) as writer:
|
||||
@ -357,9 +375,9 @@ class PrePopulate:
|
||||
with BytesIO() as temp:
|
||||
with cls.JsonLinesWriter(temp) as w:
|
||||
for line in cls.json_lines(f):
|
||||
task_data = Task.from_json(line).to_proper_dict()
|
||||
task_data = cls.task_cls.from_json(line).to_proper_dict()
|
||||
cls._upgrade_task_data(task_data)
|
||||
new_task = Task(**task_data)
|
||||
new_task = cls.task_cls(**task_data)
|
||||
w.write(new_task.to_json())
|
||||
return temp.getvalue()
|
||||
|
||||
@ -381,9 +399,9 @@ class PrePopulate:
|
||||
return index
|
||||
return public_default
|
||||
|
||||
for project in Project.get_many_public(projection=["id", "name"]):
|
||||
for project in cls.project_cls.get_many_public(projection=["id", "name"]):
|
||||
featured_index = get_index(project)
|
||||
Project.objects(id=project.id).update(featured=featured_index)
|
||||
cls.project_cls.objects(id=project.id).update(featured=featured_index)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_type(
|
||||
@ -415,36 +433,44 @@ class PrePopulate:
|
||||
|
||||
if projects:
|
||||
print("Reading projects...")
|
||||
entities[Project].update(cls._resolve_type(Project, projects))
|
||||
entities[cls.project_cls].update(
|
||||
cls._resolve_type(cls.project_cls, projects)
|
||||
)
|
||||
print("--> Reading project experiments...")
|
||||
query = Q(
|
||||
project__in=list(set(filter(None, (p.id for p in entities[Project])))),
|
||||
project__in=list(
|
||||
set(filter(None, (p.id for p in entities[cls.project_cls])))
|
||||
),
|
||||
system_tags__nin=[EntityVisibility.archived.value],
|
||||
)
|
||||
if task_statuses:
|
||||
query &= Q(status__in=list(set(task_statuses)))
|
||||
objs = Task.objects(query)
|
||||
entities[Task].update(o for o in objs if o.id not in (experiments or []))
|
||||
objs = cls.task_cls.objects(query)
|
||||
entities[cls.task_cls].update(
|
||||
o for o in objs if o.id not in (experiments or [])
|
||||
)
|
||||
|
||||
if experiments:
|
||||
print("Reading experiments...")
|
||||
entities[Task].update(cls._resolve_type(Task, experiments))
|
||||
entities[cls.task_cls].update(cls._resolve_type(cls.task_cls, experiments))
|
||||
print("--> Reading experiments projects...")
|
||||
objs = Project.objects(
|
||||
id__in=list(set(filter(None, (p.project for p in entities[Task]))))
|
||||
objs = cls.project_cls.objects(
|
||||
id__in=list(
|
||||
set(filter(None, (p.project for p in entities[cls.task_cls])))
|
||||
)
|
||||
)
|
||||
project_ids = {p.id for p in entities[Project]}
|
||||
entities[Project].update(o for o in objs if o.id not in project_ids)
|
||||
project_ids = {p.id for p in entities[cls.project_cls]}
|
||||
entities[cls.project_cls].update(o for o in objs if o.id not in project_ids)
|
||||
|
||||
model_ids = {
|
||||
model_id
|
||||
for task in entities[Task]
|
||||
for task in entities[cls.task_cls]
|
||||
for model_id in (task.output.model, task.execution.model)
|
||||
if model_id
|
||||
}
|
||||
if model_ids:
|
||||
print("Reading models...")
|
||||
entities[Model] = set(Model.objects(id__in=list(model_ids)))
|
||||
entities[cls.model_cls] = set(cls.model_cls.objects(id__in=list(model_ids)))
|
||||
|
||||
return entities
|
||||
|
||||
@ -479,11 +505,11 @@ class PrePopulate:
|
||||
|
||||
@classmethod
|
||||
def _cleanup_entity(cls, entity_cls, entity):
|
||||
if entity_cls == Task:
|
||||
if entity_cls == cls.task_cls:
|
||||
cls._cleanup_task(entity)
|
||||
elif entity_cls == Model:
|
||||
elif entity_cls == cls.model_cls:
|
||||
cls._cleanup_model(entity)
|
||||
elif entity == Project:
|
||||
elif entity == cls.project_cls:
|
||||
cls._cleanup_project(entity)
|
||||
|
||||
@classmethod
|
||||
@ -549,13 +575,13 @@ class PrePopulate:
|
||||
def _export_entity_related_data(
|
||||
cls, entity_cls, entity, base_filename: str, writer: ZipFile, hash_
|
||||
):
|
||||
if entity_cls == Task:
|
||||
if entity_cls == cls.task_cls:
|
||||
return [
|
||||
*cls._get_task_output_artifacts(entity),
|
||||
*cls._export_task_events(entity, base_filename, writer, hash_),
|
||||
]
|
||||
|
||||
if entity_cls == Model:
|
||||
if entity_cls == cls.model_cls:
|
||||
entity.uri = cls._get_fixed_url(entity.uri)
|
||||
return [entity.uri] if entity.uri else []
|
||||
|
||||
@ -590,9 +616,12 @@ class PrePopulate:
|
||||
else:
|
||||
print(f"Artifact {full_path} not found")
|
||||
|
||||
@staticmethod
|
||||
def _get_base_filename(cls_: type):
|
||||
return f"{cls_.__module__}.{cls_.__name__}"
|
||||
@classmethod
|
||||
def _get_base_filename(cls, cls_: type):
|
||||
name = f"{cls_.__module__}.{cls_.__name__}"
|
||||
if cls.module_name_prefix and name.startswith(cls.module_name_prefix):
|
||||
name = name[len(cls.module_name_prefix) :]
|
||||
return name
|
||||
|
||||
@classmethod
|
||||
def _export(
|
||||
@ -692,6 +721,16 @@ class PrePopulate:
|
||||
print(f"Reading {reader.filename}:{full_name}...")
|
||||
cls._import_events(f, full_name, company_id, user_id)
|
||||
|
||||
@classmethod
|
||||
def _get_entity_type(cls, full_name) -> Type[mongoengine.Document]:
|
||||
module_name, _, class_name = full_name.rpartition(".")
|
||||
if cls.module_name_prefix and not module_name.startswith(
|
||||
cls.module_name_prefix
|
||||
):
|
||||
module_name = cls.module_name_prefix + module_name
|
||||
module = importlib.import_module(module_name)
|
||||
return getattr(module, class_name)
|
||||
|
||||
@classmethod
|
||||
def _import_entity(
|
||||
cls,
|
||||
@ -701,14 +740,12 @@ class PrePopulate:
|
||||
user_id: str,
|
||||
metadata: Mapping[str, Any],
|
||||
) -> Optional[Sequence[Task]]:
|
||||
module_name, _, class_name = full_name.rpartition(".")
|
||||
module = importlib.import_module(module_name)
|
||||
cls_: Type[mongoengine.Document] = getattr(module, class_name)
|
||||
cls_ = cls._get_entity_type(full_name)
|
||||
print(f"Writing {cls_.__name__.lower()}s into database")
|
||||
tasks = []
|
||||
override_project_count = 0
|
||||
for item in cls.json_lines(f):
|
||||
if cls_ == Task:
|
||||
if cls_ == cls.task_cls:
|
||||
task_data = json.loads(item)
|
||||
artifacts_path = ("execution", "artifacts")
|
||||
artifacts = nested_get(task_data, artifacts_path)
|
||||
@ -725,7 +762,7 @@ class PrePopulate:
|
||||
doc.user = user_id
|
||||
if hasattr(doc, "company"):
|
||||
doc.company = company_id
|
||||
if isinstance(doc, Project):
|
||||
if isinstance(doc, cls.project_cls):
|
||||
override_project_name = metadata.get("project_name", None)
|
||||
if override_project_name:
|
||||
if override_project_count:
|
||||
@ -744,7 +781,7 @@ class PrePopulate:
|
||||
|
||||
doc.save()
|
||||
|
||||
if isinstance(doc, Task):
|
||||
if isinstance(doc, cls.task_cls):
|
||||
tasks.append(doc)
|
||||
cls.event_bll.delete_task_events(company_id, doc.id, allow_locked=True)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user