Rename TF patching functions for better readability

This commit is contained in:
allegroai 2021-04-25 10:43:11 +03:00
parent dc6c8cfddc
commit a9f52a468c
2 changed files with 17 additions and 12 deletions

View File

@ -26,14 +26,17 @@ except ImportError:
class TensorflowBinding(object): class TensorflowBinding(object):
@classmethod @classmethod
def update_current_task(cls, task, save_models, report_tensorboard): def update_current_task(cls, task, patch_reporting=True, patch_model_io=True):
if not task: if not task:
IsTensorboardInit.clear_tensorboard_used() IsTensorboardInit.clear_tensorboard_used()
if report_tensorboard:
EventTrainsWriter.update_current_task(task) EventTrainsWriter.update_current_task(task)
if patch_reporting:
PatchSummaryToEventTransformer.update_current_task(task) PatchSummaryToEventTransformer.update_current_task(task)
PatchTensorFlowEager.update_current_task(task) PatchTensorFlowEager.update_current_task(task)
if save_models:
if patch_model_io:
PatchKerasModelIO.update_current_task(task) PatchKerasModelIO.update_current_task(task)
PatchTensorflowModelIO.update_current_task(task) PatchTensorflowModelIO.update_current_task(task)
PatchTensorflow2ModelIO.update_current_task(task) PatchTensorflow2ModelIO.update_current_task(task)
@ -1145,11 +1148,11 @@ class PatchTensorFlowEager(object):
PatchTensorFlowEager.defaults_dict.update(kwargs) PatchTensorFlowEager.defaults_dict.update(kwargs)
PatchTensorFlowEager.__main_task = task PatchTensorFlowEager.__main_task = task
# make sure we patched the SummaryToEventTransformer # make sure we patched the SummaryToEventTransformer
PatchTensorFlowEager._patch_model_checkpoint() PatchTensorFlowEager._patch_summary_ops()
PostImportHookPatching.add_on_import('tensorflow', PatchTensorFlowEager._patch_model_checkpoint) PostImportHookPatching.add_on_import('tensorflow', PatchTensorFlowEager._patch_summary_ops)
@staticmethod @staticmethod
def _patch_model_checkpoint(): def _patch_summary_ops():
if PatchTensorFlowEager.__original_fn_scalar is not None: if PatchTensorFlowEager.__original_fn_scalar is not None:
return return
if 'tensorflow' in sys.modules: if 'tensorflow' in sys.modules:

View File

@ -545,11 +545,13 @@ class Task(_Task):
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('tensorflow', True) \ if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('tensorflow', True) \
or auto_connect_frameworks.get('tensorboard', True): or auto_connect_frameworks.get('tensorboard', True):
PatchAbsl.update_current_task(Task.__main_task) PatchAbsl.update_current_task(Task.__main_task)
TensorflowBinding.update_current_task(task, TensorflowBinding.update_current_task(
is_auto_connect_frameworks_bool or task,
auto_connect_frameworks.get('tensorflow', True), patch_reporting=(is_auto_connect_frameworks_bool
is_auto_connect_frameworks_bool or or auto_connect_frameworks.get('tensorboard', True)),
auto_connect_frameworks.get('tensorboard', True)) patch_model_io=(is_auto_connect_frameworks_bool
or auto_connect_frameworks.get('tensorflow', True)),
)
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('pytorch', True): if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('pytorch', True):
PatchPyTorchModelIO.update_current_task(task) PatchPyTorchModelIO.update_current_task(task)
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('xgboost', True): if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('xgboost', True):