Mirror of https://github.com/clearml/clearml-server, synced 2025-06-26 23:15:47 +00:00
Add data_tool export improvements: a 'company' flag, an increased batch size for better performance, date-time prefixes on log strings, more logging, an option to create a separate zip file per root project, and an option to translate URLs during tool export
This commit is contained in:
parent bf00441146
commit a7e340212f
@@ -65,6 +65,14 @@ from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
 replace_s3_scheme = os.getenv("CLEARML_REPLACE_S3_SCHEME")
 
 
+def _print(msg: str):
+    time = datetime.now().isoformat(sep=" ", timespec="seconds")
+    print(f"{time} {msg}")
+
+
+UrlTranslation = Tuple[str, str]
+
+
 class PrePopulate:
     module_name_prefix = "apiserver."
     event_bll = EventBLL()
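
The _print helper added above simply prefixes each message with a seconds-resolution timestamp, which is what puts the date-time on all the log strings below. Shown standalone with a sample call (the sample output is illustrative):

from datetime import datetime

def _print(msg: str):
    # Same helper as in the hunk above: timestamped log line.
    time = datetime.now().isoformat(sep=" ", timespec="seconds")
    print(f"{time} {msg}")

_print("Reading projects...")  # e.g. "2025-06-26 23:15:47 Reading projects..."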
@@ -163,7 +171,7 @@ class PrePopulate:
                 return True, files
 
         except Exception as ex:
-            print("Error reading map file. " + str(ex))
+            _print("Error reading map file. " + str(ex))
             return True, files
 
         return False, files
@@ -204,7 +212,7 @@ class PrePopulate:
             return False
 
         fileserver_links = [a for a in artifacts if is_fileserver_link(a)]
-        print(
+        _print(
             f"Found {len(fileserver_links)} files on the fileserver from {len(artifacts)} total"
         )
 
@@ -216,81 +224,114 @@ class PrePopulate:
         filename: str,
         experiments: Sequence[str] = None,
         projects: Sequence[str] = None,
+        company: str = None,
         artifacts_path: str = None,
         task_statuses: Sequence[str] = None,
         tag_exported_entities: bool = False,
         metadata: Mapping[str, Any] = None,
         export_events: bool = True,
         export_users: bool = False,
+        project_split: bool = False,
+        url_trans: UrlTranslation = None,
     ) -> Sequence[str]:
         cls._init_entity_types()
 
         if task_statuses and not set(task_statuses).issubset(get_options(TaskStatus)):
             raise ValueError("Invalid task statuses")
 
-        file = Path(filename)
-        if not (experiments or projects):
-            projects = cls.project_cls.objects(parent=None).scalar("id")
-
-        entities = cls._resolve_entities(
-            experiments=experiments, projects=projects, task_statuses=task_statuses
-        )
-
-        hash_ = hashlib.md5()
-        if metadata:
-            meta_str = json.dumps(metadata)
-            hash_.update(meta_str.encode())
-            metadata_hash = hash_.hexdigest()
-        else:
-            meta_str, metadata_hash = "", ""
-
-        map_file = file.with_suffix(".map")
-        updated, old_files = cls._check_for_update(
-            map_file, entities=entities, metadata_hash=metadata_hash
-        )
-        if not updated:
-            print(f"There are no updates from the last export")
-            return old_files
-
-        for old in old_files:
-            old_path = Path(old)
-            if old_path.is_file():
-                old_path.unlink()
-
-        with ZipFile(file, **cls.zip_args) as zfile:
-            if metadata:
-                zfile.writestr(cls.metadata_filename, meta_str)
-            if export_users:
-                cls._export_users(zfile)
-            artifacts = cls._export(
-                zfile,
-                entities=entities,
-                hash_=hash_,
-                tag_entities=tag_exported_entities,
-                export_events=export_events,
-                cleanup_users=not export_users,
+        def export_to_zip_core(file_base_name: Path, projects_: Sequence[str]):
+            entities = cls._resolve_entities(
+                experiments=experiments, projects=projects_, task_statuses=task_statuses
             )
 
-        file_with_hash = file.with_name(f"{file.stem}_{hash_.hexdigest()}{file.suffix}")
-        file.replace(file_with_hash)
-        created_files = [str(file_with_hash)]
+            hash_ = hashlib.md5()
+            if metadata:
+                meta_str = json.dumps(metadata)
+                hash_.update(meta_str.encode())
+                metadata_hash = hash_.hexdigest()
+            else:
+                meta_str, metadata_hash = "", ""
 
-        artifacts = cls._filter_artifacts(artifacts)
-        if artifacts and artifacts_path and os.path.isdir(artifacts_path):
-            artifacts_file = file_with_hash.with_suffix(cls.artifacts_ext)
-            with ZipFile(artifacts_file, **cls.zip_args) as zfile:
-                cls._export_artifacts(zfile, artifacts, artifacts_path)
-            created_files.append(str(artifacts_file))
+            map_file = file_base_name.with_suffix(".map")
+            updated, old_files = cls._check_for_update(
+                map_file, entities=entities, metadata_hash=metadata_hash
+            )
+            if not updated:
+                _print(f"There are no updates from the last export")
+                return old_files
 
-        cls._write_update_file(
-            map_file,
-            entities=entities,
-            created_files=created_files,
-            metadata_hash=metadata_hash,
-        )
+            for old in old_files:
+                old_path = Path(old)
+                if old_path.is_file():
+                    old_path.unlink()
 
-        if created_files:
-            print("Created files:\n" + "\n".join(file for file in created_files))
+            temp_file = file_base_name.with_suffix(file_base_name.suffix + "$")
+            try:
+                with ZipFile(temp_file, **cls.zip_args) as zfile:
+                    if metadata:
+                        zfile.writestr(cls.metadata_filename, meta_str)
+                    if export_users:
+                        cls._export_users(zfile)
+                    artifacts = cls._export(
+                        zfile,
+                        entities=entities,
+                        hash_=hash_,
+                        tag_entities=tag_exported_entities,
+                        export_events=export_events,
+                        cleanup_users=not export_users,
+                        url_trans=url_trans,
+                    )
+            except:
+                temp_file.unlink(missing_ok=True)
+                raise
+
+            file_with_hash = file_base_name.with_stem(
+                f"{file_base_name.stem}_{hash_.hexdigest()}"
+            )
+            temp_file.replace(file_with_hash)
+            files = [str(file_with_hash)]
+
+            artifacts = cls._filter_artifacts(artifacts)
+            if artifacts and artifacts_path and os.path.isdir(artifacts_path):
+                artifacts_file = file_with_hash.with_suffix(cls.artifacts_ext)
+                with ZipFile(artifacts_file, **cls.zip_args) as zfile:
+                    cls._export_artifacts(zfile, artifacts, artifacts_path)
+                files.append(str(artifacts_file))
+
+            cls._write_update_file(
+                map_file,
+                entities=entities,
+                created_files=files,
+                metadata_hash=metadata_hash,
+            )
+
+            if files:
+                _print("Created files:\n" + "\n".join(file for file in files))
+
+            return files
+
+        filename = Path(filename)
+        if not (experiments or projects):
+            query = dict(parent=None)
+            if company:
+                query["company"] = company
+            projects = list(cls.project_cls.objects(**query).scalar("id"))
+            # projects.append(None)
+
+        if projects and project_split:
+            created_files = list(
+                chain.from_iterable(
+                    export_to_zip_core(
+                        file_base_name=filename.with_stem(f"{filename.stem}_{pid}"),
+                        projects_=[pid],
+                    )
+                    for pid in projects
+                )
+            )
+        else:
+            created_files = export_to_zip_core(
+                file_base_name=filename, projects_=projects
+            )
 
         return created_files
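
The project_split branch above derives one archive name per root project with Path.with_stem (Python 3.9+). A standalone sketch of that naming, with hypothetical project ids:

from pathlib import Path

filename = Path("/backups/export.zip")
for pid in ("proj-a", "proj-b"):  # hypothetical root project ids
    # Mirrors file_base_name=filename.with_stem(f"{filename.stem}_{pid}") above.
    print(filename.with_stem(f"{filename.stem}_{pid}"))
# /backups/export_proj-a.zip
# /backups/export_proj-b.zip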
@@ -320,8 +361,10 @@ class PrePopulate:
                         meta_user_id = metadata.get("user_id", "")
                         meta_user_name = metadata.get("user_name", "")
                         user_id, user_name = meta_user_id, meta_user_name
-            except Exception:
-                pass
+            except Exception as ex:
+                _print(
+                    f"Error getting metadata from {cls.metadata_filename}: {str(ex)}"
+                )
 
             # Make sure we won't end up with an invalid company ID
             if company_id is None:
@@ -347,7 +390,7 @@ class PrePopulate:
         if artifacts_path and os.path.isdir(artifacts_path):
             artifacts_file = Path(filename).with_suffix(cls.artifacts_ext)
             if artifacts_file.is_file():
-                print(f"Unzipping artifacts into {artifacts_path}")
+                _print(f"Unzipping artifacts into {artifacts_path}")
                 with ZipFile(artifacts_file) as zfile:
                     zfile.extractall(artifacts_path)
 
@@ -370,7 +413,7 @@ class PrePopulate:
         base_file_name, _, old_hash = file.stem.rpartition("_")
         new_hash = hash_.hexdigest()
         if old_hash == new_hash:
-            print(f"The file {filename} was not updated")
+            _print(f"The file {filename} was not updated")
             temp_file.unlink()
             return []
 
@@ -384,7 +427,7 @@ class PrePopulate:
             artifacts_file.replace(new_artifacts)
             upadated.append(str(new_artifacts))
 
-        print(f"File {str(file)} replaced with {str(new_file)}")
+        _print(f"File {str(file)} replaced with {str(new_file)}")
         file.unlink()
 
         return upadated
@@ -446,12 +489,12 @@ class PrePopulate:
 
         not_found = missing - set(resolved_by_name)
         if not_found:
-            print(f"ERROR: no match for {', '.join(not_found)}")
+            _print(f"ERROR: no match for {', '.join(not_found)}")
             exit(1)
 
         duplicates = [k for k, v in resolved_by_name.items() if len(v) > 1]
         if duplicates:
-            print(f"ERROR: more than one match for {', '.join(duplicates)}")
+            _print(f"ERROR: more than one match for {', '.join(duplicates)}")
             exit(1)
 
         def get_new_items(input_: Iterable) -> list:
@@ -489,20 +532,24 @@ class PrePopulate:
             return
 
         prefixes = [
-            cls.ParentPrefix(prefix=f"{project.name.rpartition('/')[0]}/", path=project.path)
+            cls.ParentPrefix(
+                prefix=f"{project.name.rpartition('/')[0]}/", path=project.path
+            )
             for project in orphans
         ]
         prefixes.sort(key=lambda p: len(p.path), reverse=True)
         for project in projects:
-            prefix = first(pref for pref in prefixes if project.path[:len(pref.path)] == pref.path)
+            prefix = first(
+                pref for pref in prefixes if project.path[: len(pref.path)] == pref.path
+            )
             if not prefix:
                 continue
-            project.path = project.path[len(prefix.path):]
+            project.path = project.path[len(prefix.path) :]
             if not project.path:
                 project.parent = None
             project.name = project.name.removeprefix(prefix.prefix)
 
-        # print(
+        # _print(
         #     f"ERROR: the following projects are exported without their parents: {orphans}"
         # )
         # exit(1)
@@ -518,16 +565,20 @@ class PrePopulate:
         entities: Dict[Any] = defaultdict(set)
 
         if projects:
-            print("Reading projects...")
-            projects = project_ids_with_children(projects)
-            entities[cls.project_cls].update(
-                cls._resolve_entity_type(cls.project_cls, projects)
-            )
-            print("--> Reading project experiments...")
+            _print("Reading projects...")
+            root = None in projects
+            projects = [p for p in projects if p]
+            if projects:
+                projects = project_ids_with_children(projects)
+                entities[cls.project_cls].update(
+                    cls._resolve_entity_type(cls.project_cls, projects)
+                )
+            _print("--> Reading project experiments...")
+            p_ids = list(set(p.id for p in entities[cls.project_cls]))
+            if root:
+                p_ids.append(None)
             query = Q(
-                project__in=list(
-                    set(filter(None, (p.id for p in entities[cls.project_cls])))
-                ),
+                project__in=p_ids,
                 system_tags__nin=[EntityVisibility.archived.value],
             )
             if task_statuses:
@@ -538,9 +589,11 @@ class PrePopulate:
             )
 
         if experiments:
-            print("Reading experiments...")
-            entities[cls.task_cls].update(cls._resolve_entity_type(cls.task_cls, experiments))
-            print("--> Reading experiments projects...")
+            _print("Reading experiments...")
+            entities[cls.task_cls].update(
+                cls._resolve_entity_type(cls.task_cls, experiments)
+            )
+            _print("--> Reading experiments projects...")
             objs = cls.project_cls.objects(
                 id__in=list(
                     set(filter(None, (p.project for p in entities[cls.task_cls])))
@@ -560,7 +613,7 @@ class PrePopulate:
         )
         model_ids = {tm.model for tm in task_models}
         if model_ids:
-            print("Reading models...")
+            _print("Reading models...")
             entities[cls.model_cls] = set(cls.model_cls.objects(id__in=list(model_ids)))
 
         # noinspection PyTypeChecker
@@ -625,22 +678,41 @@ class PrePopulate:
         except AttributeError:
             pass
 
+    @staticmethod
+    def _translate_url(url_: str, url_trans: UrlTranslation) -> str:
+        if not (url_ and url_trans):
+            return url_
+
+        source, target = url_trans
+        if not url_.startswith(source):
+            return url_
+
+        return target + url_[len(source):]
+
     @classmethod
     def _export_task_events(
-        cls, task: Task, base_filename: str, writer: ZipFile, hash_
+        cls,
+        task: Task,
+        base_filename: str,
+        writer: ZipFile,
+        hash_,
+        url_trans: UrlTranslation,
     ) -> Sequence[str]:
         artifacts = []
         filename = f"{base_filename}_{task.id}{cls.events_file_suffix}.json"
-        print(f"Writing task events into {writer.filename}:{filename}")
+        _print(f"Writing task events into {writer.filename}:{filename}")
 
         with BytesIO() as f:
             with cls.JsonLinesWriter(f) as w:
                 scroll_id = None
+                events_count = 0
                 while True:
                     res = cls.event_bll.get_task_events(
                         company_id=task.company,
                         task_id=task.id,
                         event_type=EventType.all,
                         scroll_id=scroll_id,
                         size=10_000,
                     )
                     if not res.events:
                         break
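
The _translate_url helper added above does a plain prefix swap on matching URLs. A standalone copy with a sample translation (the URLs are examples only):

from typing import Tuple

UrlTranslation = Tuple[str, str]

def _translate_url(url_: str, url_trans: UrlTranslation) -> str:
    # Same logic as the staticmethod above: swap the source prefix for the target.
    if not (url_ and url_trans):
        return url_
    source, target = url_trans
    if not url_.startswith(source):
        return url_
    return target + url_[len(source):]

print(_translate_url(
    "http://files.clear.ml/img.png",
    ("http://files.clear.ml", "https://files.internal.example"),
))  # https://files.internal.example/img.png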
@@ -650,16 +722,22 @@ class PrePopulate:
                         if event_type == EventType.metrics_image.value:
                             url = cls._get_fixed_url(event.get("url"))
                             if url:
-                                event["url"] = url
                                 artifacts.append(url)
+                                event["url"] = cls._translate_url(url, url_trans)
                         elif event_type == EventType.metrics_plot.value:
                             plot_str: str = event.get("plot_str", "")
-                            for match in cls.img_source_regex.findall(plot_str):
-                                url = cls._get_fixed_url(match)
-                                if match != url:
-                                    plot_str = plot_str.replace(match, url)
-                                artifacts.append(url)
+                            if plot_str:
+                                for match in cls.img_source_regex.findall(plot_str):
+                                    url = cls._get_fixed_url(match)
+                                    artifacts.append(url)
+                                    new_url = cls._translate_url(url, url_trans)
+                                    if match != new_url:
+                                        plot_str = plot_str.replace(match, new_url)
+                                event["plot_str"] = plot_str
                         w.write(json.dumps(event))
+                        events_count += 1
+                    _print(f"Got {events_count} events for task {task.id}")
+            _print(f"Writing {events_count} events for task {task.id}")
             data = f.getvalue()
             hash_.update(data)
             writer.writestr(filename, data)
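
The event loop above pages task events out of the event store in 10,000-event batches via a scroll id. A minimal sketch of that paging pattern; get_batch is a stand-in for event_bll.get_task_events, and the res.events / res.next_scroll_id attribute names are assumptions for this sketch, not the real EventBLL API:

from types import SimpleNamespace

def iter_event_batches(get_batch, size=10_000):
    # Scroll-style paging: keep requesting batches until an empty page comes back.
    scroll_id = None
    while True:
        res = get_batch(scroll_id=scroll_id, size=size)
        if not res.events:
            break
        yield res.events
        scroll_id = res.next_scroll_id

# Stub backend: two pages of events, then an empty page ends the scroll.
pages = [
    SimpleNamespace(events=[1, 2], next_scroll_id="s1"),
    SimpleNamespace(events=[3], next_scroll_id="s2"),
    SimpleNamespace(events=[], next_scroll_id=None),
]
for batch in iter_event_batches(lambda scroll_id, size: pages.pop(0)):
    print(batch)  # [1, 2] then [3]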
@@ -677,53 +755,62 @@ class PrePopulate:
             fixed.host += ".s3.amazonaws.com"
             return fixed.url
         except Exception as ex:
-            print(f"Failed processing link {url}. " + str(ex))
+            _print(f"Failed processing link {url}. " + str(ex))
             return url
 
     @classmethod
     def _export_entity_related_data(
-        cls, entity_cls, entity, base_filename: str, writer: ZipFile, hash_
+        cls,
+        entity_cls,
+        entity,
+        base_filename: str,
+        writer: ZipFile,
+        hash_,
+        url_trans: UrlTranslation,
     ):
         if entity_cls == cls.task_cls:
             return [
-                *cls._get_task_output_artifacts(entity),
-                *cls._export_task_events(entity, base_filename, writer, hash_),
+                *cls._get_task_output_artifacts(entity, url_trans),
+                *cls._export_task_events(
+                    entity, base_filename, writer, hash_, url_trans
+                ),
             ]
 
         if entity_cls == cls.model_cls:
-            entity.uri = cls._get_fixed_url(entity.uri)
-            return [entity.uri] if entity.uri else []
+            url = cls._get_fixed_url(entity.uri)
+            entity.uri = cls._translate_url(url, url_trans)
+            return [url] if url else []
 
         return []
 
     @classmethod
-    def _get_task_output_artifacts(cls, task: Task) -> Sequence[str]:
+    def _get_task_output_artifacts(cls, task: Task, url_trans: UrlTranslation) -> Sequence[str]:
         if not task.execution.artifacts:
             return []
 
+        artifact_urls = []
         for a in task.execution.artifacts.values():
             if a.mode == ArtifactModes.output:
-                a.uri = cls._get_fixed_url(a.uri)
+                url = cls._get_fixed_url(a.uri)
+                a.uri = cls._translate_url(url, url_trans)
+                if url and a.mode == ArtifactModes.output:
+                    artifact_urls.append(url)
 
-        return [
-            a.uri
-            for a in task.execution.artifacts.values()
-            if a.mode == ArtifactModes.output and a.uri
-        ]
+        return artifact_urls
 
     @classmethod
     def _export_artifacts(
         cls, writer: ZipFile, artifacts: Sequence[str], artifacts_path: str
     ):
         unique_paths = set(unquote(str(furl(artifact).path)) for artifact in artifacts)
-        print(f"Writing {len(unique_paths)} artifacts into {writer.filename}")
+        _print(f"Writing {len(unique_paths)} artifacts into {writer.filename}")
         for path in unique_paths:
             path = path.lstrip("/")
             full_path = os.path.join(artifacts_path, path)
             if os.path.isfile(full_path):
                 writer.write(full_path, path)
             else:
-                print(f"Artifact {full_path} not found")
+                _print(f"Artifact {full_path} not found")
 
     @classmethod
     def _export_users(cls, writer: ZipFile):
@@ -742,7 +829,7 @@ class PrePopulate:
             return
 
         auth_users = {uid: data for uid, data in auth_users.items() if uid in be_users}
-        print(f"Writing {len(auth_users)} users into {writer.filename}")
+        _print(f"Writing {len(auth_users)} users into {writer.filename}")
         data = {}
         for field, users in (("auth", auth_users), ("backend", be_users)):
             with BytesIO() as f:
@@ -773,6 +860,7 @@ class PrePopulate:
         tag_entities: bool = False,
         export_events: bool = True,
         cleanup_users: bool = True,
+        url_trans: UrlTranslation = None,
     ) -> Sequence[str]:
         """
         Export the requested experiments, projects and models and return the list of artifact files
@@ -780,7 +868,7 @@ class PrePopulate:
         The projects should be sorted by name so that on import the hierarchy is correctly restored from top to bottom
         """
         artifacts = []
-        now = datetime.utcnow()
+        now = datetime.now(timezone.utc)
         for cls_ in sorted(entities, key=attrgetter("__name__")):
             items = sorted(entities[cls_], key=attrgetter("name", "id"))
             if not items:
@@ -790,11 +878,11 @@ class PrePopulate:
                 for item in items:
                     artifacts.extend(
                         cls._export_entity_related_data(
-                            cls_, item, base_filename, writer, hash_
+                            cls_, item, base_filename, writer, hash_, url_trans
                         )
                     )
             filename = base_filename + ".json"
-            print(f"Writing {len(items)} items into {writer.filename}:{filename}")
+            _print(f"Writing {len(items)} items into {writer.filename}:{filename}")
             with BytesIO() as f:
                 with cls.JsonLinesWriter(f) as w:
                     for item in items:
@@ -968,7 +1056,7 @@ class PrePopulate:
         for entity_file in entity_files:
             with reader.open(entity_file) as f:
                 full_name = splitext(entity_file.orig_filename)[0]
-                print(f"Reading {reader.filename}:{full_name}...")
+                _print(f"Reading {reader.filename}:{full_name}...")
                 res = cls._import_entity(
                     f,
                     full_name=full_name,
@@ -996,7 +1084,7 @@ class PrePopulate:
                 continue
             with reader.open(events_file) as f:
                 full_name = splitext(events_file.orig_filename)[0]
-                print(f"Reading {reader.filename}:{full_name}...")
+                _print(f"Reading {reader.filename}:{full_name}...")
                 cls._import_events(f, company_id, task.user, task.id)
 
     @classmethod
@@ -1082,14 +1170,16 @@ class PrePopulate:
             )
 
         models = task_data.get("models", {})
-        now = datetime.utcnow()
+        now = datetime.now(timezone.utc)
         for old_field, type_ in (
             ("execution.model", TaskModelTypes.input),
             ("output.model", TaskModelTypes.output),
         ):
             old_path = old_field.split(".")
             old_model = nested_get(task_data, old_path)
-            new_models = [m for m in models.get(type_, []) if m.get("model") is not None]
+            new_models = [
+                m for m in models.get(type_, []) if m.get("model") is not None
+            ]
             name = TaskModelNames[type_]
             if old_model and not any(
                 m
@@ -1127,7 +1217,7 @@ class PrePopulate:
     ) -> Optional[Sequence[Task]]:
         user_mapping = user_mapping or {}
         cls_ = cls._get_entity_type(full_name)
-        print(f"Writing {cls_.__name__.lower()}s into database")
+        _print(f"Writing {cls_.__name__.lower()}s into database")
         tasks = []
         override_project_count = 0
         data_upgrade_funcs: Mapping[Type, Callable] = {
@@ -1164,21 +1254,23 @@ class PrePopulate:
                 doc.logo_blob = metadata.get("logo_blob", None)
 
                 cls_.objects(company=company_id, name=doc.name, id__ne=doc.id).update(
-                    set__name=f"{doc.name}_{datetime.utcnow().strftime('%Y-%m-%d_%H-%M-%S')}"
+                    set__name=f"{doc.name}_{datetime.now(timezone.utc).strftime('%Y-%m-%d_%H-%M-%S')}"
                 )
 
             doc.save()
 
             if isinstance(doc, cls.task_cls):
                 tasks.append(doc)
-                cls.event_bll.delete_task_events(company_id, doc.id, wait_for_delete=True)
+                cls.event_bll.delete_task_events(
+                    company_id, doc.id, wait_for_delete=True
+                )
 
         if tasks:
             return tasks
 
     @classmethod
     def _import_events(cls, f: IO[bytes], company_id: str, user_id: str, task_id: str):
-        print(f"Writing events for task {task_id} into database")
+        _print(f"Writing events for task {task_id} into database")
         for events_chunk in chunked_iter(cls.json_lines(f), 1000):
             events = [json.loads(item) for item in events_chunk]
             for ev in events:
@@ -32,8 +32,8 @@ class TestTasksArtifacts(TestService):
 
         # test edit
         artifacts = [
-            dict(key="bb", type="str", uri="test1", mode="output"),
-            dict(key="aa", type="int", uri="test2", mode="input"),
+            dict(key="bb", type="str", uri="http://files.clear.ml/test1", mode="output"),
+            dict(key="aa", type="int", uri="http://files.clear.ml/test2", mode="input"),
         ]
         self.api.tasks.edit(task=task, execution={"artifacts": artifacts})
         res = self.api.tasks.get_by_id(task=task).task
@@ -14,6 +14,12 @@ class TestTaskPlots(TestService):
 
     @staticmethod
     def _create_task_event(task, iteration, **kwargs):
+        plot_str = kwargs.get("plot_str")
+        if plot_str:
+            if not plot_str.startswith("http"):
+                plot_str = "http://files.clear.ml/" + plot_str
+            kwargs["plot_str"] = '{"source": "' + plot_str + '"}'
+
         return {
             "worker": "test",
             "type": "plot",