mirror of
https://github.com/clearml/clearml
synced 2025-04-25 16:59:46 +00:00
Fix Model.connect() in remote execution might result in the wrong model being connected
This commit is contained in:
parent
e7f4497e36
commit
32832ae46d
@ -1088,12 +1088,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
if not model.ready:
|
||||
# raise ValueError('Model %s is not published (not ready)' % model_id)
|
||||
self.log.debug('Model %s [%s] is not published yet (not ready)' % (model_id, model.uri))
|
||||
name = name or Path(model.uri).stem
|
||||
else:
|
||||
# clear the input model
|
||||
model = None
|
||||
model_id = ''
|
||||
name = name or 'Input Model'
|
||||
from ...model import InputModel
|
||||
# noinspection PyProtectedMember
|
||||
name = name or InputModel._get_connect_name(model)
|
||||
|
||||
with self._edit_lock:
|
||||
self.reload()
|
||||
|
@ -1595,6 +1595,7 @@ class InputModel(Model):
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
_EMPTY_MODEL_ID = _Model._EMPTY_MODEL_ID
|
||||
_WARNING_CONNECTED_NAMES = {}
|
||||
|
||||
@classmethod
|
||||
def import_model(
|
||||
@ -1932,9 +1933,11 @@ class InputModel(Model):
|
||||
|
||||
:param object task: A Task object.
|
||||
:param str name: The model name to be stored on the Task
|
||||
(default the filename, of the model weights, without the file extension)
|
||||
(default to filename of the model weights, without the file extension, or to `Input Model` if that is not found)
|
||||
"""
|
||||
self._set_task(task)
|
||||
name = name or InputModel._get_connect_name(self)
|
||||
InputModel._warn_on_same_name_connect(name)
|
||||
|
||||
model_id = None
|
||||
# noinspection PyProtectedMember
|
||||
@ -1966,6 +1969,28 @@ class InputModel(Model):
|
||||
if not self._task.get_labels_enumeration() and model.data.labels:
|
||||
task.set_model_label_enumeration(model.data.labels)
|
||||
|
||||
@classmethod
|
||||
def _warn_on_same_name_connect(cls, name):
|
||||
if name not in cls._WARNING_CONNECTED_NAMES:
|
||||
cls._WARNING_CONNECTED_NAMES[name] = False
|
||||
return
|
||||
if cls._WARNING_CONNECTED_NAMES[name]:
|
||||
return
|
||||
get_logger().warning("Connecting multiple input models with the same name: `{}`. This might result in the wrong model being used when executing remotely".format(name))
|
||||
cls._WARNING_CONNECTED_NAMES[name] = True
|
||||
|
||||
@staticmethod
|
||||
def _get_connect_name(model):
|
||||
default_name = "Input Model"
|
||||
if model is None:
|
||||
return default_name
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
model_uri = getattr(model, "url", getattr(model, "uri", None))
|
||||
return Path(model_uri).stem
|
||||
except Exception:
|
||||
return default_name
|
||||
|
||||
|
||||
class OutputModel(BaseModel):
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user