Add system_tags and tags to Model

This commit is contained in:
allegroai 2020-10-12 10:50:38 +03:00
parent 5d20a0fa98
commit 1e4ab0510c
2 changed files with 56 additions and 13 deletions

View File

@ -166,7 +166,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
def update(self, model_file=None, design=None, labels=None, name=None, comment=None, tags=None,
task_id=None, project_id=None, parent_id=None, uri=None, framework=None,
upload_storage_uri=None, target_filename=None, iteration=None):
upload_storage_uri=None, target_filename=None, iteration=None, system_tags=None):
""" Update model weights file and various model properties """
if self.id is None:
@ -189,6 +189,9 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
task = task_id or self.data.task
project = project_id or self.data.project
parent = parent_id or self.data.parent
tags = tags or self.data.tags
if Session.check_min_api_version('2.3'):
system_tags = system_tags or self.data.system_tags
self._edit(
uri=uri,
@ -201,15 +204,17 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
task=task,
project=project,
parent=parent,
tags=tags,
system_tags=system_tags,
)
def edit(self, design=None, labels=None, name=None, comment=None, tags=None,
uri=None, framework=None, iteration=None):
uri=None, framework=None, iteration=None, system_tags=None):
return self._edit(design=design, labels=labels, name=name, comment=comment, tags=tags,
uri=uri, framework=framework, iteration=iteration)
uri=uri, framework=framework, iteration=iteration, system_tags=system_tags)
def _edit(self, design=None, labels=None, name=None, comment=None, tags=None,
uri=None, framework=None, iteration=None, **extra):
uri=None, framework=None, iteration=None, system_tags=None, **extra):
def offline_store(**kwargs):
for k, v in kwargs.items():
setattr(self.data, k, v or getattr(self.data, k, None))
@ -218,9 +223,16 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
return offline_store(design=design, labels=labels, name=name, comment=comment, tags=tags,
uri=uri, framework=framework, iteration=iteration, **extra)
if tags:
extra.update({'system_tags': tags or self.data.system_tags}
if hasattr(self.data, 'system_tags') else {'tags': tags or self.data.tags})
if Session.check_min_api_version('2.3'):
if tags is not None:
extra.update({'tags': tags})
if system_tags is not None:
extra.update({'system_tags': system_tags})
elif tags is not None or system_tags is not None:
if tags and system_tags:
system_tags = system_tags[:]
system_tags += [t for t in tags if t not in system_tags]
extra.update({'system_tags': system_tags or tags or self.data.system_tags})
self.send(models.EditRequest(
model=self.id,

View File

@ -176,6 +176,29 @@ class BaseModel(object):
"""
self._get_base_model().update(tags=value)
@property
def system_tags(self):
# type: () -> List[str]
"""
A list of system tags describing the model.
:return: The list of tags.
"""
data = self._get_model_data()
return data.system_tags if Session.check_min_api_version('2.3') else data.tags
@system_tags.setter
def system_tags(self, value):
# type: (List[str]) -> None
"""
Set the list of system tags describing the model.
:param value: The tags.
:type value: list(str)
"""
self._get_base_model().update(system_tags=value)
@property
def config_text(self):
# type: () -> str
@ -275,7 +298,7 @@ class BaseModel(object):
:return: The model weights, or a list of the locally stored filenames.
"""
# check if model was packaged
if self._package_tag not in self._get_model_data().tags:
if not self._is_package():
raise ValueError('Model is not packaged')
# download packaged model
@ -331,9 +354,12 @@ class BaseModel(object):
pass
def _set_package_tag(self):
if self._package_tag not in self.tags:
self.tags.append(self._package_tag)
self._get_base_model().edit(tags=self.tags)
if self._package_tag not in self.system_tags:
self.system_tags.append(self._package_tag)
self._get_base_model().edit(system_tags=self.system_tags)
def _is_package(self):
return self._package_tag in (self.system_tags or [])
@staticmethod
def _config_dict_to_text(config):
@ -390,7 +416,7 @@ class Model(BaseModel):
:return: A local path to the model (or a downloaded copy of it).
"""
if extract_archive and self._package_tag in self.tags:
if extract_archive and self._is_package():
return self.get_weights_package(return_path=True, raise_on_error=raise_on_error)
return self.get_weights(raise_on_error=raise_on_error)
@ -1033,7 +1059,8 @@ class OutputModel(BaseModel):
auto_delete_file=True, # type: bool
register_uri=None, # type: Optional[str]
iteration=None, # type: Optional[int]
update_comment=True # type: bool
update_comment=True, # type: bool
is_package=False, # type: bool
):
# type: (...) -> str
"""
@ -1061,6 +1088,7 @@ class OutputModel(BaseModel):
- ``True`` - Update model comment (Default)
- ``False`` - Do not update
:param bool is_package: Mark the weights file as compressed package, usually a zip file.
:return: The uploaded URI.
"""
@ -1151,6 +1179,9 @@ class OutputModel(BaseModel):
else:
output_uri = None
if is_package:
self._set_package_tag()
# make sure that if we are in dev move we report that we are training (not debugging)
self._task._output_model_updated()