Support querying model metadata in Model.query_models()

This commit is contained in:
allegroai 2022-12-22 21:59:24 +02:00
parent 8ba78b5a11
commit 8992275f8e

View File

@ -448,6 +448,117 @@ class BaseModel(object):
return config_text
def set_metadata(self, key, value, v_type=None):
# type: (str, str, Optional[str]) -> bool
"""
Set one metadata entry. All parameters must be strings or castable to strings
:param key: Key of the metadata entry
:param value: Value of the metadata entry
:param v_type: Type of the metadata entry
:return: True if the metadata was set and False otherwise
"""
self._reload_required = (
_Model._get_default_session()
.send(
models.AddOrUpdateMetadataRequest(
metadata=[{
"key": str(key),
"value": str(value),
"type": str(v_type)
if v_type in (float, int, bool, six.string_types, list, tuple, dict) else
str(None)
}],
model=self.id,
replace_metadata=False,
)
)
.ok()
)
return self._reload_required
def get_metadata(self, key):
# type: (str) -> Optional[str]
"""
Get one metadata entry value (as a string) based on its key. See `Model.get_metadata_casted`
if you wish to cast the value to its type (if possible)
:param key: Key of the metadata entry you want to get
:return: String representation of the value of the metadata entry or None if the entry was not found
"""
self._reload_if_required()
return self.get_all_metadata().get(str(key), {}).get("value")
def get_metadata_casted(self, key):
# type: (str) -> Optional[str]
"""
Get one metadata entry based on its key, casted to its type if possible
:param key: Key of the metadata entry you want to get
:return: The value of the metadata entry, casted to its type (if not possible,
the string representation will be returned) or None if the entry was not found
"""
key = str(key)
metadata = self.get_all_metadata()
if key not in metadata:
return None
return cast_basic_type(metadata[key].get("value"), metadata[key].get("type"))
def get_all_metadata(self):
# type: () -> Dict[str, Dict[str, str]]
"""
See `Model.get_all_metadata_casted` if you wish to cast the value to its type (if possible)
:return: Get all metadata as a dictionary of format Dict[key, Dict[value, type]]. The key, value and type
entries are all strings. Note that each entry might have an additional 'key' entry, repeating the key
"""
self._reload_if_required()
return self._get_model_data().metadata or {}
def get_all_metadata_casted(self):
# type: () -> Dict[str, Dict[str, Any]]
"""
:return: Get all metadata as a dictionary of format Dict[key, Dict[value, type]]. The key and type
entries are strings. The value is cast to its type if possible. Note that each entry might
have an additional 'key' entry, repeating the key
"""
self._reload_if_required()
result = {}
metadata = self.get_all_metadata()
for key, metadata_entry in metadata.items():
result[key] = cast_basic_type(metadata_entry.get("value"), metadata_entry.get("type"))
return result
def set_all_metadata(self, metadata, replace=True):
# type: (Dict[str, Dict[str, str]], bool) -> bool
"""
Set metadata based on the given parameters. Allows replacing all entries or updating the current entries.
:param metadata: A dictionary of format Dict[key, Dict[value, type]] representing the metadata you want to set
:param replace: If True, replace all metadata with the entries in the `metadata` parameter. If False,
keep the old metadata and update it with the entries in the `metadata` parameter (add or change it)
:return: True if the metadata was set and False otherwise
"""
metadata_array = [
{"key": str(k), "value": str(v_t.get("value")), "type": str(v_t.get("type"))} for k, v_t in metadata.items()
]
self._reload_required = (
_Model._get_default_session()
.send(models.AddOrUpdateMetadataRequest(metadata=metadata_array, model=self.id, replace_metadata=replace))
.ok()
)
return self._reload_required
def _reload_if_required(self):
if not self._reload_required:
return
self._get_base_model().reload()
self._reload_required = False
class Model(BaseModel):
"""
@ -467,7 +578,6 @@ class Model(BaseModel):
super(Model, self).__init__()
self._base_model_id = model_id
self._base_model = None
self._reload_required = False
def get_local_copy(self, extract_archive=True, raise_on_error=False):
# type: (bool, bool) -> str
@ -514,6 +624,7 @@ class Model(BaseModel):
only_published=False, # type: bool
include_archived=False, # type: bool
max_results=None, # type: Optional[int]
metadata=None # type: Optional[Dict[str, str]]
):
# type: (...) -> List[Model]
"""
@ -529,6 +640,8 @@ class Model(BaseModel):
:param include_archived: If True return archived models.
:param max_results: Optional return the last X models,
sorted by last update time (from the most recent to the least).
:param metadata: Filter based on metadata. This parameter is a dictionary. Notice that the type of the
metadata field is not required.
:return: ModeList of Models objects
"""
@ -546,24 +659,40 @@ class Model(BaseModel):
only_fields = ['id', 'created', 'system_tags']
# noinspection PyProtectedMember
res = _Model._get_default_session().send(
models.GetAllRequest(
project=[project.id] if project else None,
name=exact_match_regex(model_name) if model_name is not None else None,
only_fields=only_fields,
tags=tags or None,
system_tags=["-" + cls._archived_tag] if not include_archived else None,
ready=True if only_published else None,
order_by=['-created'],
page=0 if max_results else None,
page_size=max_results or None,
)
)
if not res.response.models:
return []
extra_fields = {"metadata.{}.value".format(k): v for k, v in (metadata or {}).items()}
return [Model(model_id=m.id) for m in res.response.models]
models_fetched = []
page = 0
page_size = 500
results_left = max_results if max_results is not None else float("inf")
while True:
# noinspection PyProtectedMember
res = _Model._get_default_session().send(
models.GetAllRequest(
project=[project.id] if project else None,
name=exact_match_regex(model_name) if model_name is not None else None,
only_fields=only_fields,
tags=tags or None,
system_tags=["-" + cls._archived_tag] if not include_archived else None,
ready=True if only_published else None,
order_by=['-created'],
page=page,
page_size=page_size if results_left > page_size else results_left,
_allow_extra_fields_=True,
**extra_fields
)
)
if not res.response.models:
break
models_fetched.extend(res.response.models)
results_left -= len(res.response.models)
if results_left <= 0 or len(res.response.models) < page_size:
break
page += 1
return [Model(model_id=m.id) for m in models_fetched]
@property
def id(self):
@ -635,108 +764,6 @@ class Model(BaseModel):
return True
def set_metadata(self, key, value, type):
# type: (str, str, str) -> bool
"""
Set one metadata entry. All parameters must be strings or castable to strings
:param key: Key of the metadata entry
:param value: Value of the metadata entry
:param type: Type of the metadata entry
:return: True if the metadata was set and False otherwise
"""
self._reload_required = True
result = _Model._get_default_session().send(models.AddOrUpdateMetadataRequest(
metadata=[{"key": str(key), "value": str(value), "type": str(type)}],
model=self._base_model_id,
replace_metadata=False
))
return bool(result)
def get_metadata(self, key):
# type: (str) -> Optional[str]
"""
Get one metadata entry value (as a string) based on its key. See `Model.get_metadata_casted`
if you wish to cast the value to its type (if possible)
:param key: Key of the metadata entry you want to get
:return: String representation of the value of the metadata entry or None if the entry was not found
"""
self._reload_if_required()
return self.get_all_metadata().get(str(key), {}).get("value")
def get_metadata_casted(self, key):
# type: (str) -> Optional[str]
"""
Get one metadata entry based on its key, casted to its type if possible
:param key: Key of the metadata entry you want to get
:return: The value of the metadata entry, casted to its type (if not possible,
the string representation will be returned) or None if the entry was not found
"""
key = str(key)
metadata = self.get_all_metadata()
if key not in metadata:
return None
return cast_basic_type(metadata[key].get("value"), metadata[key].get("type"))
def get_all_metadata(self):
# type: () -> Dict[str, Dict[str, str]]
"""
See `Model.get_all_metadata_casted` if you wish to cast the value to its type (if possible)
:return: Get all metadata as a dictionary of format Dict[key, Dict[value, type]]. The key, value and type
entries are all strings. Note that each entry might have an additional 'key' entry, repeating the key
"""
self._reload_if_required()
return self._get_model_data().metadata or {}
def get_all_metadata_casted(self):
# type: () -> Dict[str, Dict[str, Any]]
"""
:return: Get all metadata as a dictionary of format Dict[key, Dict[value, type]]. The key and type
entries are strings. The value is cast to its type if possible. Note that each entry might
have an additional 'key' entry, repeating the key
"""
self._reload_if_required()
result = {}
metadata = self.get_all_metadata()
for key, metadata_entry in metadata.items():
result[key] = cast_basic_type(metadata_entry.get("value"), metadata_entry.get("type"))
return result
def set_all_metadata(self, metadata, replace=True):
# type: (Dict[str, Dict[str, str]], bool) -> bool
"""
Set metadata based on the given parameters. Allows replacing all entries or updating the current entries.
:param metadata: A dictionary of format Dict[key, Dict[value, type]] representing the metadata you want to set
:param replace: If True, replace all metadata with the entries in the `metadata` parameter. If False,
keep the old metadata and update it with the entries in the `metadata` parameter (add or change it)
:return: True if the metadata was set and False otherwise
"""
self._reload_required = True
metadata_array = [
{"key": str(k), "value": str(v_t.get("value")), "type": str(v_t.get("type"))} for k, v_t in metadata.items()
]
result = _Model._get_default_session().send(models.AddOrUpdateMetadataRequest(
metadata=metadata_array,
model=self._base_model_id,
replace_metadata=replace
))
return bool(result)
def _reload_if_required(self):
if not self._base_model:
self._get_base_model()
if self._reload_required:
self._base_model.reload()
self._reload_required = False
class InputModel(Model):
"""