Black formatting
Some checks failed
CodeQL / Analyze (python) (push) Has been cancelled

This commit is contained in:
clearml 2025-03-22 23:03:35 +02:00
parent deba24c689
commit 297990c454

View File

@ -56,13 +56,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
return self.id return self.id
def __init__( def __init__(
self, self, upload_storage_uri, cache_dir, model_id=None, upload_storage_suffix="models", session=None, log=None
upload_storage_uri,
cache_dir,
model_id=None,
upload_storage_suffix="models",
session=None,
log=None
): ):
super(Model, self).__init__(id=model_id, session=session, log=log) super(Model, self).__init__(id=model_id, session=session, log=log)
self._upload_storage_suffix = upload_storage_suffix self._upload_storage_suffix = upload_storage_suffix
@ -84,9 +78,13 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
self.reload() self.reload()
else: else:
from ..model import BaseModel from ..model import BaseModel
# edit will reload # edit will reload
self._edit( self._edit(
system_tags=list(set((self.data.system_tags or []) if hasattr(self.data, "system_tags") else []) | {BaseModel._archived_tag}) system_tags=list(
set((self.data.system_tags or []) if hasattr(self.data, "system_tags") else [])
| {BaseModel._archived_tag}
)
) )
def unarchive(self): def unarchive(self):
@ -95,9 +93,13 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
self.reload() self.reload()
else: else:
from ..model import BaseModel from ..model import BaseModel
# edit will reload # edit will reload
self._edit( self._edit(
system_tags=list(set((self.data.system_tags or []) if hasattr(self.data, "system_tags") else []) - {BaseModel._archived_tag}) system_tags=list(
set((self.data.system_tags or []) if hasattr(self.data, "system_tags") else [])
- {BaseModel._archived_tag}
)
) )
def _reload(self): def _reload(self):
@ -110,9 +112,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
res = self.send(models.GetByIdRequest(model=self.id)) res = self.send(models.GetByIdRequest(model=self.id))
return res.response.model return res.response.model
def _upload_model( def _upload_model(self, model_file, async_enable=False, target_filename=None, cb=None):
self, model_file, async_enable=False, target_filename=None, cb=None
):
if not self.upload_storage_uri: if not self.upload_storage_uri:
raise ValueError("Model has no storage URI defined (nowhere to upload to)") raise ValueError("Model has no storage URI defined (nowhere to upload to)")
target_filename = target_filename or Path(model_file).name target_filename = target_filename or Path(model_file).name
@ -128,15 +128,12 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
dest_path=dest_path, dest_path=dest_path,
async_enable=async_enable, async_enable=async_enable,
cb=partial(self._upload_callback, cb=cb), cb=partial(self._upload_callback, cb=cb),
return_canonized=False return_canonized=False,
) )
if async_enable: if async_enable:
def msg(num_results): def msg(num_results):
self.log.info( self.log.info("Waiting for previous model to upload (%d pending, %s)" % (num_results, dest_path))
"Waiting for previous model to upload (%d pending, %s)"
% (num_results, dest_path)
)
self._add_async_result(result, wait_on_max_results=2, wait_cb=msg) self._add_async_result(result, wait_on_max_results=2, wait_cb=msg)
return dest_path return dest_path
@ -206,9 +203,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
return list(design.values())[0] return list(design.values())[0]
raise ValueError( raise ValueError("design must be a string or a dictionary with at least one value")
"design must be a string or a dictionary with at least one value"
)
def update( def update(
self, self,
@ -226,7 +221,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
upload_storage_uri=None, upload_storage_uri=None,
target_filename=None, target_filename=None,
iteration=None, iteration=None,
system_tags=None system_tags=None,
): ):
"""Update model weights file and various model properties""" """Update model weights file and various model properties"""
@ -241,11 +236,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri) Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri)
# upload model file if needed and get uri # upload model file if needed and get uri
uri = uri or ( uri = uri or (self._upload_model(model_file, target_filename=target_filename) if model_file else self.data.uri)
self._upload_model(model_file, target_filename=target_filename)
if model_file
else self.data.uri
)
# update fields # update fields
design = self._wrap_design(design) if design else self.data.design design = self._wrap_design(design) if design else self.data.design
name = name or self.data.name name = name or self.data.name
@ -283,7 +274,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
uri=None, uri=None,
framework=None, framework=None,
iteration=None, iteration=None,
system_tags=None system_tags=None,
): ):
return self._edit( return self._edit(
design=design, design=design,
@ -369,7 +360,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
async_enable=False, async_enable=False,
target_filename=None, target_filename=None,
cb=None, cb=None,
iteration=None iteration=None,
): ):
"""Update the given model for a given task ID""" """Update the given model for a given task ID"""
if async_enable: if async_enable:
@ -417,9 +408,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
) )
return uri return uri
else: else:
uri = self._upload_model( uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename)
model_file, async_enable=async_enable, target_filename=target_filename
)
Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri) Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri)
self.update( self.update(
uri=uri, uri=uri,
@ -436,9 +425,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
return uri return uri
def update_for_task( def update_for_task(self, task_id, name=None, model_id=None, type_="output", iteration=None):
self, task_id, name=None, model_id=None, type_="output", iteration=None
):
if Session.check_min_api_version("2.13"): if Session.check_min_api_version("2.13"):
req = tasks.AddOrUpdateModelRequest( req = tasks.AddOrUpdateModelRequest(
task=task_id, name=name, type=type_, model=model_id, iteration=iteration task=task_id, name=name, type=type_, model=model_id, iteration=iteration
@ -450,9 +437,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
# backwards compatibility, None # backwards compatibility, None
req = None req = None
else: else:
raise ValueError( raise ValueError("Type '{}' unsupported (use either 'input' or 'output')".format(type_))
"Type '{}' unsupported (use either 'input' or 'output')".format(type_)
)
if req: if req:
self.send(req) self.send(req)
@ -502,11 +487,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
@property @property
def tags(self): def tags(self):
return ( return self.data.system_tags if hasattr(self.data, "system_tags") else self.data.tags
self.data.system_tags
if hasattr(self.data, "system_tags")
else self.data.tags
)
@property @property
def task(self): def task(self):
@ -549,11 +530,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
return None return None
# check if we already downloaded the file # check if we already downloaded the file
downloaded_models = [ downloaded_models = [k for k, (i, u) in Model._local_model_to_id_uri.items() if i == self.id and u == uri]
k
for k, (i, u) in Model._local_model_to_id_uri.items()
if i == self.id and u == uri
]
for dl_file in downloaded_models: for dl_file in downloaded_models:
if Path(dl_file).exists() and not force_download: if Path(dl_file).exists() and not force_download:
return dl_file return dl_file
@ -593,9 +570,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
def get_model_package(self): def get_model_package(self):
"""Get a named tuple containing the model's weights and design""" """Get a named tuple containing the model's weights and design"""
return ModelPackage( return ModelPackage(weights=self.download_model_weights(), design=self.save_model_design_file())
weights=self.download_model_weights(), design=self.save_model_design_file()
)
def get_model_design(self): def get_model_design(self):
"""Get model description (text)""" """Get model description (text)"""