mirror of
https://github.com/clearml/clearml
synced 2025-04-07 06:04:25 +00:00
Create OutputModel base model lazily
This commit is contained in:
parent
2c44bff461
commit
d4b11dfa22
146
clearml/model.py
146
clearml/model.py
@ -357,6 +357,9 @@ class BaseModel(object):
|
||||
self._task = None
|
||||
self._reload_required = False
|
||||
self._reporter = None
|
||||
self._floating_data = None
|
||||
self._name = None
|
||||
self._task_connect_name = None
|
||||
self._set_task(task)
|
||||
|
||||
def get_weights(self, raise_on_error=False, force_download=False):
|
||||
@ -1055,6 +1058,7 @@ class BaseModel(object):
|
||||
def _init_reporter(self):
|
||||
if self._reporter:
|
||||
return
|
||||
self._base_model = self._get_force_base_model()
|
||||
metrics_manager = Metrics(
|
||||
session=_Model._get_default_session(),
|
||||
storage_uri=None,
|
||||
@ -1126,6 +1130,8 @@ class BaseModel(object):
|
||||
|
||||
:return: True if the metadata was set and False otherwise
|
||||
"""
|
||||
if not self._base_model:
|
||||
self._base_model = self._get_force_base_model()
|
||||
self._reload_required = (
|
||||
_Model._get_default_session()
|
||||
.send(
|
||||
@ -1167,6 +1173,8 @@ class BaseModel(object):
|
||||
|
||||
:return: String representation of the value of the metadata entry or None if the entry was not found
|
||||
"""
|
||||
if not self._base_model:
|
||||
self._base_model = self._get_force_base_model()
|
||||
self._reload_if_required()
|
||||
return self.get_all_metadata().get(str(key), {}).get("value")
|
||||
|
||||
@ -1180,6 +1188,8 @@ class BaseModel(object):
|
||||
:return: The value of the metadata entry, casted to its type (if not possible,
|
||||
the string representation will be returned) or None if the entry was not found
|
||||
"""
|
||||
if not self._base_model:
|
||||
self._base_model = self._get_force_base_model()
|
||||
key = str(key)
|
||||
metadata = self.get_all_metadata()
|
||||
if key not in metadata:
|
||||
@ -1194,6 +1204,8 @@ class BaseModel(object):
|
||||
:return: Get all metadata as a dictionary of format Dict[key, Dict[value, type]]. The key, value and type
|
||||
entries are all strings. Note that each entry might have an additional 'key' entry, repeating the key
|
||||
"""
|
||||
if not self._base_model:
|
||||
self._base_model = self._get_force_base_model()
|
||||
self._reload_if_required()
|
||||
return self._get_model_data().metadata or {}
|
||||
|
||||
@ -1204,6 +1216,8 @@ class BaseModel(object):
|
||||
entries are strings. The value is cast to its type if possible. Note that each entry might
|
||||
have an additional 'key' entry, repeating the key
|
||||
"""
|
||||
if not self._base_model:
|
||||
self._base_model = self._get_force_base_model()
|
||||
self._reload_if_required()
|
||||
result = {}
|
||||
metadata = self.get_all_metadata()
|
||||
@ -1224,6 +1238,8 @@ class BaseModel(object):
|
||||
|
||||
:return: True if the metadata was set and False otherwise
|
||||
"""
|
||||
if not self._base_model:
|
||||
self._base_model = self._get_force_base_model()
|
||||
metadata_array = [
|
||||
{
|
||||
"key": str(k),
|
||||
@ -1249,6 +1265,74 @@ class BaseModel(object):
|
||||
self._get_base_model().reload()
|
||||
self._reload_required = False
|
||||
|
||||
def _update_base_model(self, model_name=None, task_model_entry=None):
|
||||
if not self._task:
|
||||
return self._base_model
|
||||
# update the model from the task inputs
|
||||
labels = self._task.get_labels_enumeration()
|
||||
# noinspection PyProtectedMember
|
||||
config_text = self._task._get_model_config_text()
|
||||
model_name = (
|
||||
model_name or self._name or (self._floating_data.name if self._floating_data else None) or self._task.name
|
||||
)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
task_model_entry = (
|
||||
task_model_entry
|
||||
or self._task_connect_name
|
||||
or Path(self._get_model_data().uri).stem
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
parent = self._task.input_models_id.get(task_model_entry)
|
||||
self._base_model.update(
|
||||
labels=(self._floating_data.labels if self._floating_data else None) or labels,
|
||||
design=(self._floating_data.design if self._floating_data else None) or config_text,
|
||||
task_id=self._task.id,
|
||||
project_id=self._task.project,
|
||||
parent_id=parent,
|
||||
name=model_name,
|
||||
comment=self._floating_data.comment if self._floating_data else None,
|
||||
tags=self._floating_data.tags if self._floating_data else None,
|
||||
framework=self._floating_data.framework if self._floating_data else None,
|
||||
upload_storage_uri=self._floating_data.upload_storage_uri if self._floating_data else None,
|
||||
)
|
||||
|
||||
# remove model floating change set, by now they should have matched the task.
|
||||
self._floating_data = None
|
||||
|
||||
# now we have to update the creator task so it points to us
|
||||
if str(self._task.status) not in (
|
||||
str(self._task.TaskStatusEnum.created),
|
||||
str(self._task.TaskStatusEnum.in_progress),
|
||||
):
|
||||
self._log.warning(
|
||||
"Could not update last created model in Task {}, "
|
||||
"Task status '{}' cannot be updated".format(
|
||||
self._task.id, self._task.status
|
||||
)
|
||||
)
|
||||
elif task_model_entry:
|
||||
self._base_model.update_for_task(
|
||||
task_id=self._task.id,
|
||||
model_id=self.id,
|
||||
type_="output",
|
||||
name=task_model_entry,
|
||||
)
|
||||
|
||||
return self._base_model
|
||||
|
||||
def _get_force_base_model(self, model_name=None, task_model_entry=None):
|
||||
if self._base_model:
|
||||
return self._base_model
|
||||
if not self._task:
|
||||
return None
|
||||
|
||||
# create a new model from the task
|
||||
# noinspection PyProtectedMember
|
||||
self._base_model = self._task._get_output_model(model_id=None)
|
||||
return self._update_base_model(model_name=model_name, task_model_entry=task_model_entry)
|
||||
|
||||
|
||||
class Model(BaseModel):
|
||||
"""
|
||||
@ -2060,6 +2144,7 @@ class OutputModel(BaseModel):
|
||||
self._base_model = None
|
||||
self._base_model_id = None
|
||||
self._task_connect_name = None
|
||||
self._name = name
|
||||
self._label_enumeration = label_enumeration
|
||||
# noinspection PyProtectedMember
|
||||
self._floating_data = create_dummy_model(
|
||||
@ -2300,7 +2385,11 @@ class OutputModel(BaseModel):
|
||||
if out_model_file_name
|
||||
else (self._task_connect_name or "Output Model")
|
||||
)
|
||||
model = self._get_force_base_model(task_model_entry=name)
|
||||
if not self._base_model:
|
||||
model = self._get_force_base_model(task_model_entry=name)
|
||||
else:
|
||||
self._update_base_model(task_model_entry=name)
|
||||
model = self._base_model
|
||||
if not model:
|
||||
raise ValueError("Failed creating internal output model")
|
||||
|
||||
@ -2639,61 +2728,6 @@ class OutputModel(BaseModel):
|
||||
)
|
||||
return weights_filename_offline or register_uri
|
||||
|
||||
def _get_force_base_model(self, model_name=None, task_model_entry=None):
|
||||
if self._base_model:
|
||||
return self._base_model
|
||||
|
||||
# create a new model from the task
|
||||
# noinspection PyProtectedMember
|
||||
self._base_model = self._task._get_output_model(model_id=None)
|
||||
# update the model from the task inputs
|
||||
labels = self._task.get_labels_enumeration()
|
||||
# noinspection PyProtectedMember
|
||||
config_text = self._task._get_model_config_text()
|
||||
model_name = model_name or self._floating_data.name or self._task.name
|
||||
task_model_entry = (
|
||||
task_model_entry
|
||||
or self._task_connect_name
|
||||
or Path(self._get_model_data().uri).stem
|
||||
)
|
||||
parent = self._task.input_models_id.get(task_model_entry)
|
||||
self._base_model.update(
|
||||
labels=self._floating_data.labels or labels,
|
||||
design=self._floating_data.design or config_text,
|
||||
task_id=self._task.id,
|
||||
project_id=self._task.project,
|
||||
parent_id=parent,
|
||||
name=model_name,
|
||||
comment=self._floating_data.comment,
|
||||
tags=self._floating_data.tags,
|
||||
framework=self._floating_data.framework,
|
||||
upload_storage_uri=self._floating_data.upload_storage_uri,
|
||||
)
|
||||
|
||||
# remove model floating change set, by now they should have matched the task.
|
||||
self._floating_data = None
|
||||
|
||||
# now we have to update the creator task so it points to us
|
||||
if str(self._task.status) not in (
|
||||
str(self._task.TaskStatusEnum.created),
|
||||
str(self._task.TaskStatusEnum.in_progress),
|
||||
):
|
||||
self._log.warning(
|
||||
"Could not update last created model in Task {}, "
|
||||
"Task status '{}' cannot be updated".format(
|
||||
self._task.id, self._task.status
|
||||
)
|
||||
)
|
||||
else:
|
||||
self._base_model.update_for_task(
|
||||
task_id=self._task.id,
|
||||
model_id=self.id,
|
||||
type_="output",
|
||||
name=task_model_entry,
|
||||
)
|
||||
|
||||
return self._base_model
|
||||
|
||||
def _get_base_model(self):
|
||||
if self._floating_data:
|
||||
return self._floating_data
|
||||
|
Loading…
Reference in New Issue
Block a user