diff --git a/trains/backend_interface/model.py b/trains/backend_interface/model.py index 2720e229..98283eba 100644 --- a/trains/backend_interface/model.py +++ b/trains/backend_interface/model.py @@ -166,7 +166,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin): def update(self, model_file=None, design=None, labels=None, name=None, comment=None, tags=None, task_id=None, project_id=None, parent_id=None, uri=None, framework=None, - upload_storage_uri=None, target_filename=None, iteration=None): + upload_storage_uri=None, target_filename=None, iteration=None, system_tags=None): """ Update model weights file and various model properties """ if self.id is None: @@ -189,6 +189,9 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin): task = task_id or self.data.task project = project_id or self.data.project parent = parent_id or self.data.parent + tags = tags or self.data.tags + if Session.check_min_api_version('2.3'): + system_tags = system_tags or self.data.system_tags self._edit( uri=uri, @@ -201,15 +204,17 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin): task=task, project=project, parent=parent, + tags=tags, + system_tags=system_tags, ) def edit(self, design=None, labels=None, name=None, comment=None, tags=None, - uri=None, framework=None, iteration=None): + uri=None, framework=None, iteration=None, system_tags=None): return self._edit(design=design, labels=labels, name=name, comment=comment, tags=tags, - uri=uri, framework=framework, iteration=iteration) + uri=uri, framework=framework, iteration=iteration, system_tags=system_tags) def _edit(self, design=None, labels=None, name=None, comment=None, tags=None, - uri=None, framework=None, iteration=None, **extra): + uri=None, framework=None, iteration=None, system_tags=None, **extra): def offline_store(**kwargs): for k, v in kwargs.items(): setattr(self.data, k, v or getattr(self.data, k, None)) @@ -218,9 +223,16 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin): return offline_store(design=design, labels=labels, name=name, comment=comment, tags=tags, uri=uri, framework=framework, iteration=iteration, **extra) - if tags: - extra.update({'system_tags': tags or self.data.system_tags} - if hasattr(self.data, 'system_tags') else {'tags': tags or self.data.tags}) + if Session.check_min_api_version('2.3'): + if tags is not None: + extra.update({'tags': tags}) + if system_tags is not None: + extra.update({'system_tags': system_tags}) + elif tags is not None or system_tags is not None: + if tags and system_tags: + system_tags = system_tags[:] + system_tags += [t for t in tags if t not in system_tags] + extra.update({'system_tags': system_tags or tags or self.data.system_tags}) self.send(models.EditRequest( model=self.id, diff --git a/trains/model.py b/trains/model.py index 05dc709c..a5a8938c 100644 --- a/trains/model.py +++ b/trains/model.py @@ -176,6 +176,29 @@ class BaseModel(object): """ self._get_base_model().update(tags=value) + @property + def system_tags(self): + # type: () -> List[str] + """ + A list of system tags describing the model. + + :return: The list of tags. + """ + data = self._get_model_data() + return data.system_tags if Session.check_min_api_version('2.3') else data.tags + + @system_tags.setter + def system_tags(self, value): + # type: (List[str]) -> None + """ + Set the list of system tags describing the model. + + :param value: The tags. + + :type value: list(str) + """ + self._get_base_model().update(system_tags=value) + @property def config_text(self): # type: () -> str @@ -275,7 +298,7 @@ class BaseModel(object): :return: The model weights, or a list of the locally stored filenames. """ # check if model was packaged - if self._package_tag not in self._get_model_data().tags: + if not self._is_package(): raise ValueError('Model is not packaged') # download packaged model @@ -331,9 +354,12 @@ class BaseModel(object): pass def _set_package_tag(self): - if self._package_tag not in self.tags: - self.tags.append(self._package_tag) - self._get_base_model().edit(tags=self.tags) + if self._package_tag not in self.system_tags: + self.system_tags.append(self._package_tag) + self._get_base_model().edit(system_tags=self.system_tags) + + def _is_package(self): + return self._package_tag in (self.system_tags or []) @staticmethod def _config_dict_to_text(config): @@ -390,7 +416,7 @@ class Model(BaseModel): :return: A local path to the model (or a downloaded copy of it). """ - if extract_archive and self._package_tag in self.tags: + if extract_archive and self._is_package(): return self.get_weights_package(return_path=True, raise_on_error=raise_on_error) return self.get_weights(raise_on_error=raise_on_error) @@ -1033,7 +1059,8 @@ class OutputModel(BaseModel): auto_delete_file=True, # type: bool register_uri=None, # type: Optional[str] iteration=None, # type: Optional[int] - update_comment=True # type: bool + update_comment=True, # type: bool + is_package=False, # type: bool ): # type: (...) -> str """ @@ -1061,6 +1088,7 @@ class OutputModel(BaseModel): - ``True`` - Update model comment (Default) - ``False`` - Do not update + :param bool is_package: Mark the weights file as compressed package, usually a zip file. :return: The uploaded URI. """ @@ -1151,6 +1179,9 @@ class OutputModel(BaseModel): else: output_uri = None + if is_package: + self._set_package_tag() + # make sure that if we are in dev move we report that we are training (not debugging) self._task._output_model_updated()