mirror of
				https://github.com/clearml/clearml
				synced 2025-06-26 18:16:07 +00:00 
			
		
		
		
	Fix Model.connect() in remote execution might result in the wrong model being connected
This commit is contained in:
		
							parent
							
								
									e7f4497e36
								
							
						
					
					
						commit
						32832ae46d
					
				@ -1088,12 +1088,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
 | 
			
		||||
            if not model.ready:
 | 
			
		||||
                # raise ValueError('Model %s is not published (not ready)' % model_id)
 | 
			
		||||
                self.log.debug('Model %s [%s] is not published yet (not ready)' % (model_id, model.uri))
 | 
			
		||||
            name = name or Path(model.uri).stem
 | 
			
		||||
        else:
 | 
			
		||||
            # clear the input model
 | 
			
		||||
            model = None
 | 
			
		||||
            model_id = ''
 | 
			
		||||
            name = name or 'Input Model'
 | 
			
		||||
        from ...model import InputModel
 | 
			
		||||
        # noinspection PyProtectedMember
 | 
			
		||||
        name = name or InputModel._get_connect_name(model)
 | 
			
		||||
 | 
			
		||||
        with self._edit_lock:
 | 
			
		||||
            self.reload()
 | 
			
		||||
 | 
			
		||||
@ -1595,6 +1595,7 @@ class InputModel(Model):
 | 
			
		||||
 | 
			
		||||
    # noinspection PyProtectedMember
 | 
			
		||||
    _EMPTY_MODEL_ID = _Model._EMPTY_MODEL_ID
 | 
			
		||||
    _WARNING_CONNECTED_NAMES = {}
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def import_model(
 | 
			
		||||
@ -1932,9 +1933,11 @@ class InputModel(Model):
 | 
			
		||||
 | 
			
		||||
        :param object task: A Task object.
 | 
			
		||||
        :param str name: The model name to be stored on the Task
 | 
			
		||||
            (default the filename, of the model weights, without the file extension)
 | 
			
		||||
            (default to filename of the model weights, without the file extension, or to `Input Model` if that is not found)
 | 
			
		||||
        """
 | 
			
		||||
        self._set_task(task)
 | 
			
		||||
        name = name or InputModel._get_connect_name(self)
 | 
			
		||||
        InputModel._warn_on_same_name_connect(name)
 | 
			
		||||
 | 
			
		||||
        model_id = None
 | 
			
		||||
        # noinspection PyProtectedMember
 | 
			
		||||
@ -1966,6 +1969,28 @@ class InputModel(Model):
 | 
			
		||||
            if not self._task.get_labels_enumeration() and model.data.labels:
 | 
			
		||||
                task.set_model_label_enumeration(model.data.labels)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def _warn_on_same_name_connect(cls, name):
 | 
			
		||||
        if name not in cls._WARNING_CONNECTED_NAMES:
 | 
			
		||||
            cls._WARNING_CONNECTED_NAMES[name] = False
 | 
			
		||||
            return
 | 
			
		||||
        if cls._WARNING_CONNECTED_NAMES[name]:
 | 
			
		||||
            return
 | 
			
		||||
        get_logger().warning("Connecting multiple input models with the same name: `{}`. This might result in the wrong model being used when executing remotely".format(name))
 | 
			
		||||
        cls._WARNING_CONNECTED_NAMES[name] = True
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _get_connect_name(model):
 | 
			
		||||
        default_name = "Input Model"
 | 
			
		||||
        if model is None:
 | 
			
		||||
            return default_name
 | 
			
		||||
        # noinspection PyBroadException
 | 
			
		||||
        try:
 | 
			
		||||
            model_uri = getattr(model, "url", getattr(model, "uri", None))
 | 
			
		||||
            return Path(model_uri).stem
 | 
			
		||||
        except Exception:
 | 
			
		||||
            return default_name
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class OutputModel(BaseModel):
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user