Add Task.connect support for class / instance objects

Add task.execute_function_remotely(...)  (issue #230)
This commit is contained in:
allegroai 2020-11-08 00:17:21 +02:00
parent de85580faa
commit 501e27057b

View File

@ -17,7 +17,7 @@ try:
except ImportError:
from collections import Sequence as CollectionsSequence
from typing import Optional, Union, Mapping, Sequence, Any, Dict, Iterable, TYPE_CHECKING
from typing import Optional, Union, Mapping, Sequence, Any, Dict, Iterable, TYPE_CHECKING, Callable
import psutil
import six
@ -162,6 +162,7 @@ class Task(_Task):
self._detect_repo_async_thread = None
self._resource_monitor = None
self._calling_filename = None
self._remote_functions_generated = {}
# register atexit, so that we mark the task as stopped
self._at_exit_called = False
@ -937,7 +938,9 @@ class Task(_Task):
- argparse - An argparse object for parameters.
- dict - A dictionary for parameters.
- TaskParameters - A TaskParameters object.
- model - A model object for initial model warmup, or for model update/snapshot uploading.
- Model - A model object for initial model warmup, or for model update/snapshot uploading.
- Class type - A Class type, storing all class properties (excluding '_' prefix properties)
- Object - A class instance, storing all instance properties (excluding '_' prefix properties)
:param str name: A section name associated with the connected object. Default: 'General'
Currently only supported for `dict` / `TaskParameter` objects
@ -949,13 +952,15 @@ class Task(_Task):
:raise: Raise an exception on unsupported objects.
"""
# dispatching by match order
dispatch = (
(OutputModel, self._connect_output_model),
(InputModel, self._connect_input_model),
(ArgumentParser, self._connect_argparse),
(dict, self._connect_dictionary),
(TaskParameters, self._connect_task_parameters),
(type, self._connect_object),
(object, self._connect_object),
)
multi_config_support = Session.check_min_api_version('2.9')
@ -1657,7 +1662,7 @@ class Task(_Task):
return True
def execute_remotely(self, queue_name=None, clone=False, exit_process=True):
# type: (Optional[str], bool, bool) -> ()
# type: (Optional[str], bool, bool) -> Optional[Task]
"""
If task is running locally (i.e., not by ``trains-agent``), then clone the Task and enqueue it for remote
execution; or, stop the execution of the current Task, reset its state, and enqueue it. If ``exit==True``,
@ -1684,10 +1689,12 @@ class Task(_Task):
.. warning::
If ``clone==False``, then ``exit_process`` must be ``True``.
:return Task: return the task object of the newly generated remotely executing task
"""
# do nothing, we are running remotely
if running_remotely():
return
if running_remotely() and self.is_main_task():
return None
if not clone and not exit_process:
raise ValueError(
@ -1727,7 +1734,86 @@ class Task(_Task):
LoggerRoot.get_base_logger().warning('Terminating local execution process')
exit(0)
return
return task
def create_function_task(self, func, func_name=None, task_name=None, **kwargs):
    # type: (Callable, Optional[str], Optional[str], **Optional[Any]) -> Optional[Task]
    """
    Create a new task, and call ``func`` with the specified kwargs.
    One can think of this call as remote forking, where the newly created instance is the new Task
    calling the specified func with the appropriate kwargs and leave once the func terminates.
    Notice that a remote executed function cannot create another child remote executed function.

    .. note::
        - Must be called from the main Task, i.e. the one created by Task.init(...)
        - The remote Tasks inherits the environment from the creating Task
        - In the remote Task, the entrypoint is the same as the creating Task
        - In the remote Task, the execution is the same until reaching this function call

    :param func: A function to execute remotely as a single Task.
        On the remote executed Task the entry-point and the environment are copied from this
        calling process, only this function call redirects the execution flow to the called func,
        alongside the passed arguments
    :param func_name: A unique identifier of the function. Default the function name without the namespace.
        For example Class.foo() becomes 'foo'
    :param task_name: The newly created Task name. Default: the calling Task name + function name
    :param kwargs: name specific arguments for the target function.
        These arguments will appear under the configuration, "Function" section

    :return Task: Return the newly created Task or None if running remotely and execution is skipped

    :raise ValueError: If called from a non-main Task, if ``func`` is not callable, if the server
        API is older than 2.9, if no usable function name can be resolved, or if the resolved
        function name was already registered by this Task.
    """
    if not self.is_main_task():
        # the message must name this method, otherwise the user is pointed at an API that does not exist
        raise ValueError("Only the main Task object can call create_function_task()")
    if not callable(func):
        raise ValueError("func must be callable")
    if not Session.check_min_api_version('2.9'):
        raise ValueError("Remote function execution is not supported, "
                         "please upgrade to the latest server version")

    # callables such as functools.partial have no __name__; fall back to an explicit error
    # instead of leaking an AttributeError, and reject names that are empty after stripping
    # (an empty marker would be indistinguishable from "no marker" in the remote branch below)
    func_name = str(func_name or getattr(func, '__name__', '') or '').strip()
    if not func_name:
        raise ValueError("A non-empty func_name is required when func has no usable __name__")
    if func_name in self._remote_functions_generated:
        raise ValueError("Function name must be unique, a function by the name '{}' "
                         "was already created by this Task.".format(func_name))

    section_name = 'Function'
    tag_name = 'func'
    func_marker = '__func_readonly__'

    # sanitize the dict, leave only basic types that we might want to override later in the UI
    func_params = {k: v for k, v in kwargs.items() if verify_basic_value(v)}
    func_params[func_marker] = func_name

    # do not query if we are running locally, there is no need.
    task_func_marker = self.running_locally() or self.get_parameter('{}/{}'.format(section_name, func_marker))

    # if we are running locally or if we are running remotely but we are not a forked tasks
    # condition explained:
    # (1) running in development mode creates all the forked tasks
    # (2) running remotely but this is not one of the forked tasks (i.e. it is missing the fork tag attribute)
    if self.running_locally() or not task_func_marker:
        self._wait_for_repo_detection(300)
        task = self.clone(self, name=task_name or '{} <{}>'.format(self.name, func_name), parent=self.id)
        task.set_system_tags((task.get_system_tags() or []) + [tag_name])
        task.connect(func_params, name=section_name)
        self._remote_functions_generated[func_name] = task.id
        return task

    # check if we are one of the generated functions and if this is us,
    # if we are not the correct function, do nothing and leave
    if task_func_marker != func_name:
        # keep a placeholder entry so the uniqueness check above still applies to this name
        self._remote_functions_generated[func_name] = len(self._remote_functions_generated) + 1
        return None

    # mark this is us:
    self._remote_functions_generated[func_name] = self.id

    # this is us for sure, let's update the arguments and call the function
    self.connect(func_params, name=section_name)
    # the marker is bookkeeping only, it must not be passed to the target function
    func_params.pop(func_marker, None)
    kwargs.update(func_params)
    func(**kwargs)

    # This is it, leave the process
    exit(0)
def wait_for_status(
self,
@ -2349,6 +2435,28 @@ class Task(_Task):
self.set_parameters(attr_class.to_dict(), __parameters_prefix=name)
return attr_class
def _connect_object(self, an_object, name=None):
    # type: (Any, Optional[str]) -> Any
    """
    Connect a class type or a class instance by exposing its attributes as a parameter dictionary.

    Only entries of ``an_object.__dict__`` whose key does not start with '_' and whose value is
    both an instance of ``self._parameters_allowed_types`` and JSON serializable are collected.
    The collected dictionary is passed to ``self._connect_dictionary``. When running remotely
    on the main task, the values returned by ``_connect_dictionary`` are written back onto
    ``an_object`` wherever they differ from the current attribute value.

    :param an_object: The class type or instance to connect. It must expose ``__dict__``.
    :param name: Section name, forwarded as-is to ``_connect_dictionary``.
    :return: ``an_object`` itself, in both the local and the remote flow.
    """
    def verify_type(key, value):
        if str(key).startswith('_') or not isinstance(value, self._parameters_allowed_types):
            return False
        # verify everything is json able (i.e. basic types)
        try:
            json.dumps(value)
            return True
        except (TypeError, ValueError):
            # TypeError: non-serializable value, ValueError: e.g. circular reference in a container
            return False

    a_dict = {k: v for k, v in an_object.__dict__.items() if verify_type(k, v)}
    if running_remotely() and self.is_main_task():
        a_dict = self._connect_dictionary(a_dict, name)
        # push back only the values that actually changed, leaving untouched attributes alone
        for k, v in a_dict.items():
            if getattr(an_object, k, None) != v:
                setattr(an_object, k, v)
    else:
        self._connect_dictionary(a_dict, name)

    # always hand back the connected object, so callers get the same type
    # regardless of whether the code runs locally or remotely
    return an_object
def _validate(self, check_output_dest_credentials=False):
if running_remotely():
super(Task, self)._validate(check_output_dest_credentials=False)