From d4b11dfa22ddf22773211c4beb9b671dd2a0c71b Mon Sep 17 00:00:00 2001
From: Alex Burlacu
Date: Fri, 11 Aug 2023 13:11:58 +0300
Subject: [PATCH] Create OutputModel base model lazily

---
Reviewer notes (this region is free-form patch commentary and is not part
of the commit):

* `_get_force_base_model` is removed from `OutputModel` and re-added on
  `BaseModel`. Its body is split in two: `_get_force_base_model` now only
  creates the backend model via `self._task._get_output_model(model_id=None)`
  and delegates to the new `_update_base_model`, which pushes labels, design,
  name, comment, tags, framework and upload URI to `self._base_model`.
* Differences between the new `_update_base_model` and the removed code:
  - returns early (without updating) when `self._task` is falsy, and
    `_get_force_base_model` returns None in that case;
  - every read of `self._floating_data` is guarded, because the method
    resets it to None and may now be called more than once;
  - the model name falls back to `self._name` before the floating data name;
  - deriving `task_model_entry` is wrapped in `try/except Exception: pass`,
    so on failure the value passed in by the caller is kept;
  - `update_for_task` is only called when `task_model_entry` is truthy
    (previously an unconditional `else`).
* `BaseModel.__init__` now initialises `_floating_data`, `_name` and
  `_task_connect_name` to None; `OutputModel.__init__` stores `name` in
  `self._name`.
* `_init_reporter` and six `BaseModel` methods whose docstrings concern
  metadata now call `_get_force_base_model()` before doing their work.
* In the `OutputModel` code path that raises "Failed creating internal
  output model", an already existing base model is now refreshed through
  `_update_base_model(task_model_entry=name)` instead of being used as is.

NOTE(review): `_update_base_model` dereferences `self._base_model` without a
None check; both callers in this patch ensure it is set first - confirm this
for any future caller.
NOTE(review): line breaks and indentation of this patch were reconstructed
from a whitespace-collapsed copy (hunk line counts and the diffstat were used
as a cross-check); verify that it applies cleanly on blob 7b9271c9.

 clearml/model.py | 146 +++++++++++++++++++++++++++++------------------
 1 file changed, 90 insertions(+), 56 deletions(-)

diff --git a/clearml/model.py b/clearml/model.py
index 7b9271c9..c380ee34 100644
--- a/clearml/model.py
+++ b/clearml/model.py
@@ -357,6 +357,9 @@ class BaseModel(object):
         self._task = None
         self._reload_required = False
         self._reporter = None
+        self._floating_data = None
+        self._name = None
+        self._task_connect_name = None
         self._set_task(task)
 
     def get_weights(self, raise_on_error=False, force_download=False):
@@ -1055,6 +1058,7 @@ class BaseModel(object):
     def _init_reporter(self):
         if self._reporter:
             return
+        self._base_model = self._get_force_base_model()
         metrics_manager = Metrics(
             session=_Model._get_default_session(),
             storage_uri=None,
@@ -1126,6 +1130,8 @@ class BaseModel(object):
 
         :return: True if the metadata was set and False otherwise
         """
+        if not self._base_model:
+            self._base_model = self._get_force_base_model()
         self._reload_required = (
             _Model._get_default_session()
             .send(
@@ -1167,6 +1173,8 @@ class BaseModel(object):
 
         :return: String representation of the value of the metadata entry or None if the entry was not found
         """
+        if not self._base_model:
+            self._base_model = self._get_force_base_model()
         self._reload_if_required()
         return self.get_all_metadata().get(str(key), {}).get("value")
 
@@ -1180,6 +1188,8 @@ class BaseModel(object):
         :return: The value of the metadata entry, casted to its type (if not possible,
             the string representation will be returned) or None if the entry was not found
         """
+        if not self._base_model:
+            self._base_model = self._get_force_base_model()
         key = str(key)
         metadata = self.get_all_metadata()
         if key not in metadata:
@@ -1194,6 +1204,8 @@ class BaseModel(object):
         :return: Get all metadata as a dictionary of format Dict[key, Dict[value, type]]. The key, value and type
             entries are all strings. Note that each entry might have an additional 'key' entry, repeating the key
         """
+        if not self._base_model:
+            self._base_model = self._get_force_base_model()
         self._reload_if_required()
         return self._get_model_data().metadata or {}
 
@@ -1204,6 +1216,8 @@ class BaseModel(object):
             entries are strings. The value is cast to its type if possible. Note that each entry might
             have an additional 'key' entry, repeating the key
         """
+        if not self._base_model:
+            self._base_model = self._get_force_base_model()
         self._reload_if_required()
         result = {}
         metadata = self.get_all_metadata()
@@ -1224,6 +1238,8 @@ class BaseModel(object):
 
         :return: True if the metadata was set and False otherwise
         """
+        if not self._base_model:
+            self._base_model = self._get_force_base_model()
         metadata_array = [
             {
                 "key": str(k),
@@ -1249,6 +1265,74 @@ class BaseModel(object):
             self._get_base_model().reload()
             self._reload_required = False
 
+    def _update_base_model(self, model_name=None, task_model_entry=None):
+        if not self._task:
+            return self._base_model
+        # update the model from the task inputs
+        labels = self._task.get_labels_enumeration()
+        # noinspection PyProtectedMember
+        config_text = self._task._get_model_config_text()
+        model_name = (
+            model_name or self._name or (self._floating_data.name if self._floating_data else None) or self._task.name
+        )
+        # noinspection PyBroadException
+        try:
+            task_model_entry = (
+                task_model_entry
+                or self._task_connect_name
+                or Path(self._get_model_data().uri).stem
+            )
+        except Exception:
+            pass
+        parent = self._task.input_models_id.get(task_model_entry)
+        self._base_model.update(
+            labels=(self._floating_data.labels if self._floating_data else None) or labels,
+            design=(self._floating_data.design if self._floating_data else None) or config_text,
+            task_id=self._task.id,
+            project_id=self._task.project,
+            parent_id=parent,
+            name=model_name,
+            comment=self._floating_data.comment if self._floating_data else None,
+            tags=self._floating_data.tags if self._floating_data else None,
+            framework=self._floating_data.framework if self._floating_data else None,
+            upload_storage_uri=self._floating_data.upload_storage_uri if self._floating_data else None,
+        )
+
+        # remove model floating change set, by now they should have matched the task.
+        self._floating_data = None
+
+        # now we have to update the creator task so it points to us
+        if str(self._task.status) not in (
+            str(self._task.TaskStatusEnum.created),
+            str(self._task.TaskStatusEnum.in_progress),
+        ):
+            self._log.warning(
+                "Could not update last created model in Task {}, "
+                "Task status '{}' cannot be updated".format(
+                    self._task.id, self._task.status
+                )
+            )
+        elif task_model_entry:
+            self._base_model.update_for_task(
+                task_id=self._task.id,
+                model_id=self.id,
+                type_="output",
+                name=task_model_entry,
+            )
+
+        return self._base_model
+
+    def _get_force_base_model(self, model_name=None, task_model_entry=None):
+        if self._base_model:
+            return self._base_model
+        if not self._task:
+            return None
+
+        # create a new model from the task
+        # noinspection PyProtectedMember
+        self._base_model = self._task._get_output_model(model_id=None)
+        return self._update_base_model(model_name=model_name, task_model_entry=task_model_entry)
+
 
 class Model(BaseModel):
     """
@@ -2060,6 +2144,7 @@ class OutputModel(BaseModel):
         self._base_model = None
         self._base_model_id = None
         self._task_connect_name = None
+        self._name = name
         self._label_enumeration = label_enumeration
         # noinspection PyProtectedMember
         self._floating_data = create_dummy_model(
@@ -2300,7 +2385,11 @@ class OutputModel(BaseModel):
             if out_model_file_name
             else (self._task_connect_name or "Output Model")
         )
-        model = self._get_force_base_model(task_model_entry=name)
+        if not self._base_model:
+            model = self._get_force_base_model(task_model_entry=name)
+        else:
+            self._update_base_model(task_model_entry=name)
+            model = self._base_model
         if not model:
             raise ValueError("Failed creating internal output model")
 
@@ -2639,61 +2728,6 @@ class OutputModel(BaseModel):
             )
         return weights_filename_offline or register_uri
 
-    def _get_force_base_model(self, model_name=None, task_model_entry=None):
-        if self._base_model:
-            return self._base_model
-
-        # create a new model from the task
-        # noinspection PyProtectedMember
-        self._base_model = self._task._get_output_model(model_id=None)
-        # update the model from the task inputs
-        labels = self._task.get_labels_enumeration()
-        # noinspection PyProtectedMember
-        config_text = self._task._get_model_config_text()
-        model_name = model_name or self._floating_data.name or self._task.name
-        task_model_entry = (
-            task_model_entry
-            or self._task_connect_name
-            or Path(self._get_model_data().uri).stem
-        )
-        parent = self._task.input_models_id.get(task_model_entry)
-        self._base_model.update(
-            labels=self._floating_data.labels or labels,
-            design=self._floating_data.design or config_text,
-            task_id=self._task.id,
-            project_id=self._task.project,
-            parent_id=parent,
-            name=model_name,
-            comment=self._floating_data.comment,
-            tags=self._floating_data.tags,
-            framework=self._floating_data.framework,
-            upload_storage_uri=self._floating_data.upload_storage_uri,
-        )
-
-        # remove model floating change set, by now they should have matched the task.
-        self._floating_data = None
-
-        # now we have to update the creator task so it points to us
-        if str(self._task.status) not in (
-            str(self._task.TaskStatusEnum.created),
-            str(self._task.TaskStatusEnum.in_progress),
-        ):
-            self._log.warning(
-                "Could not update last created model in Task {}, "
-                "Task status '{}' cannot be updated".format(
-                    self._task.id, self._task.status
-                )
-            )
-        else:
-            self._base_model.update_for_task(
-                task_id=self._task.id,
-                model_id=self.id,
-                type_="output",
-                name=task_model_entry,
-            )
-
-        return self._base_model
-
     def _get_base_model(self):
         if self._floating_data:
             return self._floating_data