Create OutputModel base model lazily

This commit is contained in:
Alex Burlacu 2023-08-11 13:11:58 +03:00
parent 2c44bff461
commit d4b11dfa22

View File

@ -357,6 +357,9 @@ class BaseModel(object):
self._task = None self._task = None
self._reload_required = False self._reload_required = False
self._reporter = None self._reporter = None
self._floating_data = None
self._name = None
self._task_connect_name = None
self._set_task(task) self._set_task(task)
def get_weights(self, raise_on_error=False, force_download=False): def get_weights(self, raise_on_error=False, force_download=False):
@ -1055,6 +1058,7 @@ class BaseModel(object):
def _init_reporter(self): def _init_reporter(self):
if self._reporter: if self._reporter:
return return
self._base_model = self._get_force_base_model()
metrics_manager = Metrics( metrics_manager = Metrics(
session=_Model._get_default_session(), session=_Model._get_default_session(),
storage_uri=None, storage_uri=None,
@ -1126,6 +1130,8 @@ class BaseModel(object):
:return: True if the metadata was set and False otherwise :return: True if the metadata was set and False otherwise
""" """
if not self._base_model:
self._base_model = self._get_force_base_model()
self._reload_required = ( self._reload_required = (
_Model._get_default_session() _Model._get_default_session()
.send( .send(
@ -1167,6 +1173,8 @@ class BaseModel(object):
:return: String representation of the value of the metadata entry or None if the entry was not found :return: String representation of the value of the metadata entry or None if the entry was not found
""" """
if not self._base_model:
self._base_model = self._get_force_base_model()
self._reload_if_required() self._reload_if_required()
return self.get_all_metadata().get(str(key), {}).get("value") return self.get_all_metadata().get(str(key), {}).get("value")
@ -1180,6 +1188,8 @@ class BaseModel(object):
:return: The value of the metadata entry, casted to its type (if not possible, :return: The value of the metadata entry, casted to its type (if not possible,
the string representation will be returned) or None if the entry was not found the string representation will be returned) or None if the entry was not found
""" """
if not self._base_model:
self._base_model = self._get_force_base_model()
key = str(key) key = str(key)
metadata = self.get_all_metadata() metadata = self.get_all_metadata()
if key not in metadata: if key not in metadata:
@ -1194,6 +1204,8 @@ class BaseModel(object):
:return: Get all metadata as a dictionary of format Dict[key, Dict[value, type]]. The key, value and type :return: Get all metadata as a dictionary of format Dict[key, Dict[value, type]]. The key, value and type
entries are all strings. Note that each entry might have an additional 'key' entry, repeating the key entries are all strings. Note that each entry might have an additional 'key' entry, repeating the key
""" """
if not self._base_model:
self._base_model = self._get_force_base_model()
self._reload_if_required() self._reload_if_required()
return self._get_model_data().metadata or {} return self._get_model_data().metadata or {}
@ -1204,6 +1216,8 @@ class BaseModel(object):
entries are strings. The value is cast to its type if possible. Note that each entry might entries are strings. The value is cast to its type if possible. Note that each entry might
have an additional 'key' entry, repeating the key have an additional 'key' entry, repeating the key
""" """
if not self._base_model:
self._base_model = self._get_force_base_model()
self._reload_if_required() self._reload_if_required()
result = {} result = {}
metadata = self.get_all_metadata() metadata = self.get_all_metadata()
@ -1224,6 +1238,8 @@ class BaseModel(object):
:return: True if the metadata was set and False otherwise :return: True if the metadata was set and False otherwise
""" """
if not self._base_model:
self._base_model = self._get_force_base_model()
metadata_array = [ metadata_array = [
{ {
"key": str(k), "key": str(k),
@ -1249,6 +1265,74 @@ class BaseModel(object):
self._get_base_model().reload() self._get_base_model().reload()
self._reload_required = False self._reload_required = False
def _update_base_model(self, model_name=None, task_model_entry=None):
if not self._task:
return self._base_model
# update the model from the task inputs
labels = self._task.get_labels_enumeration()
# noinspection PyProtectedMember
config_text = self._task._get_model_config_text()
model_name = (
model_name or self._name or (self._floating_data.name if self._floating_data else None) or self._task.name
)
# noinspection PyBroadException
try:
task_model_entry = (
task_model_entry
or self._task_connect_name
or Path(self._get_model_data().uri).stem
)
except Exception:
pass
parent = self._task.input_models_id.get(task_model_entry)
self._base_model.update(
labels=(self._floating_data.labels if self._floating_data else None) or labels,
design=(self._floating_data.design if self._floating_data else None) or config_text,
task_id=self._task.id,
project_id=self._task.project,
parent_id=parent,
name=model_name,
comment=self._floating_data.comment if self._floating_data else None,
tags=self._floating_data.tags if self._floating_data else None,
framework=self._floating_data.framework if self._floating_data else None,
upload_storage_uri=self._floating_data.upload_storage_uri if self._floating_data else None,
)
# remove model floating change set, by now they should have matched the task.
self._floating_data = None
# now we have to update the creator task so it points to us
if str(self._task.status) not in (
str(self._task.TaskStatusEnum.created),
str(self._task.TaskStatusEnum.in_progress),
):
self._log.warning(
"Could not update last created model in Task {}, "
"Task status '{}' cannot be updated".format(
self._task.id, self._task.status
)
)
elif task_model_entry:
self._base_model.update_for_task(
task_id=self._task.id,
model_id=self.id,
type_="output",
name=task_model_entry,
)
return self._base_model
def _get_force_base_model(self, model_name=None, task_model_entry=None):
if self._base_model:
return self._base_model
if not self._task:
return None
# create a new model from the task
# noinspection PyProtectedMember
self._base_model = self._task._get_output_model(model_id=None)
return self._update_base_model(model_name=model_name, task_model_entry=task_model_entry)
class Model(BaseModel): class Model(BaseModel):
""" """
@ -2060,6 +2144,7 @@ class OutputModel(BaseModel):
self._base_model = None self._base_model = None
self._base_model_id = None self._base_model_id = None
self._task_connect_name = None self._task_connect_name = None
self._name = name
self._label_enumeration = label_enumeration self._label_enumeration = label_enumeration
# noinspection PyProtectedMember # noinspection PyProtectedMember
self._floating_data = create_dummy_model( self._floating_data = create_dummy_model(
@ -2300,7 +2385,11 @@ class OutputModel(BaseModel):
if out_model_file_name if out_model_file_name
else (self._task_connect_name or "Output Model") else (self._task_connect_name or "Output Model")
) )
if not self._base_model:
model = self._get_force_base_model(task_model_entry=name) model = self._get_force_base_model(task_model_entry=name)
else:
self._update_base_model(task_model_entry=name)
model = self._base_model
if not model: if not model:
raise ValueError("Failed creating internal output model") raise ValueError("Failed creating internal output model")
@ -2639,61 +2728,6 @@ class OutputModel(BaseModel):
) )
return weights_filename_offline or register_uri return weights_filename_offline or register_uri
def _get_force_base_model(self, model_name=None, task_model_entry=None):
if self._base_model:
return self._base_model
# create a new model from the task
# noinspection PyProtectedMember
self._base_model = self._task._get_output_model(model_id=None)
# update the model from the task inputs
labels = self._task.get_labels_enumeration()
# noinspection PyProtectedMember
config_text = self._task._get_model_config_text()
model_name = model_name or self._floating_data.name or self._task.name
task_model_entry = (
task_model_entry
or self._task_connect_name
or Path(self._get_model_data().uri).stem
)
parent = self._task.input_models_id.get(task_model_entry)
self._base_model.update(
labels=self._floating_data.labels or labels,
design=self._floating_data.design or config_text,
task_id=self._task.id,
project_id=self._task.project,
parent_id=parent,
name=model_name,
comment=self._floating_data.comment,
tags=self._floating_data.tags,
framework=self._floating_data.framework,
upload_storage_uri=self._floating_data.upload_storage_uri,
)
# remove model floating change set, by now they should have matched the task.
self._floating_data = None
# now we have to update the creator task so it points to us
if str(self._task.status) not in (
str(self._task.TaskStatusEnum.created),
str(self._task.TaskStatusEnum.in_progress),
):
self._log.warning(
"Could not update last created model in Task {}, "
"Task status '{}' cannot be updated".format(
self._task.id, self._task.status
)
)
else:
self._base_model.update_for_task(
task_id=self._task.id,
model_id=self.id,
type_="output",
name=task_model_entry,
)
return self._base_model
def _get_base_model(self): def _get_base_model(self):
if self._floating_data: if self._floating_data:
return self._floating_data return self._floating_data