mirror of
https://github.com/clearml/clearml
synced 2025-01-31 09:07:00 +00:00
Rename TF patching functions for better readability
This commit is contained in:
parent
dc6c8cfddc
commit
a9f52a468c
@ -26,14 +26,17 @@ except ImportError:
|
||||
|
||||
class TensorflowBinding(object):
|
||||
@classmethod
|
||||
def update_current_task(cls, task, save_models, report_tensorboard):
|
||||
def update_current_task(cls, task, patch_reporting=True, patch_model_io=True):
|
||||
if not task:
|
||||
IsTensorboardInit.clear_tensorboard_used()
|
||||
if report_tensorboard:
|
||||
EventTrainsWriter.update_current_task(task)
|
||||
|
||||
EventTrainsWriter.update_current_task(task)
|
||||
|
||||
if patch_reporting:
|
||||
PatchSummaryToEventTransformer.update_current_task(task)
|
||||
PatchTensorFlowEager.update_current_task(task)
|
||||
if save_models:
|
||||
|
||||
if patch_model_io:
|
||||
PatchKerasModelIO.update_current_task(task)
|
||||
PatchTensorflowModelIO.update_current_task(task)
|
||||
PatchTensorflow2ModelIO.update_current_task(task)
|
||||
@ -1145,11 +1148,11 @@ class PatchTensorFlowEager(object):
|
||||
PatchTensorFlowEager.defaults_dict.update(kwargs)
|
||||
PatchTensorFlowEager.__main_task = task
|
||||
# make sure we patched the SummaryToEventTransformer
|
||||
PatchTensorFlowEager._patch_model_checkpoint()
|
||||
PostImportHookPatching.add_on_import('tensorflow', PatchTensorFlowEager._patch_model_checkpoint)
|
||||
PatchTensorFlowEager._patch_summary_ops()
|
||||
PostImportHookPatching.add_on_import('tensorflow', PatchTensorFlowEager._patch_summary_ops)
|
||||
|
||||
@staticmethod
|
||||
def _patch_model_checkpoint():
|
||||
def _patch_summary_ops():
|
||||
if PatchTensorFlowEager.__original_fn_scalar is not None:
|
||||
return
|
||||
if 'tensorflow' in sys.modules:
|
||||
|
@ -545,11 +545,13 @@ class Task(_Task):
|
||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('tensorflow', True) \
|
||||
or auto_connect_frameworks.get('tensorboard', True):
|
||||
PatchAbsl.update_current_task(Task.__main_task)
|
||||
TensorflowBinding.update_current_task(task,
|
||||
is_auto_connect_frameworks_bool or
|
||||
auto_connect_frameworks.get('tensorflow', True),
|
||||
is_auto_connect_frameworks_bool or
|
||||
auto_connect_frameworks.get('tensorboard', True))
|
||||
TensorflowBinding.update_current_task(
|
||||
task,
|
||||
patch_reporting=(is_auto_connect_frameworks_bool
|
||||
or auto_connect_frameworks.get('tensorboard', True)),
|
||||
patch_model_io=(is_auto_connect_frameworks_bool
|
||||
or auto_connect_frameworks.get('tensorflow', True)),
|
||||
)
|
||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('pytorch', True):
|
||||
PatchPyTorchModelIO.update_current_task(task)
|
||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('xgboost', True):
|
||||
|
Loading…
Reference in New Issue
Block a user