diff --git a/trains/binding/frameworks/__init__.py b/trains/binding/frameworks/__init__.py index d7276b0d..075af390 100644 --- a/trains/binding/frameworks/__init__.py +++ b/trains/binding/frameworks/__init__.py @@ -167,6 +167,11 @@ class WeightsFileHandler(object): except Exception: pass + # if callback forced us to leave they return None + if model_info is None: + # callback forced quit + return filepath + if not model_info.local_model_path: get_logger(TrainsFrameworkAdapter).debug("Could not retrieve model file location, model is not logged") return filepath @@ -324,6 +329,13 @@ class WeightsFileHandler(object): model_info = cb('save', model_info) except Exception: pass + + # if callbacks force us to leave they return None + if model_info is None: + # callback forced quit + return saved_path + + # update the trains_out_model after the pre callbacks trains_out_model = model_info.model # check if object already has InputModel