mirror of
https://github.com/clearml/clearml
synced 2025-04-30 11:04:23 +00:00
Fix Model.connect() in remote execution might result in the wrong model being connected
This commit is contained in:
parent
e7f4497e36
commit
32832ae46d
@ -1088,12 +1088,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
if not model.ready:
|
if not model.ready:
|
||||||
# raise ValueError('Model %s is not published (not ready)' % model_id)
|
# raise ValueError('Model %s is not published (not ready)' % model_id)
|
||||||
self.log.debug('Model %s [%s] is not published yet (not ready)' % (model_id, model.uri))
|
self.log.debug('Model %s [%s] is not published yet (not ready)' % (model_id, model.uri))
|
||||||
name = name or Path(model.uri).stem
|
|
||||||
else:
|
else:
|
||||||
# clear the input model
|
# clear the input model
|
||||||
model = None
|
model = None
|
||||||
model_id = ''
|
model_id = ''
|
||||||
name = name or 'Input Model'
|
from ...model import InputModel
|
||||||
|
# noinspection PyProtectedMember
|
||||||
|
name = name or InputModel._get_connect_name(model)
|
||||||
|
|
||||||
with self._edit_lock:
|
with self._edit_lock:
|
||||||
self.reload()
|
self.reload()
|
||||||
|
@ -1595,6 +1595,7 @@ class InputModel(Model):
|
|||||||
|
|
||||||
# noinspection PyProtectedMember
|
# noinspection PyProtectedMember
|
||||||
_EMPTY_MODEL_ID = _Model._EMPTY_MODEL_ID
|
_EMPTY_MODEL_ID = _Model._EMPTY_MODEL_ID
|
||||||
|
_WARNING_CONNECTED_NAMES = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def import_model(
|
def import_model(
|
||||||
@ -1932,9 +1933,11 @@ class InputModel(Model):
|
|||||||
|
|
||||||
:param object task: A Task object.
|
:param object task: A Task object.
|
||||||
:param str name: The model name to be stored on the Task
|
:param str name: The model name to be stored on the Task
|
||||||
(default the filename, of the model weights, without the file extension)
|
(default to filename of the model weights, without the file extension, or to `Input Model` if that is not found)
|
||||||
"""
|
"""
|
||||||
self._set_task(task)
|
self._set_task(task)
|
||||||
|
name = name or InputModel._get_connect_name(self)
|
||||||
|
InputModel._warn_on_same_name_connect(name)
|
||||||
|
|
||||||
model_id = None
|
model_id = None
|
||||||
# noinspection PyProtectedMember
|
# noinspection PyProtectedMember
|
||||||
@ -1966,6 +1969,28 @@ class InputModel(Model):
|
|||||||
if not self._task.get_labels_enumeration() and model.data.labels:
|
if not self._task.get_labels_enumeration() and model.data.labels:
|
||||||
task.set_model_label_enumeration(model.data.labels)
|
task.set_model_label_enumeration(model.data.labels)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _warn_on_same_name_connect(cls, name):
|
||||||
|
if name not in cls._WARNING_CONNECTED_NAMES:
|
||||||
|
cls._WARNING_CONNECTED_NAMES[name] = False
|
||||||
|
return
|
||||||
|
if cls._WARNING_CONNECTED_NAMES[name]:
|
||||||
|
return
|
||||||
|
get_logger().warning("Connecting multiple input models with the same name: `{}`. This might result in the wrong model being used when executing remotely".format(name))
|
||||||
|
cls._WARNING_CONNECTED_NAMES[name] = True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_connect_name(model):
|
||||||
|
default_name = "Input Model"
|
||||||
|
if model is None:
|
||||||
|
return default_name
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
model_uri = getattr(model, "url", getattr(model, "uri", None))
|
||||||
|
return Path(model_uri).stem
|
||||||
|
except Exception:
|
||||||
|
return default_name
|
||||||
|
|
||||||
|
|
||||||
class OutputModel(BaseModel):
|
class OutputModel(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user