mirror of
https://github.com/clearml/clearml
synced 2025-02-12 07:35:08 +00:00
Add WeightsFileHandler callback type enum
This commit is contained in:
parent
c50574ca4d
commit
0626d807a7
@ -3,9 +3,10 @@ import shutil
|
|||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import weakref
|
import weakref
|
||||||
|
from enum import Enum
|
||||||
from random import randint
|
from random import randint
|
||||||
from tempfile import mkstemp
|
from tempfile import mkstemp
|
||||||
from typing import TYPE_CHECKING, Callable, Dict, Optional, Any
|
from typing import TYPE_CHECKING, Callable, Dict, Optional, Any, Union
|
||||||
|
|
||||||
import six
|
import six
|
||||||
from pathlib2 import Path
|
from pathlib2 import Path
|
||||||
@ -56,6 +57,16 @@ class WeightsFileHandler(object):
|
|||||||
_model_pre_callbacks = {}
|
_model_pre_callbacks = {}
|
||||||
_model_post_callbacks = {}
|
_model_post_callbacks = {}
|
||||||
|
|
||||||
|
class CallbackType(Enum):
|
||||||
|
def __str__(self):
|
||||||
|
return str(self.value)
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return str(self) == str(other)
|
||||||
|
|
||||||
|
save = 'save'
|
||||||
|
load = 'load'
|
||||||
|
|
||||||
class ModelInfo(object):
|
class ModelInfo(object):
|
||||||
def __init__(self, model, upload_filename, local_model_path, local_model_id, framework, task):
|
def __init__(self, model, upload_filename, local_model_path, local_model_id, framework, task):
|
||||||
# type: (Optional[Model], Optional[str], str, str, str, Task) -> None
|
# type: (Optional[Model], Optional[str], str, str, str, Task) -> None
|
||||||
@ -99,7 +110,7 @@ class WeightsFileHandler(object):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def add_pre_callback(cls, callback_function):
|
def add_pre_callback(cls, callback_function):
|
||||||
# type: (Callable[[str, ModelInfo], Optional[ModelInfo]]) -> int
|
# type: (Callable[[Union[str, CallbackType], ModelInfo], Optional[ModelInfo]]) -> int
|
||||||
"""
|
"""
|
||||||
Add a pre-save/load callback for weights files and return its handle. If the callback was already added,
|
Add a pre-save/load callback for weights files and return its handle. If the callback was already added,
|
||||||
return the existing handle.
|
return the existing handle.
|
||||||
@ -117,7 +128,7 @@ class WeightsFileHandler(object):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def add_post_callback(cls, callback_function):
|
def add_post_callback(cls, callback_function):
|
||||||
# type: (Callable[[str, ModelInfo], ModelInfo]) -> int
|
# type: (Callable[[Union[str, CallbackType], ModelInfo], ModelInfo]) -> int
|
||||||
"""
|
"""
|
||||||
Add a post-save/load callback for weights files and return its handle.
|
Add a post-save/load callback for weights files and return its handle.
|
||||||
If the callback was already added, return the existing handle.
|
If the callback was already added, return the existing handle.
|
||||||
@ -165,7 +176,7 @@ class WeightsFileHandler(object):
|
|||||||
for cb in WeightsFileHandler._model_pre_callbacks.values():
|
for cb in WeightsFileHandler._model_pre_callbacks.values():
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
model_info = cb('load', model_info)
|
model_info = cb(WeightsFileHandler.CallbackType.load, model_info)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -238,7 +249,7 @@ class WeightsFileHandler(object):
|
|||||||
for cb in WeightsFileHandler._model_post_callbacks.values():
|
for cb in WeightsFileHandler._model_post_callbacks.values():
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
model_info = cb('load', model_info)
|
model_info = cb(WeightsFileHandler.CallbackType.load, model_info)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
trains_in_model = model_info.model
|
trains_in_model = model_info.model
|
||||||
@ -334,7 +345,7 @@ class WeightsFileHandler(object):
|
|||||||
for cb in WeightsFileHandler._model_pre_callbacks.values():
|
for cb in WeightsFileHandler._model_pre_callbacks.values():
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
model_info = cb('save', model_info)
|
model_info = cb(WeightsFileHandler.CallbackType.save, model_info)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -392,7 +403,7 @@ class WeightsFileHandler(object):
|
|||||||
for cb in WeightsFileHandler._model_post_callbacks.values():
|
for cb in WeightsFileHandler._model_post_callbacks.values():
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
model_info = cb('save', model_info)
|
model_info = cb(WeightsFileHandler.CallbackType.save, model_info)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
trains_out_model = model_info.model
|
trains_out_model = model_info.model
|
||||||
|
Loading…
Reference in New Issue
Block a user