Add Task.init auto_connect_frameworks fine granularity control

This commit is contained in:
allegroai 2019-11-08 22:30:09 +02:00
parent 9362831269
commit 4372dda696
2 changed files with 74 additions and 12 deletions

View File

@ -153,8 +153,13 @@ class Task(_Task):
For example: trains[s3], trains[gs], trains[azure]
:param auto_connect_arg_parser: Automatically grab the ArgParser and connect it with the task.
if set to false, you can manually connect the ArgParser with task.connect(parser)
:param auto_connect_frameworks: If true automatically patch MatplotLib, Keras callbacks, and TensorBoard/X to
serialize plots, graphs and model location to trains backend (in addition to original output destination)
:param auto_connect_frameworks: If True automatically patch MatplotLib, XGBoost, scikit-learn,
Keras callbacks, and TensorBoard/X to serialize plots, graphs and model location to trains backend
(in addition to original output destination).
Fine-grained control is possible by passing a dictionary instead of a Boolean.
Missing keys are treated as True, while an empty dictionary is treated as False. Full example:
auto_connect_frameworks={'matplotlib': True, 'tensorflow': True, 'pytorch': True,
'xgboost': True, 'scikit': True}
:param auto_resource_monitoring: If true, machine vitals will be sent along side the task scalars,
Resources graphs will appear under the title ':resource monitor:' in the scalars tab.
:return: Task() object
@ -258,12 +263,18 @@ class Task(_Task):
# patch OS forking
PatchOsFork.patch_fork()
if auto_connect_frameworks:
PatchedJoblib.update_current_task(task)
PatchedMatplotlib.update_current_task(Task.__main_task)
PatchAbsl.update_current_task(Task.__main_task)
TensorflowBinding.update_current_task(task)
PatchPyTorchModelIO.update_current_task(task)
PatchXGBoostModelIO.update_current_task(task)
is_auto_connect_frameworks_bool = not isinstance(auto_connect_frameworks, dict)
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('scikit', True):
PatchedJoblib.update_current_task(task)
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('matplotlib', True):
PatchedMatplotlib.update_current_task(Task.__main_task)
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('tensorflow', True):
PatchAbsl.update_current_task(Task.__main_task)
TensorflowBinding.update_current_task(task)
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('pytorch', True):
PatchPyTorchModelIO.update_current_task(task)
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('xgboost', True):
PatchXGBoostModelIO.update_current_task(task)
if auto_resource_monitoring:
task._resource_monitor = ResourceMonitor(task)
task._resource_monitor.start()
@ -705,13 +716,13 @@ class Task(_Task):
def get_last_iteration(self):
"""
Return the last reported iteration (i.e. the maximum iteration the task reported a metric for)
Return the maximum reported iteration (i.e. the maximum iteration the task reported a metric for)
Notice, this is not a cached call, it will ask the backend for the answer (no local caching)
:return: last reported iteration number (integer)
"""
self.reload()
return self.data.last_iteration
return max(self.data.last_iteration, self._reporter.max_iteration if self._reporter else 0)
def set_last_iteration(self, last_iteration):
"""
@ -1020,7 +1031,7 @@ class Task(_Task):
def _connect_dictionary(self, dictionary):
self._try_set_connected_parameter_type(self._ConnectedParametersType.dictionary)
if running_remotely() and self.is_main_task():
if running_remotely():
dictionary = self._arguments.copy_to_dict(dictionary)
else:
dictionary = self._arguments.copy_from_dict(dictionary)
@ -1544,7 +1555,8 @@ class Task(_Task):
# compare after casting to string to avoid enum instance issues
# remember we might have replaced the api version by now, so enums are different
return all(str(server_data) == str(task_data.get(task_data_key)) for server_data, task_data_key in compares)
return all(six.text_type(server_data) == six.text_type(task_data.get(task_data_key))
for server_data, task_data_key in compares)
@classmethod
def __close_timed_out_task(cls, task_data):

View File

@ -0,0 +1,50 @@
class ProxyDictPostWrite(dict):
    """
    Dictionary wrapper that notifies a callback after an item is set.

    On every ``proxy[key] = value`` the value is stored first and then
    ``update_func(update_obj, proxy)`` is called with this (top level) dictionary.
    Values that are dictionaries at construction time are wrapped recursively,
    so an item set on a nested dictionary is reported through its parent.

    Only ``__setitem__`` is overridden; the other ``dict`` methods are inherited as is.

    :param update_obj: Opaque object handed back to ``update_func`` as its first argument
    :param update_func: Callable ``update_func(update_obj, dictionary)``; a false value disables reporting
    :param args: Positional arguments forwarded to ``dict()``
    :param kwargs: Keyword arguments forwarded to ``dict()``
    """

    def __init__(self, update_obj, update_func, *args, **kwargs):
        super(ProxyDictPostWrite, self).__init__(*args, **kwargs)
        # keep reporting disabled until the nested wrappers are in place
        self._update_func = None
        for key, value in list(self.items()):
            if isinstance(value, dict):
                # Pass the nested mapping positionally (not as **value): keyword expansion
                # fails on non-string keys and on keys named like our own parameters.
                nested = ProxyDictPostWrite(update_obj, self._set_callback, value)
                # store through the base class so wrapping itself is never reported
                super(ProxyDictPostWrite, self).__setitem__(key, nested)
        self._update_obj = update_obj
        self._update_func = update_func

    def __setitem__(self, key, value):
        """ Store the item, then report the change """
        super(ProxyDictPostWrite, self).__setitem__(key, value)
        self._set_callback()

    def _set_callback(self, *_):
        """
        Report a change of this dictionary to the registered callback (if any).
        Also used as the callback of nested dictionaries, whose arguments are ignored
        so that the parent dictionary is always the one reported.
        """
        if self._update_func:
            self._update_func(self._update_obj, self)
class ProxyDictPreWrite(dict):
    """
    Dictionary wrapper that consults a callback before an item is set.

    On every ``proxy[key] = value`` the callback ``update_func(update_obj, (key, value))``
    is called first. A false return value drops the assignment; otherwise the returned
    ``(key, value)`` pair is what gets stored. If ``update_func`` is a false value,
    no callback is consulted and the assignment is stored unchanged.

    Values that are dictionaries at construction time are wrapped recursively. An item set
    on a nested dictionary is reported to the parent callback with the key rewritten as
    ``'<parent key>.<key>'`` (keys are joined with ``+``, so they are expected to be strings).

    Only ``__setitem__`` is overridden; the other ``dict`` methods are inherited as is.

    :param update_obj: Opaque object handed back to ``update_func`` as its first argument
    :param update_func: Callable ``update_func(update_obj, (key, value))`` returning the pair to store,
        or a false value to reject the assignment
    :param args: Positional arguments forwarded to ``dict()``
    :param kwargs: Keyword arguments forwarded to ``dict()``
    """

    def __init__(self, update_obj, update_func, *args, **kwargs):
        super(ProxyDictPreWrite, self).__init__(*args, **kwargs)
        # no callback while the nested wrappers are being built
        self._update_func = None
        for key, value in list(self.items()):
            if isinstance(value, dict):
                # Pass the nested mapping positionally (not as **value): keyword expansion
                # fails on non-string keys and on keys named like our own parameters.
                # The nested dictionary gets its own key as update_obj, used as the path prefix.
                nested = ProxyDictPreWrite(key, self._nested_callback, value)
                # store through the base class so wrapping never goes through the callback
                super(ProxyDictPreWrite, self).__setitem__(key, nested)
        self._update_obj = update_obj
        self._update_func = update_func

    def __setitem__(self, key, value):
        """ Ask the callback first, store only the pair it approves """
        key_value = self._set_callback((key, value,))
        if key_value:
            super(ProxyDictPreWrite, self).__setitem__(*key_value)

    def _set_callback(self, key_value, *_):
        """
        Run the registered callback on a pending ``(key, value)`` assignment.

        :param key_value: Tuple of ``(key, value)`` about to be stored
        :return: The pair to store, or None if the callback rejected the assignment
        """
        if not self._update_func:
            return key_value
        # normalize any false answer to None
        return self._update_func(self._update_obj, key_value) or None

    def _nested_callback(self, prefix, key_value):
        """
        Callback given to nested dictionaries: prepend the nested dictionary's own key
        (``prefix``) to the pending key and pass the assignment on to this dictionary's callback.
        """
        return self._set_callback((prefix + '.' + key_value[0], key_value[1],))