Add WeightsFileHandler callback type enum

This commit is contained in:
allegroai 2020-06-19 19:09:17 +03:00
parent c50574ca4d
commit 0626d807a7

View File

@ -3,9 +3,10 @@ import shutil
import sys
import threading
import weakref
from enum import Enum
from random import randint
from tempfile import mkstemp
from typing import TYPE_CHECKING, Callable, Dict, Optional, Any
from typing import TYPE_CHECKING, Callable, Dict, Optional, Any, Union
import six
from pathlib2 import Path
@ -56,6 +57,16 @@ class WeightsFileHandler(object):
_model_pre_callbacks = {}
_model_post_callbacks = {}
class CallbackType(Enum):
def __str__(self):
return str(self.value)
def __eq__(self, other):
return str(self) == str(other)
save = 'save'
load = 'load'
class ModelInfo(object):
def __init__(self, model, upload_filename, local_model_path, local_model_id, framework, task):
# type: (Optional[Model], Optional[str], str, str, str, Task) -> None
@ -99,7 +110,7 @@ class WeightsFileHandler(object):
@classmethod
def add_pre_callback(cls, callback_function):
# type: (Callable[[str, ModelInfo], Optional[ModelInfo]]) -> int
# type: (Callable[[Union[str, CallbackType], ModelInfo], Optional[ModelInfo]]) -> int
"""
Add a pre-save/load callback for weights files and return its handle. If the callback was already added,
return the existing handle.
@ -117,7 +128,7 @@ class WeightsFileHandler(object):
@classmethod
def add_post_callback(cls, callback_function):
# type: (Callable[[str, ModelInfo], ModelInfo]) -> int
# type: (Callable[[Union[str, CallbackType], ModelInfo], ModelInfo]) -> int
"""
Add a post-save/load callback for weights files and return its handle.
If the callback was already added, return the existing handle.
@ -165,7 +176,7 @@ class WeightsFileHandler(object):
for cb in WeightsFileHandler._model_pre_callbacks.values():
# noinspection PyBroadException
try:
model_info = cb('load', model_info)
model_info = cb(WeightsFileHandler.CallbackType.load, model_info)
except Exception:
pass
@ -238,7 +249,7 @@ class WeightsFileHandler(object):
for cb in WeightsFileHandler._model_post_callbacks.values():
# noinspection PyBroadException
try:
model_info = cb('load', model_info)
model_info = cb(WeightsFileHandler.CallbackType.load, model_info)
except Exception:
pass
trains_in_model = model_info.model
@ -334,7 +345,7 @@ class WeightsFileHandler(object):
for cb in WeightsFileHandler._model_pre_callbacks.values():
# noinspection PyBroadException
try:
model_info = cb('save', model_info)
model_info = cb(WeightsFileHandler.CallbackType.save, model_info)
except Exception:
pass
@ -392,7 +403,7 @@ class WeightsFileHandler(object):
for cb in WeightsFileHandler._model_post_callbacks.values():
# noinspection PyBroadException
try:
model_info = cb('save', model_info)
model_info = cb(WeightsFileHandler.CallbackType.save, model_info)
except Exception:
pass
trains_out_model = model_info.model