mirror of
https://github.com/clearml/clearml
synced 2025-04-25 00:37:52 +00:00
Add Task.init auto_connect_frameworks fine granularity control
This commit is contained in:
parent
9362831269
commit
4372dda696
@ -153,8 +153,13 @@ class Task(_Task):
|
||||
For example: trains[s3], trains[gs], trains[azure]
|
||||
:param auto_connect_arg_parser: Automatically grab the ArgParser and connect it with the task.
|
||||
if set to false, you can manually connect the ArgParser with task.connect(parser)
|
||||
:param auto_connect_frameworks: If true automatically patch MatplotLib, Keras callbacks, and TensorBoard/X to
|
||||
serialize plots, graphs and model location to trains backend (in addition to original output destination)
|
||||
:param auto_connect_frameworks: If True automatically patch MatplotLib, XGBoost, scikit-learn,
|
||||
Keras callbacks, and TensorBoard/X to serialize plots, graphs and model location to trains backend
|
||||
(in addition to original output destination).
|
||||
Fine grained control is possible by passing a dictionary instead of a Boolean.
|
||||
Missing keys are considered to have True value, empty dictionary is considered as False, full example:
|
||||
auto_connect_frameworks={'matplotlib': True, 'tensorflow': True, 'pytorch': True,
|
||||
'xgboost': True, 'scikit': True}
|
||||
:param auto_resource_monitoring: If true, machine vitals will be sent along side the task scalars,
|
||||
Resources graphs will appear under the title ':resource monitor:' in the scalars tab.
|
||||
:return: Task() object
|
||||
@ -258,12 +263,18 @@ class Task(_Task):
|
||||
# patch OS forking
|
||||
PatchOsFork.patch_fork()
|
||||
if auto_connect_frameworks:
|
||||
PatchedJoblib.update_current_task(task)
|
||||
PatchedMatplotlib.update_current_task(Task.__main_task)
|
||||
PatchAbsl.update_current_task(Task.__main_task)
|
||||
TensorflowBinding.update_current_task(task)
|
||||
PatchPyTorchModelIO.update_current_task(task)
|
||||
PatchXGBoostModelIO.update_current_task(task)
|
||||
is_auto_connect_frameworks_bool = not isinstance(auto_connect_frameworks, dict)
|
||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('scikit', True):
|
||||
PatchedJoblib.update_current_task(task)
|
||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('matplotlib', True):
|
||||
PatchedMatplotlib.update_current_task(Task.__main_task)
|
||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('tensorflow', True):
|
||||
PatchAbsl.update_current_task(Task.__main_task)
|
||||
TensorflowBinding.update_current_task(task)
|
||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('pytorch', True):
|
||||
PatchPyTorchModelIO.update_current_task(task)
|
||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('xgboost', True):
|
||||
PatchXGBoostModelIO.update_current_task(task)
|
||||
if auto_resource_monitoring:
|
||||
task._resource_monitor = ResourceMonitor(task)
|
||||
task._resource_monitor.start()
|
||||
@ -705,13 +716,13 @@ class Task(_Task):
|
||||
|
||||
def get_last_iteration(self):
|
||||
"""
|
||||
Return the last reported iteration (i.e. the maximum iteration the task reported a metric for)
|
||||
Return the maximum reported iteration (i.e. the maximum iteration the task reported a metric for)
|
||||
Notice, this is not a cached call, it will ask the backend for the answer (no local caching)
|
||||
|
||||
:return: last reported iteration number (integer)
|
||||
"""
|
||||
self.reload()
|
||||
return self.data.last_iteration
|
||||
return max(self.data.last_iteration, self._reporter.max_iteration if self._reporter else 0)
|
||||
|
||||
def set_last_iteration(self, last_iteration):
|
||||
"""
|
||||
@ -1020,7 +1031,7 @@ class Task(_Task):
|
||||
def _connect_dictionary(self, dictionary):
|
||||
self._try_set_connected_parameter_type(self._ConnectedParametersType.dictionary)
|
||||
|
||||
if running_remotely() and self.is_main_task():
|
||||
if running_remotely():
|
||||
dictionary = self._arguments.copy_to_dict(dictionary)
|
||||
else:
|
||||
dictionary = self._arguments.copy_from_dict(dictionary)
|
||||
@ -1544,7 +1555,8 @@ class Task(_Task):
|
||||
|
||||
# compare after casting to string to avoid enum instance issues
|
||||
# remember we might have replaced the api version by now, so enums are different
|
||||
return all(str(server_data) == str(task_data.get(task_data_key)) for server_data, task_data_key in compares)
|
||||
return all(six.text_type(server_data) == six.text_type(task_data.get(task_data_key))
|
||||
for server_data, task_data_key in compares)
|
||||
|
||||
@classmethod
|
||||
def __close_timed_out_task(cls, task_data):
|
||||
|
50
trains/utilities/proxy_object.py
Normal file
50
trains/utilities/proxy_object.py
Normal file
@ -0,0 +1,50 @@
|
||||
|
||||
|
||||
class ProxyDictPostWrite(dict):
|
||||
""" Dictionary wrapper that updates an arguments instance on any item set in the dictionary """
|
||||
|
||||
def __init__(self, update_obj, update_func, *args, **kwargs):
|
||||
super(ProxyDictPostWrite, self).__init__(*args, **kwargs)
|
||||
self._update_func = None
|
||||
for k, i in self.items():
|
||||
if isinstance(i, dict):
|
||||
self.update({k: ProxyDictPostWrite(update_obj, self._set_callback, **i)})
|
||||
self._update_obj = update_obj
|
||||
self._update_func = update_func
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
super(ProxyDictPostWrite, self).__setitem__(key, value)
|
||||
self._set_callback()
|
||||
|
||||
def _set_callback(self, *_):
|
||||
if self._update_func:
|
||||
self._update_func(self._update_obj, self)
|
||||
|
||||
|
||||
class ProxyDictPreWrite(dict):
|
||||
""" Dictionary wrapper that prevents modifications to the dictionary """
|
||||
|
||||
def __init__(self, update_obj, update_func, *args, **kwargs):
|
||||
super(ProxyDictPreWrite, self).__init__(*args, **kwargs)
|
||||
self._update_func = None
|
||||
for k, i in self.items():
|
||||
if isinstance(i, dict):
|
||||
self.update({k: ProxyDictPreWrite(k, self._nested_callback, **i)})
|
||||
self._update_obj = update_obj
|
||||
self._update_func = update_func
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
key_value = self._set_callback((key, value,))
|
||||
if key_value:
|
||||
super(ProxyDictPreWrite, self).__setitem__(*key_value)
|
||||
|
||||
def _set_callback(self, key_value, *_):
|
||||
if self._update_func:
|
||||
res = self._update_func(self._update_obj, key_value)
|
||||
if not res:
|
||||
return None
|
||||
return res
|
||||
return key_value
|
||||
|
||||
def _nested_callback(self, prefix, key_value):
|
||||
return self._set_callback((prefix+'.'+key_value[0], key_value[1],))
|
Loading…
Reference in New Issue
Block a user