mirror of
https://github.com/clearml/clearml
synced 2025-02-07 13:23:40 +00:00
Rename TF patching functions for better readability
This commit is contained in:
parent
dc6c8cfddc
commit
a9f52a468c
@ -26,14 +26,17 @@ except ImportError:
|
|||||||
|
|
||||||
class TensorflowBinding(object):
|
class TensorflowBinding(object):
|
||||||
@classmethod
|
@classmethod
|
||||||
def update_current_task(cls, task, save_models, report_tensorboard):
|
def update_current_task(cls, task, patch_reporting=True, patch_model_io=True):
|
||||||
if not task:
|
if not task:
|
||||||
IsTensorboardInit.clear_tensorboard_used()
|
IsTensorboardInit.clear_tensorboard_used()
|
||||||
if report_tensorboard:
|
|
||||||
EventTrainsWriter.update_current_task(task)
|
EventTrainsWriter.update_current_task(task)
|
||||||
|
|
||||||
|
if patch_reporting:
|
||||||
PatchSummaryToEventTransformer.update_current_task(task)
|
PatchSummaryToEventTransformer.update_current_task(task)
|
||||||
PatchTensorFlowEager.update_current_task(task)
|
PatchTensorFlowEager.update_current_task(task)
|
||||||
if save_models:
|
|
||||||
|
if patch_model_io:
|
||||||
PatchKerasModelIO.update_current_task(task)
|
PatchKerasModelIO.update_current_task(task)
|
||||||
PatchTensorflowModelIO.update_current_task(task)
|
PatchTensorflowModelIO.update_current_task(task)
|
||||||
PatchTensorflow2ModelIO.update_current_task(task)
|
PatchTensorflow2ModelIO.update_current_task(task)
|
||||||
@ -1145,11 +1148,11 @@ class PatchTensorFlowEager(object):
|
|||||||
PatchTensorFlowEager.defaults_dict.update(kwargs)
|
PatchTensorFlowEager.defaults_dict.update(kwargs)
|
||||||
PatchTensorFlowEager.__main_task = task
|
PatchTensorFlowEager.__main_task = task
|
||||||
# make sure we patched the SummaryToEventTransformer
|
# make sure we patched the SummaryToEventTransformer
|
||||||
PatchTensorFlowEager._patch_model_checkpoint()
|
PatchTensorFlowEager._patch_summary_ops()
|
||||||
PostImportHookPatching.add_on_import('tensorflow', PatchTensorFlowEager._patch_model_checkpoint)
|
PostImportHookPatching.add_on_import('tensorflow', PatchTensorFlowEager._patch_summary_ops)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _patch_model_checkpoint():
|
def _patch_summary_ops():
|
||||||
if PatchTensorFlowEager.__original_fn_scalar is not None:
|
if PatchTensorFlowEager.__original_fn_scalar is not None:
|
||||||
return
|
return
|
||||||
if 'tensorflow' in sys.modules:
|
if 'tensorflow' in sys.modules:
|
||||||
|
@ -545,11 +545,13 @@ class Task(_Task):
|
|||||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('tensorflow', True) \
|
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('tensorflow', True) \
|
||||||
or auto_connect_frameworks.get('tensorboard', True):
|
or auto_connect_frameworks.get('tensorboard', True):
|
||||||
PatchAbsl.update_current_task(Task.__main_task)
|
PatchAbsl.update_current_task(Task.__main_task)
|
||||||
TensorflowBinding.update_current_task(task,
|
TensorflowBinding.update_current_task(
|
||||||
is_auto_connect_frameworks_bool or
|
task,
|
||||||
auto_connect_frameworks.get('tensorflow', True),
|
patch_reporting=(is_auto_connect_frameworks_bool
|
||||||
is_auto_connect_frameworks_bool or
|
or auto_connect_frameworks.get('tensorboard', True)),
|
||||||
auto_connect_frameworks.get('tensorboard', True))
|
patch_model_io=(is_auto_connect_frameworks_bool
|
||||||
|
or auto_connect_frameworks.get('tensorflow', True)),
|
||||||
|
)
|
||||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('pytorch', True):
|
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('pytorch', True):
|
||||||
PatchPyTorchModelIO.update_current_task(task)
|
PatchPyTorchModelIO.update_current_task(task)
|
||||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('xgboost', True):
|
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('xgboost', True):
|
||||||
|
Loading…
Reference in New Issue
Block a user