diff --git a/clearml/task.py b/clearml/task.py
index e6b0cfd7..ac65b785 100644
--- a/clearml/task.py
+++ b/clearml/task.py
@@ -39,9 +39,19 @@ from pathlib2 import Path
 from .backend_config.defs import get_active_config_file, get_config_file
 from .backend_api.services import tasks, projects, events, queues
 from .backend_api.session.session import (
-    Session, ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_HOST, ENV_WEB_HOST, ENV_FILES_HOST, )
-from .backend_api.session.defs import (ENV_DEFERRED_TASK_INIT, ENV_IGNORE_MISSING_CONFIG,
-                                       ENV_OFFLINE_MODE, MissingConfigError)
+    Session,
+    ENV_ACCESS_KEY,
+    ENV_SECRET_KEY,
+    ENV_HOST,
+    ENV_WEB_HOST,
+    ENV_FILES_HOST,
+)
+from .backend_api.session.defs import (
+    ENV_DEFERRED_TASK_INIT,
+    ENV_IGNORE_MISSING_CONFIG,
+    ENV_OFFLINE_MODE,
+    MissingConfigError,
+)
 from .backend_interface.metrics import Metrics
 from .backend_interface.model import Model as BackendModel
 from .backend_interface.base import InterfaceBase
@@ -78,8 +88,15 @@ from .binding.jsonargs_bind import PatchJsonArgParse
 from .binding.gradio_bind import PatchGradio
 from .binding.frameworks import WeightsFileHandler
 from .config import (
-    config, DEV_TASK_NO_REUSE, get_is_master_node, DEBUG_SIMULATE_REMOTE_TASK, DEV_DEFAULT_OUTPUT_URI,
-    deferred_config, TASK_SET_ITERATION_OFFSET, HOST_MACHINE_IP)
+    config,
+    DEV_TASK_NO_REUSE,
+    get_is_master_node,
+    DEBUG_SIMULATE_REMOTE_TASK,
+    DEV_DEFAULT_OUTPUT_URI,
+    deferred_config,
+    TASK_SET_ITERATION_OFFSET,
+    HOST_MACHINE_IP,
+)
 from .config import running_remotely, get_remote_task_id
 from .config.cache import SessionCache
 from .debugging.log import LoggerRoot
@@ -89,22 +106,33 @@ from .model import Model, InputModel, OutputModel, Framework
 from .task_parameters import TaskParameters
 from .utilities.config import verify_basic_value
 from .binding.args import (
-    argparser_parseargs_called, get_argparser_last_args,
-    argparser_update_currenttask, )
+    argparser_parseargs_called,
+    get_argparser_last_args,
+    argparser_update_currenttask,
+)
 from .utilities.dicts import ReadOnlyDict, merge_dicts, RequirementsDict
 from .utilities.proxy_object import (
-    ProxyDictPreWrite, ProxyDictPostWrite, flatten_dictionary,
-    nested_from_flat_dictionary, naive_nested_from_flat_dictionary, StubObject as _TaskStub)
+    ProxyDictPreWrite,
+    ProxyDictPostWrite,
+    flatten_dictionary,
+    nested_from_flat_dictionary,
+    naive_nested_from_flat_dictionary,
+    StubObject as _TaskStub,
+)
 from .utilities.resource_monitor import ResourceMonitor
 from .utilities.seed import make_deterministic
 from .utilities.lowlevel.threads import get_current_thread_id
-from .utilities.lowlevel.distributed import get_torch_local_rank, get_torch_distributed_anchor_task_id, \
-    create_torch_distributed_anchor
+from .utilities.lowlevel.distributed import (
+    get_torch_local_rank,
+    get_torch_distributed_anchor_task_id,
+    create_torch_distributed_anchor,
+)
 from .utilities.process.mp import BackgroundMonitor, leave_process
 from .utilities.process.exit_hooks import ExitHooks
 from .utilities.matching import matches_any_wildcard
 from .utilities.parallel import FutureTaskCaller
 from .utilities.networking import get_private_ip
+
 # noinspection PyProtectedMember
 from .backend_interface.task.args import _Arguments
 
@@ -177,9 +205,9 @@ class Task(_Task):
     __main_task = None  # type: Optional[Task]
     __exit_hook = None
     __forked_proc_main_pid = None
-    __task_id_reuse_time_window_in_hours = deferred_config('development.task_reuse_time_window_in_hours', 24.0, float)
-    __detect_repo_async = deferred_config('development.vcs_repo_detect_async', False)
-    __default_output_uri = DEV_DEFAULT_OUTPUT_URI.get() or deferred_config('development.default_output_uri', None)
+    __task_id_reuse_time_window_in_hours = deferred_config("development.task_reuse_time_window_in_hours", 24.0, float)
+    __detect_repo_async = deferred_config("development.vcs_repo_detect_async", False)
+    __default_output_uri = DEV_DEFAULT_OUTPUT_URI.get() or deferred_config("development.default_output_uri", None)
 
     __hidden_tag = "hidden"
 
@@ -199,10 +227,7 @@ class Task(_Task):
 
     @classmethod
     def _options(cls):
-        return {
-            var for var, val in vars(cls).items()
-            if isinstance(val, six.string_types)
-        }
+        return {var for var, val in vars(cls).items() if isinstance(val, six.string_types)}
 
     def __init__(self, private=None, **kwargs):
         """
@@ -212,7 +237,8 @@ class Task(_Task):
         """
         if private is not Task.__create_protection:
             raise UsageError(
-                'Task object cannot be instantiated externally, use Task.current_task() or Task.get_task(...)')
+                "Task object cannot be instantiated externally, use Task.current_task() or Task.get_task(...)"
+            )
         self._repo_detect_lock = threading.RLock()
 
         super(Task, self).__init__(**kwargs)
@@ -249,19 +275,19 @@ class Task(_Task):
 
     @classmethod
     def init(
-            cls,
-            project_name=None,  # type: Optional[str]
-            task_name=None,  # type: Optional[str]
-            task_type=TaskTypes.training,  # type: Task.TaskTypes
-            tags=None,  # type: Optional[Sequence[str]]
-            reuse_last_task_id=True,  # type: Union[bool, str]
-            continue_last_task=False,  # type: Union[bool, str, int]
-            output_uri=None,  # type: Optional[Union[str, bool]]
-            auto_connect_arg_parser=True,  # type: Union[bool, Mapping[str, bool]]
-            auto_connect_frameworks=True,  # type: Union[bool, Mapping[str, Union[bool, str, list]]]
-            auto_resource_monitoring=True,  # type: Union[bool, Mapping[str, Any]]
-            auto_connect_streams=True,  # type: Union[bool, Mapping[str, bool]]
-            deferred_init=False,  # type: bool
+        cls,
+        project_name=None,  # type: Optional[str]
+        task_name=None,  # type: Optional[str]
+        task_type=TaskTypes.training,  # type: Task.TaskTypes
+        tags=None,  # type: Optional[Sequence[str]]
+        reuse_last_task_id=True,  # type: Union[bool, str]
+        continue_last_task=False,  # type: Union[bool, str, int]
+        output_uri=None,  # type: Optional[Union[str, bool]]
+        auto_connect_arg_parser=True,  # type: Union[bool, Mapping[str, bool]]
+        auto_connect_frameworks=True,  # type: Union[bool, Mapping[str, Union[bool, str, list]]]
+        auto_resource_monitoring=True,  # type: Union[bool, Mapping[str, Any]]
+        auto_connect_streams=True,  # type: Union[bool, Mapping[str, bool]]
+        deferred_init=False,  # type: bool
     ):
         # type: (...) -> TaskInstance
         """
@@ -473,9 +499,9 @@ def verify_defaults_match():
 
         def verify_defaults_match():
             validate = [
-                ('project name', project_name, cls.__main_task.get_project_name()),
-                ('task name', task_name, cls.__main_task.name),
-                ('task type', str(task_type) if task_type else task_type, str(cls.__main_task.task_type)),
+                ("project name", project_name, cls.__main_task.get_project_name()),
+                ("task name", task_name, cls.__main_task.name),
+                ("task type", str(task_type) if task_type else task_type, str(cls.__main_task.task_type)),
             ]
 
             for field, default, current in validate:
@@ -506,7 +532,7 @@
 
                 # if we are using threads to send the reports,
                 # after forking there are no threads, so we will need to recreate them
-                if not getattr(cls, '_report_subprocess_enabled'):
+                if not getattr(cls, "_report_subprocess_enabled"):
                     # remove the logger from the previous process
                     cls.__main_task.get_logger()
                     # create a new logger (to catch stdout/err)
@@ -522,7 +548,7 @@
 
                 # if we are using threads to send the reports,
                 # after forking there are no threads, so we will need to recreate them
-                if not getattr(cls, '_report_subprocess_enabled'):
+                if not getattr(cls, "_report_subprocess_enabled"):
                     # start all reporting threads
                     BackgroundMonitor.start_all(task=cls.__main_task)
 
@@ -556,8 +582,9 @@
             task_type = cls.TaskTypes.training
         elif isinstance(task_type, six.string_types):
             if task_type not in Task.TaskTypes.__members__:
-                raise ValueError("Task type '{}' not supported, options are: {}".format(
-                    task_type, Task.TaskTypes.__members__.keys()))
+                raise ValueError(
+                    "Task type '{}' not supported, options are: {}".format(task_type, Task.TaskTypes.__members__.keys())
+                )
             task_type = Task.TaskTypes.__members__[str(task_type)]
 
         is_deferred = False
@@ -574,6 +601,7 @@
             deferred_init = True
 
         if not is_sub_process_task_id and deferred_init and deferred_init != cls.__nested_deferred_init_flag:
+
             def completed_cb(x):
                 Task.__forked_proc_main_pid = os.getpid()
                 Task.__main_task = x
@@ -609,9 +637,12 @@
                     tags=tags,
                     reuse_last_task_id=reuse_last_task_id,
                     continue_last_task=continue_last_task,
-                    detect_repo=False if (
-                        isinstance(auto_connect_frameworks, dict) and
-                        not auto_connect_frameworks.get('detect_repository', True)) else True,
+                    detect_repo=False
+                    if (
+                        isinstance(auto_connect_frameworks, dict)
+                        and not auto_connect_frameworks.get("detect_repository", True)
+                    )
+                    else True,
                     auto_connect_streams=auto_connect_streams,
                 )
             # check if we are local rank 0 (local master),
@@ -691,6 +722,7 @@
         # always patch OS forking because of ProcessPool and the alike
         PatchOsFork.patch_fork(task)
         if auto_connect_frameworks:
+
             def should_connect(*keys):
                 """
                 Evaluates value of auto_connect_frameworks[keys[0]]...[keys[-1]].
@@ -743,26 +775,26 @@
             # if we are deferred, stop here (the rest we do in the actual init)
             if is_deferred:
                 from .backend_interface.logger import StdStreamPatch
+
                 # patch console outputs, we will keep them in memory until we complete the Task init
                 # notice we do not load config defaults, as they are not threadsafe
                 # we might also need to override them with the vault
                 StdStreamPatch.patch_std_streams(
                     task.get_logger(),
-                    connect_stdout=(
-                        auto_connect_streams is True) or (
-                        isinstance(auto_connect_streams, dict) and auto_connect_streams.get('stdout', False)
-                    ),
-                    connect_stderr=(
-                        auto_connect_streams is True) or (
-                        isinstance(auto_connect_streams, dict) and auto_connect_streams.get('stderr', False)
-                    ),
+                    connect_stdout=(auto_connect_streams is True)
+                    or (isinstance(auto_connect_streams, dict) and auto_connect_streams.get("stdout", False)),
+                    connect_stderr=(auto_connect_streams is True)
+                    or (isinstance(auto_connect_streams, dict) and auto_connect_streams.get("stderr", False)),
                     load_config_defaults=False,
                 )
                 return task  # noqa
 
            if auto_resource_monitoring and not is_sub_process_task_id:
-                resource_monitor_cls = auto_resource_monitoring \
-                    if isinstance(auto_resource_monitoring, six.class_types) else ResourceMonitor
+                resource_monitor_cls = (
+                    auto_resource_monitoring
+                    if isinstance(auto_resource_monitoring, six.class_types)
+                    else ResourceMonitor
+                )
                resource_monitor_kwargs = dict(
                    report_mem_used_per_process=not config.get("development.worker.report_global_mem_used", False),
                    first_report_sec=config.get("development.worker.report_start_sec", None),
@@ -785,10 +817,7 @@
                        "report_global_mem_used"
                    )
                    resource_monitor_kwargs.update(auto_resource_monitoring)
-                task._resource_monitor = resource_monitor_cls(
-                    task,
-                    **resource_monitor_kwargs
-                )
+                task._resource_monitor = resource_monitor_cls(task, **resource_monitor_kwargs)
                task._resource_monitor.start()
 
            # make sure all random generators are initialized with new seed
@@ -831,16 +860,20 @@
        # show the debug metrics page in the log, it is very convenient
        if not is_sub_process_task_id:
            if cls._offline_mode:
-                logger.report_text('ClearML running in offline mode, session stored in {}'.format(
-                    task.get_offline_mode_folder()))
+                logger.report_text(
+                    "ClearML running in offline mode, session stored in {}".format(task.get_offline_mode_folder())
+                )
            else:
-                logger.report_text('ClearML results page: {}'.format(task.get_output_log_web_page()))
+                logger.report_text("ClearML results page: {}".format(task.get_output_log_web_page()))
 
            # Make sure we start the dev worker if required, otherwise it will only be started when we write
            # something to the log.
            task._dev_mode_setup_worker()
-            if (not task._reporter or not task._reporter.is_constructed()) and \
-                    is_sub_process_task_id and not cls._report_subprocess_enabled:
+            if (
+                (not task._reporter or not task._reporter.is_constructed())
+                and is_sub_process_task_id
+                and not cls._report_subprocess_enabled
+            ):
                task._setup_reporter()
 
        # start monitoring in background process or background threads
@@ -930,7 +963,8 @@
            # notice this applies for both raw tcp and http, it is so that we can
            # detect the host machine exposed ports, and register them on the router
            external_host_port_mapping = self._get_runtime_properties().get(
-                self._external_endpoint_host_tcp_port_mapping["tcp_host_mapping"])
+                self._external_endpoint_host_tcp_port_mapping["tcp_host_mapping"]
+            )
            self._external_endpoint_ports["tcp_host_mapping"] = external_host_port_mapping
 
            # check if we need to parse the port mapping, only if running on "bare-metal" host machine.
@@ -950,17 +984,21 @@
                        in_range = in_range.split("-")
                        if int(in_range[0]) <= port <= int(in_range[-1]):
                            # we found a match:
-                            out_port = int(out_range[0]) + (port-int(in_range[0]))
-                            print("INFO: Task.request_external_endpoint(...) changed requested external port to {}, "
-                                  "conforming to mapped external host ports [{} -> {}]".format(out_port, port, port_range))
+                            out_port = int(out_range[0]) + (port - int(in_range[0]))
+                            print(
+                                "INFO: Task.request_external_endpoint(...) changed requested external port to {}, "
+                                "conforming to mapped external host ports [{} -> {}]".format(out_port, port, port_range)
+                            )
                            break
                    if not out_port:
                        raise ValueError("match not found defaulting to original port")
                except Exception:
-                    print("WARNING: Task.request_external_endpoint(...) failed matching requested port to "
-                          "mapped external host port [{} to {}], "
-                          "proceeding with original port {}".format(port, external_host_port_mapping, port))
+                    print(
+                        "WARNING: Task.request_external_endpoint(...) failed matching requested port to "
+                        "mapped external host port [{} to {}], "
+                        "proceeding with original port {}".format(port, external_host_port_mapping, port)
+                    )
 
            # change the requested port to the one we have on the machine
            if out_port:
@@ -993,7 +1031,7 @@
            return self.wait_for_external_endpoint(
                wait_interval_seconds=wait_interval_seconds,
                wait_timeout_seconds=wait_timeout_seconds,
-                protocol=protocol
+                protocol=protocol,
            )
        return None
 
@@ -1023,7 +1061,7 @@
                wait_interval_seconds=wait_interval_seconds,
                wait_timeout_seconds=wait_timeout_seconds,
                protocol=protocol,
-                warn=True
+                warn=True,
            )
        results = []
        protocols = ["http", "tcp"]
@@ -1142,24 +1180,24 @@
 
    @classmethod
    def create(
-            cls,
-            project_name=None,  # type: Optional[str]
-            task_name=None,  # type: Optional[str]
-            task_type=None,  # type: Optional[str]
-            repo=None,  # type: Optional[str]
-            branch=None,  # type: Optional[str]
-            commit=None,  # type: Optional[str]
-            script=None,  # type: Optional[str]
-            working_directory=None,  # type: Optional[str]
-            packages=None,  # type: Optional[Union[bool, Sequence[str]]]
-            requirements_file=None,  # type: Optional[Union[str, Path]]
-            docker=None,  # type: Optional[str]
-            docker_args=None,  # type: Optional[str]
-            docker_bash_setup_script=None,  # type: Optional[str]
-            argparse_args=None,  # type: Optional[Sequence[Tuple[str, str]]]
-            base_task_id=None,  # type: Optional[str]
-            add_task_init_call=True,  # type: bool
-            force_single_script_file=False,  # type: bool
+        cls,
+        project_name=None,  # type: Optional[str]
+        task_name=None,  # type: Optional[str]
+        task_type=None,  # type: Optional[str]
+        repo=None,  # type: Optional[str]
+        branch=None,  # type: Optional[str]
+        commit=None,  # type: Optional[str]
+        script=None,  # type: Optional[str]
+        working_directory=None,  # type: Optional[str]
+        packages=None,  # type: Optional[Union[bool, Sequence[str]]]
+        requirements_file=None,  # type: Optional[Union[str, Path]]
+        docker=None,  # type: Optional[str]
+        docker_args=None,  # type: Optional[str]
+        docker_bash_setup_script=None,  # type: Optional[str]
+        argparse_args=None,  # type: Optional[Sequence[Tuple[str, str]]]
+        base_task_id=None,  # type: Optional[str]
+        add_task_init_call=True,  # type: bool
+        force_single_script_file=False,  # type: bool
    ):
        # type: (...) -> TaskInstance
        """
@@ -1211,16 +1249,27 @@
            raise UsageError("Creating task in offline mode. Use 'Task.init' instead.")
        if not project_name and not base_task_id:
            if not cls.__main_task:
-                raise ValueError("Please provide project_name, no global task context found "
-                                 "(Task.current_task hasn't been called)")
+                raise ValueError(
+                    "Please provide project_name, no global task context found "
+                    "(Task.current_task hasn't been called)"
+                )
            project_name = cls.__main_task.get_project_name()
        from .backend_interface.task.populate import CreateAndPopulate
+
        manual_populate = CreateAndPopulate(
-            project_name=project_name, task_name=task_name, task_type=task_type,
-            repo=repo, branch=branch, commit=commit,
-            script=script, working_directory=working_directory,
-            packages=packages, requirements_file=requirements_file,
-            docker=docker, docker_args=docker_args, docker_bash_setup_script=docker_bash_setup_script,
+            project_name=project_name,
+            task_name=task_name,
+            task_type=task_type,
+            repo=repo,
+            branch=branch,
+            commit=commit,
+            script=script,
+            working_directory=working_directory,
+            packages=packages,
+            requirements_file=requirements_file,
+            docker=docker,
+            docker_args=docker_args,
+            docker_bash_setup_script=docker_bash_setup_script,
            base_task_id=base_task_id,
            add_task_init_call=add_task_init_call,
            force_single_script_file=force_single_script_file,
@@ -1252,13 +1301,13 @@
 
    @classmethod
    def get_task(
-            cls,
-            task_id=None,  # type: Optional[str]
-            project_name=None,  # type: Optional[str]
-            task_name=None,  # type: Optional[str]
-            tags=None,  # type: Optional[Sequence[str]]
-            allow_archived=True,  # type: bool
-            task_filter=None  # type: Optional[dict]
+        cls,
+        task_id=None,  # type: Optional[str]
+        project_name=None,  # type: Optional[str]
+        task_name=None,  # type: Optional[str]
+        tags=None,  # type: Optional[Sequence[str]]
+        allow_archived=True,  # type: bool
+        task_filter=None,  # type: Optional[dict]
    ):
        # type: (...) -> TaskInstance
        """
@@ -1332,19 +1381,23 @@
        :rtype: Task
        """
        return cls.__get_task(
-            task_id=task_id, project_name=project_name, task_name=task_name, tags=tags,
-            include_archived=allow_archived, task_filter=task_filter,
+            task_id=task_id,
+            project_name=project_name,
+            task_name=task_name,
+            tags=tags,
+            include_archived=allow_archived,
+            task_filter=task_filter,
        )
 
    @classmethod
    def get_tasks(
-            cls,
-            task_ids=None,  # type: Optional[Sequence[str]]
-            project_name=None,  # type: Optional[Union[Sequence[str],str]]
-            task_name=None,  # type: Optional[str]
-            tags=None,  # type: Optional[Sequence[str]]
-            allow_archived=True,  # type: bool
-            task_filter=None  # type: Optional[Dict]
+        cls,
+        task_ids=None,  # type: Optional[Sequence[str]]
+        project_name=None,  # type: Optional[Union[Sequence[str],str]]
+        task_name=None,  # type: Optional[str]
+        tags=None,  # type: Optional[Sequence[str]]
+        allow_archived=True,  # type: bool
+        task_filter=None,  # type: Optional[Dict]
    ):
        # type: (...) -> List[TaskInstance]
        """
@@ -1414,19 +1467,20 @@
        """
        task_filter = task_filter or {}
        if not allow_archived:
-            task_filter['system_tags'] = (task_filter.get('system_tags') or []) + ['-{}'.format(cls.archived_tag)]
+            task_filter["system_tags"] = (task_filter.get("system_tags") or []) + ["-{}".format(cls.archived_tag)]
 
-        return cls.__get_tasks(task_ids=task_ids, project_name=project_name, tags=tags,
-                               task_name=task_name, **task_filter)
+        return cls.__get_tasks(
+            task_ids=task_ids, project_name=project_name, tags=tags, task_name=task_name, **task_filter
+        )
 
    @classmethod
    def query_tasks(
-            cls,
-            project_name=None,  # type: Optional[Union[Sequence[str],str]]
-            task_name=None,  # type: Optional[str]
-            tags=None,  # type: Optional[Sequence[str]]
-            additional_return_fields=None,  # type: Optional[Sequence[str]]
-            task_filter=None,  # type: Optional[Dict]
+        cls,
+        project_name=None,  # type: Optional[Union[Sequence[str],str]]
+        task_name=None,  # type: Optional[str]
+        tags=None,  # type: Optional[Sequence[str]]
+        additional_return_fields=None,  # type: Optional[Sequence[str]]
+        task_filter=None,  # type: Optional[Dict]
    ):
        # type: (...) -> Union[List[str], List[Dict[str, str]]]
        """
@@ -1491,20 +1545,27 @@
        """
        task_filter = task_filter or {}
        if tags:
-            task_filter['tags'] = (task_filter.get('tags') or []) + list(tags)
+            task_filter["tags"] = (task_filter.get("tags") or []) + list(tags)
        return_fields = {}
        if additional_return_fields:
-            return_fields = set(list(additional_return_fields) + ['id'])
-            task_filter['only_fields'] = (task_filter.get('only_fields') or []) + list(return_fields)
+            return_fields = set(list(additional_return_fields) + ["id"])
+            task_filter["only_fields"] = (task_filter.get("only_fields") or []) + list(return_fields)
 
-        if task_filter.get('type'):
-            task_filter['type'] = [str(task_type) for task_type in task_filter['type']]
+        if task_filter.get("type"):
+            task_filter["type"] = [str(task_type) for task_type in task_filter["type"]]
 
        results = cls._query_tasks(project_name=project_name, task_name=task_name, **task_filter)
-        return [t.id for t in results] if not additional_return_fields else \
-            [{k: cls._get_data_property(prop_path=k, data=r, raise_on_error=False, log_on_error=False)
-              for k in return_fields}
-             for r in results]
+        return (
+            [t.id for t in results]
+            if not additional_return_fields
+            else [
+                {
+                    k: cls._get_data_property(prop_path=k, data=r, raise_on_error=False, log_on_error=False)
+                    for k in return_fields
+                }
+                for r in results
+            ]
+        )
 
    @property
    def output_uri(self):
@@ -1545,10 +1606,13 @@
        # check if we have the correct packages / configuration
        if value and value != self.storage_uri:
            from .storage.helper import StorageHelper
+
            helper = StorageHelper.get(value)
            if not helper:
-                raise ValueError("Could not get access credentials for '{}' "
-                                 ", check configuration file ~/clearml.conf".format(value))
+                raise ValueError(
+                    "Could not get access credentials for '{}' "
+                    ", check configuration file ~/clearml.conf".format(value)
+                )
            helper.check_write_permissions(value)
        self.storage_uri = value
 
@@ -1560,7 +1624,7 @@
        :return: The artifacts.
""" - if not Session.check_min_api_version('2.3'): + if not Session.check_min_api_version("2.3"): return ReadOnlyDict() artifacts_pairs = [] if self.data.execution and self.data.execution.artifacts: @@ -1617,12 +1681,12 @@ class Task(_Task): @classmethod def clone( - cls, - source_task=None, # type: Optional[Union[Task, str]] - name=None, # type: Optional[str] - comment=None, # type: Optional[str] - parent=None, # type: Optional[str] - project=None, # type: Optional[str] + cls, + source_task=None, # type: Optional[Union[Task, str]] + name=None, # type: Optional[str] + comment=None, # type: Optional[str] + parent=None, # type: Optional[str] + project=None, # type: Optional[str] ): # type: (...) -> TaskInstance """ @@ -1646,9 +1710,10 @@ class Task(_Task): :rtype: Task """ assert isinstance(source_task, (six.string_types, Task)) - if not Session.check_min_api_version('2.4'): - raise ValueError("ClearML-server does not support DevOps features, " - "upgrade clearml-server to 0.12.0 or above") + if not Session.check_min_api_version("2.4"): + raise ValueError( + "ClearML-server does not support DevOps features, " "upgrade clearml-server to 0.12.0 or above" + ) task_id = source_task if isinstance(source_task, six.string_types) else source_task.id if not parent: @@ -1658,8 +1723,9 @@ class Task(_Task): elif isinstance(parent, Task): parent = parent.id - cloned_task_id = cls._clone_task(cloned_task_id=task_id, name=name, comment=comment, - parent=parent, project=project) + cloned_task_id = cls._clone_task( + cloned_task_id=task_id, name=name, comment=comment, parent=parent, project=project + ) cloned_task = cls.get_task(task_id=cloned_task_id) return cloned_task @@ -1708,9 +1774,10 @@ class Task(_Task): """ assert isinstance(task, (six.string_types, Task)) - if not Session.check_min_api_version('2.4'): - raise ValueError("ClearML-server does not support DevOps features, " - "upgrade clearml-server to 0.12.0 or above") + if not Session.check_min_api_version("2.4"): + raise ValueError( + "ClearML-server does not support DevOps features, " "upgrade clearml-server to 0.12.0 or above" + ) # make sure we have wither name ot id mutually_exclusive(queue_name=queue_name, queue_id=queue_id) @@ -1809,9 +1876,10 @@ class Task(_Task): """ assert isinstance(task, (six.string_types, Task)) - if not Session.check_min_api_version('2.4'): - raise ValueError("ClearML-server does not support DevOps features, " - "upgrade clearml-server to 0.12.0 or above") + if not Session.check_min_api_version("2.4"): + raise ValueError( + "ClearML-server does not support DevOps features, " "upgrade clearml-server to 0.12.0 or above" + ) task_id = task if isinstance(task, six.string_types) else task.id session = cls._get_default_session() @@ -1908,19 +1976,21 @@ class Task(_Task): (object, self._connect_object), ) - multi_config_support = Session.check_min_api_version('2.9') + multi_config_support = Session.check_min_api_version("2.9") if multi_config_support and not name and not isinstance(mutable, (OutputModel, InputModel)): name = self._default_configuration_section_name if not multi_config_support and name and name != self._default_configuration_section_name: - raise ValueError("Multiple configurations is not supported with the current 'clearml-server', " - "please upgrade to the latest version") + raise ValueError( + "Multiple configurations is not supported with the current 'clearml-server', " + "please upgrade to the latest version" + ) for mutable_type, method in dispatch: if isinstance(mutable, mutable_type): return 
method(mutable, name=name, ignore_remote_overrides=ignore_remote_overrides) - raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__) + raise Exception("Unsupported mutable type %s: no connect function found" % type(mutable).__name__) def set_packages(self, packages): # type: (Union[str, Path, Sequence[str]]) -> () @@ -1938,7 +2008,7 @@ class Task(_Task): """ if running_remotely() or packages is None: return - self._wait_for_repo_detection(timeout=300.) + self._wait_for_repo_detection(timeout=300.0) if packages and isinstance(packages, (str, Path)) and Path(packages).is_file(): with open(Path(packages).as_posix(), "rt") as f: @@ -1972,7 +2042,7 @@ class Task(_Task): """ if running_remotely(): return - self._wait_for_repo_detection(timeout=300.) + self._wait_for_repo_detection(timeout=300.0) with self._edit_lock: self.reload() if repo is not None: @@ -1994,7 +2064,7 @@ class Task(_Task): :return: A `RequirementsDict` object that holds the `pip`, `conda`, `orig_pip` requirements. """ if not running_remotely() and self.is_main_task(): - self._wait_for_repo_detection(timeout=300.) + self._wait_for_repo_detection(timeout=300.0) requirements_dict = RequirementsDict() requirements_dict.update(self.data.script.requirements) return requirements_dict @@ -2052,26 +2122,38 @@ class Task(_Task): except ImportError: pass if not pathlib_Path or not isinstance(configuration, pathlib_Path): - raise ValueError("connect_configuration supports `dict`, `str` and 'Path' types, " - "{} is not supported".format(type(configuration))) + raise ValueError( + "connect_configuration supports `dict`, `str` and 'Path' types, " + "{} is not supported".format(type(configuration)) + ) if pathlib_Path and isinstance(configuration, pathlib_Path): cast_Path = pathlib_Path - multi_config_support = Session.check_min_api_version('2.9') + multi_config_support = Session.check_min_api_version("2.9") if multi_config_support and not name: name = self._default_configuration_section_name if not multi_config_support and name and name != self._default_configuration_section_name: - raise ValueError("Multiple configurations is not supported with the current 'clearml-server', " - "please upgrade to the latest version") + raise ValueError( + "Multiple configurations is not supported with the current 'clearml-server', " + "please upgrade to the latest version" + ) # parameter dictionary - if isinstance(configuration, (dict, list,)): + if isinstance( + configuration, + ( + dict, + list, + ), + ): + def _update_config_dict(task, config_dict): if multi_config_support: # noinspection PyProtectedMember task._set_configuration( - name=name, description=description, config_type='dictionary', config_dict=config_dict) + name=name, description=description, config_type="dictionary", config_dict=config_dict + ) else: # noinspection PyProtectedMember task._set_model_config(config_dict=config_dict) @@ -2087,25 +2169,33 @@ class Task(_Task): configuration_ = ProxyDictPostWrite(self, _update_config_dict, configuration_) return configuration_ - if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()) or ignore_remote_overrides: + if ( + not running_remotely() + or not (self.is_main_task() or self._is_remote_main_task()) + or ignore_remote_overrides + ): configuration = get_dev_config(configuration) else: # noinspection PyBroadException try: - remote_configuration = self._get_configuration_dict(name=name) \ - if multi_config_support else self._get_model_config_dict() + remote_configuration 
= ( + self._get_configuration_dict(name=name) + if multi_config_support + else self._get_model_config_dict() + ) except Exception: remote_configuration = None if remote_configuration is None: LoggerRoot.get_base_logger().warning( - "Could not retrieve remote configuration named \'{}\'\n" - "Using default configuration: {}".format(name, str(configuration))) + "Could not retrieve remote configuration named '{}'\n" + "Using default configuration: {}".format(name, str(configuration)) + ) # update back configuration section if multi_config_support: self._set_configuration( - name=name, description=description, - config_type='dictionary', config_dict=configuration) + name=name, description=description, config_type="dictionary", config_dict=configuration + ) return configuration if not remote_configuration: @@ -2121,51 +2211,67 @@ class Task(_Task): return configuration # it is a path to a local file - if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()) or ignore_remote_overrides: + if ( + not running_remotely() + or not (self.is_main_task() or self._is_remote_main_task()) + or ignore_remote_overrides + ): # check if not absolute path configuration_path = cast_Path(configuration) if not configuration_path.is_file(): ValueError("Configuration file does not exist") try: - with open(configuration_path.as_posix(), 'rt') as f: + with open(configuration_path.as_posix(), "rt") as f: configuration_text = f.read() except Exception: - raise ValueError("Could not connect configuration file {}, file could not be read".format( - configuration_path.as_posix())) + raise ValueError( + "Could not connect configuration file {}, file could not be read".format( + configuration_path.as_posix() + ) + ) if multi_config_support: self._set_configuration( - name=name, description=description, - config_type=configuration_path.suffixes[-1].lstrip('.') - if configuration_path.suffixes and configuration_path.suffixes[-1] else 'file', - config_text=configuration_text) + name=name, + description=description, + config_type=configuration_path.suffixes[-1].lstrip(".") + if configuration_path.suffixes and configuration_path.suffixes[-1] + else "file", + config_text=configuration_text, + ) else: self._set_model_config(config_text=configuration_text) return configuration else: - configuration_text = self._get_configuration_text(name=name) if multi_config_support \ - else self._get_model_config_text() + configuration_text = ( + self._get_configuration_text(name=name) if multi_config_support else self._get_model_config_text() + ) if configuration_text is None: LoggerRoot.get_base_logger().warning( - "Could not retrieve remote configuration named \'{}\'\n" - "Using default configuration: {}".format(name, str(configuration))) + "Could not retrieve remote configuration named '{}'\n" + "Using default configuration: {}".format(name, str(configuration)) + ) # update back configuration section if multi_config_support: configuration_path = cast_Path(configuration) if configuration_path.is_file(): - with open(configuration_path.as_posix(), 'rt') as f: + with open(configuration_path.as_posix(), "rt") as f: configuration_text = f.read() self._set_configuration( - name=name, description=description, - config_type=configuration_path.suffixes[-1].lstrip('.') - if configuration_path.suffixes and configuration_path.suffixes[-1] else 'file', - config_text=configuration_text) + name=name, + description=description, + config_type=configuration_path.suffixes[-1].lstrip(".") + if configuration_path.suffixes and 
configuration_path.suffixes[-1] + else "file", + config_text=configuration_text, + ) return configuration configuration_path = cast_Path(configuration) - fd, local_filename = mkstemp(prefix='clearml_task_config_', - suffix=configuration_path.suffixes[-1] if - configuration_path.suffixes else '.txt') + fd, local_filename = mkstemp( + prefix="clearml_task_config_", + suffix=configuration_path.suffixes[-1] if configuration_path.suffixes else ".txt", + ) with open(fd, "w") as f: f.write(configuration_text) return cast_Path(local_filename) if isinstance(configuration, cast_Path) else local_filename @@ -2196,8 +2302,9 @@ class Task(_Task): "General/_ignore_remote_overrides_label_enumeration_", ignore_remote_overrides ) if not isinstance(enumeration, dict): - raise ValueError("connect_label_enumeration supports only `dict` type, " - "{} is not supported".format(type(enumeration))) + raise ValueError( + "connect_label_enumeration supports only `dict` type, " "{} is not supported".format(type(enumeration)) + ) if ( not running_remotely() @@ -2230,7 +2337,7 @@ class Task(_Task): wait=False, # type: bool addr=None, # type: Optional[str] devices=None, # type: Optional[Union[int, Sequence[int]]] - hide_children=False # bool + hide_children=False, # bool ): """ Enqueue multiple clones of the current task to a queue, allowing the task @@ -2339,7 +2446,7 @@ class Task(_Task): ), "node_rank": 0, "wait": wait, - "devices": devices + "devices": devices, } editable_conf = {"total_num_nodes": total_num_nodes, "queue": queue} editable_conf = self.connect(editable_conf, name=self._launch_multi_node_section) @@ -2539,7 +2646,7 @@ class Task(_Task): # wait for repository detection (5 minutes should be reasonable time to detect all packages) if self._logger and not self.__is_subprocess(): - self._wait_for_repo_detection(timeout=300.) + self._wait_for_repo_detection(timeout=300.0) self.__shutdown() # unregister atexit callbacks and signal hooks, if we are the main task @@ -2570,11 +2677,11 @@ class Task(_Task): PatchOsFork.patch_fork(None) def delete( - self, - delete_artifacts_and_models=True, - skip_models_used_by_other_tasks=True, - raise_on_error=False, - callback=None, + self, + delete_artifacts_and_models=True, + skip_models_used_by_other_tasks=True, + raise_on_error=False, + callback=None, ): # type: (bool, bool, bool, Callable[[str, str], bool]) -> bool """ @@ -2632,11 +2739,12 @@ class Task(_Task): which is the same as ``artifact.columns``. 
""" if not isinstance(uniqueness_columns, CollectionsSequence) and uniqueness_columns is not True: - raise ValueError('uniqueness_columns should be a List (sequence) or True') + raise ValueError("uniqueness_columns should be a List (sequence) or True") if isinstance(uniqueness_columns, str): uniqueness_columns = [uniqueness_columns] self._artifacts_manager.register_artifact( - name=name, artifact=artifact, metadata=metadata, uniqueness_columns=uniqueness_columns) + name=name, artifact=artifact, metadata=metadata, uniqueness_columns=uniqueness_columns + ) def unregister_artifact(self, name): # type: (str) -> None @@ -2664,17 +2772,17 @@ class Task(_Task): return self._artifacts_manager.registered_artifacts def upload_artifact( - self, - name, # type: str - artifact_object, # type: Union[str, Mapping, pandas.DataFrame, numpy.ndarray, Image.Image, Any] - metadata=None, # type: Optional[Mapping] - delete_after_upload=False, # type: bool - auto_pickle=None, # type: Optional[bool] - preview=None, # type: Any - wait_on_upload=False, # type: bool - extension_name=None, # type: Optional[str] - serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]] - retries=0 # type: int + self, + name, # type: str + artifact_object, # type: Union[str, Mapping, pandas.DataFrame, numpy.ndarray, Image.Image, Any] + metadata=None, # type: Optional[Mapping] + delete_after_upload=False, # type: bool + auto_pickle=None, # type: Optional[bool] + preview=None, # type: Any + wait_on_upload=False, # type: bool + extension_name=None, # type: Optional[str] + serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]] + retries=0, # type: int ): # type: (...) -> bool """ @@ -2786,9 +2894,7 @@ class Task(_Task): n_last_iterations = MAX_SERIES_PER_METRIC.get() if isinstance(n_last_iterations, int) and n_last_iterations >= 0: - samples = self._get_debug_samples( - title=title, series=series, n_last_iterations=n_last_iterations - ) + samples = self._get_debug_samples(title=title, series=series, n_last_iterations=n_last_iterations) else: raise TypeError( "Parameter n_last_iterations is expected to be a positive integer value," @@ -2816,20 +2922,17 @@ class Task(_Task): scroll_id = response.response_data.get("scroll_id", None) for metric_resp in response.response_data.get("metrics", []): - iterations_events = [iteration["events"] for iteration in metric_resp.get("iterations", [])] # type: List[List[dict]] - flattened_events = (event - for single_iter_events in iterations_events - for event in single_iter_events) + iterations_events = [ + iteration["events"] for iteration in metric_resp.get("iterations", []) + ] # type: List[List[dict]] + flattened_events = (event for single_iter_events in iterations_events for event in single_iter_events) debug_samples.extend(flattened_events) - response = self._send_debug_image_request( - title, series, n_last_iterations, scroll_id=scroll_id - ) + response = self._send_debug_image_request(title, series, n_last_iterations, scroll_id=scroll_id) - if (len(debug_samples) == n_last_iterations - or all( - len(metric_resp.get("iterations", [])) == 0 - for metric_resp in response.response_data.get("metrics", []))): + if len(debug_samples) == n_last_iterations or all( + len(metric_resp.get("iterations", [])) == 0 for metric_resp in response.response_data.get("metrics", []) + ): break return debug_samples @@ -2995,8 +3098,9 @@ class Task(_Task): scalar_metrics = dict() for i in metrics.values(): for j in i.values(): - 
scalar_metrics.setdefault(j['metric'], {}).setdefault( - j['variant'], {'last': j['value'], 'min': j['min_value'], 'max': j['max_value']}) + scalar_metrics.setdefault(j["metric"], {}).setdefault( + j["variant"], {"last": j["value"], "min": j["min_value"], "max": j["max_value"]} + ) return scalar_metrics def get_parameters_as_dict(self, cast=False): @@ -3043,9 +3147,9 @@ class Task(_Task): return dict(params.get(section, {})) def set_user_properties( - self, - *iterables, # type: Union[Mapping[str, Union[str, dict, None]], Iterable[dict]] - **properties # type: Union[str, dict, int, float, None] + self, + *iterables, # type: Union[Mapping[str, Union[str, dict, None]], Iterable[dict]] + **properties # type: Union[str, dict, int, float, None] ): # type: (...) -> bool """ @@ -3117,9 +3221,9 @@ class Task(_Task): return False return self._hyper_params_manager.edit_hyper_params( - iterables=list(properties.items()) + ( - list(iterables.items()) if isinstance(iterables, dict) else list(iterables)), - replace='none', + iterables=list(properties.items()) + + (list(iterables.items()) if isinstance(iterables, dict) else list(iterables)), + replace="none", force_section="properties", ) @@ -3147,17 +3251,17 @@ class Task(_Task): "working_dir": script.working_dir, "entry_point": script.entry_point, "branch": script.branch, - "repository": script.repository + "repository": script.repository, } def set_script( - self, - repository=None, # type: Optional[str] - branch=None, # type: Optional[str] - commit=None, # type: Optional[str] - diff=None, # type: Optional[str] - working_dir=None, # type: Optional[str] - entry_point=None, # type: Optional[str] + self, + repository=None, # type: Optional[str] + branch=None, # type: Optional[str] + commit=None, # type: Optional[str] + diff=None, # type: Optional[str] + working_dir=None, # type: Optional[str] + entry_point=None, # type: Optional[str] ): # type: (...) -> None """ @@ -3219,11 +3323,11 @@ class Task(_Task): return self._hyper_params_manager.delete_hyper_params(*iterables) def set_base_docker( - self, - docker_cmd=None, # type: Optional[str] - docker_image=None, # type: Optional[str] - docker_arguments=None, # type: Optional[Union[str, Sequence[str]]] - docker_setup_bash_script=None # type: Optional[Union[str, Sequence[str]]] + self, + docker_cmd=None, # type: Optional[str] + docker_image=None, # type: Optional[str] + docker_arguments=None, # type: Optional[Union[str, Sequence[str]]] + docker_setup_bash_script=None, # type: Optional[Union[str, Sequence[str]]] ): # type: (...) -> () """ @@ -3245,7 +3349,7 @@ class Task(_Task): super(Task, self).set_base_docker( docker_cmd=docker_cmd or docker_image, docker_arguments=docker_arguments, - docker_setup_bash_script=docker_setup_bash_script + docker_setup_bash_script=docker_setup_bash_script, ) @classmethod @@ -3335,7 +3439,8 @@ class Task(_Task): if not clone and not exit_process: raise ValueError( "clone==False and exit_process==False is not supported. " - "Task enqueuing itself must exit the process afterwards.") + "Task enqueuing itself must exit the process afterwards." + ) # make sure we analyze the process if self.status in (Task.TaskStatusEnum.in_progress,): @@ -3343,7 +3448,7 @@ class Task(_Task): # wait for repository detection (5 minutes should be reasonable time to detect all packages) self.flush(wait_for_uploads=True) if self._logger and not self.__is_subprocess(): - self._wait_for_repo_detection(timeout=300.) 
+                self._wait_for_repo_detection(timeout=300.0)
        else:
            # close ourselves (it will make sure the repo is updated)
            self.close()
@@ -3354,7 +3459,7 @@
        else:
            task = self
            # check if the server supports enqueueing aborted/stopped Tasks
-            if Session.check_min_api_server_version('2.13'):
+            if Session.check_min_api_server_version("2.13"):
                self.mark_stopped(force=True)
            else:
                self.reset()
@@ -3363,7 +3468,8 @@
        if queue_name:
            Task.enqueue(task, queue_name=queue_name)
            LoggerRoot.get_base_logger().warning(
-                'Switching to remote execution, output log page {}'.format(task.get_output_log_web_page()))
+                "Switching to remote execution, output log page {}".format(task.get_output_log_web_page())
+            )
        else:
            # Remove the development system tag
            system_tags = [t for t in task.get_system_tags() if t != self._development_tag]
@@ -3374,7 +3480,8 @@
        # leave this process.
        if exit_process:
            LoggerRoot.get_base_logger().warning(
-                'ClearML Terminating local execution process - continuing execution remotely')
+                "ClearML Terminating local execution process - continuing execution remotely"
+            )
            leave_process(0)
 
        return task
@@ -3409,25 +3516,28 @@
            raise ValueError("Only the main Task object can call create_function_task()")
        if not callable(func):
            raise ValueError("func must be callable")
-        if not Session.check_min_api_version('2.9'):
-            raise ValueError("Remote function execution is not supported, "
-                             "please upgrade to the latest server version")
+        if not Session.check_min_api_version("2.9"):
+            raise ValueError(
+                "Remote function execution is not supported, " "please upgrade to the latest server version"
+            )
 
        func_name = str(func_name or func.__name__).strip()
        if func_name in self._remote_functions_generated:
-            raise ValueError("Function name must be unique, a function by the name '{}' "
-                             "was already created by this Task.".format(func_name))
+            raise ValueError(
+                "Function name must be unique, a function by the name '{}' "
+                "was already created by this Task.".format(func_name)
+            )
 
-        section_name = 'Function'
-        tag_name = 'func'
-        func_marker = '__func_readonly__'
+        section_name = "Function"
+        tag_name = "func"
+        func_marker = "__func_readonly__"
 
        # sanitize the dict, leave only basic types that we might want to override later in the UI
        func_params = {k: v for k, v in kwargs.items() if verify_basic_value(v)}
        func_params[func_marker] = func_name
 
        # do not query if we are running locally, there is no need.
-        task_func_marker = self.running_locally() or self.get_parameter('{}/{}'.format(section_name, func_marker))
+        task_func_marker = self.running_locally() or self.get_parameter("{}/{}".format(section_name, func_marker))
 
        # if we are running locally or if we are running remotely but we are not a forked tasks
        # condition explained:
@@ -3435,7 +3545,7 @@
        # (2) running remotely but this is not one of the forked tasks (i.e. it is missing the fork tag attribute)
        if self.running_locally() or not task_func_marker:
            self._wait_for_repo_detection(300)
-            task = self.clone(self, name=task_name or '{} <{}>'.format(self.name, func_name), parent=self.id)
+            task = self.clone(self, name=task_name or "{} <{}>".format(self.name, func_name), parent=self.id)
            task.set_system_tags((task.get_system_tags() or []) + [tag_name])
            task.connect(func_params, name=section_name)
            self._remote_functions_generated[func_name] = task.id
@@ -3459,10 +3569,10 @@
            leave_process(0)
 
    def wait_for_status(
-            self,
-            status=(_Task.TaskStatusEnum.completed, _Task.TaskStatusEnum.stopped, _Task.TaskStatusEnum.closed),
-            raise_on_status=(_Task.TaskStatusEnum.failed,),
-            check_interval_sec=60.,
+        self,
+        status=(_Task.TaskStatusEnum.completed, _Task.TaskStatusEnum.stopped, _Task.TaskStatusEnum.closed),
+        raise_on_status=(_Task.TaskStatusEnum.failed,),
+        check_interval_sec=60.0,
    ):
        # type: (Iterable[Task.TaskStatusEnum], Optional[Iterable[Task.TaskStatusEnum]], float) -> ()
        """
@@ -3497,15 +3607,15 @@
        """
        self.reload()
        export_data = self.data.to_dict()
-        export_data.pop('last_metrics', None)
-        export_data.pop('last_iteration', None)
-        export_data.pop('status_changed', None)
-        export_data.pop('status_reason', None)
-        export_data.pop('status_message', None)
-        export_data.get('execution', {}).pop('artifacts', None)
-        export_data.get('execution', {}).pop('model', None)
-        export_data['project_name'] = self.get_project_name()
-        export_data['session_api_version'] = self.session.api_version
+        export_data.pop("last_metrics", None)
+        export_data.pop("last_iteration", None)
+        export_data.pop("status_changed", None)
+        export_data.pop("status_reason", None)
+        export_data.pop("status_message", None)
+        export_data.get("execution", {}).pop("artifacts", None)
+        export_data.get("execution", {}).pop("model", None)
+        export_data["project_name"] = self.get_project_name()
+        export_data["session_api_version"] = self.session.api_version
        return export_data
 
    def update_task(self, task_data):
@@ -3553,9 +3663,9 @@
        return result
 
    def register_abort_callback(
-            self,
-            callback_function,  # type: Optional[Callable]
-            callback_execution_timeout=30.  # type: float
+        self,
+        callback_function,  # type: Optional[Callable]
+        callback_execution_timeout=30.0,  # type: float
    ):
        # type: (...) -> None
        """
        Register a Task abort callback (single callback function support only).
@@ -3582,7 +3692,8 @@
 
        if float(callback_execution_timeout) <= 0:
            raise ValueError(
-                "function_timeout_sec must be positive timeout in seconds, got {}".format(callback_execution_timeout))
+                "function_timeout_sec must be positive timeout in seconds, got {}".format(callback_execution_timeout)
+            )
 
        # if we are running remotely we might not have a DevWorker monitoring us, so let's create one
        if not self._dev_worker:
@@ -3591,9 +3702,7 @@
            poll_freq = 15.0
 
        self._dev_worker.register_abort_callback(
-            callback_function=callback_function,
-            execution_timeout=callback_execution_timeout,
-            poll_freq=poll_freq
+            callback_function=callback_function, execution_timeout=callback_execution_timeout, poll_freq=poll_freq
        )
 
    @classmethod
@@ -3610,26 +3719,27 @@
        """
        # restore original API version (otherwise, we might not be able to restore the data correctly)
-        force_api_version = task_data.get('session_api_version') or None
+        force_api_version = task_data.get("session_api_version") or None
        original_api_version = Session.api_version
        original_force_max_api_version = Session.force_max_api_version
        if force_api_version:
            Session.force_max_api_version = str(force_api_version)
 
        if not target_task:
-            project_name = task_data.get('project_name') or Task._get_project_name(task_data.get('project', ''))
-            target_task = Task.create(project_name=project_name, task_name=task_data.get('name', None))
+            project_name = task_data.get("project_name") or Task._get_project_name(task_data.get("project", ""))
+            target_task = Task.create(project_name=project_name, task_name=task_data.get("name", None))
        elif isinstance(target_task, six.string_types):
            target_task = Task.get_task(task_id=target_task)  # type: Optional[Task]
        elif not isinstance(target_task, Task):
            raise ValueError(
                "`target_task` must be either Task id (str) or Task object, "
-                "received `target_task` type {}".format(type(target_task)))
+                "received `target_task` type {}".format(type(target_task))
+            )
        target_task.reload()
        cur_data = target_task.data.to_dict()
        cur_data = merge_dicts(cur_data, task_data) if update else dict(**task_data)
-        cur_data.pop('id', None)
-        cur_data.pop('project', None)
+        cur_data.pop("id", None)
+        cur_data.pop("project", None)
        # noinspection PyProtectedMember
        valid_fields = list(tasks.EditRequest._get_data_props().keys())
        cur_data = dict((k, cur_data[k]) for k in valid_fields if k in cur_data)
@@ -3693,11 +3803,7 @@
        """
        if running_remotely() or bool(offline_mode) == InterfaceBase._offline_mode:
            return
-        if (
-            cls.current_task()
-            and cls.current_task().status != cls.TaskStatusEnum.closed
-            and not offline_mode
-        ):
+        if cls.current_task() and cls.current_task().status != cls.TaskStatusEnum.closed and not offline_mode:
            raise UsageError(
                "Switching from offline mode to online mode, but the current task has not been closed. Use `Task.close` to close it."
) @@ -3733,12 +3839,12 @@ class Task(_Task): :return: Newly created task ID or the ID of the continued task (previous_task_id) """ - print('ClearML: Importing offline session from {}'.format(session_folder_zip)) + print("ClearML: Importing offline session from {}".format(session_folder_zip)) temp_folder = None if Path(session_folder_zip).is_file(): # unzip the file: - temp_folder = mkdtemp(prefix='clearml-offline-') + temp_folder = mkdtemp(prefix="clearml-offline-") ZipFile(session_folder_zip).extractall(path=temp_folder) session_folder_zip = temp_folder @@ -3747,11 +3853,12 @@ class Task(_Task): raise ValueError("Could not find the session folder / zip-file {}".format(session_folder)) try: - with open((session_folder / cls._offline_filename).as_posix(), 'rt') as f: + with open((session_folder / cls._offline_filename).as_posix(), "rt") as f: export_data = json.load(f) except Exception as ex: raise ValueError( - "Could not read Task object {}: Exception {}".format(session_folder / cls._offline_filename, ex)) + "Could not read Task object {}: Exception {}".format(session_folder / cls._offline_filename, ex) + ) current_task = cls.import_task(export_data) if previous_task_id: task_holding_reports = cls.get_task(task_id=previous_task_id) @@ -3763,21 +3870,23 @@ class Task(_Task): # fix artifacts if current_task.data.execution.artifacts: from . import StorageManager + # noinspection PyProtectedMember - offline_folder = os.path.join(export_data.get('offline_folder', ''), 'data/') + offline_folder = os.path.join(export_data.get("offline_folder", ""), "data/") # noinspection PyProtectedMember remote_url = current_task._get_default_report_storage_uri() - if remote_url and remote_url.endswith('/'): + if remote_url and remote_url.endswith("/"): remote_url = remote_url[:-1] for artifact in current_task.data.execution.artifacts: - local_path = artifact.uri.replace(offline_folder, '', 1) - local_file = session_folder / 'data' / local_path + local_path = artifact.uri.replace(offline_folder, "", 1) + local_file = session_folder / "data" / local_path if local_file.is_file(): remote_path = local_path.replace( - '.{}{}'.format(export_data['id'], os.sep), '.{}{}'.format(current_task.id, os.sep), 1) - artifact.uri = '{}/{}'.format(remote_url, remote_path) + ".{}{}".format(export_data["id"], os.sep), ".{}{}".format(current_task.id, os.sep), 1 + ) + artifact.uri = "{}/{}".format(remote_url, remote_path) StorageManager.upload_file(local_file=local_file.as_posix(), remote_url=artifact.uri) # noinspection PyProtectedMember task_holding_reports._edit(execution=current_task.data.execution) @@ -3792,7 +3901,7 @@ class Task(_Task): iteration_offset=iteration_offset, remote_url=task_holding_reports._get_default_report_storage_uri(), only_with_id=output_model["id"], - session=task_holding_reports.session + session=task_holding_reports.session, ) # logs TaskHandler.report_offline_session(task_holding_reports, session_folder, iteration_offset=iteration_offset) @@ -3805,7 +3914,7 @@ class Task(_Task): session=task_holding_reports.session, ) # print imported results page - print('ClearML results page: {}'.format(task_holding_reports.get_output_log_web_page())) + print("ClearML results page: {}".format(task_holding_reports.get_output_log_web_page())) task_holding_reports.mark_completed() # close task task_holding_reports.close() @@ -3822,13 +3931,7 @@ class Task(_Task): @classmethod def set_credentials( - cls, - api_host=None, - web_host=None, - files_host=None, - key=None, - secret=None, - store_conf_file=False + cls, 
-            api_host=None,
-            web_host=None,
-            files_host=None,
-            key=None,
-            secret=None,
-            store_conf_file=False
+        cls, api_host=None, web_host=None, files_host=None, key=None, secret=None, store_conf_file=False
     ):
         # type: (Optional[str], Optional[str], Optional[str], Optional[str], Optional[str], bool) -> None
         """
@@ -3884,17 +3987,20 @@ class Task(_Task):
             active_conf_file = get_active_config_file()
             if active_conf_file:
                 getLogger().warning(
-                    'Could not store credentials in configuration file, '
-                    '\'{}\' already exists'.format(active_conf_file))
+                    "Could not store credentials in configuration file, " "'{}' already exists".format(active_conf_file)
+                )
             else:
-                conf = {'api': dict(
-                    api_server=Session.default_host,
-                    web_server=Session.default_web,
-                    files_server=Session.default_files,
-                    credentials=dict(access_key=Session.default_key, secret_key=Session.default_secret))}
-                with open(get_config_file(), 'wt') as f:
-                    lines = json.dumps(conf, indent=4).split('\n')
-                    f.write('\n'.join(lines[1:-1]))
+                conf = {
+                    "api": dict(
+                        api_server=Session.default_host,
+                        web_server=Session.default_web,
+                        files_server=Session.default_files,
+                        credentials=dict(access_key=Session.default_key, secret_key=Session.default_secret),
+                    )
+                }
+                with open(get_config_file(), "wt") as f:
+                    lines = json.dumps(conf, indent=4).split("\n")
+                    f.write("\n".join(lines[1:-1]))

     @classmethod
     def debug_simulate_remote_task(cls, task_id, reset_task=False):
@@ -3925,6 +4031,7 @@
         from .config.remote import override_current_task_id
         from .config.defs import LOG_TO_BACKEND_ENV_VAR
+
         override_current_task_id(task_id)
         LOG_TO_BACKEND_ENV_VAR.set(True)
         DEBUG_SIMULATE_REMOTE_TASK.set(True)
@@ -3943,10 +4050,7 @@
         if not return_name or not queue_id:
             return queue_id
         try:
-            queue_name_result = Task._send(
-                Task._get_default_session(),
-                queues.GetByIdRequest(queue_id)
-            )
+            queue_name_result = Task._send(Task._get_default_session(), queues.GetByIdRequest(queue_id))
             return queue_name_result.response.queue.name
         except Exception as e:
             getLogger().warning("Could not get name of queue with ID '{}': {}".format(queue_id, e))
@@ -3970,8 +4074,10 @@
         """
         if not project_name:
             if not cls.__main_task:
-                raise ValueError("Please provide project_name, no global task context found "
-                                 "(Task.current_task hasn't been called)")
+                raise ValueError(
+                    "Please provide project_name, no global task context found "
+                    "(Task.current_task hasn't been called)"
+                )
             project_name = cls.__main_task.get_project_name()

         try:
@@ -4047,8 +4153,15 @@
     @classmethod
     def _create_dev_task(
-            cls, default_project_name, default_task_name, default_task_type, tags,
-            reuse_last_task_id, continue_last_task=False, detect_repo=True, auto_connect_streams=True
+        cls,
+        default_project_name,
+        default_task_name,
+        default_task_type,
+        tags,
+        reuse_last_task_id,
+        continue_last_task=False,
+        detect_repo=True,
+        auto_connect_streams=True,
     ):
         if not default_project_name or not default_task_name:
             # get project name and task name from repository name and entry_point
@@ -4056,14 +4169,14 @@
             if not default_project_name:
                 # noinspection PyBroadException
                 try:
-                    parts = result.script['repository'].split('/')
-                    default_project_name = (parts[-1] or parts[-2]).replace('.git', '') or 'Untitled'
+                    parts = result.script["repository"].split("/")
+                    default_project_name = (parts[-1] or parts[-2]).replace(".git", "") or "Untitled"
                 except Exception:
-                    default_project_name = 'Untitled'
+                    default_project_name = "Untitled"
             if not default_task_name:
                 # noinspection PyBroadException
                 try:
-                    default_task_name = os.path.splitext(os.path.basename(result.script['entry_point']))[0]
+                    default_task_name = os.path.splitext(os.path.basename(result.script["entry_point"]))[0]
                 except Exception:
                     pass
@@ -4100,7 +4213,7 @@
         elif not reuse_last_task_id or not cls.__task_is_relevant(default_task):
             default_task_id = None
         else:
-            default_task_id = default_task.get('id') if default_task else None
+            default_task_id = default_task.get("id") if default_task else None

         if default_task_id:
             try:
@@ -4111,9 +4224,10 @@
                 )
                 # instead of resting the previously used task we are continuing the training with it.
-                if task and \
-                        (continue_last_task or
-                         (isinstance(continue_last_task, int) and not isinstance(continue_last_task, bool))):
+                if task and (
+                    continue_last_task
+                    or (isinstance(continue_last_task, int) and not isinstance(continue_last_task, bool))
+                ):
                     task.reload()
                     task.mark_started(force=True)
                     # allow to disable the
@@ -4123,14 +4237,17 @@
                         task.set_initial_iteration(continue_last_task)

                 else:
-                    task_tags = task.data.system_tags if hasattr(task.data, 'system_tags') else task.data.tags
-                    task_artifacts = task.data.execution.artifacts \
-                        if hasattr(task.data.execution, 'artifacts') else None
-                    if ((task._status in (
-                            cls.TaskStatusEnum.published, cls.TaskStatusEnum.closed))
-                            or task.output_models_id or (cls.archived_tag in task_tags)
-                            or (cls._development_tag not in task_tags)
-                            or task_artifacts):
+                    task_tags = task.data.system_tags if hasattr(task.data, "system_tags") else task.data.tags
+                    task_artifacts = (
+                        task.data.execution.artifacts if hasattr(task.data.execution, "artifacts") else None
+                    )
+                    if (
+                        (task._status in (cls.TaskStatusEnum.published, cls.TaskStatusEnum.closed))
+                        or task.output_models_id
+                        or (cls.archived_tag in task_tags)
+                        or (cls._development_tag not in task_tags)
+                        or task_artifacts
+                    ):
                         # If the task is published or closed, we shouldn't reset it so we can't use it in dev mode
                         # If the task is archived, or already has an output model,
                         # we shouldn't use it in development mode either
@@ -4146,7 +4263,8 @@
                     # clear the heaviest stuff first
                     task._clear_task(
                         system_tags=[cls._development_tag],
-                        comment=make_message('Auto-generated at %(time)s by %(user)s@%(host)s'))
+                        comment=make_message("Auto-generated at %(time)s by %(user)s@%(host)s"),
+                    )

             except (Exception, ValueError):
                 # we failed reusing task, create a new one
@@ -4188,26 +4306,28 @@
         # force update of base logger to this current task (this is the main logger task)
         logger = task._get_logger(auto_connect_streams=auto_connect_streams)
         if closed_old_task:
-            logger.report_text('ClearML Task: Closing old development task id={}'.format(default_task.get('id')))
+            logger.report_text("ClearML Task: Closing old development task id={}".format(default_task.get("id")))
         # print warning, reusing/creating a task
         if default_task_id and not continue_last_task:
-            logger.report_text('ClearML Task: overwriting (reusing) task id=%s' % task.id)
+            logger.report_text("ClearML Task: overwriting (reusing) task id=%s" % task.id)
         elif default_task_id and continue_last_task:
-            logger.report_text('ClearML Task: continuing previous task id=%s '
-                               'Notice this run will not be reproducible!' % task.id)
+            logger.report_text(
+                "ClearML Task: continuing previous task id=%s " "Notice this run will not be reproducible!" % task.id
+            )
         else:
-            logger.report_text('ClearML Task: created new task id=%s' % task.id)
+            logger.report_text("ClearML Task: created new task id=%s" % task.id)

         # update current repository and put warning into logs
         if detect_repo:
             # noinspection PyBroadException
             try:
                 import traceback
+
                 stack = traceback.extract_stack(limit=10)
                 # NOTICE WE ARE ALWAYS 3 down from caller in stack!
                 for i in range(len(stack) - 1, 0, -1):
                     # look for the Task.init call, then the one above it is the callee module
-                    if stack[i].name == 'init':
+                    if stack[i].name == "init":
                         task._calling_filename = os.path.abspath(stack[i - 1].filename)
                         break
             except Exception:
@@ -4241,16 +4361,19 @@
         if not self._logger:
             # do not recreate logger after task was closed/quit
-            if self._at_exit_called and self._at_exit_called in (True, get_current_thread_id(),):
+            if self._at_exit_called and self._at_exit_called in (
+                True,
+                get_current_thread_id(),
+            ):
                 raise ValueError("Cannot use Task Logger after task was closed")
             # Get a logger object
             self._logger = Logger(
                 private_task=self,
-                connect_stdout=(auto_connect_streams is True) or
-                (isinstance(auto_connect_streams, dict) and auto_connect_streams.get('stdout', False)),
-                connect_stderr=(auto_connect_streams is True) or
-                (isinstance(auto_connect_streams, dict) and auto_connect_streams.get('stderr', False)),
-                connect_logging=isinstance(auto_connect_streams, dict) and auto_connect_streams.get('logging', False),
+                connect_stdout=(auto_connect_streams is True)
+                or (isinstance(auto_connect_streams, dict) and auto_connect_streams.get("stdout", False)),
+                connect_stderr=(auto_connect_streams is True)
+                or (isinstance(auto_connect_streams, dict) and auto_connect_streams.get("stderr", False)),
+                connect_logging=isinstance(auto_connect_streams, dict) and auto_connect_streams.get("logging", False),
             )
             # make sure we set our reported to async mode
             # we make sure we flush it in self._at_exit
@@ -4288,7 +4411,7 @@
                 True,
                 description="If True, ignore UI/backend overrides when running remotely."
" Set it to False if you would like the overrides to be applied", - value_type=bool + value_type=bool, ) elif not self.running_locally(): ignore_remote_overrides = self.get_parameter(overrides_name, default=ignore_remote_overrides, cast=True) @@ -4314,11 +4437,11 @@ class Task(_Task): # add into comment that we are using this model # refresh comment - comment = self._reload_field("comment") or self.comment or '' + comment = self._reload_field("comment") or self.comment or "" - if not comment.endswith('\n'): - comment += '\n' - comment += 'Using model id: {}'.format(model.id) + if not comment.endswith("\n"): + comment += "\n" + comment += "Using model id: {}".format(model.id) self.set_comment(comment) model.connect(self, name, ignore_remote_overrides=ignore_remote_overrides) @@ -4361,8 +4484,7 @@ class Task(_Task): if running_remotely() and (self.is_main_task() or self._is_remote_main_task()) and not ignore_remote_overrides: self._arguments.copy_to_parser(parser, parsed_args) else: - self._arguments.copy_defaults_from_argparse( - parser, args=args, namespace=namespace, parsed_args=parsed_args) + self._arguments.copy_defaults_from_argparse(parser, args=args, namespace=namespace, parsed_args=parsed_args) return parser def _connect_dictionary(self, dictionary, name=None, ignore_remote_overrides=False): @@ -4395,7 +4517,11 @@ class Task(_Task): if isinstance(v, dict): _check_keys(v, warning_sent) - if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()) or ignore_remote_overrides: + if ( + not running_remotely() + or not (self.is_main_task() or self._is_remote_main_task()) + or ignore_remote_overrides + ): _check_keys(dictionary) flat_dict = {str(k): v for k, v in flatten_dictionary(dictionary).items()} self._arguments.copy_from_dict(flat_dict, prefix=name) @@ -4418,7 +4544,7 @@ class Task(_Task): parameters = self.get_parameters(cast=True) if name: parameters = dict( - (k[len(name) + 1:], v) for k, v in parameters.items() if k.startswith("{}/".format(name)) + (k[len(name) + 1 :], v) for k, v in parameters.items() if k.startswith("{}/".format(name)) ) parameters.pop(ignore_remote_overrides_section, None) attr_class.update_from_dict(parameters) @@ -4431,7 +4557,7 @@ class Task(_Task): def _connect_object(self, an_object, name=None, ignore_remote_overrides=False): def verify_type(key, value): - if str(key).startswith('_') or not isinstance(value, self._parameters_allowed_types): + if str(key).startswith("_") or not isinstance(value, self._parameters_allowed_types): return False # verify everything is json able (i.e. 
basic types) try: @@ -4462,16 +4588,12 @@ class Task(_Task): if self._at_exit_called: return - self.log.warning( - "### TASK STOPPED - USER ABORTED - {} ###".format( - stop_reason.upper().replace('_', ' ') - ) - ) + self.log.warning("### TASK STOPPED - USER ABORTED - {} ###".format(stop_reason.upper().replace("_", " "))) self.flush(wait_for_uploads=True) # if running remotely, we want the daemon to kill us if self.running_locally(): - self.stopped(status_reason='USER ABORTED') + self.stopped(status_reason="USER ABORTED") if self._dev_worker: self._dev_worker.unregister() @@ -4520,8 +4642,12 @@ class Task(_Task): kill_ourselves.terminate() def _dev_mode_setup_worker(self): - if (running_remotely() and not DEBUG_SIMULATE_REMOTE_TASK.get()) \ - or not self.is_main_task() or self._at_exit_called or self._offline_mode: + if ( + (running_remotely() and not DEBUG_SIMULATE_REMOTE_TASK.get()) + or not self.is_main_task() + or self._at_exit_called + or self._offline_mode + ): return if self._dev_worker: @@ -4548,19 +4674,22 @@ class Task(_Task): # if negative timeout, just kill the thread: if timeout is not None and timeout < 0: from .utilities.lowlevel.threads import kill_thread + kill_thread(self._detect_repo_async_thread) else: - self.log.info('Waiting for repository detection and full package requirement analysis') + self.log.info("Waiting for repository detection and full package requirement analysis") self._detect_repo_async_thread.join(timeout=timeout) # because join has no return value if self._detect_repo_async_thread.is_alive(): - self.log.info('Repository and package analysis timed out ({} sec), ' - 'giving up'.format(timeout)) + self.log.info( + "Repository and package analysis timed out ({} sec), " "giving up".format(timeout) + ) # done waiting, kill the thread from .utilities.lowlevel.threads import kill_thread + kill_thread(self._detect_repo_async_thread) else: - self.log.info('Finished repository detection and package analysis') + self.log.info("Finished repository detection and package analysis") self._detect_repo_async_thread = None except Exception: pass @@ -4635,14 +4764,17 @@ class Task(_Task): # first thing mark task as stopped, so we will not end up with "running" on lost tasks # if we are running remotely, the daemon will take care of it wait_for_std_log = True - if (not running_remotely() or DEBUG_SIMULATE_REMOTE_TASK.get()) \ - and self.is_main_task() and not is_sub_process: + if ( + (not running_remotely() or DEBUG_SIMULATE_REMOTE_TASK.get()) + and self.is_main_task() + and not is_sub_process + ): # check if we crashed, ot the signal is not interrupt (manual break) - task_status = ('stopped',) + task_status = ("stopped",) if self.__exit_hook: is_exception = self.__exit_hook.exception # check if we are running inside a debugger - if not is_exception and sys.modules.get('pydevd'): + if not is_exception and sys.modules.get("pydevd"): # noinspection PyBroadException try: is_exception = sys.last_type @@ -4650,26 +4782,35 @@ class Task(_Task): pass # check if this is Jupyter interactive session, do not mark as exception - if 'IPython' in sys.modules: + if "IPython" in sys.modules: is_exception = None # only if we have an exception (and not ctrl-break) or signal is not SIGTERM / SIGINT - if (is_exception and not isinstance(is_exception, KeyboardInterrupt) - and is_exception != KeyboardInterrupt) \ - or (not self.__exit_hook.remote_user_aborted and - (self.__exit_hook.signal not in (None, 2, 15) or self.__exit_hook.exit_code)): + if ( + is_exception + and not 
isinstance(is_exception, KeyboardInterrupt) + and is_exception != KeyboardInterrupt + ) or ( + not self.__exit_hook.remote_user_aborted + and (self.__exit_hook.signal not in (None, 2, 15) or self.__exit_hook.exit_code) + ): task_status = ( - 'failed', - 'Exception {}'.format(is_exception) if is_exception else - 'Signal {}'.format(self.__exit_hook.signal)) + "failed", + "Exception {}".format(is_exception) + if is_exception + else "Signal {}".format(self.__exit_hook.signal), + ) wait_for_uploads = False else: - wait_for_uploads = (self.__exit_hook.remote_user_aborted or self.__exit_hook.signal is None) - if not self.__exit_hook.remote_user_aborted and self.__exit_hook.signal is None and \ - not is_exception: - task_status = ('completed',) + wait_for_uploads = self.__exit_hook.remote_user_aborted or self.__exit_hook.signal is None + if ( + not self.__exit_hook.remote_user_aborted + and self.__exit_hook.signal is None + and not is_exception + ): + task_status = ("completed",) else: - task_status = ('stopped',) + task_status = ("stopped",) # user aborted. do not bother flushing the stdout logs wait_for_std_log = self.__exit_hook.signal is not None @@ -4679,7 +4820,7 @@ class Task(_Task): self._summary_artifacts() # make sure that if we crashed the thread we are not waiting forever if not is_sub_process: - self._wait_for_repo_detection(timeout=10.) + self._wait_for_repo_detection(timeout=10.0) # kill the repo thread (negative timeout, do not wait), if it hasn't finished yet. if not is_sub_process: @@ -4687,9 +4828,10 @@ class Task(_Task): # wait for uploads print_done_waiting = False - if wait_for_uploads and (BackendModel.get_num_results() > 0 or - (self.__reporter and self.__reporter.events_waiting())): - self.log.info('Waiting to finish uploads') + if wait_for_uploads and ( + BackendModel.get_num_results() > 0 or (self.__reporter and self.__reporter.events_waiting()) + ): + self.log.info("Waiting to finish uploads") print_done_waiting = True # from here, do not send log in background thread if wait_for_uploads: @@ -4706,12 +4848,13 @@ class Task(_Task): # noinspection PyBroadException try: from .storage.helper import StorageHelper + StorageHelper.close_async_threads() except Exception: pass if print_done_waiting: - self.log.info('Finished uploading') + self.log.info("Finished uploading") # elif self._logger: # # noinspection PyProtectedMember # self._logger._flush_stdout_handler() @@ -4735,12 +4878,12 @@ class Task(_Task): # change task status if not task_status: pass - elif task_status[0] == 'failed': + elif task_status[0] == "failed": self.mark_failed(status_reason=task_status[1]) - elif task_status[0] == 'completed': + elif task_status[0] == "completed": self.set_progress(100) self.mark_completed() - elif task_status[0] == 'stopped': + elif task_status[0] == "stopped": self.stopped() # this is so in theory we can close a main task and start a new one @@ -4760,13 +4903,13 @@ class Task(_Task): self._edit() # create zip file offline_folder = self.get_offline_mode_folder() - zip_file = offline_folder.as_posix() + '.zip' - with ZipFile(zip_file, 'w', allowZip64=True, compression=ZIP_DEFLATED) as zf: - for filename in offline_folder.rglob('*'): + zip_file = offline_folder.as_posix() + ".zip" + with ZipFile(zip_file, "w", allowZip64=True, compression=ZIP_DEFLATED) as zf: + for filename in offline_folder.rglob("*"): if filename.is_file(): relative_file_name = filename.relative_to(offline_folder).as_posix() zf.write(filename.as_posix(), arcname=relative_file_name) - print('ClearML Task: Offline 
session stored in {}'.format(zip_file)) + print("ClearML Task: Offline session stored in {}".format(zip_file)) except Exception: pass @@ -4811,13 +4954,13 @@ class Task(_Task): @classmethod def __get_task( - cls, - task_id=None, # type: Optional[str] - project_name=None, # type: Optional[str] - task_name=None, # type: Optional[str] - include_archived=True, # type: bool - tags=None, # type: Optional[Sequence[str]] - task_filter=None # type: Optional[dict] + cls, + task_id=None, # type: Optional[str] + project_name=None, # type: Optional[str] + task_name=None, # type: Optional[str] + include_archived=True, # type: bool + tags=None, # type: Optional[Sequence[str]] + task_filter=None, # type: Optional[dict] ): # type: (...) -> TaskInstance @@ -4825,49 +4968,57 @@ class Task(_Task): return cls(private=cls.__create_protection, task_id=task_id, log_to_backend=False) if project_name: - res = cls._send( - cls._get_default_session(), - projects.GetAllRequest( - name=exact_match_regex(project_name) - ) - ) - project = get_single_result(entity='project', query=project_name, results=res.response.projects) + res = cls._send(cls._get_default_session(), projects.GetAllRequest(name=exact_match_regex(project_name))) + project = get_single_result(entity="project", query=project_name, results=res.response.projects) else: project = None # get default session, before trying to access tasks.Task so that we do not create two sessions. session = cls._get_default_session() - system_tags = 'system_tags' if hasattr(tasks.Task, 'system_tags') else 'tags' + system_tags = "system_tags" if hasattr(tasks.Task, "system_tags") else "tags" task_filter = task_filter or {} if not include_archived: - task_filter['system_tags'] = (task_filter.get('system_tags') or []) + ['-{}'.format(cls.archived_tag)] + task_filter["system_tags"] = (task_filter.get("system_tags") or []) + ["-{}".format(cls.archived_tag)] if tags: - task_filter['tags'] = (task_filter.get('tags') or []) + list(tags) + task_filter["tags"] = (task_filter.get("tags") or []) + list(tags) res = cls._send( session, tasks.GetAllRequest( project=[project.id] if project else None, name=exact_match_regex(task_name) if task_name else None, - only_fields=['id', 'name', 'last_update', system_tags], + only_fields=["id", "name", "last_update", system_tags], **task_filter - ) + ), ) res_tasks = res.response.tasks # if we have more than one result, filter out the 'archived' results # notice that if we only have one result we do get the archived one as well. 
         if len(res_tasks) > 1:
-            filtered_tasks = [t for t in res_tasks if not getattr(t, system_tags, None) or
-                              cls.archived_tag not in getattr(t, system_tags, None)]
+            filtered_tasks = [
+                t
+                for t in res_tasks
+                if not getattr(t, system_tags, None) or cls.archived_tag not in getattr(t, system_tags, None)
+            ]
             # if we did not filter everything (otherwise we have only archived tasks, so we return them)
             if filtered_tasks:
                 res_tasks = filtered_tasks

         task = get_single_result(
-            entity='task',
-            query={k: v for k, v in dict(
-                project_name=project_name, task_name=task_name, tags=tags,
-                include_archived=include_archived, task_filter=task_filter).items() if v},
-            results=res_tasks, raise_on_error=False)
+            entity="task",
+            query={
+                k: v
+                for k, v in dict(
+                    project_name=project_name,
+                    task_name=task_name,
+                    tags=tags,
+                    include_archived=include_archived,
+                    task_filter=task_filter,
+                ).items()
+                if v
+            },
+            results=res_tasks,
+            raise_on_error=False,
+        )
         if not task:  # should never happen
             return None  # noqa
@@ -4936,9 +5087,8 @@
             res = cls._send(
                 cls._get_default_session(),
                 projects.GetAllRequest(
-                    name=exact_match_regex(name) if exact_match_regex_flag else name,
-                    **aux_kwargs
-                )
+                    name=exact_match_regex(name) if exact_match_regex_flag else name, **aux_kwargs
+                ),
             )
             if res.response and res.response.projects:
                 project_ids.extend([project.id for project in res.response.projects])
@@ -4954,14 +5104,14 @@
             return []

         session = cls._get_default_session()
-        system_tags = 'system_tags' if hasattr(tasks.Task, 'system_tags') else 'tags'
-        only_fields = ['id', 'name', 'last_update', system_tags]
+        system_tags = "system_tags" if hasattr(tasks.Task, "system_tags") else "tags"
+        only_fields = ["id", "name", "last_update", system_tags]

-        if kwargs and kwargs.get('only_fields'):
-            only_fields = list(set(kwargs.pop('only_fields')) | set(only_fields))
+        if kwargs and kwargs.get("only_fields"):
+            only_fields = list(set(kwargs.pop("only_fields")) | set(only_fields))

         # if we have specific page to look for, we should only get the requested one
-        if not fetch_only_first_page and kwargs and 'page' in kwargs:
+        if not fetch_only_first_page and kwargs and "page" in kwargs:
             fetch_only_first_page = True

         ret_tasks = []
@@ -5011,8 +5161,7 @@
     @classmethod
     def __get_last_used_task_id(cls, default_project_name, default_task_name, default_task_type):
-        hash_key = cls.__get_hash_key(
-            cls._get_api_server(), default_project_name, default_task_name, default_task_type)
+        hash_key = cls.__get_hash_key(cls._get_api_server(), default_project_name, default_task_name, default_task_type)

         # check if we have a cached task_id we can reuse
         # it must be from within the last 24h and with the same project/name/type
@@ -5023,12 +5172,12 @@
             return None

         try:
-            task_data['type'] = cls.TaskTypes(task_data['type'])
+            task_data["type"] = cls.TaskTypes(task_data["type"])
         except (ValueError, KeyError):
             LoggerRoot.get_base_logger().warning(
                 "Corrupted session cache entry: {}. "
                 "Unsupported task type: {}"
-                "Creating a new task.".format(hash_key, task_data['type']),
+                "Creating a new task.".format(hash_key, task_data["type"]),
             )
             return None
@@ -5037,19 +5186,22 @@
     @classmethod
     def __update_last_used_task_id(cls, default_project_name, default_task_name, default_task_type, task_id):
-        hash_key = cls.__get_hash_key(
-            cls._get_api_server(), default_project_name, default_task_name, default_task_type)
+        hash_key = cls.__get_hash_key(cls._get_api_server(), default_project_name, default_task_name, default_task_type)

         task_id = str(task_id)
         # update task session cache
         task_sessions = SessionCache.load_dict(str(cls))
-        last_task_session = {'time': time.time(), 'project': default_project_name, 'name': default_task_name,
-                             'type': default_task_type, 'id': task_id}
+        last_task_session = {
+            "time": time.time(),
+            "project": default_project_name,
+            "name": default_task_name,
+            "type": default_task_type,
+            "id": task_id,
+        }

         # remove stale sessions
         for k in list(task_sessions.keys()):
-            if ((time.time() - task_sessions[k].get('time', 0)) >
-                    60 * 60 * cls.__task_id_reuse_time_window_in_hours):
+            if (time.time() - task_sessions[k].get("time", 0)) > 60 * 60 * cls.__task_id_reuse_time_window_in_hours:
                 task_sessions.pop(k)

         # update current session
         task_sessions[hash_key] = last_task_session
@@ -5058,11 +5210,12 @@
     @classmethod
     def __task_timed_out(cls, task_data):
-        return \
-            task_data and \
-            task_data.get('id') and \
-            task_data.get('time') and \
-            (time.time() - task_data.get('time')) > (60 * 60 * cls.__task_id_reuse_time_window_in_hours)
+        return (
+            task_data
+            and task_data.get("id")
+            and task_data.get("time")
+            and (time.time() - task_data.get("time")) > (60 * 60 * cls.__task_id_reuse_time_window_in_hours)
+        )

     @classmethod
     def __get_task_api_obj(cls, task_id, only_fields=None):
@@ -5101,14 +5254,14 @@
         if cls.__task_timed_out(task_data):
             return False

-        task_id = task_data.get('id')
+        task_id = task_data.get("id")

         if not task_id:
             return False

         # noinspection PyBroadException
         try:
-            task = cls.__get_task_api_obj(task_id, ('id', 'name', 'project', 'type'))
+            task = cls.__get_task_api_obj(task_id, ("id", "name", "project", "type"))
         except Exception:
             task = None
@@ -5120,8 +5273,7 @@
             # noinspection PyBroadException
             try:
                 project = cls._send(
-                    cls._get_default_session(),
-                    projects.GetByIdRequest(project=task.project)
+                    cls._get_default_session(), projects.GetByIdRequest(project=task.project)
                 ).response.project

                 if project:
@@ -5129,31 +5281,37 @@
             except Exception:
                 pass

-        if task_data.get('type') and \
-                task_data.get('type') not in (cls.TaskTypes.training, cls.TaskTypes.testing) and \
-                not Session.check_min_api_version(2.8):
-            print('WARNING: Changing task type to "{}" : '
-                  'clearml-server does not support task type "{}", '
-                  'please upgrade clearml-server.'.format(cls.TaskTypes.training, task_data['type'].value))
-            task_data['type'] = cls.TaskTypes.training
+        if (
+            task_data.get("type")
+            and task_data.get("type") not in (cls.TaskTypes.training, cls.TaskTypes.testing)
+            and not Session.check_min_api_version(2.8)
+        ):
+            print(
+                'WARNING: Changing task type to "{}" : '
+                'clearml-server does not support task type "{}", '
+                "please upgrade clearml-server.".format(cls.TaskTypes.training, task_data["type"].value)
+            )
+            task_data["type"] = cls.TaskTypes.training

         compares = (
-            (task.name, 'name'),
-            (project_name, 'project'),
-            (task.type, 'type'),
+            (task.name, "name"),
+            (project_name, "project"),
+            (task.type, "type"),
         )

         # compare after casting to string to avoid enum instance issues
         # remember we might have replaced the api version by now, so enums are different
-        return all(six.text_type(server_data) == six.text_type(task_data.get(task_data_key))
-                   for server_data, task_data_key in compares)
+        return all(
+            six.text_type(server_data) == six.text_type(task_data.get(task_data_key))
+            for server_data, task_data_key in compares
+        )

     @classmethod
     def __close_timed_out_task(cls, task_data):
         if not task_data:
             return False

-        task = cls.__get_task_api_obj(task_data.get('id'), ('id', 'status'))
+        task = cls.__get_task_api_obj(task_data.get("id"), ("id", "status"))

         if task is None:
             return False
@@ -5170,11 +5328,7 @@
         if task.status not in stopped_statuses:
             cls._send(
                 cls._get_default_session(),
-                tasks.StoppedRequest(
-                    task=task.id,
-                    force=True,
-                    status_message="Stopped timed out development task"
-                ),
+                tasks.StoppedRequest(task=task.id, force=True, status_message="Stopped timed out development task"),
             )

         return True
@@ -5207,17 +5361,18 @@
     def __getstate__(self):
         # type: () -> dict
-        return {'main': self.is_main_task(), 'id': self.id, 'offline': self.is_offline()}
+        return {"main": self.is_main_task(), "id": self.id, "offline": self.is_offline()}

     def __setstate__(self, state):
-        if state['main'] and not self.__main_task:
+        if state["main"] and not self.__main_task:
             Task.__forked_proc_main_pid = None
-            Task.__update_master_pid_task(task=state['id'])
-        if state['offline']:
-            Task.set_offline(offline_mode=state['offline'])
+            Task.__update_master_pid_task(task=state["id"])
+        if state["offline"]:
+            Task.set_offline(offline_mode=state["offline"])

-        task = Task.init(
-            continue_last_task=state['id'],
-            auto_connect_frameworks={'detect_repository': False}) \
-            if state['main'] else Task.get_task(task_id=state['id'])
+        task = (
+            Task.init(continue_last_task=state["id"], auto_connect_frameworks={"detect_repository": False})
+            if state["main"]
+            else Task.get_task(task_id=state["id"])
+        )
         self.__dict__ = task.__dict__