Add WeightsFileHandler callback type enum

This commit is contained in:
allegroai 2020-06-19 19:09:17 +03:00
parent c50574ca4d
commit 0626d807a7

View File

@ -3,9 +3,10 @@ import shutil
import sys import sys
import threading import threading
import weakref import weakref
from enum import Enum
from random import randint from random import randint
from tempfile import mkstemp from tempfile import mkstemp
from typing import TYPE_CHECKING, Callable, Dict, Optional, Any from typing import TYPE_CHECKING, Callable, Dict, Optional, Any, Union
import six import six
from pathlib2 import Path from pathlib2 import Path
@ -56,6 +57,16 @@ class WeightsFileHandler(object):
_model_pre_callbacks = {} _model_pre_callbacks = {}
_model_post_callbacks = {} _model_post_callbacks = {}
class CallbackType(Enum):
def __str__(self):
return str(self.value)
def __eq__(self, other):
return str(self) == str(other)
save = 'save'
load = 'load'
class ModelInfo(object): class ModelInfo(object):
def __init__(self, model, upload_filename, local_model_path, local_model_id, framework, task): def __init__(self, model, upload_filename, local_model_path, local_model_id, framework, task):
# type: (Optional[Model], Optional[str], str, str, str, Task) -> None # type: (Optional[Model], Optional[str], str, str, str, Task) -> None
@ -99,7 +110,7 @@ class WeightsFileHandler(object):
@classmethod @classmethod
def add_pre_callback(cls, callback_function): def add_pre_callback(cls, callback_function):
# type: (Callable[[str, ModelInfo], Optional[ModelInfo]]) -> int # type: (Callable[[Union[str, CallbackType], ModelInfo], Optional[ModelInfo]]) -> int
""" """
Add a pre-save/load callback for weights files and return its handle. If the callback was already added, Add a pre-save/load callback for weights files and return its handle. If the callback was already added,
return the existing handle. return the existing handle.
@ -117,7 +128,7 @@ class WeightsFileHandler(object):
@classmethod @classmethod
def add_post_callback(cls, callback_function): def add_post_callback(cls, callback_function):
# type: (Callable[[str, ModelInfo], ModelInfo]) -> int # type: (Callable[[Union[str, CallbackType], ModelInfo], ModelInfo]) -> int
""" """
Add a post-save/load callback for weights files and return its handle. Add a post-save/load callback for weights files and return its handle.
If the callback was already added, return the existing handle. If the callback was already added, return the existing handle.
@ -165,7 +176,7 @@ class WeightsFileHandler(object):
for cb in WeightsFileHandler._model_pre_callbacks.values(): for cb in WeightsFileHandler._model_pre_callbacks.values():
# noinspection PyBroadException # noinspection PyBroadException
try: try:
model_info = cb('load', model_info) model_info = cb(WeightsFileHandler.CallbackType.load, model_info)
except Exception: except Exception:
pass pass
@ -238,7 +249,7 @@ class WeightsFileHandler(object):
for cb in WeightsFileHandler._model_post_callbacks.values(): for cb in WeightsFileHandler._model_post_callbacks.values():
# noinspection PyBroadException # noinspection PyBroadException
try: try:
model_info = cb('load', model_info) model_info = cb(WeightsFileHandler.CallbackType.load, model_info)
except Exception: except Exception:
pass pass
trains_in_model = model_info.model trains_in_model = model_info.model
@ -334,7 +345,7 @@ class WeightsFileHandler(object):
for cb in WeightsFileHandler._model_pre_callbacks.values(): for cb in WeightsFileHandler._model_pre_callbacks.values():
# noinspection PyBroadException # noinspection PyBroadException
try: try:
model_info = cb('save', model_info) model_info = cb(WeightsFileHandler.CallbackType.save, model_info)
except Exception: except Exception:
pass pass
@ -392,7 +403,7 @@ class WeightsFileHandler(object):
for cb in WeightsFileHandler._model_post_callbacks.values(): for cb in WeightsFileHandler._model_post_callbacks.values():
# noinspection PyBroadException # noinspection PyBroadException
try: try:
model_info = cb('save', model_info) model_info = cb(WeightsFileHandler.CallbackType.save, model_info)
except Exception: except Exception:
pass pass
trains_out_model = model_info.model trains_out_model = model_info.model