diff --git a/trains/task.py b/trains/task.py index 353913ec..3ee0ba3e 100644 --- a/trains/task.py +++ b/trains/task.py @@ -17,7 +17,7 @@ try: except ImportError: from collections import Sequence as CollectionsSequence -from typing import Optional, Union, Mapping, Sequence, Any, Dict, Iterable, TYPE_CHECKING +from typing import Optional, Union, Mapping, Sequence, Any, Dict, Iterable, TYPE_CHECKING, Callable import psutil import six @@ -162,6 +162,7 @@ class Task(_Task): self._detect_repo_async_thread = None self._resource_monitor = None self._calling_filename = None + self._remote_functions_generated = {} # register atexit, so that we mark the task as stopped self._at_exit_called = False @@ -937,7 +938,9 @@ class Task(_Task): - argparse - An argparse object for parameters. - dict - A dictionary for parameters. - TaskParameters - A TaskParameters object. - - model - A model object for initial model warmup, or for model update/snapshot uploading. + - Model - A model object for initial model warmup, or for model update/snapshot uploading. + - Class type - A Class type, storing all class properties (excluding '_' prefix properties) + - Object - A class instance, storing all instance properties (excluding '_' prefix properties) :param str name: A section name associated with the connected object. Default: 'General' Currently only supported for `dict` / `TaskParameter` objects @@ -949,13 +952,15 @@ class Task(_Task): :raise: Raise an exception on unsupported objects. """ - + # dispatching by match order dispatch = ( (OutputModel, self._connect_output_model), (InputModel, self._connect_input_model), (ArgumentParser, self._connect_argparse), (dict, self._connect_dictionary), (TaskParameters, self._connect_task_parameters), + (type, self._connect_object), + (object, self._connect_object), ) multi_config_support = Session.check_min_api_version('2.9') @@ -1657,7 +1662,7 @@ class Task(_Task): return True def execute_remotely(self, queue_name=None, clone=False, exit_process=True): - # type: (Optional[str], bool, bool) -> () + # type: (Optional[str], bool, bool) -> Optional[Task] """ If task is running locally (i.e., not by ``trains-agent``), then clone the Task and enqueue it for remote execution; or, stop the execution of the current Task, reset its state, and enqueue it. If ``exit==True``, @@ -1684,10 +1689,12 @@ class Task(_Task): .. warning:: If ``clone==False``, then ``exit_process`` must be ``True``. + + :return Task: return the task object of the newly generated remotely excuting task """ # do nothing, we are running remotely - if running_remotely(): - return + if running_remotely() and self.is_main_task(): + return None if not clone and not exit_process: raise ValueError( @@ -1727,7 +1734,86 @@ class Task(_Task): LoggerRoot.get_base_logger().warning('Terminating local execution process') exit(0) - return + return task + + def create_function_task(self, func, func_name=None, task_name=None, **kwargs): + # type: (Callable, Optional[str], Optional[str], **Optional[Any]) -> Optional[Task] + """ + Create a new task, and call ``func`` with the specified kwargs. + One can think of this call as remote forking, where the newly created instance is the new Task + calling the specified func with the appropriate kwargs and leave once the func terminates. + Notice that a remote executed function cannot create another child remote executed function. + + .. note:: + - Must be called from the main Task, i.e. the one created by Task.init(...) + - The remote Tasks inherits the environment from the creating Task + - In the remote Task, the entrypoint is the same as the creating Task + - In the remote Task, the execution is the same until reaching this function call + + :param func: A function to execute remotely as a single Task. + On the remote executed Task the entry-point and the environment are copied from this + calling process, only this function call redirect the the execution flow to the called func, + alongside the passed arguments + :param func_name: A unique identifier of the function. Default the function name without the namespace. + For example Class.foo() becomes 'foo' + :param task_name: The newly create Task name. Default: the calling Task name + function name + :param kwargs: name specific arguments for the target function. + These arguments will appear under the configuration, "Function" section + + :return Task: Return the newly created Task or None if running remotely and execution is skipped + """ + if not self.is_main_task(): + raise ValueError("Only the main Task object can call execute_function_remotely()") + if not callable(func): + raise ValueError("func must be callable") + if not Session.check_min_api_version('2.9'): + raise ValueError("Remote function execution is not supported, " + "please upgrade to the latest server version") + + func_name = str(func_name or func.__name__).strip() + if func_name in self._remote_functions_generated: + raise ValueError("Function name must be unique, a function by the name '{}' " + "was already created by this Task.".format(func_name)) + + section_name = 'Function' + tag_name = 'func' + func_marker = '__func_readonly__' + + # sanitize the dict, leave only basic types that we might want to override later in the UI + func_params = {k: v for k, v in kwargs.items() if verify_basic_value(v)} + func_params[func_marker] = func_name + + # do not query if we are running locally, there is no need. + task_func_marker = self.running_locally() or self.get_parameter('{}/{}'.format(section_name, func_marker)) + + # if we are running locally or if we are running remotely but we are not a forked tasks + # condition explained: + # (1) running in development mode creates all the forked tasks + # (2) running remotely but this is not one of the forked tasks (i.e. it is missing the fork tag attribute) + if self.running_locally() or not task_func_marker: + self._wait_for_repo_detection(300) + task = self.clone(self, name=task_name or '{} <{}>'.format(self.name, func_name), parent=self.id) + task.set_system_tags((task.get_system_tags() or []) + [tag_name]) + task.connect(func_params, name=section_name) + self._remote_functions_generated[func_name] = task.id + return task + + # check if we are one of the generated functions and if this is us, + # if we are not the correct function, not do nothing and leave + if task_func_marker != func_name: + self._remote_functions_generated[func_name] = len(self._remote_functions_generated) + 1 + return + + # mark this is us: + self._remote_functions_generated[func_name] = self.id + + # this is us for sure, let's update the arguments and call the function + self.connect(func_params, name=section_name) + func_params.pop(func_marker, None) + kwargs.update(func_params) + func(**kwargs) + # This is it, leave the process + exit(0) def wait_for_status( self, @@ -2349,6 +2435,28 @@ class Task(_Task): self.set_parameters(attr_class.to_dict(), __parameters_prefix=name) return attr_class + def _connect_object(self, an_object, name=None): + def verify_type(key, value): + if str(key).startswith('_') or not isinstance(value, self._parameters_allowed_types): + return False + # verify everything is json able (i.e. basic types) + try: + json.dumps(value) + return True + except TypeError: + return False + + a_dict = {k: v for k, v in an_object.__dict__.items() if verify_type(k, v)} + if running_remotely() and self.is_main_task(): + a_dict = self._connect_dictionary(a_dict, name) + for k, v in a_dict.items(): + if getattr(an_object, k, None) != a_dict[k]: + setattr(an_object, k, v) + + return an_object + else: + return self._connect_dictionary(a_dict, name) + def _validate(self, check_output_dest_credentials=False): if running_remotely(): super(Task, self)._validate(check_output_dest_credentials=False)