diff --git a/clearml/automation/controller.py b/clearml/automation/controller.py index c251ad13..6b2c9713 100644 --- a/clearml/automation/controller.py +++ b/clearml/automation/controller.py @@ -98,6 +98,7 @@ class PipelineController(object): monitor_metrics = attrib(type=list, default=None) # List of metric title/series to monitor monitor_artifacts = attrib(type=list, default=None) # List of artifact names to monitor monitor_models = attrib(type=list, default=None) # List of models to monitor + explicit_docker_image = attrib(type=str, default=None) # The Docker image the node uses, specified at creation def __attrs_post_init__(self): if self.parents is None: @@ -2043,6 +2044,7 @@ class PipelineController(object): monitor_metrics=monitor_metrics, monitor_models=monitor_models, job_code_section=job_code_section, + explicit_docker_image=docker ) self._retries[name] = 0 self._retries_callbacks[name] = retry_on_failure if callable(retry_on_failure) else \ @@ -2096,10 +2098,12 @@ class PipelineController(object): task_overrides = self._parse_task_overrides(node.task_overrides) if node.task_overrides else None extra_args = dict() - extra_args['project'] = self._get_target_project(return_project_id=True) or None + extra_args["project"] = self._get_target_project(return_project_id=True) or None # set Task name to match job name if self._pipeline_as_sub_project: - extra_args['name'] = node.name + extra_args["name"] = node.name + if node.explicit_docker_image: + extra_args["explicit_docker_image"] = node.explicit_docker_image skip_node = None if self._pre_step_callbacks.get(node.name): diff --git a/clearml/automation/job.py b/clearml/automation/job.py index a0547db1..3dc23643 100644 --- a/clearml/automation/job.py +++ b/clearml/automation/job.py @@ -378,8 +378,15 @@ class BaseJob(object): cls._hashing_callback = a_function @classmethod - def _create_task_hash(cls, task, section_overrides=None, params_override=None, configurations_override=None): - # type: (Task, Optional[dict], Optional[dict], Optional[dict]) -> Optional[str] + def _create_task_hash( + cls, + task, + section_overrides=None, + params_override=None, + configurations_override=None, + explicit_docker_image=None + ): + # type: (Task, Optional[dict], Optional[dict], Optional[dict], Optional[str]) -> Optional[str] """ Create Hash (str) representing the state of the Task @@ -388,20 +395,22 @@ class BaseJob(object): :param params_override: Alternative to the entire Task's hyper parameters section (notice this should not be a nested dict but a flat key/value) :param configurations_override: dictionary of configuration override objects (tasks.ConfigurationItem) + :param explicit_docker_image: The explicit docker image. Used to invalidate the hash when the docker image + was explicitly changed :return: str hash of the Task configuration """ if not task: return None - if section_overrides and section_overrides.get('script'): - script = section_overrides['script'] + if section_overrides and section_overrides.get("script"): + script = section_overrides["script"] if not isinstance(script, dict): script = script.to_dict() else: script = task.data.script.to_dict() if task.data.script else {} # if we have a repository, we must make sure we have a specific version_num to ensure consistency - if script.get('repository') and not script.get('version_num') and not script.get('tag'): + if script.get("repository") and not script.get("version_num") and not script.get("tag"): return None # we need to ignore `requirements` section because ir might be changing from run to run @@ -412,17 +421,19 @@ class BaseJob(object): # currently we do not add the docker image to the hash (only args and setup script), # because default docker image will cause the step to change docker = None - if hasattr(task.data, 'container'): + if hasattr(task.data, "container"): docker = dict(**(task.data.container or dict())) - docker.pop('image', None) + docker.pop("image", None) + if explicit_docker_image: + docker["image"] = explicit_docker_image - hash_func = 'md5' if Session.check_min_api_version('2.13') else 'crc32' + hash_func = "md5" if Session.check_min_api_version("2.13") else "crc32" # make sure that if we only have docker args/bash, # we use encode it, otherwise we revert to the original encoding (excluding docker altogether) repr_dict = dict(script=script, hyper_params=hyper_params, configs=configs) if docker: - repr_dict['docker'] = docker + repr_dict["docker"] = docker # callback for modifying the representation dict if cls._hashing_callback: @@ -579,6 +590,7 @@ class ClearmlJob(BaseJob): section_overrides=sections, params_override=task_params, configurations_override=configuration_overrides or None, + explicit_docker_image=kwargs.get("explicit_docker_image") ) task = self._get_cached_task(task_hash) # if we found a task, just use