Add Task.init auto_connect_frameworks fine granularity control

This commit is contained in:
allegroai 2019-11-08 22:30:09 +02:00
parent 9362831269
commit 4372dda696
2 changed files with 74 additions and 12 deletions

View File

@ -153,8 +153,13 @@ class Task(_Task):
For example: trains[s3], trains[gs], trains[azure] For example: trains[s3], trains[gs], trains[azure]
:param auto_connect_arg_parser: Automatically grab the ArgParser and connect it with the task. :param auto_connect_arg_parser: Automatically grab the ArgParser and connect it with the task.
if set to false, you can manually connect the ArgParser with task.connect(parser) if set to false, you can manually connect the ArgParser with task.connect(parser)
:param auto_connect_frameworks: If true automatically patch MatplotLib, Keras callbacks, and TensorBoard/X to :param auto_connect_frameworks: If True automatically patch MatplotLib, XGBoost, scikit-learn,
serialize plots, graphs and model location to trains backend (in addition to original output destination) Keras callbacks, and TensorBoard/X to serialize plots, graphs and model location to trains backend
(in addition to original output destination).
Fine grained control is possible by passing a dictionary instead of a Boolean.
Missing keys are considered to have True value, empty dictionary is considered as False, full example:
auto_connect_frameworks={'matplotlib': True, 'tensorflow': True, 'pytorch': True,
'xgboost': True, 'scikit': True}
:param auto_resource_monitoring: If true, machine vitals will be sent along side the task scalars, :param auto_resource_monitoring: If true, machine vitals will be sent along side the task scalars,
Resources graphs will appear under the title ':resource monitor:' in the scalars tab. Resources graphs will appear under the title ':resource monitor:' in the scalars tab.
:return: Task() object :return: Task() object
@ -258,12 +263,18 @@ class Task(_Task):
# patch OS forking # patch OS forking
PatchOsFork.patch_fork() PatchOsFork.patch_fork()
if auto_connect_frameworks: if auto_connect_frameworks:
PatchedJoblib.update_current_task(task) is_auto_connect_frameworks_bool = not isinstance(auto_connect_frameworks, dict)
PatchedMatplotlib.update_current_task(Task.__main_task) if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('scikit', True):
PatchAbsl.update_current_task(Task.__main_task) PatchedJoblib.update_current_task(task)
TensorflowBinding.update_current_task(task) if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('matplotlib', True):
PatchPyTorchModelIO.update_current_task(task) PatchedMatplotlib.update_current_task(Task.__main_task)
PatchXGBoostModelIO.update_current_task(task) if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('tensorflow', True):
PatchAbsl.update_current_task(Task.__main_task)
TensorflowBinding.update_current_task(task)
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('pytorch', True):
PatchPyTorchModelIO.update_current_task(task)
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('xgboost', True):
PatchXGBoostModelIO.update_current_task(task)
if auto_resource_monitoring: if auto_resource_monitoring:
task._resource_monitor = ResourceMonitor(task) task._resource_monitor = ResourceMonitor(task)
task._resource_monitor.start() task._resource_monitor.start()
@ -705,13 +716,13 @@ class Task(_Task):
def get_last_iteration(self): def get_last_iteration(self):
""" """
Return the last reported iteration (i.e. the maximum iteration the task reported a metric for) Return the maximum reported iteration (i.e. the maximum iteration the task reported a metric for)
Notice, this is not a cached call, it will ask the backend for the answer (no local caching) Notice, this is not a cached call, it will ask the backend for the answer (no local caching)
:return: last reported iteration number (integer) :return: last reported iteration number (integer)
""" """
self.reload() self.reload()
return self.data.last_iteration return max(self.data.last_iteration, self._reporter.max_iteration if self._reporter else 0)
def set_last_iteration(self, last_iteration): def set_last_iteration(self, last_iteration):
""" """
@ -1020,7 +1031,7 @@ class Task(_Task):
def _connect_dictionary(self, dictionary): def _connect_dictionary(self, dictionary):
self._try_set_connected_parameter_type(self._ConnectedParametersType.dictionary) self._try_set_connected_parameter_type(self._ConnectedParametersType.dictionary)
if running_remotely() and self.is_main_task(): if running_remotely():
dictionary = self._arguments.copy_to_dict(dictionary) dictionary = self._arguments.copy_to_dict(dictionary)
else: else:
dictionary = self._arguments.copy_from_dict(dictionary) dictionary = self._arguments.copy_from_dict(dictionary)
@ -1544,7 +1555,8 @@ class Task(_Task):
# compare after casting to string to avoid enum instance issues # compare after casting to string to avoid enum instance issues
# remember we might have replaced the api version by now, so enums are different # remember we might have replaced the api version by now, so enums are different
return all(str(server_data) == str(task_data.get(task_data_key)) for server_data, task_data_key in compares) return all(six.text_type(server_data) == six.text_type(task_data.get(task_data_key))
for server_data, task_data_key in compares)
@classmethod @classmethod
def __close_timed_out_task(cls, task_data): def __close_timed_out_task(cls, task_data):

View File

@ -0,0 +1,50 @@
class ProxyDictPostWrite(dict):
    """
    Dictionary wrapper that notifies a callback after any item is set.

    After every ``d[key] = value`` the wrapper calls
    ``update_func(update_obj, d)``. Nested dictionaries found at construction
    time are wrapped recursively; an assignment inside a nested dictionary is
    propagated upwards, so the callback always receives the outermost wrapper.

    Only ``__setitem__`` is intercepted by this class.
    """

    def __init__(self, update_obj, update_func, *args, **kwargs):
        """
        :param update_obj: Opaque object handed back to ``update_func`` as its first argument.
        :param update_func: Callable ``update_func(update_obj, this_dict)`` invoked after each
            item assignment. A falsy value disables the notification.
        :param args: Positional arguments forwarded to ``dict()``.
        :param kwargs: Keyword arguments forwarded to ``dict()``.
        """
        super(ProxyDictPostWrite, self).__init__(*args, **kwargs)
        # keep the callback disabled while the nested values are being wrapped
        self._update_func = None
        # iterate over a snapshot, values are replaced while looping
        for k, i in list(self.items()):
            if isinstance(i, dict):
                # The nested mapping is passed positionally (not as **i): keyword unpacking
                # fails on non-string keys and on keys named like this constructor's parameters.
                self.update({k: ProxyDictPostWrite(update_obj, self._set_callback, i)})
        self._update_obj = update_obj
        self._update_func = update_func

    def __setitem__(self, key, value):
        """ Store the item, then notify the callback. """
        super(ProxyDictPostWrite, self).__setitem__(key, value)
        self._set_callback()

    def _set_callback(self, *_):
        """
        Invoke the update callback with (update_obj, self).

        Extra positional arguments are ignored; this lets a nested wrapper use its parent's
        ``_set_callback`` as its own ``update_func``.
        """
        if self._update_func:
            self._update_func(self._update_obj, self)
class ProxyDictPreWrite(dict):
    """
    Dictionary wrapper that routes every item assignment through a callback before writing.

    On ``d[key] = value`` the wrapper calls ``update_func(update_obj, (key, value))``.
    If the callback returns a falsy value the assignment is dropped; otherwise the returned
    ``(key, value)`` pair is what gets stored. Without a callback the assignment is stored as is.

    Nested dictionaries found at construction time are wrapped recursively. An assignment inside
    a nested dictionary is reported to the parent with the key prefixed by the parent key and
    a dot (e.g. ``'section.option'``).

    Only ``__setitem__`` is intercepted by this class.
    """

    def __init__(self, update_obj, update_func, *args, **kwargs):
        """
        :param update_obj: Opaque object handed back to ``update_func`` as its first argument.
        :param update_func: Callable ``update_func(update_obj, (key, value))`` returning the
            ``(key, value)`` pair to store, or a falsy value to reject the assignment.
            A falsy ``update_func`` disables the check.
        :param args: Positional arguments forwarded to ``dict()``.
        :param kwargs: Keyword arguments forwarded to ``dict()``.
        """
        super(ProxyDictPreWrite, self).__init__(*args, **kwargs)
        # keep the callback disabled while the nested values are being wrapped
        self._update_func = None
        # iterate over a snapshot, values are replaced while looping
        for k, i in list(self.items()):
            if isinstance(i, dict):
                # The nested wrapper gets its own key as update_obj, so _nested_callback receives
                # it as the prefix. The nested mapping is passed positionally (not as **i):
                # keyword unpacking fails on non-string keys and on keys named like this
                # constructor's parameters.
                self.update({k: ProxyDictPreWrite(k, self._nested_callback, i)})
        self._update_obj = update_obj
        self._update_func = update_func

    def __setitem__(self, key, value):
        """ Store the item only if the callback approves, using the pair the callback returned. """
        key_value = self._set_callback((key, value,))
        if key_value:
            # NOTE(review): for a nested wrapper, if the root callback echoes the dotted key back,
            # the value is stored here under that dotted key - confirm this is intended.
            super(ProxyDictPreWrite, self).__setitem__(*key_value)

    def _set_callback(self, key_value, *_):
        """
        Ask the callback what to do with a pending assignment.

        :param key_value: The ``(key, value)`` pair about to be written.
        :return: The pair to store, or None when the callback rejected the assignment.
        """
        if not self._update_func:
            return key_value
        res = self._update_func(self._update_obj, key_value)
        # normalize any falsy answer to None
        return res if res else None

    def _nested_callback(self, prefix, key_value):
        """
        Callback installed on nested wrappers: forward the assignment to this dictionary's
        own callback, with the key rewritten as '<prefix>.<key>'.
        """
        return self._set_callback((prefix + '.' + key_value[0], key_value[1],))