diff --git a/trains/binding/frameworks/__init__.py b/trains/binding/frameworks/__init__.py index 8b36ef9c..2443e811 100644 --- a/trains/binding/frameworks/__init__.py +++ b/trains/binding/frameworks/__init__.py @@ -114,14 +114,15 @@ class WeightsFileHandler(object): WeightsFileHandler._model_store_lookup_lock.acquire() # check if object already has InputModel - trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get(id(model), (None, None)) + trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get( + id(model) if model is not None else None, (None, None)) if ref_model is not None and model != ref_model(): # old id pop it - it was probably reused because the object is dead WeightsFileHandler._model_in_store_lookup.pop(id(model)) trains_in_model, ref_model = None, None # check if object already has InputModel - model_name_id = getattr(model, 'name', '') + model_name_id = getattr(model, 'name', '') if model else '' # noinspection PyBroadException try: config_text = None @@ -149,7 +150,7 @@ class WeightsFileHandler(object): weights_url=filepath, config_dict=config_dict, config_text=config_text, - name=task.name + ' ' + model_name_id, + name=task.name + (' ' + model_name_id) if model_name_id else '', label_enumeration=task.get_labels_enumeration(), framework=framework, create_as_published=False, @@ -159,12 +160,13 @@ class WeightsFileHandler(object): for cb in WeightsFileHandler._model_post_callbacks.values(): trains_in_model = cb('load', trains_in_model, filepath, filepath, framework, task) - # noinspection PyBroadException - try: - ref_model = weakref.ref(model) - except Exception: - ref_model = None - WeightsFileHandler._model_in_store_lookup[id(model)] = (trains_in_model, ref_model) + if model is not None: + # noinspection PyBroadException + try: + ref_model = weakref.ref(model) + except Exception: + ref_model = None + WeightsFileHandler._model_in_store_lookup[id(model)] = (trains_in_model, ref_model) # todo: support multiple models for the same task task.connect(trains_in_model) # if we are running remotely we should deserialize the object @@ -196,7 +198,8 @@ class WeightsFileHandler(object): WeightsFileHandler._model_store_lookup_lock.acquire() # check if object already has InputModel - trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get(id(model), (None, None)) + trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get( + id(model) if model is not None else None, (None, None)) # notice ref_model() is not an error/typo this is a weakref object call if ref_model is not None and model != ref_model(): # old id pop it - it was probably reused because the object is dead @@ -232,20 +235,24 @@ class WeightsFileHandler(object): # check if object already has InputModel if trains_out_model is None: - # if we are overwriting a local file, try to load registered model - # if there is an output_uri, then by definition we will not overwrite previously stored models. - if not task.output_uri: - try: - in_model_id = InputModel.load_model(weights_url=saved_path) - if in_model_id: - in_model_id = in_model_id.id - get_logger(TrainsFrameworkAdapter).info( - "Found existing registered model id={} [{}] reusing it.".format( - in_model_id, saved_path)) - except: + in_model_id, model_uri = Model._local_model_to_id_uri.get(saved_path, (None, None)) + + if not in_model_id: + # if we are overwriting a local file, try to load registered model + # if there is an output_uri, then by definition we will not overwrite previously stored models. + if not task.output_uri: + try: + in_model_id = InputModel.load_model(weights_url=saved_path) + if in_model_id: + in_model_id = in_model_id.id + + get_logger(TrainsFrameworkAdapter).info( + "Found existing registered model id={} [{}] reusing it.".format( + in_model_id, saved_path)) + except: + in_model_id = None + else: in_model_id = None - else: - in_model_id = None trains_out_model = OutputModel( task=task, @@ -255,12 +262,13 @@ class WeightsFileHandler(object): framework=framework, base_model_id=in_model_id ) - # noinspection PyBroadException - try: - ref_model = weakref.ref(model) - except Exception: - ref_model = None - WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model) + if model is not None: + # noinspection PyBroadException + try: + ref_model = weakref.ref(model) + except Exception: + ref_model = None + WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model) # call post model callback functions for cb in WeightsFileHandler._model_post_callbacks.values():