Add raise_on_error flag to Model.update_weights() and Model.update_weights_package()

This commit is contained in:
allegroai 2023-12-24 11:37:59 +02:00
parent 52d8835710
commit 7813602ad2
2 changed files with 13 additions and 3 deletions

View File

@ -528,8 +528,10 @@ class _Arguments(object):
"Failed parsing task parameter {}={} keeping default {}={}".format(k, param, k, v)
)
# assume more general purpose type int -> float
if v_type == int:
# if parameter is empty and default value is None, keep as None
if param == '' and v is None:
v_type = type(None)
elif v_type == int: # assume more general purpose type int -> float
if v is not None and int(v) != float(v):
v_type = float
elif v_type == bool:

View File

@ -2347,6 +2347,7 @@ class OutputModel(BaseModel):
iteration=None, # type: Optional[int]
update_comment=True, # type: bool
is_package=False, # type: bool
async_enable=True, # type: bool
):
# type: (...) -> str
"""
@ -2374,6 +2375,8 @@ class OutputModel(BaseModel):
- ``True`` - Update model comment (Default)
- ``False`` - Do not update
:param bool is_package: Mark the weights file as compressed package, usually a zip file.
:param bool async_enable: Whether to upload model in background or to block.
Will raise an error in the main thread if the weights failed to be uploaded or not.
:return: The uploaded URI.
"""
@ -2421,6 +2424,7 @@ class OutputModel(BaseModel):
target_filename=target_filename or Path(weights_filename).name,
auto_delete_file=auto_delete_file,
iteration=iteration,
async_enable=async_enable
)
# make sure we delete the previous file, if it exists
@ -2502,7 +2506,7 @@ class OutputModel(BaseModel):
output_uri = model.update_and_upload(
model_file=weights_filename,
task_id=self._task.id,
async_enable=True,
async_enable=async_enable,
target_filename=target_filename,
framework=self.framework or framework,
comment=comment,
@ -2535,6 +2539,7 @@ class OutputModel(BaseModel):
target_filename=None, # type: Optional[str]
auto_delete_file=True, # type: bool
iteration=None, # type: Optional[int]
async_enable=True, # type: bool
):
# type: (...) -> str
"""
@ -2559,6 +2564,8 @@ class OutputModel(BaseModel):
- ``False`` - Do not delete
:param int iteration: The iteration number.
:param bool async_enable: Whether to upload model in background or to block.
Will raise an error in the main thread if the weights failed to be uploaded or not.
:return: The uploaded URI for the weights package.
"""
@ -2626,6 +2633,7 @@ class OutputModel(BaseModel):
target_filename=target_filename or "model_package.zip",
iteration=iteration,
update_comment=False,
async_enable=async_enable
)
# set the model tag (by now we should have a model object) so we know we have packaged file
self._set_package_tag()