Create OutputModel base model lazily

This commit is contained in:
Alex Burlacu 2023-08-11 13:11:58 +03:00
parent 2c44bff461
commit d4b11dfa22

View File

@ -357,6 +357,9 @@ class BaseModel(object):
self._task = None
self._reload_required = False
self._reporter = None
self._floating_data = None
self._name = None
self._task_connect_name = None
self._set_task(task)
def get_weights(self, raise_on_error=False, force_download=False):
@ -1055,6 +1058,7 @@ class BaseModel(object):
def _init_reporter(self):
if self._reporter:
return
self._base_model = self._get_force_base_model()
metrics_manager = Metrics(
session=_Model._get_default_session(),
storage_uri=None,
@ -1126,6 +1130,8 @@ class BaseModel(object):
:return: True if the metadata was set and False otherwise
"""
if not self._base_model:
self._base_model = self._get_force_base_model()
self._reload_required = (
_Model._get_default_session()
.send(
@ -1167,6 +1173,8 @@ class BaseModel(object):
:return: String representation of the value of the metadata entry or None if the entry was not found
"""
if not self._base_model:
self._base_model = self._get_force_base_model()
self._reload_if_required()
return self.get_all_metadata().get(str(key), {}).get("value")
@ -1180,6 +1188,8 @@ class BaseModel(object):
:return: The value of the metadata entry, casted to its type (if not possible,
the string representation will be returned) or None if the entry was not found
"""
if not self._base_model:
self._base_model = self._get_force_base_model()
key = str(key)
metadata = self.get_all_metadata()
if key not in metadata:
@ -1194,6 +1204,8 @@ class BaseModel(object):
:return: Get all metadata as a dictionary of format Dict[key, Dict[value, type]]. The key, value and type
entries are all strings. Note that each entry might have an additional 'key' entry, repeating the key
"""
if not self._base_model:
self._base_model = self._get_force_base_model()
self._reload_if_required()
return self._get_model_data().metadata or {}
@ -1204,6 +1216,8 @@ class BaseModel(object):
entries are strings. The value is cast to its type if possible. Note that each entry might
have an additional 'key' entry, repeating the key
"""
if not self._base_model:
self._base_model = self._get_force_base_model()
self._reload_if_required()
result = {}
metadata = self.get_all_metadata()
@ -1224,6 +1238,8 @@ class BaseModel(object):
:return: True if the metadata was set and False otherwise
"""
if not self._base_model:
self._base_model = self._get_force_base_model()
metadata_array = [
{
"key": str(k),
@ -1249,6 +1265,74 @@ class BaseModel(object):
self._get_base_model().reload()
self._reload_required = False
def _update_base_model(self, model_name=None, task_model_entry=None):
if not self._task:
return self._base_model
# update the model from the task inputs
labels = self._task.get_labels_enumeration()
# noinspection PyProtectedMember
config_text = self._task._get_model_config_text()
model_name = (
model_name or self._name or (self._floating_data.name if self._floating_data else None) or self._task.name
)
# noinspection PyBroadException
try:
task_model_entry = (
task_model_entry
or self._task_connect_name
or Path(self._get_model_data().uri).stem
)
except Exception:
pass
parent = self._task.input_models_id.get(task_model_entry)
self._base_model.update(
labels=(self._floating_data.labels if self._floating_data else None) or labels,
design=(self._floating_data.design if self._floating_data else None) or config_text,
task_id=self._task.id,
project_id=self._task.project,
parent_id=parent,
name=model_name,
comment=self._floating_data.comment if self._floating_data else None,
tags=self._floating_data.tags if self._floating_data else None,
framework=self._floating_data.framework if self._floating_data else None,
upload_storage_uri=self._floating_data.upload_storage_uri if self._floating_data else None,
)
# remove model floating change set, by now they should have matched the task.
self._floating_data = None
# now we have to update the creator task so it points to us
if str(self._task.status) not in (
str(self._task.TaskStatusEnum.created),
str(self._task.TaskStatusEnum.in_progress),
):
self._log.warning(
"Could not update last created model in Task {}, "
"Task status '{}' cannot be updated".format(
self._task.id, self._task.status
)
)
elif task_model_entry:
self._base_model.update_for_task(
task_id=self._task.id,
model_id=self.id,
type_="output",
name=task_model_entry,
)
return self._base_model
def _get_force_base_model(self, model_name=None, task_model_entry=None):
if self._base_model:
return self._base_model
if not self._task:
return None
# create a new model from the task
# noinspection PyProtectedMember
self._base_model = self._task._get_output_model(model_id=None)
return self._update_base_model(model_name=model_name, task_model_entry=task_model_entry)
class Model(BaseModel):
"""
@ -2060,6 +2144,7 @@ class OutputModel(BaseModel):
self._base_model = None
self._base_model_id = None
self._task_connect_name = None
self._name = name
self._label_enumeration = label_enumeration
# noinspection PyProtectedMember
self._floating_data = create_dummy_model(
@ -2300,7 +2385,11 @@ class OutputModel(BaseModel):
if out_model_file_name
else (self._task_connect_name or "Output Model")
)
model = self._get_force_base_model(task_model_entry=name)
if not self._base_model:
model = self._get_force_base_model(task_model_entry=name)
else:
self._update_base_model(task_model_entry=name)
model = self._base_model
if not model:
raise ValueError("Failed creating internal output model")
@ -2639,61 +2728,6 @@ class OutputModel(BaseModel):
)
return weights_filename_offline or register_uri
def _get_force_base_model(self, model_name=None, task_model_entry=None):
if self._base_model:
return self._base_model
# create a new model from the task
# noinspection PyProtectedMember
self._base_model = self._task._get_output_model(model_id=None)
# update the model from the task inputs
labels = self._task.get_labels_enumeration()
# noinspection PyProtectedMember
config_text = self._task._get_model_config_text()
model_name = model_name or self._floating_data.name or self._task.name
task_model_entry = (
task_model_entry
or self._task_connect_name
or Path(self._get_model_data().uri).stem
)
parent = self._task.input_models_id.get(task_model_entry)
self._base_model.update(
labels=self._floating_data.labels or labels,
design=self._floating_data.design or config_text,
task_id=self._task.id,
project_id=self._task.project,
parent_id=parent,
name=model_name,
comment=self._floating_data.comment,
tags=self._floating_data.tags,
framework=self._floating_data.framework,
upload_storage_uri=self._floating_data.upload_storage_uri,
)
# remove model floating change set, by now they should have matched the task.
self._floating_data = None
# now we have to update the creator task so it points to us
if str(self._task.status) not in (
str(self._task.TaskStatusEnum.created),
str(self._task.TaskStatusEnum.in_progress),
):
self._log.warning(
"Could not update last created model in Task {}, "
"Task status '{}' cannot be updated".format(
self._task.id, self._task.status
)
)
else:
self._base_model.update_for_task(
task_id=self._task.id,
model_id=self.id,
type_="output",
name=task_model_entry,
)
return self._base_model
def _get_base_model(self):
if self._floating_data:
return self._floating_data