mirror of
https://github.com/clearml/clearml-server
synced 2025-02-01 03:16:44 +00:00
154 lines
5.5 KiB
Python
154 lines
5.5 KiB
Python
import importlib
import sys
from collections import defaultdict
from datetime import datetime
from os.path import splitext
from typing import List, Optional, Any, Type, Set, Dict
from zipfile import ZipFile, ZIP_BZIP2

import mongoengine
from tqdm import tqdm
|
|
|
|
|
|
class PrePopulate:
    """Export projects/experiments to a zip archive and load them back.

    The archive holds one ``<module>.<ClassName>.json`` entry per document
    class. Every entry is a JSON array written with exactly one document per
    line; ``_import`` depends on that layout when it parses an entry.
    """

    @classmethod
    def export_to_zip(
        cls, filename: str, experiments: List[str] = None, projects: List[str] = None
    ):
        """Write the selected experiments and projects into a new zip archive.

        :param filename: path of the archive to create (opened in "w" mode)
        :param experiments: ids or names of the experiments to export
        :param projects: ids or names of the projects to export
        """
        with ZipFile(filename, mode="w", compression=ZIP_BZIP2) as zfile:
            cls._export(zfile, experiments, projects)

    @classmethod
    def import_from_zip(cls, filename: str, user_id: str = None):
        """Load every entity stored in the archive into the database.

        :param filename: path of an archive produced by ``export_to_zip``
        :param user_id: when given, assigned to the ``user`` attribute of
            every imported document that has one
        """
        with ZipFile(filename) as zfile:
            cls._import(zfile, user_id)

    @staticmethod
    def _resolve_type(
        cls: "Type[mongoengine.Document]", ids: Optional[List[str]]
    ) -> List[Any]:
        """Resolve identifiers into documents of the given class.

        Each entry is first treated as a document id. Entries that match no
        id are then looked up by exact name.

        :param cls: document class to query
        :param ids: document ids and/or names; ``None`` or empty yields ``[]``
        :return: the matching documents, one per distinct identifier
        :raises SystemExit: (code 1) when a name matches no document or more
            than one document
        """
        if not ids:
            # Nothing requested: avoid set(None) and a pointless query.
            return []

        requested = set(ids)
        items = list(cls.objects(id__in=list(requested)))
        resolved = {item.id for item in items}

        # Whatever was not found by id is treated as a name.
        for name_candidate in requested - resolved:
            results = list(cls.objects(name=name_candidate))
            if not results:
                print(f"ERROR: no match for `{name_candidate}`")
                sys.exit(1)
            if len(results) > 1:
                print(f"ERROR: more than one match for `{name_candidate}`")
                sys.exit(1)
            items.append(results[0])

        return items

    @classmethod
    def _resolve_entities(
        cls, experiments: List[str] = None, projects: List[str] = None
    ) -> "Dict[Type[mongoengine.Document], Set[mongoengine.Document]]":
        """Collect the documents to export, grouped by document class.

        Requested projects bring along the tasks that reference them, and
        requested experiments bring along the projects they reference.

        :param experiments: ids or names of the experiments to export
        :param projects: ids or names of the projects to export
        :return: mapping of document class to the set of documents to export
        """
        from database.model.project import Project
        from database.model.task.task import Task

        entities = defaultdict(set)

        if projects:
            print("Reading projects...")
            entities[Project].update(cls._resolve_type(Project, projects))
            print("--> Reading project experiments...")
            project_tasks = Task.objects(
                project__in=list({p.id for p in entities[Project] if p.id})
            )
            # Tasks requested explicitly by id are added by the branch below.
            requested = set(experiments or [])
            entities[Task].update(
                task for task in project_tasks if task.id not in requested
            )

        if experiments:
            print("Reading experiments...")
            entities[Task].update(cls._resolve_type(Task, experiments))
            print("--> Reading experiments projects...")
            task_projects = Project.objects(
                id__in=list({t.project for t in entities[Task] if t.project})
            )
            known_project_ids = {p.id for p in entities[Project]}
            entities[Project].update(
                project
                for project in task_projects
                if project.id not in known_project_ids
            )

        return entities

    @classmethod
    def _cleanup_task(cls, task):
        """Reset a task in place to a freshly created state before export.

        Clears the start/completion times and model references, sets the
        status to ``TaskStatus.created``, empties the status texts and the
        user, and stamps all timestamps with the current time.

        :param task: the task document to modify
        """
        from database.model.task.task import TaskStatus

        task.completed = None
        task.started = None
        if task.execution:
            task.execution.model = None
            task.execution.model_desc = None
            task.execution.model_labels = None
        if task.output:
            task.output.model = None

        task.status = TaskStatus.created
        task.comment = "Auto generated by Allegro.ai"
        # NOTE(review): naive UTC timestamp; presumably matches how the other
        # stored timestamps are written - confirm before making it tz-aware.
        task.created = datetime.utcnow()
        task.last_iteration = 0
        task.last_update = task.created
        task.status_changed = task.created
        task.status_message = ""
        task.status_reason = ""
        task.user = ""

    @classmethod
    def _cleanup_entity(cls, entity_cls, entity):
        """Apply the class-specific cleanup to an entity before it is written.

        Only ``Task`` entities are modified; any other class is left as is.

        :param entity_cls: document class of ``entity``
        :param entity: the document to clean up in place
        """
        from database.model.task.task import Task

        if entity_cls == Task:
            cls._cleanup_task(entity)

    @classmethod
    def _export(
        cls, writer: ZipFile, experiments: List[str] = None, projects: List[str] = None
    ):
        """Write the resolved entities into an open zip archive.

        One ``<module>.<ClassName>.json`` entry is written per document class
        that has at least one document.

        :param writer: archive opened for writing
        :param experiments: ids or names of the experiments to export
        :param projects: ids or names of the projects to export
        """
        entities = cls._resolve_entities(experiments, projects)

        for cls_, items in entities.items():
            if not items:
                continue
            filename = f"{cls_.__module__}.{cls_.__name__}.json"
            print(f"Writing {len(items)} items into {writer.filename}:{filename}")
            with writer.open(filename, "w") as f:
                # One document per line; `_import` parses the entry line by
                # line. Assumes `to_json()` emits no raw newlines - TODO confirm.
                f.write(b"[\n")
                last = len(items) - 1
                for i, item in enumerate(items):
                    cls._cleanup_entity(cls_, item)
                    f.write(item.to_json().encode("utf-8"))
                    if i != last:
                        f.write(b",")
                    f.write(b"\n")
                f.write(b"]\n")

    @staticmethod
    def _import(reader: ZipFile, user_id: str = None):
        """Insert every document found in the archive into the database.

        The document class of each entry is derived from the entry name
        (``<module>.<ClassName>.json``). Entries without a ``.json``
        extension are skipped.

        :param reader: archive opened for reading
        :param user_id: when given, assigned to the ``user`` attribute of
            every imported document that has one
        """
        for file_info in reader.filelist:
            full_name, extension = splitext(file_info.orig_filename)
            if extension != ".json":
                # Directory entries and stray files carry no entity dump.
                print(f"Skipping {reader.filename}:{file_info.orig_filename}")
                continue

            print(f"Reading {reader.filename}:{full_name}...")
            module_name, _, class_name = full_name.rpartition(".")
            # NOTE(security): the module to import is named by the archive
            # itself, so only archives from a trusted source should be loaded.
            module = importlib.import_module(module_name)
            cls_: Type[mongoengine.Document] = getattr(module, class_name)

            with reader.open(file_info) as f:
                for raw_line in tqdm(
                    f.readlines(),
                    desc=f"Writing {cls_.__name__.lower()}s into database",
                    unit="doc",
                ):
                    # Drop the array brackets and the separating comma that
                    # `_export` writes around each single-line document.
                    item = (
                        raw_line.decode("utf-8")
                        .strip()
                        .lstrip("[")
                        .rstrip("]")
                        .rstrip(",")
                        .strip()
                    )
                    if not item:
                        continue
                    doc = cls_.from_json(item)
                    if user_id is not None and hasattr(doc, "user"):
                        doc.user = user_id
                    doc.save(force_insert=True)
|