mirror of
https://github.com/clearml/clearml
synced 2025-03-03 02:32:11 +00:00
Add system_tags and tags to Model
This commit is contained in:
parent
5d20a0fa98
commit
1e4ab0510c
@ -166,7 +166,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
|
||||
def update(self, model_file=None, design=None, labels=None, name=None, comment=None, tags=None,
|
||||
task_id=None, project_id=None, parent_id=None, uri=None, framework=None,
|
||||
upload_storage_uri=None, target_filename=None, iteration=None):
|
||||
upload_storage_uri=None, target_filename=None, iteration=None, system_tags=None):
|
||||
""" Update model weights file and various model properties """
|
||||
|
||||
if self.id is None:
|
||||
@ -189,6 +189,9 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
task = task_id or self.data.task
|
||||
project = project_id or self.data.project
|
||||
parent = parent_id or self.data.parent
|
||||
tags = tags or self.data.tags
|
||||
if Session.check_min_api_version('2.3'):
|
||||
system_tags = system_tags or self.data.system_tags
|
||||
|
||||
self._edit(
|
||||
uri=uri,
|
||||
@ -201,15 +204,17 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
task=task,
|
||||
project=project,
|
||||
parent=parent,
|
||||
tags=tags,
|
||||
system_tags=system_tags,
|
||||
)
|
||||
|
||||
def edit(self, design=None, labels=None, name=None, comment=None, tags=None,
|
||||
uri=None, framework=None, iteration=None):
|
||||
uri=None, framework=None, iteration=None, system_tags=None):
|
||||
return self._edit(design=design, labels=labels, name=name, comment=comment, tags=tags,
|
||||
uri=uri, framework=framework, iteration=iteration)
|
||||
uri=uri, framework=framework, iteration=iteration, system_tags=system_tags)
|
||||
|
||||
def _edit(self, design=None, labels=None, name=None, comment=None, tags=None,
|
||||
uri=None, framework=None, iteration=None, **extra):
|
||||
uri=None, framework=None, iteration=None, system_tags=None, **extra):
|
||||
def offline_store(**kwargs):
|
||||
for k, v in kwargs.items():
|
||||
setattr(self.data, k, v or getattr(self.data, k, None))
|
||||
@ -218,9 +223,16 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
return offline_store(design=design, labels=labels, name=name, comment=comment, tags=tags,
|
||||
uri=uri, framework=framework, iteration=iteration, **extra)
|
||||
|
||||
if tags:
|
||||
extra.update({'system_tags': tags or self.data.system_tags}
|
||||
if hasattr(self.data, 'system_tags') else {'tags': tags or self.data.tags})
|
||||
if Session.check_min_api_version('2.3'):
|
||||
if tags is not None:
|
||||
extra.update({'tags': tags})
|
||||
if system_tags is not None:
|
||||
extra.update({'system_tags': system_tags})
|
||||
elif tags is not None or system_tags is not None:
|
||||
if tags and system_tags:
|
||||
system_tags = system_tags[:]
|
||||
system_tags += [t for t in tags if t not in system_tags]
|
||||
extra.update({'system_tags': system_tags or tags or self.data.system_tags})
|
||||
|
||||
self.send(models.EditRequest(
|
||||
model=self.id,
|
||||
|
@ -176,6 +176,29 @@ class BaseModel(object):
|
||||
"""
|
||||
self._get_base_model().update(tags=value)
|
||||
|
||||
@property
|
||||
def system_tags(self):
|
||||
# type: () -> List[str]
|
||||
"""
|
||||
A list of system tags describing the model.
|
||||
|
||||
:return: The list of tags.
|
||||
"""
|
||||
data = self._get_model_data()
|
||||
return data.system_tags if Session.check_min_api_version('2.3') else data.tags
|
||||
|
||||
@system_tags.setter
|
||||
def system_tags(self, value):
|
||||
# type: (List[str]) -> None
|
||||
"""
|
||||
Set the list of system tags describing the model.
|
||||
|
||||
:param value: The tags.
|
||||
|
||||
:type value: list(str)
|
||||
"""
|
||||
self._get_base_model().update(system_tags=value)
|
||||
|
||||
@property
|
||||
def config_text(self):
|
||||
# type: () -> str
|
||||
@ -275,7 +298,7 @@ class BaseModel(object):
|
||||
:return: The model weights, or a list of the locally stored filenames.
|
||||
"""
|
||||
# check if model was packaged
|
||||
if self._package_tag not in self._get_model_data().tags:
|
||||
if not self._is_package():
|
||||
raise ValueError('Model is not packaged')
|
||||
|
||||
# download packaged model
|
||||
@ -331,9 +354,12 @@ class BaseModel(object):
|
||||
pass
|
||||
|
||||
def _set_package_tag(self):
|
||||
if self._package_tag not in self.tags:
|
||||
self.tags.append(self._package_tag)
|
||||
self._get_base_model().edit(tags=self.tags)
|
||||
if self._package_tag not in self.system_tags:
|
||||
self.system_tags.append(self._package_tag)
|
||||
self._get_base_model().edit(system_tags=self.system_tags)
|
||||
|
||||
def _is_package(self):
|
||||
return self._package_tag in (self.system_tags or [])
|
||||
|
||||
@staticmethod
|
||||
def _config_dict_to_text(config):
|
||||
@ -390,7 +416,7 @@ class Model(BaseModel):
|
||||
|
||||
:return: A local path to the model (or a downloaded copy of it).
|
||||
"""
|
||||
if extract_archive and self._package_tag in self.tags:
|
||||
if extract_archive and self._is_package():
|
||||
return self.get_weights_package(return_path=True, raise_on_error=raise_on_error)
|
||||
return self.get_weights(raise_on_error=raise_on_error)
|
||||
|
||||
@ -1033,7 +1059,8 @@ class OutputModel(BaseModel):
|
||||
auto_delete_file=True, # type: bool
|
||||
register_uri=None, # type: Optional[str]
|
||||
iteration=None, # type: Optional[int]
|
||||
update_comment=True # type: bool
|
||||
update_comment=True, # type: bool
|
||||
is_package=False, # type: bool
|
||||
):
|
||||
# type: (...) -> str
|
||||
"""
|
||||
@ -1061,6 +1088,7 @@ class OutputModel(BaseModel):
|
||||
|
||||
- ``True`` - Update model comment (Default)
|
||||
- ``False`` - Do not update
|
||||
:param bool is_package: Mark the weights file as compressed package, usually a zip file.
|
||||
|
||||
:return: The uploaded URI.
|
||||
"""
|
||||
@ -1151,6 +1179,9 @@ class OutputModel(BaseModel):
|
||||
else:
|
||||
output_uri = None
|
||||
|
||||
if is_package:
|
||||
self._set_package_tag()
|
||||
|
||||
# make sure that if we are in dev move we report that we are training (not debugging)
|
||||
self._task._output_model_updated()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user