mirror of
https://github.com/clearml/clearml
synced 2025-06-26 18:16:07 +00:00
This commit is contained in:
parent
deba24c689
commit
297990c454
@ -56,13 +56,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
return self.id
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
upload_storage_uri,
|
||||
cache_dir,
|
||||
model_id=None,
|
||||
upload_storage_suffix="models",
|
||||
session=None,
|
||||
log=None
|
||||
self, upload_storage_uri, cache_dir, model_id=None, upload_storage_suffix="models", session=None, log=None
|
||||
):
|
||||
super(Model, self).__init__(id=model_id, session=session, log=log)
|
||||
self._upload_storage_suffix = upload_storage_suffix
|
||||
@ -84,9 +78,13 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
self.reload()
|
||||
else:
|
||||
from ..model import BaseModel
|
||||
|
||||
# edit will reload
|
||||
self._edit(
|
||||
system_tags=list(set((self.data.system_tags or []) if hasattr(self.data, "system_tags") else []) | {BaseModel._archived_tag})
|
||||
system_tags=list(
|
||||
set((self.data.system_tags or []) if hasattr(self.data, "system_tags") else [])
|
||||
| {BaseModel._archived_tag}
|
||||
)
|
||||
)
|
||||
|
||||
def unarchive(self):
|
||||
@ -95,9 +93,13 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
self.reload()
|
||||
else:
|
||||
from ..model import BaseModel
|
||||
|
||||
# edit will reload
|
||||
self._edit(
|
||||
system_tags=list(set((self.data.system_tags or []) if hasattr(self.data, "system_tags") else []) - {BaseModel._archived_tag})
|
||||
system_tags=list(
|
||||
set((self.data.system_tags or []) if hasattr(self.data, "system_tags") else [])
|
||||
- {BaseModel._archived_tag}
|
||||
)
|
||||
)
|
||||
|
||||
def _reload(self):
|
||||
@ -110,9 +112,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
res = self.send(models.GetByIdRequest(model=self.id))
|
||||
return res.response.model
|
||||
|
||||
def _upload_model(
|
||||
self, model_file, async_enable=False, target_filename=None, cb=None
|
||||
):
|
||||
def _upload_model(self, model_file, async_enable=False, target_filename=None, cb=None):
|
||||
if not self.upload_storage_uri:
|
||||
raise ValueError("Model has no storage URI defined (nowhere to upload to)")
|
||||
target_filename = target_filename or Path(model_file).name
|
||||
@ -128,15 +128,12 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
dest_path=dest_path,
|
||||
async_enable=async_enable,
|
||||
cb=partial(self._upload_callback, cb=cb),
|
||||
return_canonized=False
|
||||
return_canonized=False,
|
||||
)
|
||||
if async_enable:
|
||||
|
||||
def msg(num_results):
|
||||
self.log.info(
|
||||
"Waiting for previous model to upload (%d pending, %s)"
|
||||
% (num_results, dest_path)
|
||||
)
|
||||
self.log.info("Waiting for previous model to upload (%d pending, %s)" % (num_results, dest_path))
|
||||
|
||||
self._add_async_result(result, wait_on_max_results=2, wait_cb=msg)
|
||||
return dest_path
|
||||
@ -206,9 +203,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
|
||||
return list(design.values())[0]
|
||||
|
||||
raise ValueError(
|
||||
"design must be a string or a dictionary with at least one value"
|
||||
)
|
||||
raise ValueError("design must be a string or a dictionary with at least one value")
|
||||
|
||||
def update(
|
||||
self,
|
||||
@ -226,7 +221,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
upload_storage_uri=None,
|
||||
target_filename=None,
|
||||
iteration=None,
|
||||
system_tags=None
|
||||
system_tags=None,
|
||||
):
|
||||
"""Update model weights file and various model properties"""
|
||||
|
||||
@ -241,11 +236,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri)
|
||||
|
||||
# upload model file if needed and get uri
|
||||
uri = uri or (
|
||||
self._upload_model(model_file, target_filename=target_filename)
|
||||
if model_file
|
||||
else self.data.uri
|
||||
)
|
||||
uri = uri or (self._upload_model(model_file, target_filename=target_filename) if model_file else self.data.uri)
|
||||
# update fields
|
||||
design = self._wrap_design(design) if design else self.data.design
|
||||
name = name or self.data.name
|
||||
@ -283,7 +274,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
uri=None,
|
||||
framework=None,
|
||||
iteration=None,
|
||||
system_tags=None
|
||||
system_tags=None,
|
||||
):
|
||||
return self._edit(
|
||||
design=design,
|
||||
@ -369,7 +360,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
async_enable=False,
|
||||
target_filename=None,
|
||||
cb=None,
|
||||
iteration=None
|
||||
iteration=None,
|
||||
):
|
||||
"""Update the given model for a given task ID"""
|
||||
if async_enable:
|
||||
@ -417,9 +408,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
)
|
||||
return uri
|
||||
else:
|
||||
uri = self._upload_model(
|
||||
model_file, async_enable=async_enable, target_filename=target_filename
|
||||
)
|
||||
uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename)
|
||||
Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri)
|
||||
self.update(
|
||||
uri=uri,
|
||||
@ -436,9 +425,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
|
||||
return uri
|
||||
|
||||
def update_for_task(
|
||||
self, task_id, name=None, model_id=None, type_="output", iteration=None
|
||||
):
|
||||
def update_for_task(self, task_id, name=None, model_id=None, type_="output", iteration=None):
|
||||
if Session.check_min_api_version("2.13"):
|
||||
req = tasks.AddOrUpdateModelRequest(
|
||||
task=task_id, name=name, type=type_, model=model_id, iteration=iteration
|
||||
@ -450,9 +437,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
# backwards compatibility, None
|
||||
req = None
|
||||
else:
|
||||
raise ValueError(
|
||||
"Type '{}' unsupported (use either 'input' or 'output')".format(type_)
|
||||
)
|
||||
raise ValueError("Type '{}' unsupported (use either 'input' or 'output')".format(type_))
|
||||
|
||||
if req:
|
||||
self.send(req)
|
||||
@ -502,11 +487,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
|
||||
@property
|
||||
def tags(self):
|
||||
return (
|
||||
self.data.system_tags
|
||||
if hasattr(self.data, "system_tags")
|
||||
else self.data.tags
|
||||
)
|
||||
return self.data.system_tags if hasattr(self.data, "system_tags") else self.data.tags
|
||||
|
||||
@property
|
||||
def task(self):
|
||||
@ -549,11 +530,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
return None
|
||||
|
||||
# check if we already downloaded the file
|
||||
downloaded_models = [
|
||||
k
|
||||
for k, (i, u) in Model._local_model_to_id_uri.items()
|
||||
if i == self.id and u == uri
|
||||
]
|
||||
downloaded_models = [k for k, (i, u) in Model._local_model_to_id_uri.items() if i == self.id and u == uri]
|
||||
for dl_file in downloaded_models:
|
||||
if Path(dl_file).exists() and not force_download:
|
||||
return dl_file
|
||||
@ -593,9 +570,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
|
||||
def get_model_package(self):
|
||||
"""Get a named tuple containing the model's weights and design"""
|
||||
return ModelPackage(
|
||||
weights=self.download_model_weights(), design=self.save_model_design_file()
|
||||
)
|
||||
return ModelPackage(weights=self.download_model_weights(), design=self.save_model_design_file())
|
||||
|
||||
def get_model_design(self):
|
||||
"""Get model description (text)"""
|
||||
|
Loading…
Reference in New Issue
Block a user