Black formatting
Some checks failed
CodeQL / Analyze (python) (push) Has been cancelled

This commit is contained in:
clearml 2025-03-22 23:03:35 +02:00
parent deba24c689
commit 297990c454

View File

@ -56,13 +56,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
return self.id
def __init__(
self,
upload_storage_uri,
cache_dir,
model_id=None,
upload_storage_suffix="models",
session=None,
log=None
self, upload_storage_uri, cache_dir, model_id=None, upload_storage_suffix="models", session=None, log=None
):
super(Model, self).__init__(id=model_id, session=session, log=log)
self._upload_storage_suffix = upload_storage_suffix
@ -84,9 +78,13 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
self.reload()
else:
from ..model import BaseModel
# edit will reload
self._edit(
system_tags=list(set((self.data.system_tags or []) if hasattr(self.data, "system_tags") else []) | {BaseModel._archived_tag})
system_tags=list(
set((self.data.system_tags or []) if hasattr(self.data, "system_tags") else [])
| {BaseModel._archived_tag}
)
)
def unarchive(self):
@ -95,9 +93,13 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
self.reload()
else:
from ..model import BaseModel
# edit will reload
self._edit(
system_tags=list(set((self.data.system_tags or []) if hasattr(self.data, "system_tags") else []) - {BaseModel._archived_tag})
system_tags=list(
set((self.data.system_tags or []) if hasattr(self.data, "system_tags") else [])
- {BaseModel._archived_tag}
)
)
def _reload(self):
@ -110,9 +112,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
res = self.send(models.GetByIdRequest(model=self.id))
return res.response.model
def _upload_model(
self, model_file, async_enable=False, target_filename=None, cb=None
):
def _upload_model(self, model_file, async_enable=False, target_filename=None, cb=None):
if not self.upload_storage_uri:
raise ValueError("Model has no storage URI defined (nowhere to upload to)")
target_filename = target_filename or Path(model_file).name
@ -128,15 +128,12 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
dest_path=dest_path,
async_enable=async_enable,
cb=partial(self._upload_callback, cb=cb),
return_canonized=False
return_canonized=False,
)
if async_enable:
def msg(num_results):
self.log.info(
"Waiting for previous model to upload (%d pending, %s)"
% (num_results, dest_path)
)
self.log.info("Waiting for previous model to upload (%d pending, %s)" % (num_results, dest_path))
self._add_async_result(result, wait_on_max_results=2, wait_cb=msg)
return dest_path
@ -206,9 +203,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
return list(design.values())[0]
raise ValueError(
"design must be a string or a dictionary with at least one value"
)
raise ValueError("design must be a string or a dictionary with at least one value")
def update(
self,
@ -226,7 +221,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
upload_storage_uri=None,
target_filename=None,
iteration=None,
system_tags=None
system_tags=None,
):
"""Update model weights file and various model properties"""
@ -241,11 +236,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri)
# upload model file if needed and get uri
uri = uri or (
self._upload_model(model_file, target_filename=target_filename)
if model_file
else self.data.uri
)
uri = uri or (self._upload_model(model_file, target_filename=target_filename) if model_file else self.data.uri)
# update fields
design = self._wrap_design(design) if design else self.data.design
name = name or self.data.name
@ -283,7 +274,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
uri=None,
framework=None,
iteration=None,
system_tags=None
system_tags=None,
):
return self._edit(
design=design,
@ -369,7 +360,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
async_enable=False,
target_filename=None,
cb=None,
iteration=None
iteration=None,
):
"""Update the given model for a given task ID"""
if async_enable:
@ -417,9 +408,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
)
return uri
else:
uri = self._upload_model(
model_file, async_enable=async_enable, target_filename=target_filename
)
uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename)
Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri)
self.update(
uri=uri,
@ -436,9 +425,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
return uri
def update_for_task(
self, task_id, name=None, model_id=None, type_="output", iteration=None
):
def update_for_task(self, task_id, name=None, model_id=None, type_="output", iteration=None):
if Session.check_min_api_version("2.13"):
req = tasks.AddOrUpdateModelRequest(
task=task_id, name=name, type=type_, model=model_id, iteration=iteration
@ -450,9 +437,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
# backwards compatibility, None
req = None
else:
raise ValueError(
"Type '{}' unsupported (use either 'input' or 'output')".format(type_)
)
raise ValueError("Type '{}' unsupported (use either 'input' or 'output')".format(type_))
if req:
self.send(req)
@ -502,11 +487,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
@property
def tags(self):
return (
self.data.system_tags
if hasattr(self.data, "system_tags")
else self.data.tags
)
return self.data.system_tags if hasattr(self.data, "system_tags") else self.data.tags
@property
def task(self):
@ -549,11 +530,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
return None
# check if we already downloaded the file
downloaded_models = [
k
for k, (i, u) in Model._local_model_to_id_uri.items()
if i == self.id and u == uri
]
downloaded_models = [k for k, (i, u) in Model._local_model_to_id_uri.items() if i == self.id and u == uri]
for dl_file in downloaded_models:
if Path(dl_file).exists() and not force_download:
return dl_file
@ -593,9 +570,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
def get_model_package(self):
"""Get a named tuple containing the model's weights and design"""
return ModelPackage(
weights=self.download_model_weights(), design=self.save_model_design_file()
)
return ModelPackage(weights=self.download_model_weights(), design=self.save_model_design_file())
def get_model_design(self):
"""Get model description (text)"""