mirror of
https://github.com/clearml/clearml
synced 2025-04-07 22:24:30 +00:00
Create OutputModel base model lazily
This commit is contained in:
parent
2c44bff461
commit
d4b11dfa22
144
clearml/model.py
144
clearml/model.py
@ -357,6 +357,9 @@ class BaseModel(object):
|
|||||||
self._task = None
|
self._task = None
|
||||||
self._reload_required = False
|
self._reload_required = False
|
||||||
self._reporter = None
|
self._reporter = None
|
||||||
|
self._floating_data = None
|
||||||
|
self._name = None
|
||||||
|
self._task_connect_name = None
|
||||||
self._set_task(task)
|
self._set_task(task)
|
||||||
|
|
||||||
def get_weights(self, raise_on_error=False, force_download=False):
|
def get_weights(self, raise_on_error=False, force_download=False):
|
||||||
@ -1055,6 +1058,7 @@ class BaseModel(object):
|
|||||||
def _init_reporter(self):
|
def _init_reporter(self):
|
||||||
if self._reporter:
|
if self._reporter:
|
||||||
return
|
return
|
||||||
|
self._base_model = self._get_force_base_model()
|
||||||
metrics_manager = Metrics(
|
metrics_manager = Metrics(
|
||||||
session=_Model._get_default_session(),
|
session=_Model._get_default_session(),
|
||||||
storage_uri=None,
|
storage_uri=None,
|
||||||
@ -1126,6 +1130,8 @@ class BaseModel(object):
|
|||||||
|
|
||||||
:return: True if the metadata was set and False otherwise
|
:return: True if the metadata was set and False otherwise
|
||||||
"""
|
"""
|
||||||
|
if not self._base_model:
|
||||||
|
self._base_model = self._get_force_base_model()
|
||||||
self._reload_required = (
|
self._reload_required = (
|
||||||
_Model._get_default_session()
|
_Model._get_default_session()
|
||||||
.send(
|
.send(
|
||||||
@ -1167,6 +1173,8 @@ class BaseModel(object):
|
|||||||
|
|
||||||
:return: String representation of the value of the metadata entry or None if the entry was not found
|
:return: String representation of the value of the metadata entry or None if the entry was not found
|
||||||
"""
|
"""
|
||||||
|
if not self._base_model:
|
||||||
|
self._base_model = self._get_force_base_model()
|
||||||
self._reload_if_required()
|
self._reload_if_required()
|
||||||
return self.get_all_metadata().get(str(key), {}).get("value")
|
return self.get_all_metadata().get(str(key), {}).get("value")
|
||||||
|
|
||||||
@ -1180,6 +1188,8 @@ class BaseModel(object):
|
|||||||
:return: The value of the metadata entry, casted to its type (if not possible,
|
:return: The value of the metadata entry, casted to its type (if not possible,
|
||||||
the string representation will be returned) or None if the entry was not found
|
the string representation will be returned) or None if the entry was not found
|
||||||
"""
|
"""
|
||||||
|
if not self._base_model:
|
||||||
|
self._base_model = self._get_force_base_model()
|
||||||
key = str(key)
|
key = str(key)
|
||||||
metadata = self.get_all_metadata()
|
metadata = self.get_all_metadata()
|
||||||
if key not in metadata:
|
if key not in metadata:
|
||||||
@ -1194,6 +1204,8 @@ class BaseModel(object):
|
|||||||
:return: Get all metadata as a dictionary of format Dict[key, Dict[value, type]]. The key, value and type
|
:return: Get all metadata as a dictionary of format Dict[key, Dict[value, type]]. The key, value and type
|
||||||
entries are all strings. Note that each entry might have an additional 'key' entry, repeating the key
|
entries are all strings. Note that each entry might have an additional 'key' entry, repeating the key
|
||||||
"""
|
"""
|
||||||
|
if not self._base_model:
|
||||||
|
self._base_model = self._get_force_base_model()
|
||||||
self._reload_if_required()
|
self._reload_if_required()
|
||||||
return self._get_model_data().metadata or {}
|
return self._get_model_data().metadata or {}
|
||||||
|
|
||||||
@ -1204,6 +1216,8 @@ class BaseModel(object):
|
|||||||
entries are strings. The value is cast to its type if possible. Note that each entry might
|
entries are strings. The value is cast to its type if possible. Note that each entry might
|
||||||
have an additional 'key' entry, repeating the key
|
have an additional 'key' entry, repeating the key
|
||||||
"""
|
"""
|
||||||
|
if not self._base_model:
|
||||||
|
self._base_model = self._get_force_base_model()
|
||||||
self._reload_if_required()
|
self._reload_if_required()
|
||||||
result = {}
|
result = {}
|
||||||
metadata = self.get_all_metadata()
|
metadata = self.get_all_metadata()
|
||||||
@ -1224,6 +1238,8 @@ class BaseModel(object):
|
|||||||
|
|
||||||
:return: True if the metadata was set and False otherwise
|
:return: True if the metadata was set and False otherwise
|
||||||
"""
|
"""
|
||||||
|
if not self._base_model:
|
||||||
|
self._base_model = self._get_force_base_model()
|
||||||
metadata_array = [
|
metadata_array = [
|
||||||
{
|
{
|
||||||
"key": str(k),
|
"key": str(k),
|
||||||
@ -1249,6 +1265,74 @@ class BaseModel(object):
|
|||||||
self._get_base_model().reload()
|
self._get_base_model().reload()
|
||||||
self._reload_required = False
|
self._reload_required = False
|
||||||
|
|
||||||
|
def _update_base_model(self, model_name=None, task_model_entry=None):
|
||||||
|
if not self._task:
|
||||||
|
return self._base_model
|
||||||
|
# update the model from the task inputs
|
||||||
|
labels = self._task.get_labels_enumeration()
|
||||||
|
# noinspection PyProtectedMember
|
||||||
|
config_text = self._task._get_model_config_text()
|
||||||
|
model_name = (
|
||||||
|
model_name or self._name or (self._floating_data.name if self._floating_data else None) or self._task.name
|
||||||
|
)
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
task_model_entry = (
|
||||||
|
task_model_entry
|
||||||
|
or self._task_connect_name
|
||||||
|
or Path(self._get_model_data().uri).stem
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
parent = self._task.input_models_id.get(task_model_entry)
|
||||||
|
self._base_model.update(
|
||||||
|
labels=(self._floating_data.labels if self._floating_data else None) or labels,
|
||||||
|
design=(self._floating_data.design if self._floating_data else None) or config_text,
|
||||||
|
task_id=self._task.id,
|
||||||
|
project_id=self._task.project,
|
||||||
|
parent_id=parent,
|
||||||
|
name=model_name,
|
||||||
|
comment=self._floating_data.comment if self._floating_data else None,
|
||||||
|
tags=self._floating_data.tags if self._floating_data else None,
|
||||||
|
framework=self._floating_data.framework if self._floating_data else None,
|
||||||
|
upload_storage_uri=self._floating_data.upload_storage_uri if self._floating_data else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# remove model floating change set, by now they should have matched the task.
|
||||||
|
self._floating_data = None
|
||||||
|
|
||||||
|
# now we have to update the creator task so it points to us
|
||||||
|
if str(self._task.status) not in (
|
||||||
|
str(self._task.TaskStatusEnum.created),
|
||||||
|
str(self._task.TaskStatusEnum.in_progress),
|
||||||
|
):
|
||||||
|
self._log.warning(
|
||||||
|
"Could not update last created model in Task {}, "
|
||||||
|
"Task status '{}' cannot be updated".format(
|
||||||
|
self._task.id, self._task.status
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif task_model_entry:
|
||||||
|
self._base_model.update_for_task(
|
||||||
|
task_id=self._task.id,
|
||||||
|
model_id=self.id,
|
||||||
|
type_="output",
|
||||||
|
name=task_model_entry,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._base_model
|
||||||
|
|
||||||
|
def _get_force_base_model(self, model_name=None, task_model_entry=None):
|
||||||
|
if self._base_model:
|
||||||
|
return self._base_model
|
||||||
|
if not self._task:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# create a new model from the task
|
||||||
|
# noinspection PyProtectedMember
|
||||||
|
self._base_model = self._task._get_output_model(model_id=None)
|
||||||
|
return self._update_base_model(model_name=model_name, task_model_entry=task_model_entry)
|
||||||
|
|
||||||
|
|
||||||
class Model(BaseModel):
|
class Model(BaseModel):
|
||||||
"""
|
"""
|
||||||
@ -2060,6 +2144,7 @@ class OutputModel(BaseModel):
|
|||||||
self._base_model = None
|
self._base_model = None
|
||||||
self._base_model_id = None
|
self._base_model_id = None
|
||||||
self._task_connect_name = None
|
self._task_connect_name = None
|
||||||
|
self._name = name
|
||||||
self._label_enumeration = label_enumeration
|
self._label_enumeration = label_enumeration
|
||||||
# noinspection PyProtectedMember
|
# noinspection PyProtectedMember
|
||||||
self._floating_data = create_dummy_model(
|
self._floating_data = create_dummy_model(
|
||||||
@ -2300,7 +2385,11 @@ class OutputModel(BaseModel):
|
|||||||
if out_model_file_name
|
if out_model_file_name
|
||||||
else (self._task_connect_name or "Output Model")
|
else (self._task_connect_name or "Output Model")
|
||||||
)
|
)
|
||||||
|
if not self._base_model:
|
||||||
model = self._get_force_base_model(task_model_entry=name)
|
model = self._get_force_base_model(task_model_entry=name)
|
||||||
|
else:
|
||||||
|
self._update_base_model(task_model_entry=name)
|
||||||
|
model = self._base_model
|
||||||
if not model:
|
if not model:
|
||||||
raise ValueError("Failed creating internal output model")
|
raise ValueError("Failed creating internal output model")
|
||||||
|
|
||||||
@ -2639,61 +2728,6 @@ class OutputModel(BaseModel):
|
|||||||
)
|
)
|
||||||
return weights_filename_offline or register_uri
|
return weights_filename_offline or register_uri
|
||||||
|
|
||||||
def _get_force_base_model(self, model_name=None, task_model_entry=None):
|
|
||||||
if self._base_model:
|
|
||||||
return self._base_model
|
|
||||||
|
|
||||||
# create a new model from the task
|
|
||||||
# noinspection PyProtectedMember
|
|
||||||
self._base_model = self._task._get_output_model(model_id=None)
|
|
||||||
# update the model from the task inputs
|
|
||||||
labels = self._task.get_labels_enumeration()
|
|
||||||
# noinspection PyProtectedMember
|
|
||||||
config_text = self._task._get_model_config_text()
|
|
||||||
model_name = model_name or self._floating_data.name or self._task.name
|
|
||||||
task_model_entry = (
|
|
||||||
task_model_entry
|
|
||||||
or self._task_connect_name
|
|
||||||
or Path(self._get_model_data().uri).stem
|
|
||||||
)
|
|
||||||
parent = self._task.input_models_id.get(task_model_entry)
|
|
||||||
self._base_model.update(
|
|
||||||
labels=self._floating_data.labels or labels,
|
|
||||||
design=self._floating_data.design or config_text,
|
|
||||||
task_id=self._task.id,
|
|
||||||
project_id=self._task.project,
|
|
||||||
parent_id=parent,
|
|
||||||
name=model_name,
|
|
||||||
comment=self._floating_data.comment,
|
|
||||||
tags=self._floating_data.tags,
|
|
||||||
framework=self._floating_data.framework,
|
|
||||||
upload_storage_uri=self._floating_data.upload_storage_uri,
|
|
||||||
)
|
|
||||||
|
|
||||||
# remove model floating change set, by now they should have matched the task.
|
|
||||||
self._floating_data = None
|
|
||||||
|
|
||||||
# now we have to update the creator task so it points to us
|
|
||||||
if str(self._task.status) not in (
|
|
||||||
str(self._task.TaskStatusEnum.created),
|
|
||||||
str(self._task.TaskStatusEnum.in_progress),
|
|
||||||
):
|
|
||||||
self._log.warning(
|
|
||||||
"Could not update last created model in Task {}, "
|
|
||||||
"Task status '{}' cannot be updated".format(
|
|
||||||
self._task.id, self._task.status
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self._base_model.update_for_task(
|
|
||||||
task_id=self._task.id,
|
|
||||||
model_id=self.id,
|
|
||||||
type_="output",
|
|
||||||
name=task_model_entry,
|
|
||||||
)
|
|
||||||
|
|
||||||
return self._base_model
|
|
||||||
|
|
||||||
def _get_base_model(self):
|
def _get_base_model(self):
|
||||||
if self._floating_data:
|
if self._floating_data:
|
||||||
return self._floating_data
|
return self._floating_data
|
||||||
|
Loading…
Reference in New Issue
Block a user