mirror of
https://github.com/clearml/clearml
synced 2025-04-07 06:04:25 +00:00
Allow registering offline models
This commit is contained in:
parent
a5d25b1a88
commit
dc5be02328
@ -152,6 +152,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
:param force_create: If True, a new task will always be created (task_id, if provided, will be ignored)
|
||||
:type force_create: bool
|
||||
"""
|
||||
self._offline_output_models = []
|
||||
SingletonLock.instantiate()
|
||||
task_id = self._resolve_task_id(task_id, log=log) if not force_create else None
|
||||
self.__edit_lock = None
|
||||
@ -2459,6 +2460,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
export_data = self.data.to_dict()
|
||||
export_data["project_name"] = self.get_project_name()
|
||||
export_data["offline_folder"] = self.get_offline_mode_folder().as_posix()
|
||||
export_data["offline_output_models"] = self._offline_output_models
|
||||
json.dump(export_data, f, ensure_ascii=True, sort_keys=True)
|
||||
return None
|
||||
|
||||
|
@ -2,6 +2,7 @@ import abc
|
||||
import os
|
||||
import tarfile
|
||||
import zipfile
|
||||
import shutil
|
||||
from tempfile import mkdtemp, mkstemp
|
||||
|
||||
import six
|
||||
@ -1166,6 +1167,7 @@ class OutputModel(BaseModel):
|
||||
"""
|
||||
|
||||
_default_output_uri = None
|
||||
_offline_folder = "models"
|
||||
|
||||
@property
|
||||
def published(self):
|
||||
@ -1324,6 +1326,7 @@ class OutputModel(BaseModel):
|
||||
self._base_model = None
|
||||
self._base_model_id = None
|
||||
self._task_connect_name = None
|
||||
self._label_enumeration = label_enumeration
|
||||
# noinspection PyProtectedMember
|
||||
self._floating_data = create_dummy_model(
|
||||
design=_Model._wrap_design(config_text),
|
||||
@ -1504,6 +1507,17 @@ class OutputModel(BaseModel):
|
||||
if not self._task:
|
||||
raise Exception('Missing a task for this model')
|
||||
|
||||
if self._task.is_offline() and (weights_filename is None or not Path(weights_filename).is_dir()):
|
||||
return self._update_weights_offline(
|
||||
weights_filename=weights_filename,
|
||||
upload_uri=upload_uri,
|
||||
target_filename=target_filename,
|
||||
register_uri=register_uri,
|
||||
iteration=iteration,
|
||||
update_comment=update_comment,
|
||||
is_package=is_package,
|
||||
)
|
||||
|
||||
if weights_filename is not None:
|
||||
# Check if weights_filename is a folder, is package upload
|
||||
if Path(weights_filename).is_dir():
|
||||
@ -1771,6 +1785,59 @@ class OutputModel(BaseModel):
|
||||
"""
|
||||
cls._default_output_uri = str(output_uri) if output_uri else None
|
||||
|
||||
def _update_weights_offline(
|
||||
self,
|
||||
weights_filename=None, # type: Optional[str]
|
||||
upload_uri=None, # type: Optional[str]
|
||||
target_filename=None, # type: Optional[str]
|
||||
register_uri=None, # type: Optional[str]
|
||||
iteration=None, # type: Optional[int]
|
||||
update_comment=True, # type: bool
|
||||
is_package=False, # type: bool
|
||||
):
|
||||
# type: (...) -> str
|
||||
if (not weights_filename and not register_uri) or (weights_filename and register_uri):
|
||||
raise ValueError(
|
||||
"Model update must have either local weights file to upload, "
|
||||
"or pre-uploaded register_uri, never both"
|
||||
)
|
||||
if not self._task:
|
||||
raise Exception("Missing a task for this model")
|
||||
|
||||
weights_filename_offline = None
|
||||
if weights_filename:
|
||||
weights_filename_offline = (
|
||||
self._task.get_offline_mode_folder() / self._offline_folder / Path(weights_filename).name
|
||||
).as_posix()
|
||||
os.makedirs(os.path.dirname(weights_filename_offline), exist_ok=True)
|
||||
shutil.copyfile(weights_filename, weights_filename_offline)
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
self._task._offline_output_models.append(
|
||||
dict(
|
||||
init=dict(
|
||||
config_text=self.config_text,
|
||||
config_dict=self.config_dict,
|
||||
label_enumeration=self._label_enumeration,
|
||||
name=self.name,
|
||||
tags=self.tags,
|
||||
comment=self.comment,
|
||||
framework=self.framework
|
||||
),
|
||||
weights=dict(
|
||||
weights_filename=weights_filename_offline,
|
||||
upload_uri=upload_uri,
|
||||
target_filename=target_filename,
|
||||
register_uri=register_uri,
|
||||
iteration=iteration,
|
||||
update_comment=update_comment,
|
||||
is_package=is_package
|
||||
),
|
||||
output_uri=self._get_base_model().upload_storage_uri or self._default_output_uri
|
||||
)
|
||||
)
|
||||
return weights_filename_offline or register_uri
|
||||
|
||||
def _get_force_base_model(self, model_name=None, task_model_entry=None):
|
||||
if self._base_model:
|
||||
return self._base_model
|
||||
|
@ -2930,6 +2930,11 @@ class Task(_Task):
|
||||
StorageManager.upload_file(local_file=local_file.as_posix(), remote_url=artifact.uri)
|
||||
# noinspection PyProtectedMember
|
||||
task_holding_reports._edit(execution=current_task.data.execution)
|
||||
for output_model in export_data.get("offline_output_models", []):
|
||||
model = OutputModel(task=current_task, **output_model["init"])
|
||||
if output_model.get("output_uri"):
|
||||
model.set_upload_destination(output_model.get("output_uri"))
|
||||
model.update_weights(auto_delete_file=False, **output_model["weights"])
|
||||
# logs
|
||||
TaskHandler.report_offline_session(task_holding_reports, session_folder, iteration_offset=iteration_offset)
|
||||
# metrics
|
||||
@ -3833,6 +3838,9 @@ class Task(_Task):
|
||||
if self._offline_mode and not is_sub_process:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# make sure the state of the offline data is saved
|
||||
self._edit()
|
||||
|
||||
# create zip file
|
||||
offline_folder = self.get_offline_mode_folder()
|
||||
zip_file = offline_folder.as_posix() + '.zip'
|
||||
|
Loading…
Reference in New Issue
Block a user