From 0626d807a74b1f23069f555f8cc13ce29696b62a Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Fri, 19 Jun 2020 19:09:17 +0300 Subject: [PATCH] Add WeightsFileHandler callback type enum --- trains/binding/frameworks/__init__.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/trains/binding/frameworks/__init__.py b/trains/binding/frameworks/__init__.py index d1d9eacb..f10634a9 100644 --- a/trains/binding/frameworks/__init__.py +++ b/trains/binding/frameworks/__init__.py @@ -3,9 +3,10 @@ import shutil import sys import threading import weakref +from enum import Enum from random import randint from tempfile import mkstemp -from typing import TYPE_CHECKING, Callable, Dict, Optional, Any +from typing import TYPE_CHECKING, Callable, Dict, Optional, Any, Union import six from pathlib2 import Path @@ -56,6 +57,16 @@ class WeightsFileHandler(object): _model_pre_callbacks = {} _model_post_callbacks = {} + class CallbackType(Enum): + def __str__(self): + return str(self.value) + + def __eq__(self, other): + return str(self) == str(other) + + save = 'save' + load = 'load' + class ModelInfo(object): def __init__(self, model, upload_filename, local_model_path, local_model_id, framework, task): # type: (Optional[Model], Optional[str], str, str, str, Task) -> None @@ -99,7 +110,7 @@ class WeightsFileHandler(object): @classmethod def add_pre_callback(cls, callback_function): - # type: (Callable[[str, ModelInfo], Optional[ModelInfo]]) -> int + # type: (Callable[[Union[str, CallbackType], ModelInfo], Optional[ModelInfo]]) -> int """ Add a pre-save/load callback for weights files and return its handle. If the callback was already added, return the existing handle. @@ -117,7 +128,7 @@ class WeightsFileHandler(object): @classmethod def add_post_callback(cls, callback_function): - # type: (Callable[[str, ModelInfo], ModelInfo]) -> int + # type: (Callable[[Union[str, CallbackType], ModelInfo], ModelInfo]) -> int """ Add a post-save/load callback for weights files and return its handle. If the callback was already added, return the existing handle. @@ -165,7 +176,7 @@ class WeightsFileHandler(object): for cb in WeightsFileHandler._model_pre_callbacks.values(): # noinspection PyBroadException try: - model_info = cb('load', model_info) + model_info = cb(WeightsFileHandler.CallbackType.load, model_info) except Exception: pass @@ -238,7 +249,7 @@ class WeightsFileHandler(object): for cb in WeightsFileHandler._model_post_callbacks.values(): # noinspection PyBroadException try: - model_info = cb('load', model_info) + model_info = cb(WeightsFileHandler.CallbackType.load, model_info) except Exception: pass trains_in_model = model_info.model @@ -334,7 +345,7 @@ class WeightsFileHandler(object): for cb in WeightsFileHandler._model_pre_callbacks.values(): # noinspection PyBroadException try: - model_info = cb('save', model_info) + model_info = cb(WeightsFileHandler.CallbackType.save, model_info) except Exception: pass @@ -392,7 +403,7 @@ class WeightsFileHandler(object): for cb in WeightsFileHandler._model_post_callbacks.values(): # noinspection PyBroadException try: - model_info = cb('save', model_info) + model_info = cb(WeightsFileHandler.CallbackType.save, model_info) except Exception: pass trains_out_model = model_info.model