Allow registering offline models

This commit is contained in:
Alex Burlacu 2023-03-23 18:32:53 +02:00
parent a5d25b1a88
commit dc5be02328
3 changed files with 77 additions and 0 deletions

View File

@ -152,6 +152,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
:param force_create: If True, a new task will always be created (task_id, if provided, will be ignored)
:type force_create: bool
"""
self._offline_output_models = []
SingletonLock.instantiate()
task_id = self._resolve_task_id(task_id, log=log) if not force_create else None
self.__edit_lock = None
@ -2459,6 +2460,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
export_data = self.data.to_dict()
export_data["project_name"] = self.get_project_name()
export_data["offline_folder"] = self.get_offline_mode_folder().as_posix()
export_data["offline_output_models"] = self._offline_output_models
json.dump(export_data, f, ensure_ascii=True, sort_keys=True)
return None

View File

@ -2,6 +2,7 @@ import abc
import os
import tarfile
import zipfile
import shutil
from tempfile import mkdtemp, mkstemp
import six
@ -1166,6 +1167,7 @@ class OutputModel(BaseModel):
"""
_default_output_uri = None
_offline_folder = "models"
@property
def published(self):
@ -1324,6 +1326,7 @@ class OutputModel(BaseModel):
self._base_model = None
self._base_model_id = None
self._task_connect_name = None
self._label_enumeration = label_enumeration
# noinspection PyProtectedMember
self._floating_data = create_dummy_model(
design=_Model._wrap_design(config_text),
@ -1504,6 +1507,17 @@ class OutputModel(BaseModel):
if not self._task:
raise Exception('Missing a task for this model')
if self._task.is_offline() and (weights_filename is None or not Path(weights_filename).is_dir()):
return self._update_weights_offline(
weights_filename=weights_filename,
upload_uri=upload_uri,
target_filename=target_filename,
register_uri=register_uri,
iteration=iteration,
update_comment=update_comment,
is_package=is_package,
)
if weights_filename is not None:
# Check if weights_filename is a folder, is package upload
if Path(weights_filename).is_dir():
@ -1771,6 +1785,59 @@ class OutputModel(BaseModel):
"""
cls._default_output_uri = str(output_uri) if output_uri else None
def _update_weights_offline(
self,
weights_filename=None, # type: Optional[str]
upload_uri=None, # type: Optional[str]
target_filename=None, # type: Optional[str]
register_uri=None, # type: Optional[str]
iteration=None, # type: Optional[int]
update_comment=True, # type: bool
is_package=False, # type: bool
):
# type: (...) -> str
if (not weights_filename and not register_uri) or (weights_filename and register_uri):
raise ValueError(
"Model update must have either local weights file to upload, "
"or pre-uploaded register_uri, never both"
)
if not self._task:
raise Exception("Missing a task for this model")
weights_filename_offline = None
if weights_filename:
weights_filename_offline = (
self._task.get_offline_mode_folder() / self._offline_folder / Path(weights_filename).name
).as_posix()
os.makedirs(os.path.dirname(weights_filename_offline), exist_ok=True)
shutil.copyfile(weights_filename, weights_filename_offline)
# noinspection PyProtectedMember
self._task._offline_output_models.append(
dict(
init=dict(
config_text=self.config_text,
config_dict=self.config_dict,
label_enumeration=self._label_enumeration,
name=self.name,
tags=self.tags,
comment=self.comment,
framework=self.framework
),
weights=dict(
weights_filename=weights_filename_offline,
upload_uri=upload_uri,
target_filename=target_filename,
register_uri=register_uri,
iteration=iteration,
update_comment=update_comment,
is_package=is_package
),
output_uri=self._get_base_model().upload_storage_uri or self._default_output_uri
)
)
return weights_filename_offline or register_uri
def _get_force_base_model(self, model_name=None, task_model_entry=None):
if self._base_model:
return self._base_model

View File

@ -2930,6 +2930,11 @@ class Task(_Task):
StorageManager.upload_file(local_file=local_file.as_posix(), remote_url=artifact.uri)
# noinspection PyProtectedMember
task_holding_reports._edit(execution=current_task.data.execution)
for output_model in export_data.get("offline_output_models", []):
model = OutputModel(task=current_task, **output_model["init"])
if output_model.get("output_uri"):
model.set_upload_destination(output_model.get("output_uri"))
model.update_weights(auto_delete_file=False, **output_model["weights"])
# logs
TaskHandler.report_offline_session(task_holding_reports, session_folder, iteration_offset=iteration_offset)
# metrics
@ -3833,6 +3838,9 @@ class Task(_Task):
if self._offline_mode and not is_sub_process:
# noinspection PyBroadException
try:
# make sure the state of the offline data is saved
self._edit()
# create zip file
offline_folder = self.get_offline_mode_folder()
zip_file = offline_folder.as_posix() + '.zip'