import hashlib
import importlib
import os
import re
from collections import defaultdict
from datetime import datetime, timezone
from io import BytesIO
from itertools import chain
from operator import attrgetter
from os.path import splitext
from pathlib import Path
from typing import (
    Optional,
    Any,
    Type,
    Set,
    Dict,
    Sequence,
    Tuple,
    BinaryIO,
    Union,
    Mapping,
    IO,
)
from urllib.parse import unquote, urlparse
from zipfile import ZipFile, ZIP_BZIP2

import mongoengine
from boltons.iterutils import chunked_iter, first
from furl import furl
from mongoengine import Q

from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_common import EventType
from apiserver.bll.project import project_ids_with_children
from apiserver.bll.task.artifacts import get_artifact_id
from apiserver.bll.task.param_utils import (
    split_param_name,
    hyperparams_default_section,
    hyperparams_legacy_type,
)
from apiserver.config_repo import config
from apiserver.config.info import get_default_company
from apiserver.database.model import EntityVisibility, User
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import (
    Task,
    ArtifactModes,
    TaskStatus,
    TaskModelTypes,
    TaskModelNames,
)
from apiserver.database.utils import get_options
from apiserver.utilities import json
from apiserver.utilities.dicts import nested_get, nested_set, nested_delete


class PrePopulate:
    module_name_prefix = "apiserver."
    event_bll = EventBLL()
    events_file_suffix = "_events"
    export_tag_prefix = "Exported:"
    export_tag = f"{export_tag_prefix} %Y-%m-%d %H:%M:%S"
    metadata_filename = "metadata.json"
    zip_args = dict(mode="w", compression=ZIP_BZIP2)
    artifacts_ext = ".artifacts"
    img_source_regex = re.compile(
        r"['\"]source['\"]:\s?['\"](https?://(?:localhost:8081|files.*?)/.*?)['\"]",
        flags=re.IGNORECASE,
    )
    task_cls: Type[Task]
    project_cls: Type[Project]
    model_cls: Type[Model]
    user_cls: Type[User]

    # noinspection PyTypeChecker
    @classmethod
    def _init_entity_types(cls):
        if not hasattr(cls, "task_cls"):
            cls.task_cls = cls._get_entity_type("database.model.task.task.Task")
        if not hasattr(cls, "model_cls"):
            cls.model_cls = cls._get_entity_type("database.model.model.Model")
        if not hasattr(cls, "project_cls"):
            cls.project_cls = cls._get_entity_type("database.model.project.Project")
        if not hasattr(cls, "user_cls"):
            cls.user_cls = cls._get_entity_type("database.model.User")
    class JsonLinesWriter:
        def __init__(self, file: BinaryIO):
            self.file = file
            self.empty = True

        def __enter__(self):
            self._write("[")
            return self

        def __exit__(self, exc_type, exc_value, exc_traceback):
            self._write("\n]")

        def _write(self, data: str):
            self.file.write(data.encode("utf-8"))

        def write(self, line: str):
            if not self.empty:
                self._write(",")
            self._write("\n" + line)
            self.empty = False
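    # A minimal usage sketch for JsonLinesWriter (illustrative only): it emits a
    # JSON array incrementally, one serialized entity per line.
    #
    #     with BytesIO() as buf:
    #         with PrePopulate.JsonLinesWriter(buf) as w:
    #             w.write('{"id": "task-1"}')
    #             w.write('{"id": "task-2"}')
    #         data = buf.getvalue()
    #     # data == b'[\n{"id": "task-1"},\n{"id": "task-2"}\n]'
    #
    # json_lines() further below is the matching reader that strips the brackets
    # and trailing commas off again, line by line.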
    @staticmethod
    def _get_last_update_time(entity) -> datetime:
        return getattr(entity, "last_update", None) or getattr(entity, "created")

    @classmethod
    def _check_for_update(
        cls, map_file: Path, entities: dict, metadata_hash: str
    ) -> Tuple[bool, Sequence[str]]:
        if not map_file.is_file():
            return True, []

        files = []
        try:
            map_data = json.loads(map_file.read_text())
            files = map_data.get("files", [])
            for file in files:
                if not Path(file).is_file():
                    return True, files

            new_times = {
                item.id: cls._get_last_update_time(item).replace(tzinfo=timezone.utc)
                for item in chain.from_iterable(entities.values())
            }
            old_times = map_data.get("entities", {})

            if set(new_times.keys()) != set(old_times.keys()):
                return True, files

            for id_, new_timestamp in new_times.items():
                if new_timestamp != old_times[id_]:
                    return True, files

            if metadata_hash != map_data.get("metadata_hash", ""):
                return True, files

        except Exception as ex:
            print("Error reading map file. " + str(ex))
            return True, files

        return False, files
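    # For reference, a sketch of the ".map" file layout that _check_for_update
    # reads and _write_update_file produces (field values are illustrative):
    #
    #     {
    #         "files": ["export_<md5>.zip", "export_<md5>.artifacts"],
    #         "entities": {"<entity id>": "<last update timestamp>", ...},
    #         "metadata_hash": "<md5 of the metadata json, or empty>"
    #     }
    #
    # Any missing file, changed entity timestamp, or changed metadata hash makes
    # _check_for_update report that a new export is needed.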
    @classmethod
    def _write_update_file(
        cls,
        map_file: Path,
        entities: dict,
        created_files: Sequence[str],
        metadata_hash: str,
    ):
        map_file.write_text(
            json.dumps(
                dict(
                    files=created_files,
                    entities={
                        entity.id: cls._get_last_update_time(entity)
                        for entity in chain.from_iterable(entities.values())
                    },
                    metadata_hash=metadata_hash,
                )
            )
        )
    @staticmethod
    def _filter_artifacts(artifacts: Sequence[str]) -> Sequence[str]:
        def is_fileserver_link(a: str) -> bool:
            a = a.lower()
            if a.startswith("https://files."):
                return True
            if a.startswith("http"):
                parsed = urlparse(a)
                if parsed.scheme in {"http", "https"} and parsed.netloc.endswith(
                    "8081"
                ):
                    return True
            return False

        fileserver_links = [a for a in artifacts if is_fileserver_link(a)]
        print(
            f"Found {len(fileserver_links)} files on the fileserver from {len(artifacts)} total"
        )

        return fileserver_links
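    # Illustrative examples of the filtering above (URLs are made up):
    #   kept:    "https://files.example.com/project/debug_image.png"  (files.* host)
    #   kept:    "http://localhost:8081/project/artifact.bin"         (port 8081)
    #   dropped: "s3://my-bucket/artifacts/artifact.bin"              (not on the fileserver)
    # Only fileserver links are kept, since those are the ones expected to exist
    # under artifacts_path and can be packed into the companion .artifacts archive.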
    @classmethod
    def export_to_zip(
        cls,
        filename: str,
        experiments: Sequence[str] = None,
        projects: Sequence[str] = None,
        artifacts_path: str = None,
        task_statuses: Sequence[str] = None,
        tag_exported_entities: bool = False,
        metadata: Mapping[str, Any] = None,
    ) -> Sequence[str]:
        cls._init_entity_types()

        if task_statuses and not set(task_statuses).issubset(get_options(TaskStatus)):
            raise ValueError("Invalid task statuses")

        file = Path(filename)
        entities = cls._resolve_entities(
            experiments=experiments, projects=projects, task_statuses=task_statuses
        )

        hash_ = hashlib.md5()
        if metadata:
            meta_str = json.dumps(metadata)
            hash_.update(meta_str.encode())
            metadata_hash = hash_.hexdigest()
        else:
            meta_str, metadata_hash = "", ""

        map_file = file.with_suffix(".map")
        updated, old_files = cls._check_for_update(
            map_file, entities=entities, metadata_hash=metadata_hash
        )
        if not updated:
            print("There are no updates from the last export")
            return old_files

        for old in old_files:
            old_path = Path(old)
            if old_path.is_file():
                old_path.unlink()

        with ZipFile(file, **cls.zip_args) as zfile:
            if metadata:
                zfile.writestr(cls.metadata_filename, meta_str)
            artifacts = cls._export(
                zfile,
                entities=entities,
                hash_=hash_,
                tag_entities=tag_exported_entities,
            )

        file_with_hash = file.with_name(f"{file.stem}_{hash_.hexdigest()}{file.suffix}")
        file.replace(file_with_hash)
        created_files = [str(file_with_hash)]

        artifacts = cls._filter_artifacts(artifacts)
        if artifacts and artifacts_path and os.path.isdir(artifacts_path):
            artifacts_file = file_with_hash.with_suffix(cls.artifacts_ext)
            with ZipFile(artifacts_file, **cls.zip_args) as zfile:
                cls._export_artifacts(zfile, artifacts, artifacts_path)
            created_files.append(str(artifacts_file))

        cls._write_update_file(
            map_file,
            entities=entities,
            created_files=created_files,
            metadata_hash=metadata_hash,
        )

        return created_files
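    # A minimal usage sketch (argument values are assumptions for illustration):
    #
    #     created = PrePopulate.export_to_zip(
    #         "/tmp/clearml_export.zip",
    #         projects=["<project id or name>"],
    #         artifacts_path="/opt/clearml/data/fileserver",
    #         tag_exported_entities=True,
    #     )
    #     # created would look like:
    #     # ["/tmp/clearml_export_<md5>.zip", "/tmp/clearml_export_<md5>.artifacts"]
    #
    # The md5 suffix is how upgrade_zip later tells whether the archive content changed.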
    @classmethod
    def import_from_zip(
        cls,
        filename: str,
        artifacts_path: str,
        company_id: Optional[str] = None,
        user_id: str = "",
        user_name: str = "",
    ):
        cls._init_entity_types()

        metadata = None

        with ZipFile(filename) as zfile:
            try:
                with zfile.open(cls.metadata_filename) as f:
                    metadata = json.loads(f.read())

                    meta_public = metadata.get("public")
                    if company_id is None and meta_public is not None:
                        company_id = "" if meta_public else get_default_company()

                    if not user_id:
                        meta_user_id = metadata.get("user_id", "")
                        meta_user_name = metadata.get("user_name", "")
                        user_id, user_name = meta_user_id, meta_user_name
            except Exception:
                pass

            if not user_id:
                user_id, user_name = "__allegroai__", "Allegro.ai"

            # Make sure we won't end up with an invalid company ID
            if company_id is None:
                company_id = ""

            existing_user = cls.user_cls.objects(id=user_id).only("id").first()
            if not existing_user:
                cls.user_cls(id=user_id, name=user_name, company=company_id).save()

            cls._import(zfile, company_id, user_id, metadata)

        if artifacts_path and os.path.isdir(artifacts_path):
            artifacts_file = Path(filename).with_suffix(cls.artifacts_ext)
            if artifacts_file.is_file():
                print(f"Unzipping artifacts into {artifacts_path}")
                with ZipFile(artifacts_file) as zfile:
                    zfile.extractall(artifacts_path)
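    # A minimal usage sketch (the fileserver path is an assumption; point it at the
    # data folder your fileserver actually serves):
    #
    #     PrePopulate.import_from_zip(
    #         "/tmp/clearml_export_<md5>.zip",
    #         artifacts_path="/opt/clearml/data/fileserver",
    #     )
    #
    # Entities are written first and task events afterwards; the companion
    # "<name>.artifacts" archive, if present next to the zip, is extracted into
    # artifacts_path.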
    @classmethod
    def upgrade_zip(cls, filename) -> Sequence:
        # resolve the entity types before cls.task_cls is used below
        cls._init_entity_types()

        hash_ = hashlib.md5()
        task_file = cls._get_base_filename(cls.task_cls) + ".json"
        temp_file = Path("temp.zip")
        file = Path(filename)
        with ZipFile(file) as reader, ZipFile(temp_file, **cls.zip_args) as writer:
            for file_info in reader.filelist:
                if file_info.orig_filename == task_file:
                    with reader.open(file_info) as f:
                        content = cls._upgrade_tasks(f)
                else:
                    content = reader.read(file_info)
                writer.writestr(file_info, content)
                hash_.update(content)

        base_file_name, _, old_hash = file.stem.rpartition("_")
        new_hash = hash_.hexdigest()
        if old_hash == new_hash:
            print(f"The file {filename} was not updated")
            temp_file.unlink()
            return []

        new_file = file.with_name(f"{base_file_name}_{new_hash}{file.suffix}")
        temp_file.replace(new_file)
        updated = [str(new_file)]

        artifacts_file = file.with_suffix(cls.artifacts_ext)
        if artifacts_file.is_file():
            new_artifacts = new_file.with_suffix(cls.artifacts_ext)
            artifacts_file.replace(new_artifacts)
            updated.append(str(new_artifacts))

        print(f"File {str(file)} replaced with {str(new_file)}")
        file.unlink()

        return updated
    @classmethod
    def _upgrade_tasks(cls, f: IO[bytes]) -> bytes:
        """
        Build a content array that contains the upgraded tasks from the passed file.
        For each task the legacy execution.parameters and execution.model_desc
        fields are converted to the new structure.
        The fix is done on Task objects (not on the raw dictionaries) so that the
        fields are serialized back in the same order as in the original file.
        """
        with BytesIO() as temp:
            with cls.JsonLinesWriter(temp) as w:
                for line in cls.json_lines(f):
                    task_data = cls.task_cls.from_json(line).to_proper_dict()
                    cls._upgrade_task_data(task_data)
                    new_task = cls.task_cls(**task_data)
                    w.write(new_task.to_json())
            return temp.getvalue()
    @classmethod
    def update_featured_projects_order(cls):
        order = config.get("services.projects.featured.order", [])
        if not order:
            return

        public_default = config.get("services.projects.featured.public_default", 9999)

        def get_index(p: Project):
            for index, entry in enumerate(order):
                if (
                    entry.get("id", None) == p.id
                    or entry.get("name", None) == p.name
                    or ("name_regex" in entry and re.match(entry["name_regex"], p.name))
                ):
                    return index
            return public_default

        for project in cls.project_cls.get_many_public(projection=["id", "name"]):
            featured_index = get_index(project)
            cls.project_cls.objects(id=project.id).update(featured=featured_index)
    @staticmethod
    def _resolve_type(
        cls: Type[mongoengine.Document], ids: Optional[Sequence[str]]
    ) -> Sequence[Any]:
        ids = set(ids)
        items = list(cls.objects(id__in=list(ids)))
        resolved = {i.id for i in items}
        missing = ids - resolved
        for name_candidate in missing:
            results = list(cls.objects(name=name_candidate))
            if not results:
                print(f"ERROR: no match for `{name_candidate}`")
                exit(1)
            elif len(results) > 1:
                print(f"ERROR: more than one match for `{name_candidate}`")
                exit(1)
            items.append(results[0])
        return items
    @classmethod
    def _check_projects_hierarchy(cls, projects: Set[Project]):
        """
        For any exported project all its parents up to the root should be present
        """
        if not projects:
            return

        project_ids = {p.id for p in projects}
        orphans = [p.id for p in projects if p.parent and p.parent not in project_ids]
        if not orphans:
            return

        print(
            f"ERROR: the following projects are exported without their parents: {orphans}"
        )
        exit(1)
    @classmethod
    def _resolve_entities(
        cls,
        experiments: Sequence[str] = None,
        projects: Sequence[str] = None,
        task_statuses: Sequence[str] = None,
    ) -> Dict[Type[mongoengine.Document], Set[mongoengine.Document]]:
        entities = defaultdict(set)

        if projects:
            print("Reading projects...")
            projects = project_ids_with_children(projects)
            entities[cls.project_cls].update(
                cls._resolve_type(cls.project_cls, projects)
            )
            print("--> Reading project experiments...")
            query = Q(
                project__in=list(
                    set(filter(None, (p.id for p in entities[cls.project_cls])))
                ),
                system_tags__nin=[EntityVisibility.archived.value],
            )
            if task_statuses:
                query &= Q(status__in=list(set(task_statuses)))
            objs = cls.task_cls.objects(query)
            entities[cls.task_cls].update(
                o for o in objs if o.id not in (experiments or [])
            )

        if experiments:
            print("Reading experiments...")
            entities[cls.task_cls].update(cls._resolve_type(cls.task_cls, experiments))
            print("--> Reading experiments' projects...")
            objs = cls.project_cls.objects(
                id__in=list(
                    set(filter(None, (p.project for p in entities[cls.task_cls])))
                )
            )
            project_ids = {p.id for p in entities[cls.project_cls]}
            entities[cls.project_cls].update(o for o in objs if o.id not in project_ids)

        cls._check_projects_hierarchy(entities[cls.project_cls])

        task_models = chain.from_iterable(
            models
            for task in entities[cls.task_cls]
            if task.models
            for models in (task.models.input, task.models.output)
            if models
        )
        model_ids = {tm.model for tm in task_models}
        if model_ids:
            print("Reading models...")
            entities[cls.model_cls] = set(cls.model_cls.objects(id__in=list(model_ids)))

        return entities
    @classmethod
    def _filter_out_export_tags(cls, tags: Sequence[str]) -> Sequence[str]:
        if not tags:
            return tags
        return [tag for tag in tags if not tag.startswith(cls.export_tag_prefix)]
    @classmethod
    def _cleanup_model(cls, model: Model):
        model.company = ""
        model.user = ""
        model.tags = cls._filter_out_export_tags(model.tags)

    @classmethod
    def _cleanup_task(cls, task: Task):
        task.comment = "Auto generated by Allegro.ai"
        task.status_message = ""
        task.status_reason = ""
        task.user = ""
        task.company = ""
        task.tags = cls._filter_out_export_tags(task.tags)
        if task.output:
            task.output.destination = None
    @classmethod
    def _cleanup_project(cls, project: Project):
        project.user = ""
        project.company = ""
        project.tags = cls._filter_out_export_tags(project.tags)
    @classmethod
    def _cleanup_entity(cls, entity_cls, entity):
        if entity_cls == cls.task_cls:
            cls._cleanup_task(entity)
        elif entity_cls == cls.model_cls:
            cls._cleanup_model(entity)
        elif entity_cls == cls.project_cls:
            cls._cleanup_project(entity)
    @classmethod
    def _add_tag(cls, items: Sequence[Union[Project, Task, Model]], tag: str):
        try:
            for item in items:
                item.update(upsert=False, tags=sorted(item.tags + [tag]))
        except AttributeError:
            pass
    @classmethod
    def _export_task_events(
        cls, task: Task, base_filename: str, writer: ZipFile, hash_
    ) -> Sequence[str]:
        artifacts = []
        filename = f"{base_filename}_{task.id}{cls.events_file_suffix}.json"
        print(f"Writing task events into {writer.filename}:{filename}")
        with BytesIO() as f:
            with cls.JsonLinesWriter(f) as w:
                scroll_id = None
                while True:
                    res = cls.event_bll.get_task_events(
                        company_id=task.company,
                        task_id=task.id,
                        event_type=EventType.all,
                        scroll_id=scroll_id,
                    )
                    if not res.events:
                        break
                    scroll_id = res.next_scroll_id
                    for event in res.events:
                        event_type = event.get("type")
                        if event_type == EventType.metrics_image.value:
                            url = cls._get_fixed_url(event.get("url"))
                            if url:
                                event["url"] = url
                                artifacts.append(url)
                        elif event_type == EventType.metrics_plot.value:
                            plot_str: str = event.get("plot_str", "")
                            for match in cls.img_source_regex.findall(plot_str):
                                url = cls._get_fixed_url(match)
                                if match != url:
                                    plot_str = plot_str.replace(match, url)
                                artifacts.append(url)
                        w.write(json.dumps(event))
            data = f.getvalue()
            hash_.update(data)
            writer.writestr(filename, data)

        return artifacts
    @staticmethod
    def _get_fixed_url(url: Optional[str]) -> Optional[str]:
        if not (url and url.lower().startswith("s3://")):
            return url
        try:
            fixed = furl(url)
            fixed.scheme = "https"
            fixed.host += ".s3.amazonaws.com"
            return fixed.url
        except Exception as ex:
            print(f"Failed processing link {url}. " + str(ex))
            return url
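    # For example (bucket and key are assumed for illustration), an s3 link is
    # rewritten to its public HTTPS form; any other scheme is returned unchanged:
    #
    #     _get_fixed_url("s3://my-bucket/models/model.bin")
    #     # -> "https://my-bucket.s3.amazonaws.com/models/model.bin"
    #     _get_fixed_url("http://localhost:8081/proj/artifact.bin")
    #     # -> "http://localhost:8081/proj/artifact.bin"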
    @classmethod
    def _export_entity_related_data(
        cls, entity_cls, entity, base_filename: str, writer: ZipFile, hash_
    ):
        if entity_cls == cls.task_cls:
            return [
                *cls._get_task_output_artifacts(entity),
                *cls._export_task_events(entity, base_filename, writer, hash_),
            ]

        if entity_cls == cls.model_cls:
            entity.uri = cls._get_fixed_url(entity.uri)
            return [entity.uri] if entity.uri else []

        return []
    @classmethod
    def _get_task_output_artifacts(cls, task: Task) -> Sequence[str]:
        if not task.execution.artifacts:
            return []

        for a in task.execution.artifacts.values():
            if a.mode == ArtifactModes.output:
                a.uri = cls._get_fixed_url(a.uri)

        return [
            a.uri
            for a in task.execution.artifacts.values()
            if a.mode == ArtifactModes.output and a.uri
        ]
    @classmethod
    def _export_artifacts(
        cls, writer: ZipFile, artifacts: Sequence[str], artifacts_path: str
    ):
        unique_paths = set(unquote(str(furl(artifact).path)) for artifact in artifacts)
        print(f"Writing {len(unique_paths)} artifacts into {writer.filename}")
        for path in unique_paths:
            path = path.lstrip("/")
            full_path = os.path.join(artifacts_path, path)
            if os.path.isfile(full_path):
                writer.write(full_path, path)
            else:
                print(f"Artifact {full_path} not found")
    @classmethod
    def _get_base_filename(cls, cls_: type):
        name = f"{cls_.__module__}.{cls_.__name__}"
        if cls.module_name_prefix and name.startswith(cls.module_name_prefix):
            name = name[len(cls.module_name_prefix) :]
        return name
    @classmethod
    def _export(
        cls, writer: ZipFile, entities: dict, hash_, tag_entities: bool = False
    ) -> Sequence[str]:
        """
        Export the requested experiments, projects and models and return the list of artifact files.
        Always export sorted items since the order of items influences the hash.
        The projects should be sorted by name so that on import the hierarchy is correctly restored from top to bottom.
        """
        artifacts = []
        now = datetime.utcnow()
        for cls_ in sorted(entities, key=attrgetter("__name__")):
            items = sorted(entities[cls_], key=attrgetter("name", "id"))
            if not items:
                continue
            base_filename = cls._get_base_filename(cls_)
            for item in items:
                artifacts.extend(
                    cls._export_entity_related_data(
                        cls_, item, base_filename, writer, hash_
                    )
                )
            filename = base_filename + ".json"
            print(f"Writing {len(items)} items into {writer.filename}:{filename}")
            with BytesIO() as f:
                with cls.JsonLinesWriter(f) as w:
                    for item in items:
                        cls._cleanup_entity(cls_, item)
                        w.write(item.to_json())
                data = f.getvalue()
                hash_.update(data)
                writer.writestr(filename, data)

            if tag_entities:
                cls._add_tag(items, now.strftime(cls.export_tag))

        return artifacts
    @staticmethod
    def json_lines(file: IO[bytes]):
        for line in file:
            clean = (
                line.decode("utf-8")
                .rstrip("\r\n")
                .strip()
                .lstrip("[")
                .rstrip(",]")
                .strip()
            )
            if not clean:
                continue
            yield clean
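    # A small round-trip sketch: json_lines() undoes the framing that
    # JsonLinesWriter adds, yielding one JSON document per entity.
    #
    #     data = b'[\n{"id": "task-1"},\n{"id": "task-2"}\n]'
    #     list(PrePopulate.json_lines(BytesIO(data)))
    #     # -> ['{"id": "task-1"}', '{"id": "task-2"}']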
    @classmethod
    def _import(
        cls,
        reader: ZipFile,
        company_id: str = "",
        user_id: str = None,
        metadata: Mapping[str, Any] = None,
        sort_tasks_by_last_updated: bool = True,
    ):
        """
        Import entities and events from the zip file.
        Start from the entities since the event import requires the tasks to already be in the DB.
        """
        event_file_ending = cls.events_file_suffix + ".json"
        entity_files = (
            fi
            for fi in reader.filelist
            if not fi.orig_filename.endswith(event_file_ending)
            and fi.orig_filename != cls.metadata_filename
        )
        metadata = metadata or {}
        tasks = []
        for entity_file in entity_files:
            with reader.open(entity_file) as f:
                full_name = splitext(entity_file.orig_filename)[0]
                print(f"Reading {reader.filename}:{full_name}...")
                res = cls._import_entity(f, full_name, company_id, user_id, metadata)
                if res:
                    tasks = res

        if sort_tasks_by_last_updated:
            tasks = sorted(tasks, key=attrgetter("last_update"))

        for task in tasks:
            events_file = first(
                fi
                for fi in reader.filelist
                if fi.orig_filename.endswith(task.id + event_file_ending)
            )
            if not events_file:
                continue
            with reader.open(events_file) as f:
                full_name = splitext(events_file.orig_filename)[0]
                print(f"Reading {reader.filename}:{full_name}...")
                cls._import_events(f, full_name, company_id, user_id)
    @classmethod
    def _get_entity_type(cls, full_name) -> Type[mongoengine.Document]:
        module_name, _, class_name = full_name.rpartition(".")
        if cls.module_name_prefix and not module_name.startswith(
            cls.module_name_prefix
        ):
            module_name = cls.module_name_prefix + module_name
        module = importlib.import_module(module_name)
        return getattr(module, class_name)
    @staticmethod
    def _upgrade_task_data(task_data: dict) -> dict:
        """
        Migrate from execution/parameters and model_desc to the hyperparams and configuration fields
        Upgrade artifacts list to dict
        Migrate from execution.model and output.model to the new models field
        Move docker_cmd contents into the container field
        :param task_data: Upgraded in place
        :return: The upgraded task data
        """
        for old_param_field, new_param_field, default_section in (
            ("execution.parameters", "hyperparams", hyperparams_default_section),
            ("execution.model_desc", "configuration", None),
        ):
            legacy_path = old_param_field.split(".")
            legacy = nested_get(task_data, legacy_path)
            if legacy:
                for full_name, value in legacy.items():
                    section, name = split_param_name(full_name, default_section)
                    new_path = list(filter(None, (new_param_field, section, name)))
                    if not nested_get(task_data, new_path):
                        new_param = dict(
                            name=name, type=hyperparams_legacy_type, value=str(value)
                        )
                        if section is not None:
                            new_param["section"] = section
                        nested_set(task_data, path=new_path, value=new_param)
                nested_delete(task_data, legacy_path)

        artifacts_path = ("execution", "artifacts")
        artifacts = nested_get(task_data, artifacts_path)
        if isinstance(artifacts, list):
            nested_set(
                task_data,
                path=artifacts_path,
                value={get_artifact_id(a): a for a in artifacts},
            )

        models = task_data.get("models", {})
        now = datetime.utcnow()
        for old_field, type_ in (
            ("execution.model", TaskModelTypes.input),
            ("output.model", TaskModelTypes.output),
        ):
            old_path = old_field.split(".")
            old_model = nested_get(task_data, old_path)
            new_models = models.get(type_, [])
            name = TaskModelNames[type_]
            if old_model and not any(
                m
                for m in new_models
                if m.get("model") == old_model or m.get("name") == name
            ):
                model_item = {"model": old_model, "name": name, "updated": now}
                if type_ == TaskModelTypes.input:
                    new_models = [model_item, *new_models]
                else:
                    new_models = [*new_models, model_item]
                models[type_] = new_models
            nested_delete(task_data, old_path)
        task_data["models"] = models

        docker_cmd_path = ("execution", "docker_cmd")
        docker_cmd = nested_get(task_data, docker_cmd_path)
        if docker_cmd and not task_data.get("container"):
            image, _, arguments = docker_cmd.partition(" ")
            task_data["container"] = {"image": image, "arguments": arguments}
            nested_delete(task_data, docker_cmd_path)

        return task_data
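    # A before/after sketch of the docker_cmd part of this migration (field values
    # are made up; the hyperparams and models sections follow the same in-place idea):
    #
    #     before: {"execution": {"docker_cmd": "nvidia/cuda:11.0 -e FOO=bar", ...}}
    #     after:  {"execution": {...},
    #              "container": {"image": "nvidia/cuda:11.0", "arguments": "-e FOO=bar"}}
    #
    # i.e. the first whitespace-separated token becomes the container image, the rest
    # of the string becomes the container arguments, and execution.docker_cmd is removed.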
    @classmethod
    def _import_entity(
        cls,
        f: IO[bytes],
        full_name: str,
        company_id: str,
        user_id: str,
        metadata: Mapping[str, Any],
    ) -> Optional[Sequence[Task]]:
        cls_ = cls._get_entity_type(full_name)
        print(f"Writing {cls_.__name__.lower()}s into database")
        tasks = []
        override_project_count = 0
        for item in cls.json_lines(f):
            if cls_ == cls.task_cls:
                item = json.dumps(cls._upgrade_task_data(task_data=json.loads(item)))

            doc = cls_.from_json(item, created=True)
            if hasattr(doc, "user"):
                doc.user = user_id
            if hasattr(doc, "company"):
                doc.company = company_id
            if isinstance(doc, cls.project_cls):
                override_project_name = metadata.get("project_name", None)
                if override_project_name:
                    if override_project_count:
                        override_project_name = (
                            f"{override_project_name} {override_project_count + 1}"
                        )
                    override_project_count += 1
                    doc.name = override_project_name

                doc.logo_url = metadata.get("logo_url", None)
                doc.logo_blob = metadata.get("logo_blob", None)

                cls_.objects(company=company_id, name=doc.name, id__ne=doc.id).update(
                    set__name=f"{doc.name}_{datetime.utcnow().strftime('%Y-%m-%d_%H-%M-%S')}"
                )

            doc.save()

            if isinstance(doc, cls.task_cls):
                tasks.append(doc)
                cls.event_bll.delete_task_events(company_id, doc.id, allow_locked=True)

        if tasks:
            return tasks
    @classmethod
    def _import_events(cls, f: IO[bytes], full_name: str, company_id: str, _):
        _, _, task_id = full_name[0 : -len(cls.events_file_suffix)].rpartition("_")
        print(f"Writing events for task {task_id} into database")
        for events_chunk in chunked_iter(cls.json_lines(f), 1000):
            events = [json.loads(item) for item in events_chunk]
            cls.event_bll.add_events(
                company_id, events=events, worker="", allow_locked_tasks=True
            )