mirror of
https://github.com/clearml/clearml
synced 2025-06-26 18:16:07 +00:00
This commit is contained in:
parent
deba24c689
commit
297990c454
@ -56,13 +56,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
return self.id
|
return self.id
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self, upload_storage_uri, cache_dir, model_id=None, upload_storage_suffix="models", session=None, log=None
|
||||||
upload_storage_uri,
|
|
||||||
cache_dir,
|
|
||||||
model_id=None,
|
|
||||||
upload_storage_suffix="models",
|
|
||||||
session=None,
|
|
||||||
log=None
|
|
||||||
):
|
):
|
||||||
super(Model, self).__init__(id=model_id, session=session, log=log)
|
super(Model, self).__init__(id=model_id, session=session, log=log)
|
||||||
self._upload_storage_suffix = upload_storage_suffix
|
self._upload_storage_suffix = upload_storage_suffix
|
||||||
@ -84,9 +78,13 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
self.reload()
|
self.reload()
|
||||||
else:
|
else:
|
||||||
from ..model import BaseModel
|
from ..model import BaseModel
|
||||||
|
|
||||||
# edit will reload
|
# edit will reload
|
||||||
self._edit(
|
self._edit(
|
||||||
system_tags=list(set((self.data.system_tags or []) if hasattr(self.data, "system_tags") else []) | {BaseModel._archived_tag})
|
system_tags=list(
|
||||||
|
set((self.data.system_tags or []) if hasattr(self.data, "system_tags") else [])
|
||||||
|
| {BaseModel._archived_tag}
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def unarchive(self):
|
def unarchive(self):
|
||||||
@ -95,9 +93,13 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
self.reload()
|
self.reload()
|
||||||
else:
|
else:
|
||||||
from ..model import BaseModel
|
from ..model import BaseModel
|
||||||
|
|
||||||
# edit will reload
|
# edit will reload
|
||||||
self._edit(
|
self._edit(
|
||||||
system_tags=list(set((self.data.system_tags or []) if hasattr(self.data, "system_tags") else []) - {BaseModel._archived_tag})
|
system_tags=list(
|
||||||
|
set((self.data.system_tags or []) if hasattr(self.data, "system_tags") else [])
|
||||||
|
- {BaseModel._archived_tag}
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _reload(self):
|
def _reload(self):
|
||||||
@ -110,9 +112,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
res = self.send(models.GetByIdRequest(model=self.id))
|
res = self.send(models.GetByIdRequest(model=self.id))
|
||||||
return res.response.model
|
return res.response.model
|
||||||
|
|
||||||
def _upload_model(
|
def _upload_model(self, model_file, async_enable=False, target_filename=None, cb=None):
|
||||||
self, model_file, async_enable=False, target_filename=None, cb=None
|
|
||||||
):
|
|
||||||
if not self.upload_storage_uri:
|
if not self.upload_storage_uri:
|
||||||
raise ValueError("Model has no storage URI defined (nowhere to upload to)")
|
raise ValueError("Model has no storage URI defined (nowhere to upload to)")
|
||||||
target_filename = target_filename or Path(model_file).name
|
target_filename = target_filename or Path(model_file).name
|
||||||
@ -128,15 +128,12 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
dest_path=dest_path,
|
dest_path=dest_path,
|
||||||
async_enable=async_enable,
|
async_enable=async_enable,
|
||||||
cb=partial(self._upload_callback, cb=cb),
|
cb=partial(self._upload_callback, cb=cb),
|
||||||
return_canonized=False
|
return_canonized=False,
|
||||||
)
|
)
|
||||||
if async_enable:
|
if async_enable:
|
||||||
|
|
||||||
def msg(num_results):
|
def msg(num_results):
|
||||||
self.log.info(
|
self.log.info("Waiting for previous model to upload (%d pending, %s)" % (num_results, dest_path))
|
||||||
"Waiting for previous model to upload (%d pending, %s)"
|
|
||||||
% (num_results, dest_path)
|
|
||||||
)
|
|
||||||
|
|
||||||
self._add_async_result(result, wait_on_max_results=2, wait_cb=msg)
|
self._add_async_result(result, wait_on_max_results=2, wait_cb=msg)
|
||||||
return dest_path
|
return dest_path
|
||||||
@ -206,9 +203,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
|
|
||||||
return list(design.values())[0]
|
return list(design.values())[0]
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError("design must be a string or a dictionary with at least one value")
|
||||||
"design must be a string or a dictionary with at least one value"
|
|
||||||
)
|
|
||||||
|
|
||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
@ -226,7 +221,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
upload_storage_uri=None,
|
upload_storage_uri=None,
|
||||||
target_filename=None,
|
target_filename=None,
|
||||||
iteration=None,
|
iteration=None,
|
||||||
system_tags=None
|
system_tags=None,
|
||||||
):
|
):
|
||||||
"""Update model weights file and various model properties"""
|
"""Update model weights file and various model properties"""
|
||||||
|
|
||||||
@ -241,11 +236,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri)
|
Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri)
|
||||||
|
|
||||||
# upload model file if needed and get uri
|
# upload model file if needed and get uri
|
||||||
uri = uri or (
|
uri = uri or (self._upload_model(model_file, target_filename=target_filename) if model_file else self.data.uri)
|
||||||
self._upload_model(model_file, target_filename=target_filename)
|
|
||||||
if model_file
|
|
||||||
else self.data.uri
|
|
||||||
)
|
|
||||||
# update fields
|
# update fields
|
||||||
design = self._wrap_design(design) if design else self.data.design
|
design = self._wrap_design(design) if design else self.data.design
|
||||||
name = name or self.data.name
|
name = name or self.data.name
|
||||||
@ -283,7 +274,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
uri=None,
|
uri=None,
|
||||||
framework=None,
|
framework=None,
|
||||||
iteration=None,
|
iteration=None,
|
||||||
system_tags=None
|
system_tags=None,
|
||||||
):
|
):
|
||||||
return self._edit(
|
return self._edit(
|
||||||
design=design,
|
design=design,
|
||||||
@ -369,7 +360,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
async_enable=False,
|
async_enable=False,
|
||||||
target_filename=None,
|
target_filename=None,
|
||||||
cb=None,
|
cb=None,
|
||||||
iteration=None
|
iteration=None,
|
||||||
):
|
):
|
||||||
"""Update the given model for a given task ID"""
|
"""Update the given model for a given task ID"""
|
||||||
if async_enable:
|
if async_enable:
|
||||||
@ -417,9 +408,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
)
|
)
|
||||||
return uri
|
return uri
|
||||||
else:
|
else:
|
||||||
uri = self._upload_model(
|
uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename)
|
||||||
model_file, async_enable=async_enable, target_filename=target_filename
|
|
||||||
)
|
|
||||||
Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri)
|
Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri)
|
||||||
self.update(
|
self.update(
|
||||||
uri=uri,
|
uri=uri,
|
||||||
@ -436,9 +425,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
|
|
||||||
return uri
|
return uri
|
||||||
|
|
||||||
def update_for_task(
|
def update_for_task(self, task_id, name=None, model_id=None, type_="output", iteration=None):
|
||||||
self, task_id, name=None, model_id=None, type_="output", iteration=None
|
|
||||||
):
|
|
||||||
if Session.check_min_api_version("2.13"):
|
if Session.check_min_api_version("2.13"):
|
||||||
req = tasks.AddOrUpdateModelRequest(
|
req = tasks.AddOrUpdateModelRequest(
|
||||||
task=task_id, name=name, type=type_, model=model_id, iteration=iteration
|
task=task_id, name=name, type=type_, model=model_id, iteration=iteration
|
||||||
@ -450,9 +437,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
# backwards compatibility, None
|
# backwards compatibility, None
|
||||||
req = None
|
req = None
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError("Type '{}' unsupported (use either 'input' or 'output')".format(type_))
|
||||||
"Type '{}' unsupported (use either 'input' or 'output')".format(type_)
|
|
||||||
)
|
|
||||||
|
|
||||||
if req:
|
if req:
|
||||||
self.send(req)
|
self.send(req)
|
||||||
@ -502,11 +487,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def tags(self):
|
def tags(self):
|
||||||
return (
|
return self.data.system_tags if hasattr(self.data, "system_tags") else self.data.tags
|
||||||
self.data.system_tags
|
|
||||||
if hasattr(self.data, "system_tags")
|
|
||||||
else self.data.tags
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def task(self):
|
def task(self):
|
||||||
@ -549,11 +530,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# check if we already downloaded the file
|
# check if we already downloaded the file
|
||||||
downloaded_models = [
|
downloaded_models = [k for k, (i, u) in Model._local_model_to_id_uri.items() if i == self.id and u == uri]
|
||||||
k
|
|
||||||
for k, (i, u) in Model._local_model_to_id_uri.items()
|
|
||||||
if i == self.id and u == uri
|
|
||||||
]
|
|
||||||
for dl_file in downloaded_models:
|
for dl_file in downloaded_models:
|
||||||
if Path(dl_file).exists() and not force_download:
|
if Path(dl_file).exists() and not force_download:
|
||||||
return dl_file
|
return dl_file
|
||||||
@ -593,9 +570,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
|
|
||||||
def get_model_package(self):
|
def get_model_package(self):
|
||||||
"""Get a named tuple containing the model's weights and design"""
|
"""Get a named tuple containing the model's weights and design"""
|
||||||
return ModelPackage(
|
return ModelPackage(weights=self.download_model_weights(), design=self.save_model_design_file())
|
||||||
weights=self.download_model_weights(), design=self.save_model_design_file()
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_model_design(self):
|
def get_model_design(self):
|
||||||
"""Get model description (text)"""
|
"""Get model description (text)"""
|
||||||
|
Loading…
Reference in New Issue
Block a user