Add data_tool export improvements: a 'company' flag, an increased event batch size for performance, date-time prefixes on log messages, more logging, an option to create a separate zip file per root project, and an option to translate URLs during tool export

clearml 2025-06-04 11:43:31 +03:00
parent bf00441146
commit a7e340212f
3 changed files with 219 additions and 121 deletions
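
A minimal usage sketch of the extended export entry point, for orientation only: the method name (export_to_zip, inferred from the inner export_to_zip_core helper in this diff) and every argument value below are illustrative assumptions, not part of this commit.

# Hedged sketch -- not part of the diff below.
# Export each root project of one company into its own zip file, rewriting
# fileserver URLs from an assumed old prefix to an assumed new one on the way out.
created_files = PrePopulate.export_to_zip(
    filename="/tmp/clearml_export.zip",   # base name; per-project and content-hash suffixes are appended
    company="<company-id>",               # new 'company' flag: limit root projects to this company (placeholder id)
    project_split=True,                   # new flag: one zip file per root project
    url_trans=(                           # new (source_prefix, target_prefix) pair for URL translation
        "https://old-files.example.com/",
        "https://files.example.com/",
    ),
    export_users=False,
)
print("\n".join(created_files))          # the call returns the list of created file names
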

View File

@ -65,6 +65,14 @@ from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
replace_s3_scheme = os.getenv("CLEARML_REPLACE_S3_SCHEME")
def _print(msg: str):
time = datetime.now().isoformat(sep=" ", timespec="seconds")
print(f"{time} {msg}")
UrlTranslation = Tuple[str, str]
class PrePopulate:
module_name_prefix = "apiserver."
event_bll = EventBLL()
@ -163,7 +171,7 @@ class PrePopulate:
return True, files
except Exception as ex:
print("Error reading map file. " + str(ex))
_print("Error reading map file. " + str(ex))
return True, files
return False, files
@ -204,7 +212,7 @@ class PrePopulate:
return False
fileserver_links = [a for a in artifacts if is_fileserver_link(a)]
print(
_print(
f"Found {len(fileserver_links)} files on the fileserver from {len(artifacts)} total"
)
@ -216,81 +224,114 @@ class PrePopulate:
filename: str,
experiments: Sequence[str] = None,
projects: Sequence[str] = None,
company: str = None,
artifacts_path: str = None,
task_statuses: Sequence[str] = None,
tag_exported_entities: bool = False,
metadata: Mapping[str, Any] = None,
export_events: bool = True,
export_users: bool = False,
project_split: bool = False,
url_trans: UrlTranslation = None,
) -> Sequence[str]:
cls._init_entity_types()
if task_statuses and not set(task_statuses).issubset(get_options(TaskStatus)):
raise ValueError("Invalid task statuses")
file = Path(filename)
if not (experiments or projects):
projects = cls.project_cls.objects(parent=None).scalar("id")
entities = cls._resolve_entities(
experiments=experiments, projects=projects, task_statuses=task_statuses
)
hash_ = hashlib.md5()
if metadata:
meta_str = json.dumps(metadata)
hash_.update(meta_str.encode())
metadata_hash = hash_.hexdigest()
else:
meta_str, metadata_hash = "", ""
map_file = file.with_suffix(".map")
updated, old_files = cls._check_for_update(
map_file, entities=entities, metadata_hash=metadata_hash
)
if not updated:
print(f"There are no updates from the last export")
return old_files
for old in old_files:
old_path = Path(old)
if old_path.is_file():
old_path.unlink()
with ZipFile(file, **cls.zip_args) as zfile:
if metadata:
zfile.writestr(cls.metadata_filename, meta_str)
if export_users:
cls._export_users(zfile)
artifacts = cls._export(
zfile,
entities=entities,
hash_=hash_,
tag_entities=tag_exported_entities,
export_events=export_events,
cleanup_users=not export_users,
def export_to_zip_core(file_base_name: Path, projects_: Sequence[str]):
entities = cls._resolve_entities(
experiments=experiments, projects=projects_, task_statuses=task_statuses
)
file_with_hash = file.with_name(f"{file.stem}_{hash_.hexdigest()}{file.suffix}")
file.replace(file_with_hash)
created_files = [str(file_with_hash)]
hash_ = hashlib.md5()
if metadata:
meta_str = json.dumps(metadata)
hash_.update(meta_str.encode())
metadata_hash = hash_.hexdigest()
else:
meta_str, metadata_hash = "", ""
artifacts = cls._filter_artifacts(artifacts)
if artifacts and artifacts_path and os.path.isdir(artifacts_path):
artifacts_file = file_with_hash.with_suffix(cls.artifacts_ext)
with ZipFile(artifacts_file, **cls.zip_args) as zfile:
cls._export_artifacts(zfile, artifacts, artifacts_path)
created_files.append(str(artifacts_file))
map_file = file_base_name.with_suffix(".map")
updated, old_files = cls._check_for_update(
map_file, entities=entities, metadata_hash=metadata_hash
)
if not updated:
_print(f"There are no updates from the last export")
return old_files
cls._write_update_file(
map_file,
entities=entities,
created_files=created_files,
metadata_hash=metadata_hash,
)
for old in old_files:
old_path = Path(old)
if old_path.is_file():
old_path.unlink()
if created_files:
print("Created files:\n" + "\n".join(file for file in created_files))
temp_file = file_base_name.with_suffix(file_base_name.suffix + "$")
try:
with ZipFile(temp_file, **cls.zip_args) as zfile:
if metadata:
zfile.writestr(cls.metadata_filename, meta_str)
if export_users:
cls._export_users(zfile)
artifacts = cls._export(
zfile,
entities=entities,
hash_=hash_,
tag_entities=tag_exported_entities,
export_events=export_events,
cleanup_users=not export_users,
url_trans=url_trans,
)
except:
temp_file.unlink(missing_ok=True)
raise
file_with_hash = file_base_name.with_stem(
f"{file_base_name.stem}_{hash_.hexdigest()}"
)
temp_file.replace(file_with_hash)
files = [str(file_with_hash)]
artifacts = cls._filter_artifacts(artifacts)
if artifacts and artifacts_path and os.path.isdir(artifacts_path):
artifacts_file = file_with_hash.with_suffix(cls.artifacts_ext)
with ZipFile(artifacts_file, **cls.zip_args) as zfile:
cls._export_artifacts(zfile, artifacts, artifacts_path)
files.append(str(artifacts_file))
cls._write_update_file(
map_file,
entities=entities,
created_files=files,
metadata_hash=metadata_hash,
)
if files:
_print("Created files:\n" + "\n".join(file for file in files))
return files
filename = Path(filename)
if not (experiments or projects):
query = dict(parent=None)
if company:
query["company"] = company
projects = list(cls.project_cls.objects(**query).scalar("id"))
# projects.append(None)
if projects and project_split:
created_files = list(
chain.from_iterable(
export_to_zip_core(
file_base_name=filename.with_stem(f"{filename.stem}_{pid}"),
projects_=[pid],
)
for pid in projects
)
)
else:
created_files = export_to_zip_core(
file_base_name=filename, projects_=projects
)
return created_files
@ -320,8 +361,10 @@ class PrePopulate:
meta_user_id = metadata.get("user_id", "")
meta_user_name = metadata.get("user_name", "")
user_id, user_name = meta_user_id, meta_user_name
except Exception:
pass
except Exception as ex:
_print(
f"Error getting metadata from {cls.metadata_filename}: {str(ex)}"
)
# Make sure we won't end up with an invalid company ID
if company_id is None:
@ -347,7 +390,7 @@ class PrePopulate:
if artifacts_path and os.path.isdir(artifacts_path):
artifacts_file = Path(filename).with_suffix(cls.artifacts_ext)
if artifacts_file.is_file():
print(f"Unzipping artifacts into {artifacts_path}")
_print(f"Unzipping artifacts into {artifacts_path}")
with ZipFile(artifacts_file) as zfile:
zfile.extractall(artifacts_path)
@ -370,7 +413,7 @@ class PrePopulate:
base_file_name, _, old_hash = file.stem.rpartition("_")
new_hash = hash_.hexdigest()
if old_hash == new_hash:
print(f"The file {filename} was not updated")
_print(f"The file {filename} was not updated")
temp_file.unlink()
return []
@ -384,7 +427,7 @@ class PrePopulate:
artifacts_file.replace(new_artifacts)
upadated.append(str(new_artifacts))
print(f"File {str(file)} replaced with {str(new_file)}")
_print(f"File {str(file)} replaced with {str(new_file)}")
file.unlink()
return upadated
@ -446,12 +489,12 @@ class PrePopulate:
not_found = missing - set(resolved_by_name)
if not_found:
print(f"ERROR: no match for {', '.join(not_found)}")
_print(f"ERROR: no match for {', '.join(not_found)}")
exit(1)
duplicates = [k for k, v in resolved_by_name.items() if len(v) > 1]
if duplicates:
print(f"ERROR: more than one match for {', '.join(duplicates)}")
_print(f"ERROR: more than one match for {', '.join(duplicates)}")
exit(1)
def get_new_items(input_: Iterable) -> list:
@ -489,20 +532,24 @@ class PrePopulate:
return
prefixes = [
cls.ParentPrefix(prefix=f"{project.name.rpartition('/')[0]}/", path=project.path)
cls.ParentPrefix(
prefix=f"{project.name.rpartition('/')[0]}/", path=project.path
)
for project in orphans
]
prefixes.sort(key=lambda p: len(p.path), reverse=True)
for project in projects:
prefix = first(pref for pref in prefixes if project.path[:len(pref.path)] == pref.path)
prefix = first(
pref for pref in prefixes if project.path[: len(pref.path)] == pref.path
)
if not prefix:
continue
project.path = project.path[len(prefix.path):]
project.path = project.path[len(prefix.path) :]
if not project.path:
project.parent = None
project.name = project.name.removeprefix(prefix.prefix)
# print(
# _print(
# f"ERROR: the following projects are exported without their parents: {orphans}"
# )
# exit(1)
@ -518,16 +565,20 @@ class PrePopulate:
entities: Dict[Any] = defaultdict(set)
if projects:
print("Reading projects...")
projects = project_ids_with_children(projects)
entities[cls.project_cls].update(
cls._resolve_entity_type(cls.project_cls, projects)
)
print("--> Reading project experiments...")
_print("Reading projects...")
root = None in projects
projects = [p for p in projects if p]
if projects:
projects = project_ids_with_children(projects)
entities[cls.project_cls].update(
cls._resolve_entity_type(cls.project_cls, projects)
)
_print("--> Reading project experiments...")
p_ids = list(set(p.id for p in entities[cls.project_cls]))
if root:
p_ids.append(None)
query = Q(
project__in=list(
set(filter(None, (p.id for p in entities[cls.project_cls])))
),
project__in=p_ids,
system_tags__nin=[EntityVisibility.archived.value],
)
if task_statuses:
@ -538,9 +589,11 @@ class PrePopulate:
)
if experiments:
print("Reading experiments...")
entities[cls.task_cls].update(cls._resolve_entity_type(cls.task_cls, experiments))
print("--> Reading experiments projects...")
_print("Reading experiments...")
entities[cls.task_cls].update(
cls._resolve_entity_type(cls.task_cls, experiments)
)
_print("--> Reading experiments projects...")
objs = cls.project_cls.objects(
id__in=list(
set(filter(None, (p.project for p in entities[cls.task_cls])))
@ -560,7 +613,7 @@ class PrePopulate:
)
model_ids = {tm.model for tm in task_models}
if model_ids:
print("Reading models...")
_print("Reading models...")
entities[cls.model_cls] = set(cls.model_cls.objects(id__in=list(model_ids)))
# noinspection PyTypeChecker
@ -625,22 +678,41 @@ class PrePopulate:
except AttributeError:
pass
@staticmethod
def _translate_url(url_: str, url_trans: UrlTranslation) -> str:
if not (url_ and url_trans):
return url_
source, target = url_trans
if not url_.startswith(source):
return url_
return target + url_[len(source):]
@classmethod
def _export_task_events(
cls, task: Task, base_filename: str, writer: ZipFile, hash_
cls,
task: Task,
base_filename: str,
writer: ZipFile,
hash_,
url_trans: UrlTranslation,
) -> Sequence[str]:
artifacts = []
filename = f"{base_filename}_{task.id}{cls.events_file_suffix}.json"
print(f"Writing task events into {writer.filename}:{filename}")
_print(f"Writing task events into {writer.filename}:{filename}")
with BytesIO() as f:
with cls.JsonLinesWriter(f) as w:
scroll_id = None
events_count = 0
while True:
res = cls.event_bll.get_task_events(
company_id=task.company,
task_id=task.id,
event_type=EventType.all,
scroll_id=scroll_id,
size=10_000,
)
if not res.events:
break
@ -650,16 +722,22 @@ class PrePopulate:
if event_type == EventType.metrics_image.value:
url = cls._get_fixed_url(event.get("url"))
if url:
event["url"] = url
artifacts.append(url)
event["url"] = cls._translate_url(url, url_trans)
elif event_type == EventType.metrics_plot.value:
plot_str: str = event.get("plot_str", "")
for match in cls.img_source_regex.findall(plot_str):
url = cls._get_fixed_url(match)
if match != url:
plot_str = plot_str.replace(match, url)
artifacts.append(url)
if plot_str:
for match in cls.img_source_regex.findall(plot_str):
url = cls._get_fixed_url(match)
artifacts.append(url)
new_url = cls._translate_url(url, url_trans)
if match != new_url:
plot_str = plot_str.replace(match, new_url)
event["plot_str"] = plot_str
w.write(json.dumps(event))
events_count += 1
_print(f"Got {events_count} events for task {task.id}")
_print(f"Writing {events_count} events for task {task.id}")
data = f.getvalue()
hash_.update(data)
writer.writestr(filename, data)
@ -677,53 +755,62 @@ class PrePopulate:
fixed.host += ".s3.amazonaws.com"
return fixed.url
except Exception as ex:
print(f"Failed processing link {url}. " + str(ex))
_print(f"Failed processing link {url}. " + str(ex))
return url
@classmethod
def _export_entity_related_data(
cls, entity_cls, entity, base_filename: str, writer: ZipFile, hash_
cls,
entity_cls,
entity,
base_filename: str,
writer: ZipFile,
hash_,
url_trans: UrlTranslation,
):
if entity_cls == cls.task_cls:
return [
*cls._get_task_output_artifacts(entity),
*cls._export_task_events(entity, base_filename, writer, hash_),
*cls._get_task_output_artifacts(entity, url_trans),
*cls._export_task_events(
entity, base_filename, writer, hash_, url_trans
),
]
if entity_cls == cls.model_cls:
entity.uri = cls._get_fixed_url(entity.uri)
return [entity.uri] if entity.uri else []
url = cls._get_fixed_url(entity.uri)
entity.uri = cls._translate_url(url, url_trans)
return [url] if url else []
return []
@classmethod
def _get_task_output_artifacts(cls, task: Task) -> Sequence[str]:
def _get_task_output_artifacts(cls, task: Task, url_trans: UrlTranslation) -> Sequence[str]:
if not task.execution.artifacts:
return []
artifact_urls = []
for a in task.execution.artifacts.values():
if a.mode == ArtifactModes.output:
a.uri = cls._get_fixed_url(a.uri)
url = cls._get_fixed_url(a.uri)
a.uri = cls._translate_url(url, url_trans)
if url and a.mode == ArtifactModes.output:
artifact_urls.append(url)
return [
a.uri
for a in task.execution.artifacts.values()
if a.mode == ArtifactModes.output and a.uri
]
return artifact_urls
@classmethod
def _export_artifacts(
cls, writer: ZipFile, artifacts: Sequence[str], artifacts_path: str
):
unique_paths = set(unquote(str(furl(artifact).path)) for artifact in artifacts)
print(f"Writing {len(unique_paths)} artifacts into {writer.filename}")
_print(f"Writing {len(unique_paths)} artifacts into {writer.filename}")
for path in unique_paths:
path = path.lstrip("/")
full_path = os.path.join(artifacts_path, path)
if os.path.isfile(full_path):
writer.write(full_path, path)
else:
print(f"Artifact {full_path} not found")
_print(f"Artifact {full_path} not found")
@classmethod
def _export_users(cls, writer: ZipFile):
@ -742,7 +829,7 @@ class PrePopulate:
return
auth_users = {uid: data for uid, data in auth_users.items() if uid in be_users}
print(f"Writing {len(auth_users)} users into {writer.filename}")
_print(f"Writing {len(auth_users)} users into {writer.filename}")
data = {}
for field, users in (("auth", auth_users), ("backend", be_users)):
with BytesIO() as f:
@ -773,6 +860,7 @@ class PrePopulate:
tag_entities: bool = False,
export_events: bool = True,
cleanup_users: bool = True,
url_trans: UrlTranslation = None,
) -> Sequence[str]:
"""
Export the requested experiments, projects and models and return the list of artifact files
@ -780,7 +868,7 @@ class PrePopulate:
The projects should be sorted by name so that on import the hierarchy is correctly restored from top to bottom
"""
artifacts = []
now = datetime.utcnow()
now = datetime.now(timezone.utc)
for cls_ in sorted(entities, key=attrgetter("__name__")):
items = sorted(entities[cls_], key=attrgetter("name", "id"))
if not items:
@ -790,11 +878,11 @@ class PrePopulate:
for item in items:
artifacts.extend(
cls._export_entity_related_data(
cls_, item, base_filename, writer, hash_
cls_, item, base_filename, writer, hash_, url_trans
)
)
filename = base_filename + ".json"
print(f"Writing {len(items)} items into {writer.filename}:{filename}")
_print(f"Writing {len(items)} items into {writer.filename}:{filename}")
with BytesIO() as f:
with cls.JsonLinesWriter(f) as w:
for item in items:
@ -968,7 +1056,7 @@ class PrePopulate:
for entity_file in entity_files:
with reader.open(entity_file) as f:
full_name = splitext(entity_file.orig_filename)[0]
print(f"Reading {reader.filename}:{full_name}...")
_print(f"Reading {reader.filename}:{full_name}...")
res = cls._import_entity(
f,
full_name=full_name,
@ -996,7 +1084,7 @@ class PrePopulate:
continue
with reader.open(events_file) as f:
full_name = splitext(events_file.orig_filename)[0]
print(f"Reading {reader.filename}:{full_name}...")
_print(f"Reading {reader.filename}:{full_name}...")
cls._import_events(f, company_id, task.user, task.id)
@classmethod
@ -1082,14 +1170,16 @@ class PrePopulate:
)
models = task_data.get("models", {})
now = datetime.utcnow()
now = datetime.now(timezone.utc)
for old_field, type_ in (
("execution.model", TaskModelTypes.input),
("output.model", TaskModelTypes.output),
):
old_path = old_field.split(".")
old_model = nested_get(task_data, old_path)
new_models = [m for m in models.get(type_, []) if m.get("model") is not None]
new_models = [
m for m in models.get(type_, []) if m.get("model") is not None
]
name = TaskModelNames[type_]
if old_model and not any(
m
@ -1127,7 +1217,7 @@ class PrePopulate:
) -> Optional[Sequence[Task]]:
user_mapping = user_mapping or {}
cls_ = cls._get_entity_type(full_name)
print(f"Writing {cls_.__name__.lower()}s into database")
_print(f"Writing {cls_.__name__.lower()}s into database")
tasks = []
override_project_count = 0
data_upgrade_funcs: Mapping[Type, Callable] = {
@ -1164,21 +1254,23 @@ class PrePopulate:
doc.logo_blob = metadata.get("logo_blob", None)
cls_.objects(company=company_id, name=doc.name, id__ne=doc.id).update(
set__name=f"{doc.name}_{datetime.utcnow().strftime('%Y-%m-%d_%H-%M-%S')}"
set__name=f"{doc.name}_{datetime.now(timezone.utc).strftime('%Y-%m-%d_%H-%M-%S')}"
)
doc.save()
if isinstance(doc, cls.task_cls):
tasks.append(doc)
cls.event_bll.delete_task_events(company_id, doc.id, wait_for_delete=True)
cls.event_bll.delete_task_events(
company_id, doc.id, wait_for_delete=True
)
if tasks:
return tasks
@classmethod
def _import_events(cls, f: IO[bytes], company_id: str, user_id: str, task_id: str):
print(f"Writing events for task {task_id} into database")
_print(f"Writing events for task {task_id} into database")
for events_chunk in chunked_iter(cls.json_lines(f), 1000):
events = [json.loads(item) for item in events_chunk]
for ev in events:

View File

@ -32,8 +32,8 @@ class TestTasksArtifacts(TestService):
# test edit
artifacts = [
dict(key="bb", type="str", uri="test1", mode="output"),
dict(key="aa", type="int", uri="test2", mode="input"),
dict(key="bb", type="str", uri="http://files.clear.ml/test1", mode="output"),
dict(key="aa", type="int", uri="http://files.clear.ml/test2", mode="input"),
]
self.api.tasks.edit(task=task, execution={"artifacts": artifacts})
res = self.api.tasks.get_by_id(task=task).task

View File

@ -14,6 +14,12 @@ class TestTaskPlots(TestService):
@staticmethod
def _create_task_event(task, iteration, **kwargs):
plot_str = kwargs.get("plot_str")
if plot_str:
if not plot_str.startswith("http"):
plot_str = "http://files.clear.ml/" + plot_str
kwargs["plot_str"] = '{"source": "' + plot_str + '"}'
return {
"worker": "test",
"type": "plot",