mirror of
https://github.com/clearml/clearml
synced 2025-06-26 18:16:07 +00:00
Add Task.connect support for class / instance objects
Add task.execute_function_remotely(...) (issue #230)
This commit is contained in:
parent
de85580faa
commit
501e27057b
122
trains/task.py
122
trains/task.py
@ -17,7 +17,7 @@ try:
|
||||
except ImportError:
|
||||
from collections import Sequence as CollectionsSequence
|
||||
|
||||
from typing import Optional, Union, Mapping, Sequence, Any, Dict, Iterable, TYPE_CHECKING
|
||||
from typing import Optional, Union, Mapping, Sequence, Any, Dict, Iterable, TYPE_CHECKING, Callable
|
||||
|
||||
import psutil
|
||||
import six
|
||||
@ -162,6 +162,7 @@ class Task(_Task):
|
||||
self._detect_repo_async_thread = None
|
||||
self._resource_monitor = None
|
||||
self._calling_filename = None
|
||||
self._remote_functions_generated = {}
|
||||
# register atexit, so that we mark the task as stopped
|
||||
self._at_exit_called = False
|
||||
|
||||
@ -937,7 +938,9 @@ class Task(_Task):
|
||||
- argparse - An argparse object for parameters.
|
||||
- dict - A dictionary for parameters.
|
||||
- TaskParameters - A TaskParameters object.
|
||||
- model - A model object for initial model warmup, or for model update/snapshot uploading.
|
||||
- Model - A model object for initial model warmup, or for model update/snapshot uploading.
|
||||
- Class type - A Class type, storing all class properties (excluding '_' prefix properties)
|
||||
- Object - A class instance, storing all instance properties (excluding '_' prefix properties)
|
||||
|
||||
:param str name: A section name associated with the connected object. Default: 'General'
|
||||
Currently only supported for `dict` / `TaskParameters` objects
|
||||
@ -949,13 +952,15 @@ class Task(_Task):
|
||||
|
||||
:raise: Raise an exception on unsupported objects.
|
||||
"""
|
||||
|
||||
# dispatching by match order
|
||||
dispatch = (
|
||||
(OutputModel, self._connect_output_model),
|
||||
(InputModel, self._connect_input_model),
|
||||
(ArgumentParser, self._connect_argparse),
|
||||
(dict, self._connect_dictionary),
|
||||
(TaskParameters, self._connect_task_parameters),
|
||||
(type, self._connect_object),
|
||||
(object, self._connect_object),
|
||||
)
|
||||
|
||||
multi_config_support = Session.check_min_api_version('2.9')
|
||||
@ -1657,7 +1662,7 @@ class Task(_Task):
|
||||
return True
|
||||
|
||||
def execute_remotely(self, queue_name=None, clone=False, exit_process=True):
|
||||
# type: (Optional[str], bool, bool) -> ()
|
||||
# type: (Optional[str], bool, bool) -> Optional[Task]
|
||||
"""
|
||||
If task is running locally (i.e., not by ``trains-agent``), then clone the Task and enqueue it for remote
|
||||
execution; or, stop the execution of the current Task, reset its state, and enqueue it. If ``exit_process==True``,
|
||||
@ -1684,10 +1689,12 @@ class Task(_Task):
|
||||
.. warning::
|
||||
|
||||
If ``clone==False``, then ``exit_process`` must be ``True``.
|
||||
|
||||
:return Task: return the task object of the newly generated remotely executing task
|
||||
"""
|
||||
# do nothing, we are running remotely
|
||||
if running_remotely():
|
||||
return
|
||||
if running_remotely() and self.is_main_task():
|
||||
return None
|
||||
|
||||
if not clone and not exit_process:
|
||||
raise ValueError(
|
||||
@ -1727,7 +1734,86 @@ class Task(_Task):
|
||||
LoggerRoot.get_base_logger().warning('Terminating local execution process')
|
||||
exit(0)
|
||||
|
||||
return
|
||||
return task
|
||||
|
||||
def create_function_task(self, func, func_name=None, task_name=None, **kwargs):
    # type: (Callable, Optional[str], Optional[str], **Optional[Any]) -> Optional[Task]
    """
    Create a new task, and call ``func`` with the specified kwargs.
    One can think of this call as remote forking, where the newly created instance is the new Task
    calling the specified func with the appropriate kwargs and leave once the func terminates.
    Notice that a remote executed function cannot create another child remote executed function.

    .. note::
        - Must be called from the main Task, i.e. the one created by Task.init(...)
        - The remote Tasks inherits the environment from the creating Task
        - In the remote Task, the entrypoint is the same as the creating Task
        - In the remote Task, the execution is the same until reaching this function call

    :param func: A function to execute remotely as a single Task.
        On the remote executed Task the entry-point and the environment are copied from this
        calling process, only this function call redirects the execution flow to the called func,
        alongside the passed arguments
    :param func_name: A unique identifier of the function. Default the function name without the namespace.
        For example Class.foo() becomes 'foo'
    :param task_name: The newly created Task name. Default: the calling Task name + function name
    :param kwargs: name specific arguments for the target function.
        These arguments will appear under the configuration, "Function" section

    :return Task: Return the newly created Task or None if running remotely and execution is skipped.
        When this process is the forked Task matching ``func_name``, ``func`` is called and the
        process exits, so the call does not return.

    :raise ValueError: If called from a non-main Task, if ``func`` is not callable, if the server
        API version is below 2.9, if no non-empty function name can be resolved, or if the
        function name was already used by this Task.
    """
    if not self.is_main_task():
        raise ValueError("Only the main Task object can call create_function_task()")
    if not callable(func):
        raise ValueError("func must be callable")
    if not Session.check_min_api_version('2.9'):
        raise ValueError("Remote function execution is not supported, "
                         "please upgrade to the latest server version")

    # Callables such as functools.partial objects carry no __name__, fall back to an empty name
    # and reject it below instead of failing with an AttributeError.
    func_name = str(func_name or getattr(func, '__name__', None) or '').strip()
    if not func_name:
        # An empty name would be stored as an empty marker, which the remote side below
        # treats as "not a forked task" and would clone again instead of running func.
        raise ValueError("Function name must not be empty, please provide a valid func_name")
    if func_name in self._remote_functions_generated:
        raise ValueError("Function name must be unique, a function by the name '{}' "
                         "was already created by this Task.".format(func_name))

    section_name = 'Function'
    tag_name = 'func'
    func_marker = '__func_readonly__'

    # sanitize the dict, leave only basic types that we might want to override later in the UI
    func_params = {k: v for k, v in kwargs.items() if verify_basic_value(v)}
    func_params[func_marker] = func_name

    # do not query if we are running locally, there is no need.
    task_func_marker = self.running_locally() or self.get_parameter('{}/{}'.format(section_name, func_marker))

    # if we are running locally or if we are running remotely but we are not a forked tasks
    # condition explained:
    # (1) running in development mode creates all the forked tasks
    # (2) running remotely but this is not one of the forked tasks (i.e. it is missing the fork tag attribute)
    if self.running_locally() or not task_func_marker:
        self._wait_for_repo_detection(300)
        task = self.clone(self, name=task_name or '{} <{}>'.format(self.name, func_name), parent=self.id)
        task.set_system_tags((task.get_system_tags() or []) + [tag_name])
        task.connect(func_params, name=section_name)
        self._remote_functions_generated[func_name] = task.id
        return task

    # check if we are one of the generated functions and if this is us,
    # if we are not the correct function, do nothing and leave
    if task_func_marker != func_name:
        # register the name (with a running counter as a placeholder value) so that
        # a repeated call with the same name is still rejected as a duplicate
        self._remote_functions_generated[func_name] = len(self._remote_functions_generated) + 1
        return None

    # mark this is us:
    self._remote_functions_generated[func_name] = self.id

    # this is us for sure, let's update the arguments and call the function
    self.connect(func_params, name=section_name)
    # the marker is bookkeeping only, it must not be passed to the function
    func_params.pop(func_marker, None)
    kwargs.update(func_params)
    func(**kwargs)
    # This is it, leave the process
    exit(0)
|
||||
|
||||
def wait_for_status(
|
||||
self,
|
||||
@ -2349,6 +2435,28 @@ class Task(_Task):
|
||||
self.set_parameters(attr_class.to_dict(), __parameters_prefix=name)
|
||||
return attr_class
|
||||
|
||||
def _connect_object(self, an_object, name=None):
    # type: (Any, Optional[str]) -> Any
    """
    Connect the public, basic-type attributes of a class or an instance to this Task.

    The attributes are read from ``an_object.__dict__``. An attribute is kept only if its name
    does not start with '_', its value is an instance of ``self._parameters_allowed_types`` and
    the value can be serialized with ``json.dumps``. The resulting dictionary is passed to
    ``self._connect_dictionary``.

    :param an_object: The class or instance whose attributes are connected.
    :param name: The section name, passed as-is to ``_connect_dictionary``.

    :return: When running remotely as the main Task, ``an_object`` itself, after the values
        returned by ``_connect_dictionary`` were written back onto it. Otherwise, whatever
        ``_connect_dictionary`` returns.
        NOTE(review): the two branches return different things (the object vs. the
        connected dictionary) - confirm this is intended.
    """
    def verify_type(key, value):
        if str(key).startswith('_') or not isinstance(value, self._parameters_allowed_types):
            return False
        # verify everything is json able (i.e. basic types)
        try:
            json.dumps(value)
        except (TypeError, ValueError):
            # TypeError: a nested value is not serializable
            # ValueError: the container holds a circular reference
            return False
        return True

    a_dict = {k: v for k, v in an_object.__dict__.items() if verify_type(k, v)}
    if running_remotely() and self.is_main_task():
        a_dict = self._connect_dictionary(a_dict, name)
        # write back only the values that differ from what the object currently holds
        for k, v in a_dict.items():
            if getattr(an_object, k, None) != v:
                setattr(an_object, k, v)

        return an_object

    return self._connect_dictionary(a_dict, name)
|
||||
|
||||
def _validate(self, check_output_dest_credentials=False):
|
||||
if running_remotely():
|
||||
super(Task, self)._validate(check_output_dest_credentials=False)
|
||||
|
Loading…
Reference in New Issue
Block a user