mirror of
https://github.com/clearml/clearml
synced 2025-02-07 21:33:25 +00:00
Add WeightsFileHandler callback type enum
This commit is contained in:
parent
c50574ca4d
commit
0626d807a7
@ -3,9 +3,10 @@ import shutil
|
||||
import sys
|
||||
import threading
|
||||
import weakref
|
||||
from enum import Enum
|
||||
from random import randint
|
||||
from tempfile import mkstemp
|
||||
from typing import TYPE_CHECKING, Callable, Dict, Optional, Any
|
||||
from typing import TYPE_CHECKING, Callable, Dict, Optional, Any, Union
|
||||
|
||||
import six
|
||||
from pathlib2 import Path
|
||||
@ -56,6 +57,16 @@ class WeightsFileHandler(object):
|
||||
_model_pre_callbacks = {}
|
||||
_model_post_callbacks = {}
|
||||
|
||||
class CallbackType(Enum):
|
||||
def __str__(self):
|
||||
return str(self.value)
|
||||
|
||||
def __eq__(self, other):
|
||||
return str(self) == str(other)
|
||||
|
||||
save = 'save'
|
||||
load = 'load'
|
||||
|
||||
class ModelInfo(object):
|
||||
def __init__(self, model, upload_filename, local_model_path, local_model_id, framework, task):
|
||||
# type: (Optional[Model], Optional[str], str, str, str, Task) -> None
|
||||
@ -99,7 +110,7 @@ class WeightsFileHandler(object):
|
||||
|
||||
@classmethod
|
||||
def add_pre_callback(cls, callback_function):
|
||||
# type: (Callable[[str, ModelInfo], Optional[ModelInfo]]) -> int
|
||||
# type: (Callable[[Union[str, CallbackType], ModelInfo], Optional[ModelInfo]]) -> int
|
||||
"""
|
||||
Add a pre-save/load callback for weights files and return its handle. If the callback was already added,
|
||||
return the existing handle.
|
||||
@ -117,7 +128,7 @@ class WeightsFileHandler(object):
|
||||
|
||||
@classmethod
|
||||
def add_post_callback(cls, callback_function):
|
||||
# type: (Callable[[str, ModelInfo], ModelInfo]) -> int
|
||||
# type: (Callable[[Union[str, CallbackType], ModelInfo], ModelInfo]) -> int
|
||||
"""
|
||||
Add a post-save/load callback for weights files and return its handle.
|
||||
If the callback was already added, return the existing handle.
|
||||
@ -165,7 +176,7 @@ class WeightsFileHandler(object):
|
||||
for cb in WeightsFileHandler._model_pre_callbacks.values():
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
model_info = cb('load', model_info)
|
||||
model_info = cb(WeightsFileHandler.CallbackType.load, model_info)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@ -238,7 +249,7 @@ class WeightsFileHandler(object):
|
||||
for cb in WeightsFileHandler._model_post_callbacks.values():
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
model_info = cb('load', model_info)
|
||||
model_info = cb(WeightsFileHandler.CallbackType.load, model_info)
|
||||
except Exception:
|
||||
pass
|
||||
trains_in_model = model_info.model
|
||||
@ -334,7 +345,7 @@ class WeightsFileHandler(object):
|
||||
for cb in WeightsFileHandler._model_pre_callbacks.values():
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
model_info = cb('save', model_info)
|
||||
model_info = cb(WeightsFileHandler.CallbackType.save, model_info)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@ -392,7 +403,7 @@ class WeightsFileHandler(object):
|
||||
for cb in WeightsFileHandler._model_post_callbacks.values():
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
model_info = cb('save', model_info)
|
||||
model_info = cb(WeightsFileHandler.CallbackType.save, model_info)
|
||||
except Exception:
|
||||
pass
|
||||
trains_out_model = model_info.model
|
||||
|
Loading…
Reference in New Issue
Block a user