diff --git a/trains/task.py b/trains/task.py index 95232842..2c4ca443 100644 --- a/trains/task.py +++ b/trains/task.py @@ -153,8 +153,13 @@ class Task(_Task): For example: trains[s3], trains[gs], trains[azure] :param auto_connect_arg_parser: Automatically grab the ArgParser and connect it with the task. if set to false, you can manually connect the ArgParser with task.connect(parser) - :param auto_connect_frameworks: If true automatically patch MatplotLib, Keras callbacks, and TensorBoard/X to - serialize plots, graphs and model location to trains backend (in addition to original output destination) + :param auto_connect_frameworks: If True automatically patch MatplotLib, XGBoost, scikit-learn, + Keras callbacks, and TensorBoard/X to serialize plots, graphs and model location to trains backend + (in addition to original output destination). + Fine grained control is possible by passing a dictionary instead of a Boolean. + Missing keys are considered to have True value, empty dictionary is considered as False, full example: + auto_connect_frameworks={'matplotlib': True, 'tensorflow': True, 'pytorch': True, + 'xgboost': True, 'scikit': True} :param auto_resource_monitoring: If true, machine vitals will be sent along side the task scalars, Resources graphs will appear under the title ':resource monitor:' in the scalars tab. :return: Task() object @@ -258,12 +263,18 @@ class Task(_Task): # patch OS forking PatchOsFork.patch_fork() if auto_connect_frameworks: - PatchedJoblib.update_current_task(task) - PatchedMatplotlib.update_current_task(Task.__main_task) - PatchAbsl.update_current_task(Task.__main_task) - TensorflowBinding.update_current_task(task) - PatchPyTorchModelIO.update_current_task(task) - PatchXGBoostModelIO.update_current_task(task) + is_auto_connect_frameworks_bool = not isinstance(auto_connect_frameworks, dict) + if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('scikit', True): + PatchedJoblib.update_current_task(task) + if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('matplotlib', True): + PatchedMatplotlib.update_current_task(Task.__main_task) + if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('tensorflow', True): + PatchAbsl.update_current_task(Task.__main_task) + TensorflowBinding.update_current_task(task) + if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('pytorch', True): + PatchPyTorchModelIO.update_current_task(task) + if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('xgboost', True): + PatchXGBoostModelIO.update_current_task(task) if auto_resource_monitoring: task._resource_monitor = ResourceMonitor(task) task._resource_monitor.start() @@ -705,13 +716,13 @@ class Task(_Task): def get_last_iteration(self): """ - Return the last reported iteration (i.e. the maximum iteration the task reported a metric for) + Return the maximum reported iteration (i.e. the maximum iteration the task reported a metric for) Notice, this is not a cached call, it will ask the backend for the answer (no local caching) :return: last reported iteration number (integer) """ self.reload() - return self.data.last_iteration + return max(self.data.last_iteration, self._reporter.max_iteration if self._reporter else 0) def set_last_iteration(self, last_iteration): """ @@ -1020,7 +1031,7 @@ class Task(_Task): def _connect_dictionary(self, dictionary): self._try_set_connected_parameter_type(self._ConnectedParametersType.dictionary) - if running_remotely() and self.is_main_task(): + if running_remotely(): dictionary = self._arguments.copy_to_dict(dictionary) else: dictionary = self._arguments.copy_from_dict(dictionary) @@ -1544,7 +1555,8 @@ class Task(_Task): # compare after casting to string to avoid enum instance issues # remember we might have replaced the api version by now, so enums are different - return all(str(server_data) == str(task_data.get(task_data_key)) for server_data, task_data_key in compares) + return all(six.text_type(server_data) == six.text_type(task_data.get(task_data_key)) + for server_data, task_data_key in compares) @classmethod def __close_timed_out_task(cls, task_data): diff --git a/trains/utilities/proxy_object.py b/trains/utilities/proxy_object.py new file mode 100644 index 00000000..2f910f3d --- /dev/null +++ b/trains/utilities/proxy_object.py @@ -0,0 +1,50 @@ + + +class ProxyDictPostWrite(dict): + """ Dictionary wrapper that updates an arguments instance on any item set in the dictionary """ + + def __init__(self, update_obj, update_func, *args, **kwargs): + super(ProxyDictPostWrite, self).__init__(*args, **kwargs) + self._update_func = None + for k, i in self.items(): + if isinstance(i, dict): + self.update({k: ProxyDictPostWrite(update_obj, self._set_callback, **i)}) + self._update_obj = update_obj + self._update_func = update_func + + def __setitem__(self, key, value): + super(ProxyDictPostWrite, self).__setitem__(key, value) + self._set_callback() + + def _set_callback(self, *_): + if self._update_func: + self._update_func(self._update_obj, self) + + +class ProxyDictPreWrite(dict): + """ Dictionary wrapper that prevents modifications to the dictionary """ + + def __init__(self, update_obj, update_func, *args, **kwargs): + super(ProxyDictPreWrite, self).__init__(*args, **kwargs) + self._update_func = None + for k, i in self.items(): + if isinstance(i, dict): + self.update({k: ProxyDictPreWrite(k, self._nested_callback, **i)}) + self._update_obj = update_obj + self._update_func = update_func + + def __setitem__(self, key, value): + key_value = self._set_callback((key, value,)) + if key_value: + super(ProxyDictPreWrite, self).__setitem__(*key_value) + + def _set_callback(self, key_value, *_): + if self._update_func: + res = self._update_func(self._update_obj, key_value) + if not res: + return None + return res + return key_value + + def _nested_callback(self, prefix, key_value): + return self._set_callback((prefix+'.'+key_value[0], key_value[1],))