mirror of
https://github.com/clearml/clearml
synced 2025-04-25 00:37:52 +00:00
Add Task.init auto_connect_frameworks fine granularity control
This commit is contained in:
parent
9362831269
commit
4372dda696
@ -153,8 +153,13 @@ class Task(_Task):
|
|||||||
For example: trains[s3], trains[gs], trains[azure]
|
For example: trains[s3], trains[gs], trains[azure]
|
||||||
:param auto_connect_arg_parser: Automatically grab the ArgParser and connect it with the task.
|
:param auto_connect_arg_parser: Automatically grab the ArgParser and connect it with the task.
|
||||||
if set to false, you can manually connect the ArgParser with task.connect(parser)
|
if set to false, you can manually connect the ArgParser with task.connect(parser)
|
||||||
:param auto_connect_frameworks: If true automatically patch MatplotLib, Keras callbacks, and TensorBoard/X to
|
:param auto_connect_frameworks: If True automatically patch MatplotLib, XGBoost, scikit-learn,
|
||||||
serialize plots, graphs and model location to trains backend (in addition to original output destination)
|
Keras callbacks, and TensorBoard/X to serialize plots, graphs and model location to trains backend
|
||||||
|
(in addition to original output destination).
|
||||||
|
Fine grained control is possible by passing a dictionary instead of a Boolean.
|
||||||
|
Missing keys are considered to have True value, empty dictionary is considered as False, full example:
|
||||||
|
auto_connect_frameworks={'matplotlib': True, 'tensorflow': True, 'pytorch': True,
|
||||||
|
'xgboost': True, 'scikit': True}
|
||||||
:param auto_resource_monitoring: If true, machine vitals will be sent along side the task scalars,
|
:param auto_resource_monitoring: If true, machine vitals will be sent along side the task scalars,
|
||||||
Resources graphs will appear under the title ':resource monitor:' in the scalars tab.
|
Resources graphs will appear under the title ':resource monitor:' in the scalars tab.
|
||||||
:return: Task() object
|
:return: Task() object
|
||||||
@ -258,12 +263,18 @@ class Task(_Task):
|
|||||||
# patch OS forking
|
# patch OS forking
|
||||||
PatchOsFork.patch_fork()
|
PatchOsFork.patch_fork()
|
||||||
if auto_connect_frameworks:
|
if auto_connect_frameworks:
|
||||||
PatchedJoblib.update_current_task(task)
|
is_auto_connect_frameworks_bool = not isinstance(auto_connect_frameworks, dict)
|
||||||
PatchedMatplotlib.update_current_task(Task.__main_task)
|
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('scikit', True):
|
||||||
PatchAbsl.update_current_task(Task.__main_task)
|
PatchedJoblib.update_current_task(task)
|
||||||
TensorflowBinding.update_current_task(task)
|
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('matplotlib', True):
|
||||||
PatchPyTorchModelIO.update_current_task(task)
|
PatchedMatplotlib.update_current_task(Task.__main_task)
|
||||||
PatchXGBoostModelIO.update_current_task(task)
|
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('tensorflow', True):
|
||||||
|
PatchAbsl.update_current_task(Task.__main_task)
|
||||||
|
TensorflowBinding.update_current_task(task)
|
||||||
|
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('pytorch', True):
|
||||||
|
PatchPyTorchModelIO.update_current_task(task)
|
||||||
|
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('xgboost', True):
|
||||||
|
PatchXGBoostModelIO.update_current_task(task)
|
||||||
if auto_resource_monitoring:
|
if auto_resource_monitoring:
|
||||||
task._resource_monitor = ResourceMonitor(task)
|
task._resource_monitor = ResourceMonitor(task)
|
||||||
task._resource_monitor.start()
|
task._resource_monitor.start()
|
||||||
@ -705,13 +716,13 @@ class Task(_Task):
|
|||||||
|
|
||||||
def get_last_iteration(self):
|
def get_last_iteration(self):
|
||||||
"""
|
"""
|
||||||
Return the last reported iteration (i.e. the maximum iteration the task reported a metric for)
|
Return the maximum reported iteration (i.e. the maximum iteration the task reported a metric for)
|
||||||
Notice, this is not a cached call, it will ask the backend for the answer (no local caching)
|
Notice, this is not a cached call, it will ask the backend for the answer (no local caching)
|
||||||
|
|
||||||
:return: last reported iteration number (integer)
|
:return: last reported iteration number (integer)
|
||||||
"""
|
"""
|
||||||
self.reload()
|
self.reload()
|
||||||
return self.data.last_iteration
|
return max(self.data.last_iteration, self._reporter.max_iteration if self._reporter else 0)
|
||||||
|
|
||||||
def set_last_iteration(self, last_iteration):
|
def set_last_iteration(self, last_iteration):
|
||||||
"""
|
"""
|
||||||
@ -1020,7 +1031,7 @@ class Task(_Task):
|
|||||||
def _connect_dictionary(self, dictionary):
|
def _connect_dictionary(self, dictionary):
|
||||||
self._try_set_connected_parameter_type(self._ConnectedParametersType.dictionary)
|
self._try_set_connected_parameter_type(self._ConnectedParametersType.dictionary)
|
||||||
|
|
||||||
if running_remotely() and self.is_main_task():
|
if running_remotely():
|
||||||
dictionary = self._arguments.copy_to_dict(dictionary)
|
dictionary = self._arguments.copy_to_dict(dictionary)
|
||||||
else:
|
else:
|
||||||
dictionary = self._arguments.copy_from_dict(dictionary)
|
dictionary = self._arguments.copy_from_dict(dictionary)
|
||||||
@ -1544,7 +1555,8 @@ class Task(_Task):
|
|||||||
|
|
||||||
# compare after casting to string to avoid enum instance issues
|
# compare after casting to string to avoid enum instance issues
|
||||||
# remember we might have replaced the api version by now, so enums are different
|
# remember we might have replaced the api version by now, so enums are different
|
||||||
return all(str(server_data) == str(task_data.get(task_data_key)) for server_data, task_data_key in compares)
|
return all(six.text_type(server_data) == six.text_type(task_data.get(task_data_key))
|
||||||
|
for server_data, task_data_key in compares)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __close_timed_out_task(cls, task_data):
|
def __close_timed_out_task(cls, task_data):
|
||||||
|
50
trains/utilities/proxy_object.py
Normal file
50
trains/utilities/proxy_object.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
|
||||||
|
|
||||||
|
class ProxyDictPostWrite(dict):
|
||||||
|
""" Dictionary wrapper that updates an arguments instance on any item set in the dictionary """
|
||||||
|
|
||||||
|
def __init__(self, update_obj, update_func, *args, **kwargs):
|
||||||
|
super(ProxyDictPostWrite, self).__init__(*args, **kwargs)
|
||||||
|
self._update_func = None
|
||||||
|
for k, i in self.items():
|
||||||
|
if isinstance(i, dict):
|
||||||
|
self.update({k: ProxyDictPostWrite(update_obj, self._set_callback, **i)})
|
||||||
|
self._update_obj = update_obj
|
||||||
|
self._update_func = update_func
|
||||||
|
|
||||||
|
def __setitem__(self, key, value):
|
||||||
|
super(ProxyDictPostWrite, self).__setitem__(key, value)
|
||||||
|
self._set_callback()
|
||||||
|
|
||||||
|
def _set_callback(self, *_):
|
||||||
|
if self._update_func:
|
||||||
|
self._update_func(self._update_obj, self)
|
||||||
|
|
||||||
|
|
||||||
|
class ProxyDictPreWrite(dict):
|
||||||
|
""" Dictionary wrapper that prevents modifications to the dictionary """
|
||||||
|
|
||||||
|
def __init__(self, update_obj, update_func, *args, **kwargs):
|
||||||
|
super(ProxyDictPreWrite, self).__init__(*args, **kwargs)
|
||||||
|
self._update_func = None
|
||||||
|
for k, i in self.items():
|
||||||
|
if isinstance(i, dict):
|
||||||
|
self.update({k: ProxyDictPreWrite(k, self._nested_callback, **i)})
|
||||||
|
self._update_obj = update_obj
|
||||||
|
self._update_func = update_func
|
||||||
|
|
||||||
|
def __setitem__(self, key, value):
|
||||||
|
key_value = self._set_callback((key, value,))
|
||||||
|
if key_value:
|
||||||
|
super(ProxyDictPreWrite, self).__setitem__(*key_value)
|
||||||
|
|
||||||
|
def _set_callback(self, key_value, *_):
|
||||||
|
if self._update_func:
|
||||||
|
res = self._update_func(self._update_obj, key_value)
|
||||||
|
if not res:
|
||||||
|
return None
|
||||||
|
return res
|
||||||
|
return key_value
|
||||||
|
|
||||||
|
def _nested_callback(self, prefix, key_value):
|
||||||
|
return self._set_callback((prefix+'.'+key_value[0], key_value[1],))
|
Loading…
Reference in New Issue
Block a user