Rename TF patching functions for better readability

This commit is contained in:
allegroai 2021-04-25 10:43:11 +03:00
parent dc6c8cfddc
commit a9f52a468c
2 changed files with 17 additions and 12 deletions

View File

@ -26,14 +26,17 @@ except ImportError:
class TensorflowBinding(object):
@classmethod
def update_current_task(cls, task, save_models, report_tensorboard):
def update_current_task(cls, task, patch_reporting=True, patch_model_io=True):
if not task:
IsTensorboardInit.clear_tensorboard_used()
if report_tensorboard:
EventTrainsWriter.update_current_task(task)
EventTrainsWriter.update_current_task(task)
if patch_reporting:
PatchSummaryToEventTransformer.update_current_task(task)
PatchTensorFlowEager.update_current_task(task)
if save_models:
if patch_model_io:
PatchKerasModelIO.update_current_task(task)
PatchTensorflowModelIO.update_current_task(task)
PatchTensorflow2ModelIO.update_current_task(task)
@ -1145,11 +1148,11 @@ class PatchTensorFlowEager(object):
PatchTensorFlowEager.defaults_dict.update(kwargs)
PatchTensorFlowEager.__main_task = task
# make sure we patched the SummaryToEventTransformer
PatchTensorFlowEager._patch_model_checkpoint()
PostImportHookPatching.add_on_import('tensorflow', PatchTensorFlowEager._patch_model_checkpoint)
PatchTensorFlowEager._patch_summary_ops()
PostImportHookPatching.add_on_import('tensorflow', PatchTensorFlowEager._patch_summary_ops)
@staticmethod
def _patch_model_checkpoint():
def _patch_summary_ops():
if PatchTensorFlowEager.__original_fn_scalar is not None:
return
if 'tensorflow' in sys.modules:

View File

@ -545,11 +545,13 @@ class Task(_Task):
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('tensorflow', True) \
or auto_connect_frameworks.get('tensorboard', True):
PatchAbsl.update_current_task(Task.__main_task)
TensorflowBinding.update_current_task(task,
is_auto_connect_frameworks_bool or
auto_connect_frameworks.get('tensorflow', True),
is_auto_connect_frameworks_bool or
auto_connect_frameworks.get('tensorboard', True))
TensorflowBinding.update_current_task(
task,
patch_reporting=(is_auto_connect_frameworks_bool
or auto_connect_frameworks.get('tensorboard', True)),
patch_model_io=(is_auto_connect_frameworks_bool
or auto_connect_frameworks.get('tensorflow', True)),
)
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('pytorch', True):
PatchPyTorchModelIO.update_current_task(task)
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('xgboost', True):