Fix Model.connect() in remote execution might result in the wrong model being connected

This commit is contained in:
allegroai 2023-11-05 21:04:53 +02:00
parent e7f4497e36
commit 32832ae46d
2 changed files with 29 additions and 3 deletions

View File

@ -1088,12 +1088,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
if not model.ready:
# raise ValueError('Model %s is not published (not ready)' % model_id)
self.log.debug('Model %s [%s] is not published yet (not ready)' % (model_id, model.uri))
name = name or Path(model.uri).stem
else:
# clear the input model
model = None
model_id = ''
name = name or 'Input Model'
from ...model import InputModel
# noinspection PyProtectedMember
name = name or InputModel._get_connect_name(model)
with self._edit_lock:
self.reload()

View File

@ -1595,6 +1595,7 @@ class InputModel(Model):
# noinspection PyProtectedMember
_EMPTY_MODEL_ID = _Model._EMPTY_MODEL_ID
_WARNING_CONNECTED_NAMES = {}
@classmethod
def import_model(
@ -1932,9 +1933,11 @@ class InputModel(Model):
:param object task: A Task object.
:param str name: The model name to be stored on the Task
(default the filename, of the model weights, without the file extension)
(default to filename of the model weights, without the file extension, or to `Input Model` if that is not found)
"""
self._set_task(task)
name = name or InputModel._get_connect_name(self)
InputModel._warn_on_same_name_connect(name)
model_id = None
# noinspection PyProtectedMember
@ -1966,6 +1969,28 @@ class InputModel(Model):
if not self._task.get_labels_enumeration() and model.data.labels:
task.set_model_label_enumeration(model.data.labels)
@classmethod
def _warn_on_same_name_connect(cls, name):
if name not in cls._WARNING_CONNECTED_NAMES:
cls._WARNING_CONNECTED_NAMES[name] = False
return
if cls._WARNING_CONNECTED_NAMES[name]:
return
get_logger().warning("Connecting multiple input models with the same name: `{}`. This might result in the wrong model being used when executing remotely".format(name))
cls._WARNING_CONNECTED_NAMES[name] = True
@staticmethod
def _get_connect_name(model):
default_name = "Input Model"
if model is None:
return default_name
# noinspection PyBroadException
try:
model_uri = getattr(model, "url", getattr(model, "uri", None))
return Path(model_uri).stem
except Exception:
return default_name
class OutputModel(BaseModel):
"""