diff --git a/apiserver/mongo/initialize/pre_populate.py b/apiserver/mongo/initialize/pre_populate.py
index 0cc6fdc..db96e91 100644
--- a/apiserver/mongo/initialize/pre_populate.py
+++ b/apiserver/mongo/initialize/pre_populate.py
@@ -65,6 +65,14 @@ from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
 replace_s3_scheme = os.getenv("CLEARML_REPLACE_S3_SCHEME")
 
 
+def _print(msg: str):
+    time = datetime.now().isoformat(sep=" ", timespec="seconds")
+    print(f"{time} {msg}")
+
+
+UrlTranslation = Tuple[str, str]
+
+
 class PrePopulate:
     module_name_prefix = "apiserver."
     event_bll = EventBLL()
@@ -163,7 +171,7 @@ class PrePopulate:
                 return True, files
 
         except Exception as ex:
-            print("Error reading map file. " + str(ex))
+            _print("Error reading map file. " + str(ex))
             return True, files
 
         return False, files
@@ -204,7 +212,7 @@ class PrePopulate:
             return False
 
         fileserver_links = [a for a in artifacts if is_fileserver_link(a)]
-        print(
+        _print(
             f"Found {len(fileserver_links)} files on the fileserver from {len(artifacts)} total"
         )
 
@@ -216,81 +224,114 @@ class PrePopulate:
         filename: str,
         experiments: Sequence[str] = None,
         projects: Sequence[str] = None,
+        company: str = None,
         artifacts_path: str = None,
         task_statuses: Sequence[str] = None,
         tag_exported_entities: bool = False,
         metadata: Mapping[str, Any] = None,
         export_events: bool = True,
         export_users: bool = False,
+        project_split: bool = False,
+        url_trans: UrlTranslation = None,
     ) -> Sequence[str]:
         cls._init_entity_types()
 
         if task_statuses and not set(task_statuses).issubset(get_options(TaskStatus)):
             raise ValueError("Invalid task statuses")
 
-        file = Path(filename)
-        if not (experiments or projects):
-            projects = cls.project_cls.objects(parent=None).scalar("id")
-
-        entities = cls._resolve_entities(
-            experiments=experiments, projects=projects, task_statuses=task_statuses
-        )
-
-        hash_ = hashlib.md5()
-        if metadata:
-            meta_str = json.dumps(metadata)
-            hash_.update(meta_str.encode())
-            metadata_hash = hash_.hexdigest()
-        else:
-            meta_str, metadata_hash = "", ""
-
-        map_file = file.with_suffix(".map")
-        updated, old_files = cls._check_for_update(
-            map_file, entities=entities, metadata_hash=metadata_hash
-        )
-        if not updated:
-            print(f"There are no updates from the last export")
-            return old_files
-
-        for old in old_files:
-            old_path = Path(old)
-            if old_path.is_file():
-                old_path.unlink()
-
-        with ZipFile(file, **cls.zip_args) as zfile:
-            if metadata:
-                zfile.writestr(cls.metadata_filename, meta_str)
-            if export_users:
-                cls._export_users(zfile)
-            artifacts = cls._export(
-                zfile,
-                entities=entities,
-                hash_=hash_,
-                tag_entities=tag_exported_entities,
-                export_events=export_events,
-                cleanup_users=not export_users,
+        def export_to_zip_core(file_base_name: Path, projects_: Sequence[str]):
+            entities = cls._resolve_entities(
+                experiments=experiments, projects=projects_, task_statuses=task_statuses
             )
 
-        file_with_hash = file.with_name(f"{file.stem}_{hash_.hexdigest()}{file.suffix}")
-        file.replace(file_with_hash)
-        created_files = [str(file_with_hash)]
+            hash_ = hashlib.md5()
+            if metadata:
+                meta_str = json.dumps(metadata)
+                hash_.update(meta_str.encode())
+                metadata_hash = hash_.hexdigest()
+            else:
+                meta_str, metadata_hash = "", ""
 
-        artifacts = cls._filter_artifacts(artifacts)
-        if artifacts and artifacts_path and os.path.isdir(artifacts_path):
-            artifacts_file = file_with_hash.with_suffix(cls.artifacts_ext)
-            with ZipFile(artifacts_file, **cls.zip_args) as zfile:
-                cls._export_artifacts(zfile, artifacts, artifacts_path)
-                created_files.append(str(artifacts_file))
+            map_file = file_base_name.with_suffix(".map")
+            updated, old_files = cls._check_for_update(
+                map_file, entities=entities, metadata_hash=metadata_hash
+            )
+            if not updated:
+                _print(f"There are no updates from the last export")
+                return old_files
 
-        cls._write_update_file(
-            map_file,
-            entities=entities,
-            created_files=created_files,
-            metadata_hash=metadata_hash,
-        )
+            for old in old_files:
+                old_path = Path(old)
+                if old_path.is_file():
+                    old_path.unlink()
 
-        if created_files:
-            print("Created files:\n" + "\n".join(file for file in created_files))
+            temp_file = file_base_name.with_suffix(file_base_name.suffix + "$")
+            try:
+                with ZipFile(temp_file, **cls.zip_args) as zfile:
+                    if metadata:
+                        zfile.writestr(cls.metadata_filename, meta_str)
+                    if export_users:
+                        cls._export_users(zfile)
+                    artifacts = cls._export(
+                        zfile,
+                        entities=entities,
+                        hash_=hash_,
+                        tag_entities=tag_exported_entities,
+                        export_events=export_events,
+                        cleanup_users=not export_users,
+                        url_trans=url_trans,
+                    )
+            except:
+                temp_file.unlink(missing_ok=True)
+                raise
+
+            file_with_hash = file_base_name.with_stem(
+                f"{file_base_name.stem}_{hash_.hexdigest()}"
+            )
+            temp_file.replace(file_with_hash)
+            files = [str(file_with_hash)]
+
+            artifacts = cls._filter_artifacts(artifacts)
+            if artifacts and artifacts_path and os.path.isdir(artifacts_path):
+                artifacts_file = file_with_hash.with_suffix(cls.artifacts_ext)
+                with ZipFile(artifacts_file, **cls.zip_args) as zfile:
+                    cls._export_artifacts(zfile, artifacts, artifacts_path)
+                files.append(str(artifacts_file))
+
+            cls._write_update_file(
+                map_file,
+                entities=entities,
+                created_files=files,
+                metadata_hash=metadata_hash,
+            )
+
+            if files:
+                _print("Created files:\n" + "\n".join(file for file in files))
+
+            return files
+
+        filename = Path(filename)
+        if not (experiments or projects):
+            query = dict(parent=None)
+            if company:
+                query["company"] = company
+            projects = list(cls.project_cls.objects(**query).scalar("id"))
+            # projects.append(None)
+
+        if projects and project_split:
+            created_files = list(
+                chain.from_iterable(
+                    export_to_zip_core(
+                        file_base_name=filename.with_stem(f"{filename.stem}_{pid}"),
+                        projects_=[pid],
+                    )
+                    for pid in projects
+                )
+            )
+        else:
+            created_files = export_to_zip_core(
+                file_base_name=filename, projects_=projects
+            )
 
         return created_files
@@ -320,8 +361,10 @@ class PrePopulate:
                     meta_user_id = metadata.get("user_id", "")
                     meta_user_name = metadata.get("user_name", "")
                     user_id, user_name = meta_user_id, meta_user_name
-            except Exception:
-                pass
+            except Exception as ex:
+                _print(
+                    f"Error getting metadata from {cls.metadata_filename}: {str(ex)}"
+                )
 
         # Make sure we won't end up with an invalid company ID
         if company_id is None:
@@ -347,7 +390,7 @@ class PrePopulate:
         if artifacts_path and os.path.isdir(artifacts_path):
             artifacts_file = Path(filename).with_suffix(cls.artifacts_ext)
             if artifacts_file.is_file():
-                print(f"Unzipping artifacts into {artifacts_path}")
+                _print(f"Unzipping artifacts into {artifacts_path}")
                 with ZipFile(artifacts_file) as zfile:
                     zfile.extractall(artifacts_path)
 
@@ -370,7 +413,7 @@ class PrePopulate:
         base_file_name, _, old_hash = file.stem.rpartition("_")
         new_hash = hash_.hexdigest()
         if old_hash == new_hash:
-            print(f"The file {filename} was not updated")
+            _print(f"The file {filename} was not updated")
             temp_file.unlink()
             return []
 
@@ -384,7 +427,7 @@ class PrePopulate:
                 artifacts_file.replace(new_artifacts)
                 upadated.append(str(new_artifacts))
 
-        print(f"File {str(file)} replaced with {str(new_file)}")
+        _print(f"File {str(file)} replaced with {str(new_file)}")
         file.unlink()
         return upadated
 
@@ -446,12 +489,12 @@ class PrePopulate:
 
         not_found = missing - set(resolved_by_name)
         if not_found:
-            print(f"ERROR: no match for {', '.join(not_found)}")
+            _print(f"ERROR: no match for {', '.join(not_found)}")
             exit(1)
 
         duplicates = [k for k, v in resolved_by_name.items() if len(v) > 1]
         if duplicates:
-            print(f"ERROR: more than one match for {', '.join(duplicates)}")
+            _print(f"ERROR: more than one match for {', '.join(duplicates)}")
             exit(1)
 
         def get_new_items(input_: Iterable) -> list:
@@ -489,20 +532,24 @@ class PrePopulate:
             return
 
         prefixes = [
-            cls.ParentPrefix(prefix=f"{project.name.rpartition('/')[0]}/", path=project.path)
+            cls.ParentPrefix(
+                prefix=f"{project.name.rpartition('/')[0]}/", path=project.path
+            )
             for project in orphans
         ]
         prefixes.sort(key=lambda p: len(p.path), reverse=True)
 
         for project in projects:
-            prefix = first(pref for pref in prefixes if project.path[:len(pref.path)] == pref.path)
+            prefix = first(
+                pref for pref in prefixes if project.path[: len(pref.path)] == pref.path
+            )
             if not prefix:
                 continue
-            project.path = project.path[len(prefix.path):]
+            project.path = project.path[len(prefix.path) :]
             if not project.path:
                 project.parent = None
             project.name = project.name.removeprefix(prefix.prefix)
-        # print(
+        # _print(
         #     f"ERROR: the following projects are exported without their parents: {orphans}"
         # )
         # exit(1)
@@ -518,16 +565,20 @@ class PrePopulate:
         entities: Dict[Any] = defaultdict(set)
 
         if projects:
-            print("Reading projects...")
-            projects = project_ids_with_children(projects)
-            entities[cls.project_cls].update(
-                cls._resolve_entity_type(cls.project_cls, projects)
-            )
-            print("--> Reading project experiments...")
+            _print("Reading projects...")
+            root = None in projects
+            projects = [p for p in projects if p]
+            if projects:
+                projects = project_ids_with_children(projects)
+                entities[cls.project_cls].update(
+                    cls._resolve_entity_type(cls.project_cls, projects)
+                )
+            _print("--> Reading project experiments...")
+            p_ids = list(set(p.id for p in entities[cls.project_cls]))
+            if root:
+                p_ids.append(None)
             query = Q(
-                project__in=list(
-                    set(filter(None, (p.id for p in entities[cls.project_cls])))
-                ),
+                project__in=p_ids,
                 system_tags__nin=[EntityVisibility.archived.value],
             )
             if task_statuses:
@@ -538,9 +589,11 @@ class PrePopulate:
 
         if experiments:
-            print("Reading experiments...")
-            entities[cls.task_cls].update(cls._resolve_entity_type(cls.task_cls, experiments))
-            print("--> Reading experiments projects...")
+            _print("Reading experiments...")
+            entities[cls.task_cls].update(
+                cls._resolve_entity_type(cls.task_cls, experiments)
+            )
+            _print("--> Reading experiments projects...")
             objs = cls.project_cls.objects(
                 id__in=list(
                     set(filter(None, (p.project for p in entities[cls.task_cls])))
@@ -560,7 +613,7 @@ class PrePopulate:
         )
         model_ids = {tm.model for tm in task_models}
         if model_ids:
-            print("Reading models...")
+            _print("Reading models...")
             entities[cls.model_cls] = set(cls.model_cls.objects(id__in=list(model_ids)))
 
         # noinspection PyTypeChecker
@@ -625,22 +678,41 @@ class PrePopulate:
         except AttributeError:
             pass
 
+    @staticmethod
+    def _translate_url(url_: str, url_trans: UrlTranslation) -> str:
+        if not (url_ and url_trans):
+            return url_
+
+        source, target = url_trans
+        if not url_.startswith(source):
+            return url_
+
+        return target + url_[len(source):]
+
     @classmethod
     def _export_task_events(
-        cls, task: Task, base_filename: str, writer: ZipFile, hash_
+        cls,
+        task: Task,
+        base_filename: str,
+        writer: ZipFile,
+        hash_,
+        url_trans: UrlTranslation,
     ) -> Sequence[str]:
         artifacts = []
         filename = f"{base_filename}_{task.id}{cls.events_file_suffix}.json"
-        print(f"Writing task events into {writer.filename}:{filename}")
+        _print(f"Writing task events into {writer.filename}:{filename}")
+
         with BytesIO() as f:
             with cls.JsonLinesWriter(f) as w:
                 scroll_id = None
+                events_count = 0
                 while True:
                     res = cls.event_bll.get_task_events(
                         company_id=task.company,
                         task_id=task.id,
                         event_type=EventType.all,
                         scroll_id=scroll_id,
+                        size=10_000,
                     )
                     if not res.events:
                         break
@@ -650,16 +722,22 @@ class PrePopulate:
                         if event_type == EventType.metrics_image.value:
                             url = cls._get_fixed_url(event.get("url"))
                             if url:
-                                event["url"] = url
                                 artifacts.append(url)
+                                event["url"] = cls._translate_url(url, url_trans)
                         elif event_type == EventType.metrics_plot.value:
                             plot_str: str = event.get("plot_str", "")
-                            for match in cls.img_source_regex.findall(plot_str):
-                                url = cls._get_fixed_url(match)
-                                if match != url:
-                                    plot_str = plot_str.replace(match, url)
-                                artifacts.append(url)
+                            if plot_str:
+                                for match in cls.img_source_regex.findall(plot_str):
+                                    url = cls._get_fixed_url(match)
+                                    artifacts.append(url)
+                                    new_url = cls._translate_url(url, url_trans)
+                                    if match != new_url:
+                                        plot_str = plot_str.replace(match, new_url)
+                                event["plot_str"] = plot_str
                         w.write(json.dumps(event))
+                        events_count += 1
+                    _print(f"Got {events_count} events for task {task.id}")
+            _print(f"Writing {events_count} events for task {task.id}")
             data = f.getvalue()
             hash_.update(data)
             writer.writestr(filename, data)
@@ -677,53 +755,62 @@ class PrePopulate:
                 fixed.host += ".s3.amazonaws.com"
             return fixed.url
         except Exception as ex:
-            print(f"Failed processing link {url}. " + str(ex))
+            _print(f"Failed processing link {url}. " + str(ex))
             return url
 
     @classmethod
     def _export_entity_related_data(
-        cls, entity_cls, entity, base_filename: str, writer: ZipFile, hash_
+        cls,
+        entity_cls,
+        entity,
+        base_filename: str,
+        writer: ZipFile,
+        hash_,
+        url_trans: UrlTranslation,
     ):
         if entity_cls == cls.task_cls:
             return [
-                *cls._get_task_output_artifacts(entity),
-                *cls._export_task_events(entity, base_filename, writer, hash_),
+                *cls._get_task_output_artifacts(entity, url_trans),
+                *cls._export_task_events(
+                    entity, base_filename, writer, hash_, url_trans
+                ),
             ]
 
         if entity_cls == cls.model_cls:
-            entity.uri = cls._get_fixed_url(entity.uri)
-            return [entity.uri] if entity.uri else []
+            url = cls._get_fixed_url(entity.uri)
+            entity.uri = cls._translate_url(url, url_trans)
+            return [url] if url else []
 
         return []
 
     @classmethod
-    def _get_task_output_artifacts(cls, task: Task) -> Sequence[str]:
+    def _get_task_output_artifacts(cls, task: Task, url_trans: UrlTranslation) -> Sequence[str]:
         if not task.execution.artifacts:
             return []
 
+        artifact_urls = []
         for a in task.execution.artifacts.values():
             if a.mode == ArtifactModes.output:
-                a.uri = cls._get_fixed_url(a.uri)
+                url = cls._get_fixed_url(a.uri)
+                a.uri = cls._translate_url(url, url_trans)
+                if url and a.mode == ArtifactModes.output:
+                    artifact_urls.append(url)
 
-        return [
-            a.uri
-            for a in task.execution.artifacts.values()
-            if a.mode == ArtifactModes.output and a.uri
-        ]
+        return artifact_urls
 
     @classmethod
     def _export_artifacts(
         cls, writer: ZipFile, artifacts: Sequence[str], artifacts_path: str
     ):
         unique_paths = set(unquote(str(furl(artifact).path)) for artifact in artifacts)
-        print(f"Writing {len(unique_paths)} artifacts into {writer.filename}")
+        _print(f"Writing {len(unique_paths)} artifacts into {writer.filename}")
         for path in unique_paths:
             path = path.lstrip("/")
             full_path = os.path.join(artifacts_path, path)
             if os.path.isfile(full_path):
                 writer.write(full_path, path)
             else:
-                print(f"Artifact {full_path} not found")
+                _print(f"Artifact {full_path} not found")
@@ -742,7 +829,7 @@ class PrePopulate:
             return
 
         auth_users = {uid: data for uid, data in auth_users.items() if uid in be_users}
-        print(f"Writing {len(auth_users)} users into {writer.filename}")
+        _print(f"Writing {len(auth_users)} users into {writer.filename}")
         data = {}
         for field, users in (("auth", auth_users), ("backend", be_users)):
             with BytesIO() as f:
@@ -773,6 +860,7 @@ class PrePopulate:
         tag_entities: bool = False,
         export_events: bool = True,
         cleanup_users: bool = True,
+        url_trans: UrlTranslation = None,
     ) -> Sequence[str]:
         """
         Export the requested experiments, projects and models and return the list of artifact files
@@ -780,7 +868,7 @@ class PrePopulate:
         The projects should be sorted by name so that on import the hierarchy is correctly restored from top to bottom
         """
         artifacts = []
-        now = datetime.utcnow()
+        now = datetime.now(timezone.utc)
         for cls_ in sorted(entities, key=attrgetter("__name__")):
             items = sorted(entities[cls_], key=attrgetter("name", "id"))
             if not items:
@@ -790,11 +878,11 @@ class PrePopulate:
             for item in items:
                 artifacts.extend(
                     cls._export_entity_related_data(
-                        cls_, item, base_filename, writer, hash_
+                        cls_, item, base_filename, writer, hash_, url_trans
                     )
                 )
             filename = base_filename + ".json"
-            print(f"Writing {len(items)} items into {writer.filename}:{filename}")
+            _print(f"Writing {len(items)} items into {writer.filename}:{filename}")
             with BytesIO() as f:
                 with cls.JsonLinesWriter(f) as w:
                     for item in items:
@@ -968,7 +1056,7 @@ class PrePopulate:
         for entity_file in entity_files:
             with reader.open(entity_file) as f:
                 full_name = splitext(entity_file.orig_filename)[0]
-                print(f"Reading {reader.filename}:{full_name}...")
+                _print(f"Reading {reader.filename}:{full_name}...")
                 res = cls._import_entity(
                     f,
                     full_name=full_name,
@@ -996,7 +1084,7 @@ class PrePopulate:
                 continue
             with reader.open(events_file) as f:
                 full_name = splitext(events_file.orig_filename)[0]
-                print(f"Reading {reader.filename}:{full_name}...")
+                _print(f"Reading {reader.filename}:{full_name}...")
                 cls._import_events(f, company_id, task.user, task.id)
 
     @classmethod
@@ -1082,14 +1170,16 @@ class PrePopulate:
         )
 
         models = task_data.get("models", {})
-        now = datetime.utcnow()
+        now = datetime.now(timezone.utc)
         for old_field, type_ in (
             ("execution.model", TaskModelTypes.input),
             ("output.model", TaskModelTypes.output),
         ):
             old_path = old_field.split(".")
             old_model = nested_get(task_data, old_path)
-            new_models = [m for m in models.get(type_, []) if m.get("model") is not None]
+            new_models = [
+                m for m in models.get(type_, []) if m.get("model") is not None
+            ]
             name = TaskModelNames[type_]
             if old_model and not any(
                 m
@@ -1127,7 +1217,7 @@ class PrePopulate:
     ) -> Optional[Sequence[Task]]:
         user_mapping = user_mapping or {}
         cls_ = cls._get_entity_type(full_name)
-        print(f"Writing {cls_.__name__.lower()}s into database")
+        _print(f"Writing {cls_.__name__.lower()}s into database")
        tasks = []
         override_project_count = 0
         data_upgrade_funcs: Mapping[Type, Callable] = {
@@ -1164,21 +1254,23 @@ class PrePopulate:
                 doc.logo_blob = metadata.get("logo_blob", None)
 
             cls_.objects(company=company_id, name=doc.name, id__ne=doc.id).update(
-                set__name=f"{doc.name}_{datetime.utcnow().strftime('%Y-%m-%d_%H-%M-%S')}"
+                set__name=f"{doc.name}_{datetime.now(timezone.utc).strftime('%Y-%m-%d_%H-%M-%S')}"
             )
 
             doc.save()
 
             if isinstance(doc, cls.task_cls):
                 tasks.append(doc)
-                cls.event_bll.delete_task_events(company_id, doc.id, wait_for_delete=True)
+                cls.event_bll.delete_task_events(
+                    company_id, doc.id, wait_for_delete=True
+                )
 
         if tasks:
             return tasks
 
     @classmethod
     def _import_events(cls, f: IO[bytes], company_id: str, user_id: str, task_id: str):
-        print(f"Writing events for task {task_id} into database")
+        _print(f"Writing events for task {task_id} into database")
         for events_chunk in chunked_iter(cls.json_lines(f), 1000):
             events = [json.loads(item) for item in events_chunk]
             for ev in events:
diff --git a/apiserver/tests/automated/test_task_artifacts.py b/apiserver/tests/automated/test_task_artifacts.py
index 9c24f69..1c99c61 100644
--- a/apiserver/tests/automated/test_task_artifacts.py
+++ b/apiserver/tests/automated/test_task_artifacts.py
@@ -32,8 +32,8 @@ class TestTasksArtifacts(TestService):
 
         # test edit
         artifacts = [
-            dict(key="bb", type="str", uri="test1", mode="output"),
-            dict(key="aa", type="int", uri="test2", mode="input"),
+            dict(key="bb", type="str", uri="http://files.clear.ml/test1", mode="output"),
+            dict(key="aa", type="int", uri="http://files.clear.ml/test2", mode="input"),
         ]
         self.api.tasks.edit(task=task, execution={"artifacts": artifacts})
         res = self.api.tasks.get_by_id(task=task).task
diff --git a/apiserver/tests/automated/test_task_plots.py b/apiserver/tests/automated/test_task_plots.py
index 6542912..fbefc8f 100644
--- a/apiserver/tests/automated/test_task_plots.py
+++ b/apiserver/tests/automated/test_task_plots.py
@@ -14,6 +14,12 @@ class TestTaskPlots(TestService):
 
     @staticmethod
     def _create_task_event(task, iteration, **kwargs):
+        plot_str = kwargs.get("plot_str")
+        if plot_str:
+            if not plot_str.startswith("http"):
+                plot_str = "http://files.clear.ml/" + plot_str
+            kwargs["plot_str"] = '{"source": "' + plot_str + '"}'
+
         return {
             "worker": "test",
             "type": "plot",
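
Reviewer note on the `url_trans` plumbing: export now threads a `(source_prefix, target_prefix)` pair through models, output artifacts, image events, and plot image sources, rewriting each stored URL for the target server while the original URL is still collected so the backing file can be packed into the artifacts zip. Below is a minimal standalone sketch of the prefix rewrite that `_translate_url` performs; only the rewrite logic comes from the patch, and the host names and paths are hypothetical examples:

```python
from typing import Optional, Tuple

UrlTranslation = Tuple[str, str]  # (source prefix, target prefix)


def translate_url(url: Optional[str], url_trans: Optional[UrlTranslation]) -> Optional[str]:
    """Prefix swap mirroring PrePopulate._translate_url: a no-op when the url
    or the translation pair is missing, or when the url has another prefix."""
    if not (url and url_trans):
        return url
    source, target = url_trans
    if not url.startswith(source):
        return url
    return target + url[len(source):]


# Hypothetical prefixes: point fileserver links at a new host on export.
assert translate_url(
    "http://files.clear.ml/proj/model.bin",
    ("http://files.clear.ml/", "https://files.example.com/"),
) == "https://files.example.com/proj/model.bin"

# URLs under a different scheme or host pass through untouched.
assert translate_url(
    "s3://bucket/key", ("http://files.clear.ml/", "https://files.example.com/")
) == "s3://bucket/key"
```

Combined with `project_split=True`, which writes one `<stem>_<project-id>_<hash>.zip` per top-level project, this makes per-project exports portable between deployments whose fileserver addresses differ.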