diff --git a/clearml/task.py b/clearml/task.py index ac65b785..529211eb 100644 --- a/clearml/task.py +++ b/clearml/task.py @@ -205,9 +205,13 @@ class Task(_Task): __main_task = None # type: Optional[Task] __exit_hook = None __forked_proc_main_pid = None - __task_id_reuse_time_window_in_hours = deferred_config("development.task_reuse_time_window_in_hours", 24.0, float) + __task_id_reuse_time_window_in_hours = deferred_config( + "development.task_reuse_time_window_in_hours", 24.0, float + ) __detect_repo_async = deferred_config("development.vcs_repo_detect_async", False) - __default_output_uri = DEV_DEFAULT_OUTPUT_URI.get() or deferred_config("development.default_output_uri", None) + __default_output_uri = DEV_DEFAULT_OUTPUT_URI.get() or deferred_config( + "development.default_output_uri", None + ) __hidden_tag = "hidden" @@ -217,8 +221,13 @@ class Task(_Task): _external_endpoint_port_map = {"http": "_PORT", "tcp": "external_tcp_port"} _external_endpoint_address_map = {"http": "_ADDRESS", "tcp": "external_address"} _external_endpoint_service_map = {"http": "EXTERNAL", "tcp": "EXTERNAL_TCP"} - _external_endpoint_internal_port_map = {"http": "_PORT", "tcp": "upstream_task_port"} - _external_endpoint_host_tcp_port_mapping = {"tcp_host_mapping": "_external_host_tcp_port_mapping"} + _external_endpoint_internal_port_map = { + "http": "_PORT", + "tcp": "upstream_task_port", + } + _external_endpoint_host_tcp_port_mapping = { + "tcp_host_mapping": "_external_host_tcp_port_mapping" + } class _ConnectedParametersType(object): argparse = "argument_parser" @@ -227,7 +236,11 @@ class Task(_Task): @classmethod def _options(cls): - return {var for var, val in vars(cls).items() if isinstance(val, six.string_types)} + return { + var + for var, val in vars(cls).items() + if isinstance(val, six.string_types) + } def __init__(self, private=None, **kwargs): """ @@ -501,7 +514,11 @@ class Task(_Task): validate = [ ("project name", project_name, cls.__main_task.get_project_name()), ("task name", task_name, cls.__main_task.name), - ("task type", str(task_type) if task_type else task_type, str(cls.__main_task.task_type)), + ( + "task type", + str(task_type) if task_type else task_type, + str(cls.__main_task.task_type), + ), ] for field, default, current in validate: @@ -517,7 +534,10 @@ class Task(_Task): ) ) - if cls.__main_task is not None and deferred_init != cls.__nested_deferred_init_flag: + if ( + cls.__main_task is not None + and deferred_init != cls.__nested_deferred_init_flag + ): # if this is a subprocess, regardless of what the init was called for, # we have to fix the main task hooks and stdout bindings if cls.__forked_proc_main_pid != os.getpid() and cls.__is_subprocess(): @@ -539,7 +559,9 @@ class Task(_Task): cls.__main_task._logger = None cls.__main_task.__reporter = None # noinspection PyProtectedMember - cls.__main_task._get_logger(auto_connect_streams=auto_connect_streams) + cls.__main_task._get_logger( + auto_connect_streams=auto_connect_streams + ) cls.__main_task._artifacts_manager = Artifacts(cls.__main_task) # unregister signal hooks, they cause subprocess to hang @@ -583,7 +605,9 @@ class Task(_Task): elif isinstance(task_type, six.string_types): if task_type not in Task.TaskTypes.__members__: raise ValueError( - "Task type '{}' not supported, options are: {}".format(task_type, Task.TaskTypes.__members__.keys()) + "Task type '{}' not supported, options are: {}".format( + task_type, Task.TaskTypes.__members__.keys() + ) ) task_type = Task.TaskTypes.__members__[str(task_type)] @@ -593,14 
+617,20 @@ class Task(_Task): # check remote status _local_rank = get_torch_local_rank() if _local_rank is not None and _local_rank > 0: - is_sub_process_task_id = get_torch_distributed_anchor_task_id(timeout=30) + is_sub_process_task_id = get_torch_distributed_anchor_task_id( + timeout=30 + ) # only allow if running locally and creating the first Task # otherwise we ignore and perform in order if ENV_DEFERRED_TASK_INIT.get(): deferred_init = True - if not is_sub_process_task_id and deferred_init and deferred_init != cls.__nested_deferred_init_flag: + if ( + not is_sub_process_task_id + and deferred_init + and deferred_init != cls.__nested_deferred_init_flag + ): def completed_cb(x): Task.__forked_proc_main_pid = os.getpid() @@ -637,12 +667,16 @@ class Task(_Task): tags=tags, reuse_last_task_id=reuse_last_task_id, continue_last_task=continue_last_task, - detect_repo=False - if ( - isinstance(auto_connect_frameworks, dict) - and not auto_connect_frameworks.get("detect_repository", True) - ) - else True, + detect_repo=( + False + if ( + isinstance(auto_connect_frameworks, dict) + and not auto_connect_frameworks.get( + "detect_repository", True + ) + ) + else True + ), auto_connect_streams=auto_connect_streams, ) # check if we are local rank 0 (local master), @@ -662,13 +696,20 @@ class Task(_Task): task.output_uri = None # create target data folder for logger / artifacts # noinspection PyProtectedMember - Path(task._get_default_report_storage_uri()).mkdir(parents=True, exist_ok=True) + Path(task._get_default_report_storage_uri()).mkdir( + parents=True, exist_ok=True + ) elif output_uri is not None: if output_uri is True: - output_uri = task.get_project_object().default_output_destination or True + output_uri = ( + task.get_project_object().default_output_destination + or True + ) task.output_uri = output_uri elif task.get_project_object().default_output_destination: - task.output_uri = task.get_project_object().default_output_destination + task.output_uri = ( + task.get_project_object().default_output_destination + ) elif cls.__default_output_uri: task.output_uri = str(cls.__default_output_uri) # store new task ID @@ -688,8 +729,13 @@ class Task(_Task): # Setting output_uri=False argument will disable using any default when running remotely pass else: - if task.get_project_object().default_output_destination and not task.output_uri: - task.output_uri = task.get_project_object().default_output_destination + if ( + task.get_project_object().default_output_destination + and not task.output_uri + ): + task.output_uri = ( + task.get_project_object().default_output_destination + ) if cls.__default_output_uri and not task.output_uri: task.output_uri = cls.__default_output_uri # store new task ID @@ -697,7 +743,9 @@ class Task(_Task): # make sure we are started task.started(ignore_errors=True) # continue last iteration if we had any (or we need to override it) - if isinstance(continue_last_task, int) and not isinstance(continue_last_task, bool): + if isinstance(continue_last_task, int) and not isinstance( + continue_last_task, bool + ): task.set_initial_iteration(int(continue_last_task)) elif task.data.last_iteration: task.set_initial_iteration(int(task.data.last_iteration) + 1) @@ -782,9 +830,15 @@ class Task(_Task): StdStreamPatch.patch_std_streams( task.get_logger(), connect_stdout=(auto_connect_streams is True) - or (isinstance(auto_connect_streams, dict) and auto_connect_streams.get("stdout", False)), + or ( + isinstance(auto_connect_streams, dict) + and auto_connect_streams.get("stdout", 
False) + ), connect_stderr=(auto_connect_streams is True) - or (isinstance(auto_connect_streams, dict) and auto_connect_streams.get("stderr", False)), + or ( + isinstance(auto_connect_streams, dict) + and auto_connect_streams.get("stderr", False) + ), load_config_defaults=False, ) return task # noqa @@ -796,28 +850,37 @@ class Task(_Task): else ResourceMonitor ) resource_monitor_kwargs = dict( - report_mem_used_per_process=not config.get("development.worker.report_global_mem_used", False), - first_report_sec=config.get("development.worker.report_start_sec", None), + report_mem_used_per_process=not config.get( + "development.worker.report_global_mem_used", False + ), + first_report_sec=config.get( + "development.worker.report_start_sec", None + ), wait_for_first_iteration_to_start_sec=config.get( "development.worker.wait_for_first_iteration_to_start_sec", None ), max_wait_for_first_iteration_to_start_sec=config.get( - "development.worker.max_wait_for_first_iteration_to_start_sec", None + "development.worker.max_wait_for_first_iteration_to_start_sec", + None, ), ) if isinstance(auto_resource_monitoring, dict): if "report_start_sec" in auto_resource_monitoring: - auto_resource_monitoring["first_report_sec"] = auto_resource_monitoring.pop("report_start_sec") + auto_resource_monitoring["first_report_sec"] = ( + auto_resource_monitoring.pop("report_start_sec") + ) if "seconds_from_start" in auto_resource_monitoring: - auto_resource_monitoring["first_report_sec"] = auto_resource_monitoring.pop( - "seconds_from_start" + auto_resource_monitoring["first_report_sec"] = ( + auto_resource_monitoring.pop("seconds_from_start") ) if "report_global_mem_used" in auto_resource_monitoring: - auto_resource_monitoring["report_mem_used_per_process"] = auto_resource_monitoring.pop( - "report_global_mem_used" + auto_resource_monitoring["report_mem_used_per_process"] = ( + auto_resource_monitoring.pop("report_global_mem_used") ) resource_monitor_kwargs.update(auto_resource_monitoring) - task._resource_monitor = resource_monitor_cls(task, **resource_monitor_kwargs) + task._resource_monitor = resource_monitor_cls( + task, **resource_monitor_kwargs + ) task._resource_monitor.start() # make sure all random generators are initialized with new seed @@ -861,10 +924,14 @@ class Task(_Task): if not is_sub_process_task_id: if cls._offline_mode: logger.report_text( - "ClearML running in offline mode, session stored in {}".format(task.get_offline_mode_folder()) + "ClearML running in offline mode, session stored in {}".format( + task.get_offline_mode_folder() + ) ) else: - logger.report_text("ClearML results page: {}".format(task.get_output_log_web_page())) + logger.report_text( + "ClearML results page: {}".format(task.get_output_log_web_page()) + ) # Make sure we start the dev worker if required, otherwise it will only be started when we write # something to the log. task._dev_mode_setup_worker() @@ -921,14 +988,21 @@ class Task(_Task): try: from .router.router import HttpRouter # noqa except ImportError: - raise UsageError("Could not import `HttpRouter`. Please run `pip install clearml[router]`") + raise UsageError( + "Could not import `HttpRouter`. 
Please run `pip install clearml[router]`" + ) if self._http_router is None: self._http_router = HttpRouter(self) return self._http_router def request_external_endpoint( - self, port, protocol="http", wait=False, wait_interval_seconds=3.0, wait_timeout_seconds=90.0 + self, + port, + protocol="http", + wait=False, + wait_interval_seconds=3.0, + wait_timeout_seconds=90.0, ): # type: (int, str, bool, float, float) -> Optional[Dict] """ @@ -956,7 +1030,9 @@ class Task(_Task): # sync with router - get data from Task if not self._external_endpoint_ports.get(protocol): self.reload() - internal_port = self._get_runtime_properties().get(self._external_endpoint_internal_port_map[protocol]) + internal_port = self._get_runtime_properties().get( + self._external_endpoint_internal_port_map[protocol] + ) if internal_port: self._external_endpoint_ports[protocol] = internal_port @@ -965,11 +1041,15 @@ class Task(_Task): external_host_port_mapping = self._get_runtime_properties().get( self._external_endpoint_host_tcp_port_mapping["tcp_host_mapping"] ) - self._external_endpoint_ports["tcp_host_mapping"] = external_host_port_mapping + self._external_endpoint_ports["tcp_host_mapping"] = ( + external_host_port_mapping + ) # check if we need to parse the port mapping, only if running on "bare-metal" host machine. if self._external_endpoint_ports.get("tcp_host_mapping"): - external_host_port_mapping = self._external_endpoint_ports.get("tcp_host_mapping") + external_host_port_mapping = self._external_endpoint_ports.get( + "tcp_host_mapping" + ) # format is docker standard port mapping format: # example: "out:in,out_range100-out_range102:in_range0-in_range2" # notice `out` in this context means the host port, the one that @@ -987,7 +1067,9 @@ class Task(_Task): out_port = int(out_range[0]) + (port - int(in_range[0])) print( "INFO: Task.request_external_endpoint(...) changed requested external port to {}, " - "conforming to mapped external host ports [{} -> {}]".format(out_port, port, port_range) + "conforming to mapped external host ports [{} -> {}]".format( + out_port, port, port_range + ) ) break @@ -997,7 +1079,9 @@ class Task(_Task): print( "WARNING: Task.request_external_endpoint(...) failed matching requested port to " "mapped external host port [{} to {}], " - "proceeding with original port {}".format(port, external_host_port_mapping, port) + "proceeding with original port {}".format( + port, external_host_port_mapping, port + ) ) # change the requested port to the one we have on the machine @@ -1012,7 +1096,9 @@ class Task(_Task): else: raise ValueError( "Only one endpoint per protocol can be requested at the moment. 
" - "Port already exposed is: {}".format(self._external_endpoint_ports.get(protocol)) + "Port already exposed is: {}".format( + self._external_endpoint_ports.get(protocol) + ) ) # mark for the router our request @@ -1020,12 +1106,15 @@ class Task(_Task): self._set_runtime_properties( { "_SERVICE": self._external_endpoint_service_map[protocol], - self._external_endpoint_address_map[protocol]: HOST_MACHINE_IP.get() or get_private_ip(), + self._external_endpoint_address_map[protocol]: HOST_MACHINE_IP.get() + or get_private_ip(), self._external_endpoint_port_map[protocol]: port, } ) # required system_tag for the router to catch the routing request - self.set_system_tags(list(set((self.get_system_tags() or []) + ["external_service"]))) + self.set_system_tags( + list(set((self.get_system_tags() or []) + ["external_service"])) + ) self._external_endpoint_ports[protocol] = port if wait: return self.wait_for_external_endpoint( @@ -1035,7 +1124,9 @@ class Task(_Task): ) return None - def wait_for_external_endpoint(self, wait_interval_seconds=3.0, wait_timeout_seconds=90.0, protocol="http"): + def wait_for_external_endpoint( + self, wait_interval_seconds=3.0, wait_timeout_seconds=90.0, protocol="http" + ): # type: (float, float, Optional[str]) -> Union[Optional[Dict], List[Optional[Dict]]] """ Wait for an external endpoint to be assigned @@ -1083,21 +1174,31 @@ class Task(_Task): unwaited_protocols = [p for p in protocols if p not in waited_protocols] if wait_timeout_seconds <= 0 and unwaited_protocols: LoggerRoot.get_base_logger().warning( - "Timeout exceeded while waiting for {} endpoint(s)".format(",".join(unwaited_protocols)) + "Timeout exceeded while waiting for {} endpoint(s)".format( + ",".join(unwaited_protocols) + ) ) return results def _wait_for_external_endpoint( - self, wait_interval_seconds=3.0, wait_timeout_seconds=90.0, protocol="http", warn=True + self, + wait_interval_seconds=3.0, + wait_timeout_seconds=90.0, + protocol="http", + warn=True, ): if not self._external_endpoint_ports.get(protocol): self.reload() - internal_port = self._get_runtime_properties().get(self._external_endpoint_internal_port_map[protocol]) + internal_port = self._get_runtime_properties().get( + self._external_endpoint_internal_port_map[protocol] + ) if internal_port: self._external_endpoint_ports[protocol] = internal_port if not self._external_endpoint_ports.get(protocol): if warn: - LoggerRoot.get_base_logger().warning("No external {} endpoints have been requested".format(protocol)) + LoggerRoot.get_base_logger().warning( + "No external {} endpoints have been requested".format(protocol) + ) return None start_time = time.time() while True: @@ -1113,7 +1214,11 @@ class Task(_Task): endpoint = ( runtime_props.get(self._external_endpoint_address_map[protocol]) + ":" - + str(runtime_props.get(self._external_endpoint_port_map[protocol])) + + str( + runtime_props.get( + self._external_endpoint_port_map[protocol] + ) + ) ) if endpoint or browser_endpoint: return { @@ -1125,7 +1230,9 @@ class Task(_Task): if time.time() >= start_time + wait_timeout_seconds: if warn: LoggerRoot.get_base_logger().warning( - "Timeout exceeded while waiting for {} endpoint".format(protocol) + "Timeout exceeded while waiting for {} endpoint".format( + protocol + ) ) return None time.sleep(wait_interval_seconds) @@ -1150,7 +1257,9 @@ class Task(_Task): results = [] protocols = [protocol] if protocol is not None else ["http", "tcp"] for protocol in protocols: - internal_port = 
runtime_props.get(self._external_endpoint_internal_port_map[protocol]) + internal_port = runtime_props.get( + self._external_endpoint_internal_port_map[protocol] + ) if internal_port: self._external_endpoint_ports[protocol] = internal_port else: @@ -1165,7 +1274,11 @@ class Task(_Task): endpoint = ( runtime_props.get(self._external_endpoint_address_map[protocol]) + ":" - + str(runtime_props.get(self._external_endpoint_port_map[protocol])) + + str( + runtime_props.get( + self._external_endpoint_port_map[protocol] + ) + ) ) if endpoint or browser_endpoint: results.append( @@ -1296,7 +1409,10 @@ class Task(_Task): :return: Task object of the most recent task with that name. """ - warnings.warn("Warning: 'Task.get_by_name' is deprecated. Use 'Task.get_task' instead", DeprecationWarning) + warnings.warn( + "Warning: 'Task.get_by_name' is deprecated. Use 'Task.get_task' instead", + DeprecationWarning, + ) return cls.get_task(task_name=task_name) @classmethod @@ -1467,10 +1583,16 @@ class Task(_Task): """ task_filter = task_filter or {} if not allow_archived: - task_filter["system_tags"] = (task_filter.get("system_tags") or []) + ["-{}".format(cls.archived_tag)] + task_filter["system_tags"] = (task_filter.get("system_tags") or []) + [ + "-{}".format(cls.archived_tag) + ] return cls.__get_tasks( - task_ids=task_ids, project_name=project_name, tags=tags, task_name=task_name, **task_filter + task_ids=task_ids, + project_name=project_name, + tags=tags, + task_name=task_name, + **task_filter ) @classmethod @@ -1549,18 +1671,24 @@ class Task(_Task): return_fields = {} if additional_return_fields: return_fields = set(list(additional_return_fields) + ["id"]) - task_filter["only_fields"] = (task_filter.get("only_fields") or []) + list(return_fields) + task_filter["only_fields"] = (task_filter.get("only_fields") or []) + list( + return_fields + ) if task_filter.get("type"): task_filter["type"] = [str(task_type) for task_type in task_filter["type"]] - results = cls._query_tasks(project_name=project_name, task_name=task_name, **task_filter) + results = cls._query_tasks( + project_name=project_name, task_name=task_name, **task_filter + ) return ( [t.id for t in results] if not additional_return_fields else [ { - k: cls._get_data_property(prop_path=k, data=r, raise_on_error=False, log_on_error=False) + k: cls._get_data_property( + prop_path=k, data=r, raise_on_error=False, log_on_error=False + ) for k in return_fields } for r in results @@ -1601,7 +1729,9 @@ class Task(_Task): if value is False: value = None elif value is True: - value = str(self.__default_output_uri or self._get_default_report_storage_uri()) + value = str( + self.__default_output_uri or self._get_default_report_storage_uri() + ) # check if we have the correct packages / configuration if value and value != self.storage_uri: @@ -1628,9 +1758,13 @@ class Task(_Task): return ReadOnlyDict() artifacts_pairs = [] if self.data.execution and self.data.execution.artifacts: - artifacts_pairs = [(a.key, Artifact(a)) for a in self.data.execution.artifacts] + artifacts_pairs = [ + (a.key, Artifact(a)) for a in self.data.execution.artifacts + ] if self._artifacts_manager: - artifacts_pairs += list(self._artifacts_manager.registered_artifacts.items()) + artifacts_pairs += list( + self._artifacts_manager.registered_artifacts.items() + ) return ReadOnlyDict(artifacts_pairs) @property @@ -1712,10 +1846,13 @@ class Task(_Task): assert isinstance(source_task, (six.string_types, Task)) if not Session.check_min_api_version("2.4"): raise ValueError( - 
"ClearML-server does not support DevOps features, " "upgrade clearml-server to 0.12.0 or above" + "ClearML-server does not support DevOps features, " + "upgrade clearml-server to 0.12.0 or above" ) - task_id = source_task if isinstance(source_task, six.string_types) else source_task.id + task_id = ( + source_task if isinstance(source_task, six.string_types) else source_task.id + ) if not parent: if isinstance(source_task, six.string_types): source_task = cls.get_task(task_id=source_task) @@ -1724,7 +1861,11 @@ class Task(_Task): parent = parent.id cloned_task_id = cls._clone_task( - cloned_task_id=task_id, name=name, comment=comment, parent=parent, project=project + cloned_task_id=task_id, + name=name, + comment=comment, + parent=parent, + project=project, ) cloned_task = cls.get_task(task_id=cloned_task_id) return cloned_task @@ -1776,7 +1917,8 @@ class Task(_Task): assert isinstance(task, (six.string_types, Task)) if not Session.check_min_api_version("2.4"): raise ValueError( - "ClearML-server does not support DevOps features, " "upgrade clearml-server to 0.12.0 or above" + "ClearML-server does not support DevOps features, " + "upgrade clearml-server to 0.12.0 or above" ) # make sure we have wither name ot id @@ -1824,7 +1966,9 @@ class Task(_Task): :return: The number of tasks enqueued in the given queue """ if not Session.check_min_api_server_version("2.20", raise_error=True): - raise ValueError("You version of clearml-server does not support the 'queues.get_num_entries' endpoint") + raise ValueError( + "You version of clearml-server does not support the 'queues.get_num_entries' endpoint" + ) mutually_exclusive(queue_name=queue_name, queue_id=queue_id) session = cls._get_default_session() if not queue_id: @@ -1833,7 +1977,11 @@ class Task(_Task): raise ValueError('Could not find queue named "{}"'.format(queue_name)) result = get_num_enqueued_tasks(session, queue_id) if result is None: - raise ValueError("Could not query the number of enqueued tasks in queue with ID {}".format(queue_id)) + raise ValueError( + "Could not query the number of enqueued tasks in queue with ID {}".format( + queue_id + ) + ) return result @classmethod @@ -1878,7 +2026,8 @@ class Task(_Task): assert isinstance(task, (six.string_types, Task)) if not Session.check_min_api_version("2.4"): raise ValueError( - "ClearML-server does not support DevOps features, " "upgrade clearml-server to 0.12.0 or above" + "ClearML-server does not support DevOps features, " + "upgrade clearml-server to 0.12.0 or above" ) task_id = task if isinstance(task, six.string_types) else task.id @@ -1897,7 +2046,11 @@ class Task(_Task): :param progress: numeric value (0 - 100) """ if not isinstance(progress, int) or progress < 0 or progress > 100: - self.log.warning("Can't set progress {} as it is not and int between 0 and 100".format(progress)) + self.log.warning( + "Can't set progress {} as it is not and int between 0 and 100".format( + progress + ) + ) return self._set_runtime_properties({"progress": str(progress)}) @@ -1963,7 +2116,8 @@ class Task(_Task): # input model connect and task parameters will handle this instead if not isinstance(mutable, (InputModel, TaskParameters)): ignore_remote_overrides = self._handle_ignore_remote_overrides( - (name or "General") + "/_ignore_remote_overrides_", ignore_remote_overrides + (name or "General") + "/_ignore_remote_overrides_", + ignore_remote_overrides, ) # dispatching by match order dispatch = ( @@ -1977,10 +2131,18 @@ class Task(_Task): ) multi_config_support = 
Session.check_min_api_version("2.9") - if multi_config_support and not name and not isinstance(mutable, (OutputModel, InputModel)): + if ( + multi_config_support + and not name + and not isinstance(mutable, (OutputModel, InputModel)) + ): name = self._default_configuration_section_name - if not multi_config_support and name and name != self._default_configuration_section_name: + if ( + not multi_config_support + and name + and name != self._default_configuration_section_name + ): raise ValueError( "Multiple configurations is not supported with the current 'clearml-server', " "please upgrade to the latest version" @@ -1988,9 +2150,14 @@ class Task(_Task): for mutable_type, method in dispatch: if isinstance(mutable, mutable_type): - return method(mutable, name=name, ignore_remote_overrides=ignore_remote_overrides) + return method( + mutable, name=name, ignore_remote_overrides=ignore_remote_overrides + ) - raise Exception("Unsupported mutable type %s: no connect function found" % type(mutable).__name__) + raise Exception( + "Unsupported mutable type %s: no connect function found" + % type(mutable).__name__ + ) def set_packages(self, packages): # type: (Union[str, Path, Sequence[str]]) -> () @@ -2069,7 +2236,9 @@ class Task(_Task): requirements_dict.update(self.data.script.requirements) return requirements_dict - def connect_configuration(self, configuration, name=None, description=None, ignore_remote_overrides=False): + def connect_configuration( + self, configuration, name=None, description=None, ignore_remote_overrides=False + ): # type: (Union[Mapping, list, Path, str], Optional[str], Optional[str], bool) -> Union[dict, Path, str] """ Connect a configuration dictionary or configuration file (pathlib.Path / str) to a Task object. @@ -2112,7 +2281,8 @@ class Task(_Task): specified, then a path to a local configuration file is returned. Configuration object. 
""" ignore_remote_overrides = self._handle_ignore_remote_overrides( - (name or "General") + "/_ignore_remote_overrides_config_", ignore_remote_overrides + (name or "General") + "/_ignore_remote_overrides_config_", + ignore_remote_overrides, ) pathlib_Path = None # noqa cast_Path = Path @@ -2133,7 +2303,11 @@ class Task(_Task): if multi_config_support and not name: name = self._default_configuration_section_name - if not multi_config_support and name and name != self._default_configuration_section_name: + if ( + not multi_config_support + and name + and name != self._default_configuration_section_name + ): raise ValueError( "Multiple configurations is not supported with the current 'clearml-server', " "please upgrade to the latest version" @@ -2152,7 +2326,10 @@ class Task(_Task): if multi_config_support: # noinspection PyProtectedMember task._set_configuration( - name=name, description=description, config_type="dictionary", config_dict=config_dict + name=name, + description=description, + config_type="dictionary", + config_dict=config_dict, ) else: # noinspection PyProtectedMember @@ -2161,12 +2338,17 @@ class Task(_Task): def get_dev_config(configuration_): if multi_config_support: self._set_configuration( - name=name, description=description, config_type="dictionary", config_dict=configuration_ + name=name, + description=description, + config_type="dictionary", + config_dict=configuration_, ) else: self._set_model_config(config_dict=configuration) if isinstance(configuration_, dict): - configuration_ = ProxyDictPostWrite(self, _update_config_dict, configuration_) + configuration_ = ProxyDictPostWrite( + self, _update_config_dict, configuration_ + ) return configuration_ if ( @@ -2189,12 +2371,17 @@ class Task(_Task): if remote_configuration is None: LoggerRoot.get_base_logger().warning( "Could not retrieve remote configuration named '{}'\n" - "Using default configuration: {}".format(name, str(configuration)) + "Using default configuration: {}".format( + name, str(configuration) + ) ) # update back configuration section if multi_config_support: self._set_configuration( - name=name, description=description, config_type="dictionary", config_dict=configuration + name=name, + description=description, + config_type="dictionary", + config_dict=configuration, ) return configuration @@ -2233,9 +2420,12 @@ class Task(_Task): self._set_configuration( name=name, description=description, - config_type=configuration_path.suffixes[-1].lstrip(".") - if configuration_path.suffixes and configuration_path.suffixes[-1] - else "file", + config_type=( + configuration_path.suffixes[-1].lstrip(".") + if configuration_path.suffixes + and configuration_path.suffixes[-1] + else "file" + ), config_text=configuration_text, ) else: @@ -2243,7 +2433,9 @@ class Task(_Task): return configuration else: configuration_text = ( - self._get_configuration_text(name=name) if multi_config_support else self._get_model_config_text() + self._get_configuration_text(name=name) + if multi_config_support + else self._get_model_config_text() ) if configuration_text is None: LoggerRoot.get_base_logger().warning( @@ -2260,9 +2452,12 @@ class Task(_Task): self._set_configuration( name=name, description=description, - config_type=configuration_path.suffixes[-1].lstrip(".") - if configuration_path.suffixes and configuration_path.suffixes[-1] - else "file", + config_type=( + configuration_path.suffixes[-1].lstrip(".") + if configuration_path.suffixes + and configuration_path.suffixes[-1] + else "file" + ), config_text=configuration_text, ) 
return configuration @@ -2270,11 +2465,19 @@ class Task(_Task): configuration_path = cast_Path(configuration) fd, local_filename = mkstemp( prefix="clearml_task_config_", - suffix=configuration_path.suffixes[-1] if configuration_path.suffixes else ".txt", + suffix=( + configuration_path.suffixes[-1] + if configuration_path.suffixes + else ".txt" + ), ) with open(fd, "w") as f: f.write(configuration_text) - return cast_Path(local_filename) if isinstance(configuration, cast_Path) else local_filename + return ( + cast_Path(local_filename) + if isinstance(configuration, cast_Path) + else local_filename + ) def connect_label_enumeration(self, enumeration, ignore_remote_overrides=False): # type: (Dict[str, int], bool) -> Dict[str, int] @@ -2299,11 +2502,13 @@ class Task(_Task): :return: The label enumeration dictionary (JSON). """ ignore_remote_overrides = self._handle_ignore_remote_overrides( - "General/_ignore_remote_overrides_label_enumeration_", ignore_remote_overrides + "General/_ignore_remote_overrides_label_enumeration_", + ignore_remote_overrides, ) if not isinstance(enumeration, dict): raise ValueError( - "connect_label_enumeration supports only `dict` type, " "{} is not supported".format(type(enumeration)) + "connect_label_enumeration supports only `dict` type, " + "{} is not supported".format(type(enumeration)) ) if ( @@ -2429,45 +2634,66 @@ class Task(_Task): def set_launch_multi_node_runtime_props(task, conf): # noinspection PyProtectedMember task._set_runtime_properties( - {"{}/{}".format(self._launch_multi_node_section, k): v for k, v in conf.items()} + { + "{}/{}".format(self._launch_multi_node_section, k): v + for k, v in conf.items() + } ) if total_num_nodes < 1: raise UsageError("total_num_nodes needs to be at least 1") - if running_remotely() and not (self.data.execution and self.data.execution.queue) and not queue: - raise UsageError("Master task is not enqueued to any queue and the queue parameter is None") + if ( + running_remotely() + and not (self.data.execution and self.data.execution.queue) + and not queue + ): + raise UsageError( + "Master task is not enqueued to any queue and the queue parameter is None" + ) master_conf = { "master_addr": os.environ.get( - "CLEARML_MULTI_NODE_MASTER_DEF_ADDR", os.environ.get("MASTER_ADDR", addr or get_private_ip()) + "CLEARML_MULTI_NODE_MASTER_DEF_ADDR", + os.environ.get("MASTER_ADDR", addr or get_private_ip()), ), "master_port": int( - os.environ.get("CLEARML_MULTI_NODE_MASTER_DEF_PORT", os.environ.get("MASTER_PORT", port)) + os.environ.get( + "CLEARML_MULTI_NODE_MASTER_DEF_PORT", + os.environ.get("MASTER_PORT", port), + ) ), "node_rank": 0, "wait": wait, "devices": devices, } editable_conf = {"total_num_nodes": total_num_nodes, "queue": queue} - editable_conf = self.connect(editable_conf, name=self._launch_multi_node_section) + editable_conf = self.connect( + editable_conf, name=self._launch_multi_node_section + ) if not running_remotely(): return master_conf master_conf.update(editable_conf) runtime_properties = self._get_runtime_properties() - remote_node_rank = runtime_properties.get("{}/node_rank".format(self._launch_multi_node_section)) + remote_node_rank = runtime_properties.get( + "{}/node_rank".format(self._launch_multi_node_section) + ) current_conf = master_conf if remote_node_rank: # self is a child node, build the conf from the runtime proprerties current_conf = { - entry: runtime_properties.get("{}/{}".format(self._launch_multi_node_section, entry)) + entry: runtime_properties.get( + 
"{}/{}".format(self._launch_multi_node_section, entry) + ) for entry in master_conf.keys() } elif os.environ.get("CLEARML_MULTI_NODE_MASTER") is None: nodes_to_wait = [] # self is the master node, enqueue the other nodes set_launch_multi_node_runtime_props(self, master_conf) - for node_rank in range(1, master_conf.get("total_num_nodes", total_num_nodes)): + for node_rank in range( + 1, master_conf.get("total_num_nodes", total_num_nodes) + ): node = self.clone(source_task=self, parent=self.id) node_conf = copy.deepcopy(master_conf) node_conf["node_rank"] = node_rank @@ -2483,8 +2709,15 @@ class Task(_Task): Task.enqueue(node, queue_id=self.data.execution.queue) if master_conf.get("wait"): nodes_to_wait.append(node) - for node_to_wait, rank in zip(nodes_to_wait, range(1, master_conf.get("total_num_nodes", total_num_nodes))): - self.log.info("Waiting for node with task ID {} and rank {}".format(node_to_wait.id, rank)) + for node_to_wait, rank in zip( + nodes_to_wait, + range(1, master_conf.get("total_num_nodes", total_num_nodes)), + ): + self.log.info( + "Waiting for node with task ID {} and rank {}".format( + node_to_wait.id, rank + ) + ) node_to_wait.wait_for_status( status=( Task.TaskStatusEnum.completed, @@ -2495,7 +2728,11 @@ class Task(_Task): ), check_interval_sec=10, ) - self.log.info("Node with task ID {} and rank {} detected".format(node_to_wait.id, rank)) + self.log.info( + "Node with task ID {} and rank {} detected".format( + node_to_wait.id, rank + ) + ) os.environ["CLEARML_MULTI_NODE_MASTER"] = "1" num_devices = 1 @@ -2523,10 +2760,13 @@ class Task(_Task): os.environ["MASTER_ADDR"] = current_conf.get("master_addr", "") os.environ["MASTER_PORT"] = str(current_conf.get("master_port", "")) os.environ["RANK"] = str( - current_conf.get("node_rank", 0) * num_devices + int(os.environ.get("LOCAL_RANK", "0")) + current_conf.get("node_rank", 0) * num_devices + + int(os.environ.get("LOCAL_RANK", "0")) ) os.environ["NODE_RANK"] = str(current_conf.get("node_rank", "")) - os.environ["WORLD_SIZE"] = str(current_conf.get("total_num_nodes", total_num_nodes) * num_devices) + os.environ["WORLD_SIZE"] = str( + current_conf.get("total_num_nodes", total_num_nodes) * num_devices + ) return current_conf @@ -2553,7 +2793,9 @@ class Task(_Task): # flush any outstanding logs self.flush(wait_for_uploads=True) # mark task as stopped - self.stopped(force=force, status_message=str(status_message) if status_message else None) + self.stopped( + force=force, status_message=str(status_message) if status_message else None + ) def mark_stop_request(self, force=False, status_message=None): # type: (bool, Optional[str]) -> () @@ -2621,7 +2863,9 @@ class Task(_Task): - ``False`` - Do not force (default) """ if not running_remotely() or not self.is_main_task() or force: - super(Task, self).reset(set_started_on_success=set_started_on_success, force=force) + super(Task, self).reset( + set_started_on_success=set_started_on_success, force=force + ) def close(self): """ @@ -2738,12 +2982,18 @@ class Task(_Task): value of ``True``. If ``True``, the artifact uniqueness comparison criteria is all the columns, which is the same as ``artifact.columns``. 
""" - if not isinstance(uniqueness_columns, CollectionsSequence) and uniqueness_columns is not True: + if ( + not isinstance(uniqueness_columns, CollectionsSequence) + and uniqueness_columns is not True + ): raise ValueError("uniqueness_columns should be a List (sequence) or True") if isinstance(uniqueness_columns, str): uniqueness_columns = [uniqueness_columns] self._artifacts_manager.register_artifact( - name=name, artifact=artifact, metadata=metadata, uniqueness_columns=uniqueness_columns + name=name, + artifact=artifact, + metadata=metadata, + uniqueness_columns=uniqueness_columns, ) def unregister_artifact(self, name): @@ -2868,7 +3118,9 @@ class Task(_Task): exception_to_raise = e if retry < retries: getLogger().warning( - "Failed uploading artifact '{}'. Retrying... ({}/{})".format(name, retry + 1, retries) + "Failed uploading artifact '{}'. Retrying... ({}/{})".format( + name, retry + 1, retries + ) ) if exception_to_raise: raise exception_to_raise @@ -2894,7 +3146,9 @@ class Task(_Task): n_last_iterations = MAX_SERIES_PER_METRIC.get() if isinstance(n_last_iterations, int) and n_last_iterations >= 0: - samples = self._get_debug_samples(title=title, series=series, n_last_iterations=n_last_iterations) + samples = self._get_debug_samples( + title=title, series=series, n_last_iterations=n_last_iterations + ) else: raise TypeError( "Parameter n_last_iterations is expected to be a positive integer value," @@ -2903,7 +3157,9 @@ class Task(_Task): return samples - def _send_debug_image_request(self, title, series, n_last_iterations, scroll_id=None): + def _send_debug_image_request( + self, title, series, n_last_iterations, scroll_id=None + ): return Task._send( Task._get_default_session(), events.DebugImagesRequest( @@ -2923,15 +3179,23 @@ class Task(_Task): for metric_resp in response.response_data.get("metrics", []): iterations_events = [ - iteration["events"] for iteration in metric_resp.get("iterations", []) + iteration["events"] + for iteration in metric_resp.get("iterations", []) ] # type: List[List[dict]] - flattened_events = (event for single_iter_events in iterations_events for event in single_iter_events) + flattened_events = ( + event + for single_iter_events in iterations_events + for event in single_iter_events + ) debug_samples.extend(flattened_events) - response = self._send_debug_image_request(title, series, n_last_iterations, scroll_id=scroll_id) + response = self._send_debug_image_request( + title, series, n_last_iterations, scroll_id=scroll_id + ) if len(debug_samples) == n_last_iterations or all( - len(metric_resp.get("iterations", [])) == 0 for metric_resp in response.response_data.get("metrics", []) + len(metric_resp.get("iterations", [])) == 0 + for metric_resp in response.response_data.get("metrics", []) ): break @@ -3050,7 +3314,10 @@ class Task(_Task): :return: The last reported iteration number. 
""" self._reload_last_iteration() - return max(self.data.last_iteration or 0, self.__reporter.max_iteration if self.__reporter else 0) + return max( + self.data.last_iteration or 0, + self.__reporter.max_iteration if self.__reporter else 0, + ) def set_initial_iteration(self, offset=0): # type: (int) -> int @@ -3099,7 +3366,8 @@ class Task(_Task): for i in metrics.values(): for j in i.values(): scalar_metrics.setdefault(j["metric"], {}).setdefault( - j["variant"], {"last": j["value"], "min": j["min_value"], "max": j["max_value"]} + j["variant"], + {"last": j["value"], "min": j["min_value"], "max": j["max_value"]}, ) return scalar_metrics @@ -3141,7 +3409,8 @@ class Task(_Task): section = "properties" params = self._hyper_params_manager.get_hyper_params( - sections=[section], projector=(lambda x: x.get("value")) if value_only else None + sections=[section], + projector=(lambda x: x.get("value")) if value_only else None, ) return dict(params.get(section, {})) @@ -3222,7 +3491,11 @@ class Task(_Task): return self._hyper_params_manager.edit_hyper_params( iterables=list(properties.items()) - + (list(iterables.items()) if isinstance(iterables, dict) else list(iterables)), + + ( + list(iterables.items()) + if isinstance(iterables, dict) + else list(iterables) + ), replace="none", force_section="properties", ) @@ -3385,14 +3658,20 @@ class Task(_Task): # noinspection PyProtectedMember instance._first_report_sec = seconds_from_start instance.wait_for_first_iteration = wait_for_first_iteration_to_start_sec - instance.max_check_first_iteration = max_wait_for_first_iteration_to_start_sec + instance.max_check_first_iteration = ( + max_wait_for_first_iteration_to_start_sec + ) # noinspection PyProtectedMember ResourceMonitor._first_report_sec_default = seconds_from_start # noinspection PyProtectedMember - ResourceMonitor._wait_for_first_iteration_to_start_sec_default = wait_for_first_iteration_to_start_sec + ResourceMonitor._wait_for_first_iteration_to_start_sec_default = ( + wait_for_first_iteration_to_start_sec + ) # noinspection PyProtectedMember - ResourceMonitor._max_wait_for_first_iteration_to_start_sec_default = max_wait_for_first_iteration_to_start_sec + ResourceMonitor._max_wait_for_first_iteration_to_start_sec_default = ( + max_wait_for_first_iteration_to_start_sec + ) return True def execute_remotely(self, queue_name=None, clone=False, exit_process=True): @@ -3468,11 +3747,15 @@ class Task(_Task): if queue_name: Task.enqueue(task, queue_name=queue_name) LoggerRoot.get_base_logger().warning( - "Switching to remote execution, output log page {}".format(task.get_output_log_web_page()) + "Switching to remote execution, output log page {}".format( + task.get_output_log_web_page() + ) ) else: # Remove the development system tag - system_tags = [t for t in task.get_system_tags() if t != self._development_tag] + system_tags = [ + t for t in task.get_system_tags() if t != self._development_tag + ] self.set_system_tags(system_tags) # if we leave the Task out there, it makes sense to make it editable. 
self.reset(force=True) @@ -3513,12 +3796,15 @@ class Task(_Task): :return Task: Return the newly created Task or None if running remotely and execution is skipped """ if not self.is_main_task(): - raise ValueError("Only the main Task object can call create_function_task()") + raise ValueError( + "Only the main Task object can call create_function_task()" + ) if not callable(func): raise ValueError("func must be callable") if not Session.check_min_api_version("2.9"): raise ValueError( - "Remote function execution is not supported, " "please upgrade to the latest server version" + "Remote function execution is not supported, " + "please upgrade to the latest server version" ) func_name = str(func_name or func.__name__).strip() @@ -3537,7 +3823,9 @@ class Task(_Task): func_params[func_marker] = func_name # do not query if we are running locally, there is no need. - task_func_marker = self.running_locally() or self.get_parameter("{}/{}".format(section_name, func_marker)) + task_func_marker = self.running_locally() or self.get_parameter( + "{}/{}".format(section_name, func_marker) + ) # if we are running locally or if we are running remotely but we are not a forked tasks # condition explained: @@ -3545,7 +3833,11 @@ class Task(_Task): # (2) running remotely but this is not one of the forked tasks (i.e. it is missing the fork tag attribute) if self.running_locally() or not task_func_marker: self._wait_for_repo_detection(300) - task = self.clone(self, name=task_name or "{} <{}>".format(self.name, func_name), parent=self.id) + task = self.clone( + self, + name=task_name or "{} <{}>".format(self.name, func_name), + parent=self.id, + ) task.set_system_tags((task.get_system_tags() or []) + [tag_name]) task.connect(func_params, name=section_name) self._remote_functions_generated[func_name] = task.id @@ -3554,7 +3846,9 @@ class Task(_Task): # check if we are one of the generated functions and if this is us, # if we are not the correct function, not do nothing and leave if task_func_marker != func_name: - self._remote_functions_generated[func_name] = len(self._remote_functions_generated) + 1 + self._remote_functions_generated[func_name] = ( + len(self._remote_functions_generated) + 1 + ) return # mark this is us: @@ -3570,7 +3864,11 @@ class Task(_Task): def wait_for_status( self, - status=(_Task.TaskStatusEnum.completed, _Task.TaskStatusEnum.stopped, _Task.TaskStatusEnum.closed), + status=( + _Task.TaskStatusEnum.completed, + _Task.TaskStatusEnum.stopped, + _Task.TaskStatusEnum.closed, + ), raise_on_status=(_Task.TaskStatusEnum.failed,), check_interval_sec=60.0, ): @@ -3585,12 +3883,16 @@ class Task(_Task): :raise: RuntimeError if the status is one of ``{raise_on_status}``. 
""" - stopped_status = list(status) + (list(raise_on_status) if raise_on_status else []) + stopped_status = list(status) + ( + list(raise_on_status) if raise_on_status else [] + ) while self.status not in stopped_status: time.sleep(check_interval_sec) if raise_on_status and self.status in raise_on_status: - raise RuntimeError("Task {} has status: {}.".format(self.task_id, self.status)) + raise RuntimeError( + "Task {} has status: {}.".format(self.task_id, self.status) + ) # make sure we have the Task object self.reload() @@ -3627,7 +3929,9 @@ class Task(_Task): :param task_data: dictionary with full Task configuration :return: return True if Task update was successful """ - return bool(self.import_task(task_data=task_data, target_task=self, update=True)) + return bool( + self.import_task(task_data=task_data, target_task=self, update=True) + ) def rename(self, new_name): # type: (str) -> bool @@ -3642,7 +3946,9 @@ class Task(_Task): self.reload() return result - def move_to_project(self, new_project_id=None, new_project_name=None, system_tags=None): + def move_to_project( + self, new_project_id=None, new_project_name=None, system_tags=None + ): # type: (Optional[str], Optional[str], Optional[Sequence[str]]) -> bool """ Move this task to another project @@ -3656,7 +3962,10 @@ class Task(_Task): :return: True if the move was successful and False otherwise """ new_project_id = get_or_create_project( - self.session, project_name=new_project_name, project_id=new_project_id, system_tags=system_tags + self.session, + project_name=new_project_name, + project_id=new_project_id, + system_tags=system_tags, ) result = bool(self._edit(project=new_project_id)) self.reload() @@ -3683,16 +3992,22 @@ class Task(_Task): will be terminated even if the callback did not return """ if self.__is_subprocess(): - raise ValueError("Register abort callback must be called from the main process, this is a subprocess.") + raise ValueError( + "Register abort callback must be called from the main process, this is a subprocess." 
+ ) if callback_function is None: if self._dev_worker: - self._dev_worker.register_abort_callback(callback_function=None, execution_timeout=0, poll_freq=0) + self._dev_worker.register_abort_callback( + callback_function=None, execution_timeout=0, poll_freq=0 + ) return if float(callback_execution_timeout) <= 0: raise ValueError( - "function_timeout_sec must be positive timeout in seconds, got {}".format(callback_execution_timeout) + "function_timeout_sec must be a positive timeout in seconds, got {}".format( + callback_execution_timeout + ) ) # if we are running remotely we might not have a DevWorker monitoring us, so let's create one @@ -3702,7 +4017,9 @@ poll_freq = 15.0 self._dev_worker.register_abort_callback( - callback_function=callback_function, execution_timeout=callback_execution_timeout, poll_freq=poll_freq + callback_function=callback_function, + execution_timeout=callback_execution_timeout, + poll_freq=poll_freq, ) @classmethod @@ -3726,8 +4043,12 @@ Session.force_max_api_version = str(force_api_version) if not target_task: - project_name = task_data.get("project_name") or Task._get_project_name(task_data.get("project", "")) - target_task = Task.create(project_name=project_name, task_name=task_data.get("name", None)) + project_name = task_data.get("project_name") or Task._get_project_name( + task_data.get("project", "") + ) + target_task = Task.create( + project_name=project_name, task_name=task_data.get("name", None) + ) elif isinstance(target_task, six.string_types): target_task = Task.get_task(task_id=target_task) # type: Optional[Task] elif not isinstance(target_task, Task): @@ -3803,7 +4124,11 @@ """ if running_remotely() or bool(offline_mode) == InterfaceBase._offline_mode: return - if cls.current_task() and cls.current_task().status != cls.TaskStatusEnum.closed and not offline_mode: + if ( + cls.current_task() + and cls.current_task().status != cls.TaskStatusEnum.closed + and not offline_mode + ): raise UsageError( "Switching from offline mode to online mode, but the current task has not been closed. Use `Task.close` to close it." ) @@ -3825,7 +4150,9 @@ return cls._offline_mode @classmethod - def import_offline_session(cls, session_folder_zip, previous_task_id=None, iteration_offset=0): + def import_offline_session( + cls, session_folder_zip, previous_task_id=None, iteration_offset=0 + ): # type: (str, Optional[str], Optional[int]) -> (Optional[str]) """ Upload an offline session (execution) of a Task.
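The offline-session flow that the surrounding hunks rewrap is easiest to follow end to end. A minimal usage sketch, assuming a run recorded with `Task.set_offline`; the zip path and project/task names are illustrative only (the real session location is printed when the offline task closes):

```python
from clearml import Task

# Record a run with no server connectivity; everything is stored locally.
Task.set_offline(offline_mode=True)
task = Task.init(project_name="examples", task_name="offline run")
task.get_logger().report_scalar("loss", "train", value=0.5, iteration=1)
task.close()  # prints the local session folder / zip file

# Later, with connectivity restored, replay the session to the server.
Task.set_offline(offline_mode=False)
Task.import_offline_session(session_folder_zip="offline_session.zip")
```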
@@ -3850,20 +4177,26 @@ class Task(_Task): session_folder = Path(session_folder_zip) if not session_folder.is_dir(): - raise ValueError("Could not find the session folder / zip-file {}".format(session_folder)) + raise ValueError( + "Could not find the session folder / zip-file {}".format(session_folder) + ) try: with open((session_folder / cls._offline_filename).as_posix(), "rt") as f: export_data = json.load(f) except Exception as ex: raise ValueError( - "Could not read Task object {}: Exception {}".format(session_folder / cls._offline_filename, ex) + "Could not read Task object {}: Exception {}".format( + session_folder / cls._offline_filename, ex + ) ) current_task = cls.import_task(export_data) if previous_task_id: task_holding_reports = cls.get_task(task_id=previous_task_id) task_holding_reports.mark_started(force=True) - task_holding_reports = cls.import_task(export_data, target_task=task_holding_reports, update=True) + task_holding_reports = cls.import_task( + export_data, target_task=task_holding_reports, update=True + ) else: task_holding_reports = current_task task_holding_reports.mark_started(force=True) @@ -3872,7 +4205,9 @@ class Task(_Task): from . import StorageManager # noinspection PyProtectedMember - offline_folder = os.path.join(export_data.get("offline_folder", ""), "data/") + offline_folder = os.path.join( + export_data.get("offline_folder", ""), "data/" + ) # noinspection PyProtectedMember remote_url = current_task._get_default_report_storage_uri() @@ -3884,10 +4219,14 @@ class Task(_Task): local_file = session_folder / "data" / local_path if local_file.is_file(): remote_path = local_path.replace( - ".{}{}".format(export_data["id"], os.sep), ".{}{}".format(current_task.id, os.sep), 1 + ".{}{}".format(export_data["id"], os.sep), + ".{}{}".format(current_task.id, os.sep), + 1, ) artifact.uri = "{}/{}".format(remote_url, remote_path) - StorageManager.upload_file(local_file=local_file.as_posix(), remote_url=artifact.uri) + StorageManager.upload_file( + local_file=local_file.as_posix(), remote_url=artifact.uri + ) # noinspection PyProtectedMember task_holding_reports._edit(execution=current_task.data.execution) for output_model in export_data.get("offline_output_models", []): @@ -3904,7 +4243,9 @@ class Task(_Task): session=task_holding_reports.session, ) # logs - TaskHandler.report_offline_session(task_holding_reports, session_folder, iteration_offset=iteration_offset) + TaskHandler.report_offline_session( + task_holding_reports, session_folder, iteration_offset=iteration_offset + ) # metrics Metrics.report_offline_session( task_holding_reports, @@ -3914,7 +4255,11 @@ class Task(_Task): session=task_holding_reports.session, ) # print imported results page - print("ClearML results page: {}".format(task_holding_reports.get_output_log_web_page())) + print( + "ClearML results page: {}".format( + task_holding_reports.get_output_log_web_page() + ) + ) task_holding_reports.mark_completed() # close task task_holding_reports.close() @@ -3931,7 +4276,13 @@ class Task(_Task): @classmethod def set_credentials( - cls, api_host=None, web_host=None, files_host=None, key=None, secret=None, store_conf_file=False + cls, + api_host=None, + web_host=None, + files_host=None, + key=None, + secret=None, + store_conf_file=False, ): # type: (Optional[str], Optional[str], Optional[str], Optional[str], Optional[str], bool) -> None """ @@ -3987,7 +4338,8 @@ class Task(_Task): active_conf_file = get_active_config_file() if active_conf_file: getLogger().warning( - "Could not store credentials in 
configuration file, " "'{}' already exists".format(active_conf_file) + "Could not store credentials in configuration file, " + "'{}' already exists".format(active_conf_file) ) else: conf = { @@ -3995,7 +4347,10 @@ class Task(_Task): api_server=Session.default_host, web_server=Session.default_web, files_server=Session.default_files, - credentials=dict(access_key=Session.default_key, secret_key=Session.default_secret), + credentials=dict( + access_key=Session.default_key, + secret_key=Session.default_secret, + ), ) } with open(get_config_file(), "wt") as f: @@ -4050,10 +4405,14 @@ class Task(_Task): if not return_name or not queue_id: return queue_id try: - queue_name_result = Task._send(Task._get_default_session(), queues.GetByIdRequest(queue_id)) + queue_name_result = Task._send( + Task._get_default_session(), queues.GetByIdRequest(queue_id) + ) return queue_name_result.response.queue.name except Exception as e: - getLogger().warning("Could not get name of queue with ID '{}': {}".format(queue_id, e)) + getLogger().warning( + "Could not get name of queue with ID '{}': {}".format(queue_id, e) + ) return None @classmethod @@ -4104,7 +4463,9 @@ class Task(_Task): If `config_dict` is not None, `config_text` must not be provided. """ # noinspection PyProtectedMember - design = OutputModel._resolve_config(config_text=config_text, config_dict=config_dict) + design = OutputModel._resolve_config( + config_text=config_text, config_dict=config_dict + ) super(Task, self)._set_model_design(design=design) def _get_model_config_text(self): @@ -4132,7 +4493,11 @@ class Task(_Task): def _set_startup_info(self): # type: () -> () self._set_runtime_properties( - runtime_properties={"CLEARML VERSION": self.session.client, "CLI": sys.argv[0], "progress": "0"} + runtime_properties={ + "CLEARML VERSION": self.session.client, + "CLI": sys.argv[0], + "progress": "0", + } ) @classmethod @@ -4165,18 +4530,24 @@ class Task(_Task): ): if not default_project_name or not default_task_name: # get project name and task name from repository name and entry_point - result, _ = ScriptInfo.get(create_requirements=False, check_uncommitted=False) + result, _ = ScriptInfo.get( + create_requirements=False, check_uncommitted=False + ) if not default_project_name: # noinspection PyBroadException try: parts = result.script["repository"].split("/") - default_project_name = (parts[-1] or parts[-2]).replace(".git", "") or "Untitled" + default_project_name = (parts[-1] or parts[-2]).replace( + ".git", "" + ) or "Untitled" except Exception: default_project_name = "Untitled" if not default_task_name: # noinspection PyBroadException try: - default_task_name = os.path.splitext(os.path.basename(result.script["entry_point"]))[0] + default_task_name = os.path.splitext( + os.path.basename(result.script["entry_point"]) + )[0] except Exception: pass @@ -4192,7 +4563,11 @@ class Task(_Task): continue_last_task = TASK_SET_ITERATION_OFFSET.get() # if we force no task reuse from os environment - if DEV_TASK_NO_REUSE.get() or not reuse_last_task_id or isinstance(reuse_last_task_id, str): + if ( + DEV_TASK_NO_REUSE.get() + or not reuse_last_task_id + or isinstance(reuse_last_task_id, str) + ): default_task = None else: # if we have a previous session to use, get the task id from it @@ -4226,7 +4601,10 @@ class Task(_Task): # instead of resting the previously used task we are continuing the training with it. 
             if task and (
                 continue_last_task
-                or (isinstance(continue_last_task, int) and not isinstance(continue_last_task, bool))
+                or (
+                    isinstance(continue_last_task, int)
+                    and not isinstance(continue_last_task, bool)
+                )
             ):
                 task.reload()
                 task.mark_started(force=True)
@@ -4237,12 +4615,24 @@ class Task(_Task):
                     task.set_initial_iteration(continue_last_task)
             else:
-                task_tags = task.data.system_tags if hasattr(task.data, "system_tags") else task.data.tags
+                task_tags = (
+                    task.data.system_tags
+                    if hasattr(task.data, "system_tags")
+                    else task.data.tags
+                )
                 task_artifacts = (
-                    task.data.execution.artifacts if hasattr(task.data.execution, "artifacts") else None
+                    task.data.execution.artifacts
+                    if hasattr(task.data.execution, "artifacts")
+                    else None
                 )
                 if (
-                    (task._status in (cls.TaskStatusEnum.published, cls.TaskStatusEnum.closed))
+                    (
+                        task._status
+                        in (
+                            cls.TaskStatusEnum.published,
+                            cls.TaskStatusEnum.closed,
+                        )
+                    )
                     or task.output_models_id
                     or (cls.archived_tag in task_tags)
                     or (cls._development_tag not in task_tags)
@@ -4263,7 +4653,9 @@ class Task(_Task):
                         # clear the heaviest stuff first
                         task._clear_task(
                             system_tags=[cls._development_tag],
-                            comment=make_message("Auto-generated at %(time)s by %(user)s@%(host)s"),
+                            comment=make_message(
+                                "Auto-generated at %(time)s by %(user)s@%(host)s"
+                            ),
                         )

                 except (Exception, ValueError):
@@ -4284,7 +4676,12 @@ class Task(_Task):

         if in_dev_mode:
             # update this session, for later use
-            cls.__update_last_used_task_id(default_project_name, default_task_name, default_task_type.value, task.id)
+            cls.__update_last_used_task_id(
+                default_project_name,
+                default_task_name,
+                default_task_type.value,
+                task.id,
+            )
             # set default docker image from env.
             task._set_default_docker_image()
@@ -4306,13 +4703,20 @@ class Task(_Task):
         # force update of base logger to this current task (this is the main logger task)
         logger = task._get_logger(auto_connect_streams=auto_connect_streams)
         if closed_old_task:
-            logger.report_text("ClearML Task: Closing old development task id={}".format(default_task.get("id")))
+            logger.report_text(
+                "ClearML Task: Closing old development task id={}".format(
+                    default_task.get("id")
+                )
+            )
         # print warning, reusing/creating a task
         if default_task_id and not continue_last_task:
-            logger.report_text("ClearML Task: overwriting (reusing) task id=%s" % task.id)
+            logger.report_text(
+                "ClearML Task: overwriting (reusing) task id=%s" % task.id
+            )
         elif default_task_id and continue_last_task:
             logger.report_text(
-                "ClearML Task: continuing previous task id=%s " "Notice this run will not be reproducible!" % task.id
+                "ClearML Task: continuing previous task id=%s "
+                "Notice this run will not be reproducible!" % task.id
             )
         else:
             logger.report_text("ClearML Task: created new task id=%s" % task.id)
@@ -4333,7 +4737,9 @@ class Task(_Task):
             except Exception:
                 pass
             if in_dev_mode and cls.__detect_repo_async:
-                task._detect_repo_async_thread = threading.Thread(target=task._update_repository)
+                task._detect_repo_async_thread = threading.Thread(
+                    target=task._update_repository
+                )
                 task._detect_repo_async_thread.daemon = True
                 task._detect_repo_async_thread.start()
             else:
@@ -4370,10 +4776,17 @@ class Task(_Task):
             self._logger = Logger(
                 private_task=self,
                 connect_stdout=(auto_connect_streams is True)
-                or (isinstance(auto_connect_streams, dict) and auto_connect_streams.get("stdout", False)),
+                or (
+                    isinstance(auto_connect_streams, dict)
+                    and auto_connect_streams.get("stdout", False)
+                ),
                 connect_stderr=(auto_connect_streams is True)
-                or (isinstance(auto_connect_streams, dict) and auto_connect_streams.get("stderr", False)),
-                connect_logging=isinstance(auto_connect_streams, dict) and auto_connect_streams.get("logging", False),
+                or (
+                    isinstance(auto_connect_streams, dict)
+                    and auto_connect_streams.get("stderr", False)
+                ),
+                connect_logging=isinstance(auto_connect_streams, dict)
+                and auto_connect_streams.get("logging", False),
             )
             # make sure we set our reported to async mode
             # we make sure we flush it in self._at_exit
@@ -4414,7 +4827,9 @@ class Task(_Task):
                 value_type=bool,
             )
         elif not self.running_locally():
-            ignore_remote_overrides = self.get_parameter(overrides_name, default=ignore_remote_overrides, cast=True)
+            ignore_remote_overrides = self.get_parameter(
+                overrides_name, default=ignore_remote_overrides, cast=True
+            )
         return ignore_remote_overrides

     def _reconnect_output_model(self):
@@ -4448,7 +4863,13 @@ class Task(_Task):
         return model

     def _connect_argparse(
-        self, parser, args=None, namespace=None, parsed_args=None, name=None, ignore_remote_overrides=False
+        self,
+        parser,
+        args=None,
+        namespace=None,
+        parsed_args=None,
+        name=None,
+        ignore_remote_overrides=False,
     ):
         # do not allow argparser to connect to jupyter notebook
         # noinspection PyBroadException
@@ -4481,10 +4902,16 @@ class Task(_Task):
         if parsed_args is None and parser == _parser:
             parsed_args = _parsed_args

-        if running_remotely() and (self.is_main_task() or self._is_remote_main_task()) and not ignore_remote_overrides:
+        if (
+            running_remotely()
+            and (self.is_main_task() or self._is_remote_main_task())
+            and not ignore_remote_overrides
+        ):
             self._arguments.copy_to_parser(parser, parsed_args)
         else:
-            self._arguments.copy_defaults_from_argparse(parser, args=args, namespace=namespace, parsed_args=parsed_args)
+            self._arguments.copy_defaults_from_argparse(
+                parser, args=args, namespace=namespace, parsed_args=parsed_args
+            )
         return parser

     def _connect_dictionary(self, dictionary, name=None, ignore_remote_overrides=False):
@@ -4495,11 +4922,15 @@ class Task(_Task):
         def _refresh_args_dict(task, config_proxy_dict):
             # reread from task including newly added keys
             # noinspection PyProtectedMember
-            a_flat_dict = task._arguments.copy_to_dict(flatten_dictionary(config_proxy_dict), prefix=name)
+            a_flat_dict = task._arguments.copy_to_dict(
+                flatten_dictionary(config_proxy_dict), prefix=name
+            )
             # noinspection PyProtectedMember
             nested_dict = config_proxy_dict._to_dict()
             config_proxy_dict.clear()
-            config_proxy_dict._do_update(nested_from_flat_dictionary(nested_dict, a_flat_dict))
+            config_proxy_dict._do_update(
+                nested_from_flat_dictionary(nested_dict, a_flat_dict)
+            )

         def _check_keys(dict_, warning_sent=False):
             if warning_sent:
@@ -4534,17 +4965,27 @@ class Task(_Task):

         return dictionary

-    def _connect_task_parameters(self, attr_class, name=None, ignore_remote_overrides=False):
+    def _connect_task_parameters(
+        self, attr_class, name=None, ignore_remote_overrides=False
+    ):
         ignore_remote_overrides_section = "_ignore_remote_overrides_"
         if running_remotely():
             ignore_remote_overrides = self.get_parameter(
-                (name or "General") + "/" + ignore_remote_overrides_section, default=ignore_remote_overrides, cast=True
+                (name or "General") + "/" + ignore_remote_overrides_section,
+                default=ignore_remote_overrides,
+                cast=True,
             )
-        if running_remotely() and (self.is_main_task() or self._is_remote_main_task()) and not ignore_remote_overrides:
+        if (
+            running_remotely()
+            and (self.is_main_task() or self._is_remote_main_task())
+            and not ignore_remote_overrides
+        ):
             parameters = self.get_parameters(cast=True)
             if name:
                 parameters = dict(
-                    (k[len(name) + 1 :], v) for k, v in parameters.items() if k.startswith("{}/".format(name))
+                    (k[len(name) + 1 :], v)
+                    for k, v in parameters.items()
+                    if k.startswith("{}/".format(name))
                 )
             parameters.pop(ignore_remote_overrides_section, None)
             attr_class.update_from_dict(parameters)
@@ -4557,7 +4998,9 @@ class Task(_Task):

     def _connect_object(self, an_object, name=None, ignore_remote_overrides=False):
         def verify_type(key, value):
-            if str(key).startswith("_") or not isinstance(value, self._parameters_allowed_types):
+            if str(key).startswith("_") or not isinstance(
+                value, self._parameters_allowed_types
+            ):
                 return False
             # verify everything is json able (i.e. basic types)
             try:
@@ -4572,15 +5015,23 @@ class Task(_Task):
             for k, v in cls_.__dict__.items()
             if verify_type(k, v)
         }
-        if running_remotely() and (self.is_main_task() or self._is_remote_main_task()) and not ignore_remote_overrides:
-            a_dict = self._connect_dictionary(a_dict, name, ignore_remote_overrides=ignore_remote_overrides)
+        if (
+            running_remotely()
+            and (self.is_main_task() or self._is_remote_main_task())
+            and not ignore_remote_overrides
+        ):
+            a_dict = self._connect_dictionary(
+                a_dict, name, ignore_remote_overrides=ignore_remote_overrides
+            )
             for k, v in a_dict.items():
                 if getattr(an_object, k, None) != a_dict[k]:
                     setattr(an_object, k, v)
             return an_object
         else:
-            self._connect_dictionary(a_dict, name, ignore_remote_overrides=ignore_remote_overrides)
+            self._connect_dictionary(
+                a_dict, name, ignore_remote_overrides=ignore_remote_overrides
+            )
             return an_object

     def _dev_mode_stop_task(self, stop_reason, pid=None):
@@ -4588,7 +5039,11 @@ class Task(_Task):
         if self._at_exit_called:
             return

-        self.log.warning("### TASK STOPPED - USER ABORTED - {} ###".format(stop_reason.upper().replace("_", " ")))
+        self.log.warning(
+            "### TASK STOPPED - USER ABORTED - {} ###".format(
+                stop_reason.upper().replace("_", " ")
+            )
+        )
         self.flush(wait_for_uploads=True)

         # if running remotely, we want the daemon to kill us
@@ -4601,13 +5056,19 @@ class Task(_Task):
         # NOTICE! This will end the entire execution tree!
         if self.__exit_hook:
             self.__exit_hook.remote_user_aborted = True
-        self._kill_all_child_processes(send_kill=False, pid=pid, allow_kill_calling_pid=False)
+        self._kill_all_child_processes(
+            send_kill=False, pid=pid, allow_kill_calling_pid=False
+        )
         time.sleep(2.0)
-        self._kill_all_child_processes(send_kill=True, pid=pid, allow_kill_calling_pid=True)
+        self._kill_all_child_processes(
+            send_kill=True, pid=pid, allow_kill_calling_pid=True
+        )
         os._exit(1)  # noqa

     @staticmethod
-    def _kill_all_child_processes(send_kill=False, pid=None, allow_kill_calling_pid=True):
+    def _kill_all_child_processes(
+        send_kill=False, pid=None, allow_kill_calling_pid=True
+    ):
         # get current process if pid not provided
         current_pid = os.getpid()
         kill_ourselves = None
@@ -4677,19 +5138,24 @@ class Task(_Task):
                     kill_thread(self._detect_repo_async_thread)
                 else:
-                    self.log.info("Waiting for repository detection and full package requirement analysis")
+                    self.log.info(
+                        "Waiting for repository detection and full package requirement analysis"
+                    )
                     self._detect_repo_async_thread.join(timeout=timeout)
                     # because join has no return value
                     if self._detect_repo_async_thread.is_alive():
                         self.log.info(
-                            "Repository and package analysis timed out ({} sec), " "giving up".format(timeout)
+                            "Repository and package analysis timed out ({} sec), "
+                            "giving up".format(timeout)
                         )
                         # done waiting, kill the thread
                         from .utilities.lowlevel.threads import kill_thread

                         kill_thread(self._detect_repo_async_thread)
                     else:
-                        self.log.info("Finished repository detection and package analysis")
+                        self.log.info(
+                            "Finished repository detection and package analysis"
+                        )
                 self._detect_repo_async_thread = None
         except Exception:
             pass
@@ -4732,7 +5198,9 @@ class Task(_Task):
         if self._at_exit_called:
             is_sub_process = self.__is_subprocess()
             # if we are called twice (signal in the middle of the shutdown),
-            _nested_shutdown_call = bool(self._at_exit_called == get_current_thread_id())
+            _nested_shutdown_call = bool(
+                self._at_exit_called == get_current_thread_id()
+            )
            if _nested_shutdown_call and not is_sub_process:
                 # if we were called again in the main thread on the main process, let's try again
                 # make sure we only do this once
@@ -4792,17 +5260,25 @@ class Task(_Task):
                 and is_exception != KeyboardInterrupt
             ) or (
                 not self.__exit_hook.remote_user_aborted
-                and (self.__exit_hook.signal not in (None, 2, 15) or self.__exit_hook.exit_code)
+                and (
+                    self.__exit_hook.signal not in (None, 2, 15)
+                    or self.__exit_hook.exit_code
+                )
             ):
                 task_status = (
                     "failed",
-                    "Exception {}".format(is_exception)
-                    if is_exception
-                    else "Signal {}".format(self.__exit_hook.signal),
+                    (
+                        "Exception {}".format(is_exception)
+                        if is_exception
+                        else "Signal {}".format(self.__exit_hook.signal)
+                    ),
                 )
                 wait_for_uploads = False
             else:
-                wait_for_uploads = self.__exit_hook.remote_user_aborted or self.__exit_hook.signal is None
+                wait_for_uploads = (
+                    self.__exit_hook.remote_user_aborted
+                    or self.__exit_hook.signal is None
+                )
                 if (
                     not self.__exit_hook.remote_user_aborted
                     and self.__exit_hook.signal is None
@@ -4829,7 +5305,8 @@ class Task(_Task):
             # wait for uploads
             print_done_waiting = False
             if wait_for_uploads and (
-                BackendModel.get_num_results() > 0 or (self.__reporter and self.__reporter.events_waiting())
+                BackendModel.get_num_results() > 0
+                or (self.__reporter and self.__reporter.events_waiting())
             ):
                 self.log.info("Waiting to finish uploads")
                 print_done_waiting = True
@@ -4872,7 +5349,9 @@ class Task(_Task):
             if self._logger:
                 self._logger.set_flush_period(None)
                 # noinspection PyProtectedMember
-                self._logger._close_stdout_handler(wait=wait_for_uploads or wait_for_std_log)
+                self._logger._close_stdout_handler(
+                    wait=wait_for_uploads or wait_for_std_log
+                )

             if not is_sub_process:
                 # change task status
@@ -4904,10 +5383,14 @@ class Task(_Task):
                 # create zip file
                 offline_folder = self.get_offline_mode_folder()
                 zip_file = offline_folder.as_posix() + ".zip"
-                with ZipFile(zip_file, "w", allowZip64=True, compression=ZIP_DEFLATED) as zf:
+                with ZipFile(
+                    zip_file, "w", allowZip64=True, compression=ZIP_DEFLATED
+                ) as zf:
                     for filename in offline_folder.rglob("*"):
                         if filename.is_file():
-                            relative_file_name = filename.relative_to(offline_folder).as_posix()
+                            relative_file_name = filename.relative_to(
+                                offline_folder
+                            ).as_posix()
                             zf.write(filename.as_posix(), arcname=relative_file_name)
                 print("ClearML Task: Offline session stored in {}".format(zip_file))
             except Exception:
@@ -4965,11 +5448,18 @@ class Task(_Task):
         # type: (...) -> TaskInstance

         if task_id:
-            return cls(private=cls.__create_protection, task_id=task_id, log_to_backend=False)
+            return cls(
+                private=cls.__create_protection, task_id=task_id, log_to_backend=False
+            )

         if project_name:
-            res = cls._send(cls._get_default_session(), projects.GetAllRequest(name=exact_match_regex(project_name)))
-            project = get_single_result(entity="project", query=project_name, results=res.response.projects)
+            res = cls._send(
+                cls._get_default_session(),
+                projects.GetAllRequest(name=exact_match_regex(project_name)),
+            )
+            project = get_single_result(
+                entity="project", query=project_name, results=res.response.projects
+            )
         else:
             project = None

@@ -4978,7 +5468,9 @@ class Task(_Task):
         system_tags = "system_tags" if hasattr(tasks.Task, "system_tags") else "tags"
         task_filter = task_filter or {}
         if not include_archived:
-            task_filter["system_tags"] = (task_filter.get("system_tags") or []) + ["-{}".format(cls.archived_tag)]
+            task_filter["system_tags"] = (task_filter.get("system_tags") or []) + [
+                "-{}".format(cls.archived_tag)
+            ]
         if tags:
             task_filter["tags"] = (task_filter.get("tags") or []) + list(tags)
         res = cls._send(
@@ -4997,7 +5489,8 @@ class Task(_Task):
             filtered_tasks = [
                 t
                 for t in res_tasks
-                if not getattr(t, system_tags, None) or cls.archived_tag not in getattr(t, system_tags, None)
+                if not getattr(t, system_tags, None)
+                or cls.archived_tag not in getattr(t, system_tags, None)
             ]
             # if we did not filter everything (otherwise we have only archived tasks, so we return them)
             if filtered_tasks:
@@ -5042,10 +5535,20 @@ class Task(_Task):
         if task_ids:
             if isinstance(task_ids, six.string_types):
                 task_ids = [task_ids]
-            return [cls(private=cls.__create_protection, task_id=task_id, log_to_backend=False) for task_id in task_ids]
+            return [
+                cls(
+                    private=cls.__create_protection,
+                    task_id=task_id,
+                    log_to_backend=False,
+                )
+                for task_id in task_ids
+            ]

         queried_tasks = cls._query_tasks(
-            project_name=project_name, task_name=task_name, fetch_only_first_page=True, **kwargs
+            project_name=project_name,
+            task_name=task_name,
+            fetch_only_first_page=True,
+            **kwargs
         )
         if len(queried_tasks) == 500:
             LoggerRoot.get_base_logger().warning(
@@ -5053,7 +5556,10 @@ class Task(_Task):
                 " Returning only the first 500 results."
" Use Task.query_tasks() to fetch all task IDs" ) - return [cls(private=cls.__create_protection, task_id=task.id, log_to_backend=False) for task in queried_tasks] + return [ + cls(private=cls.__create_protection, task_id=task.id, log_to_backend=False) + for task in queried_tasks + ] @classmethod def _query_tasks( @@ -5087,17 +5593,24 @@ class Task(_Task): res = cls._send( cls._get_default_session(), projects.GetAllRequest( - name=exact_match_regex(name) if exact_match_regex_flag else name, **aux_kwargs + name=( + exact_match_regex(name) if exact_match_regex_flag else name + ), + **aux_kwargs ), ) if res.response and res.response.projects: - project_ids.extend([project.id for project in res.response.projects]) + project_ids.extend( + [project.id for project in res.response.projects] + ) else: projects_not_found.append(name) if projects_not_found: # If any of the given project names does not exist, fire off a warning LoggerRoot.get_base_logger().warning( - "No projects were found with name(s): {}".format(", ".join(projects_not_found)) + "No projects were found with name(s): {}".format( + ", ".join(projects_not_found) + ) ) if not project_ids: # If not a single project exists or was found, return empty right away @@ -5117,7 +5630,9 @@ class Task(_Task): ret_tasks = [] page = -1 page_size = 500 - while page == -1 or (not fetch_only_first_page and res and len(res.response.tasks) == page_size): + while page == -1 or ( + not fetch_only_first_page and res and len(res.response.tasks) == page_size + ): page += 1 # work on a copy and make sure we override all fields with ours request_kwargs = dict( @@ -5160,8 +5675,15 @@ class Task(_Task): return ":".join(map(normalize, args)) @classmethod - def __get_last_used_task_id(cls, default_project_name, default_task_name, default_task_type): - hash_key = cls.__get_hash_key(cls._get_api_server(), default_project_name, default_task_name, default_task_type) + def __get_last_used_task_id( + cls, default_project_name, default_task_name, default_task_type + ): + hash_key = cls.__get_hash_key( + cls._get_api_server(), + default_project_name, + default_task_name, + default_task_type, + ) # check if we have a cached task_id we can reuse # it must be from within the last 24h and with the same project/name/type @@ -5185,8 +5707,15 @@ class Task(_Task): return task_data @classmethod - def __update_last_used_task_id(cls, default_project_name, default_task_name, default_task_type, task_id): - hash_key = cls.__get_hash_key(cls._get_api_server(), default_project_name, default_task_name, default_task_type) + def __update_last_used_task_id( + cls, default_project_name, default_task_name, default_task_type, task_id + ): + hash_key = cls.__get_hash_key( + cls._get_api_server(), + default_project_name, + default_task_name, + default_task_type, + ) task_id = str(task_id) # update task session cache @@ -5201,7 +5730,9 @@ class Task(_Task): # remove stale sessions for k in list(task_sessions.keys()): - if (time.time() - task_sessions[k].get("time", 0)) > 60 * 60 * cls.__task_id_reuse_time_window_in_hours: + if ( + time.time() - task_sessions[k].get("time", 0) + ) > 60 * 60 * cls.__task_id_reuse_time_window_in_hours: task_sessions.pop(k) # update current session task_sessions[hash_key] = last_task_session @@ -5214,7 +5745,8 @@ class Task(_Task): task_data and task_data.get("id") and task_data.get("time") - and (time.time() - task_data.get("time")) > (60 * 60 * cls.__task_id_reuse_time_window_in_hours) + and (time.time() - task_data.get("time")) + > (60 * 60 * 
cls.__task_id_reuse_time_window_in_hours) ) @classmethod @@ -5273,7 +5805,8 @@ class Task(_Task): # noinspection PyBroadException try: project = cls._send( - cls._get_default_session(), projects.GetByIdRequest(project=task.project) + cls._get_default_session(), + projects.GetByIdRequest(project=task.project), ).response.project if project: @@ -5283,13 +5816,16 @@ class Task(_Task): if ( task_data.get("type") - and task_data.get("type") not in (cls.TaskTypes.training, cls.TaskTypes.testing) + and task_data.get("type") + not in (cls.TaskTypes.training, cls.TaskTypes.testing) and not Session.check_min_api_version(2.8) ): print( 'WARNING: Changing task type to "{}" : ' 'clearml-server does not support task type "{}", ' - "please upgrade clearml-server.".format(cls.TaskTypes.training, task_data["type"].value) + "please upgrade clearml-server.".format( + cls.TaskTypes.training, task_data["type"].value + ) ) task_data["type"] = cls.TaskTypes.training @@ -5328,7 +5864,11 @@ class Task(_Task): if task.status not in stopped_statuses: cls._send( cls._get_default_session(), - tasks.StoppedRequest(task=task.id, force=True, status_message="Stopped timed out development task"), + tasks.StoppedRequest( + task=task.id, + force=True, + status_message="Stopped timed out development task", + ), ) return True @@ -5361,7 +5901,11 @@ class Task(_Task): def __getstate__(self): # type: () -> dict - return {"main": self.is_main_task(), "id": self.id, "offline": self.is_offline()} + return { + "main": self.is_main_task(), + "id": self.id, + "offline": self.is_offline(), + } def __setstate__(self, state): if state["main"] and not self.__main_task: @@ -5371,7 +5915,10 @@ class Task(_Task): Task.set_offline(offline_mode=state["offline"]) task = ( - Task.init(continue_last_task=state["id"], auto_connect_frameworks={"detect_repository": False}) + Task.init( + continue_last_task=state["id"], + auto_connect_frameworks={"detect_repository": False}, + ) if state["main"] else Task.get_task(task_id=state["id"]) )