Moved archived_tag definition + pep8

This commit is contained in:
allegroai 2020-11-08 00:12:37 +02:00
parent 9b3d107934
commit de85580faa
4 changed files with 87 additions and 16 deletions

View File

@ -26,6 +26,7 @@ from collections import OrderedDict
from six.moves.urllib.parse import quote from six.moves.urllib.parse import quote
from ...utilities.locks import RLock as FileRLock from ...utilities.locks import RLock as FileRLock
from ...utilities.attrs import readonly
from ...binding.artifacts import Artifacts from ...binding.artifacts import Artifacts
from ...backend_interface.task.development.worker import DevWorker from ...backend_interface.task.development.worker import DevWorker
from ...backend_api import Session from ...backend_api import Session
@ -61,6 +62,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
_anonymous_dataview_id = '__anonymous__' _anonymous_dataview_id = '__anonymous__'
_development_tag = 'development' _development_tag = 'development'
archived_tag = readonly('archived')
_default_configuration_section_name = 'General' _default_configuration_section_name = 'General'
_legacy_parameters_section_name = 'Args' _legacy_parameters_section_name = 'Args'
_force_requirements = {} _force_requirements = {}
@ -841,6 +843,26 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
merged into a single key-value pair dictionary. merged into a single key-value pair dictionary.
:param kwargs: Key-value pairs, merged into the parameters dictionary created from ``args``. :param kwargs: Key-value pairs, merged into the parameters dictionary created from ``args``.
""" """
def stringify(value):
# return empty string if value is None
if value is None:
return ""
str_value = str(value)
if isinstance(value, (tuple, list, dict)) and 'None' in re.split(r'[ ,\[\]{}()]', str_value):
# If we have None in the string we have to use json to replace it with null,
# otherwise we end up with None as string when running remotely
try:
str_json = json.dumps(value)
# verify we actually have a null in the string, otherwise prefer the str cast
# This is because we prefer to have \' as in str and not \" used in json
if 'null' in re.split(r'[ ,\[\]{}()]', str_json):
return str_json
except TypeError:
# if we somehow failed to json serialize, revert to previous std casting
pass
return str_value
if not all(isinstance(x, (dict, Iterable)) for x in args): if not all(isinstance(x, (dict, Iterable)) for x in args):
raise ValueError('only dict or iterable are supported as positional arguments') raise ValueError('only dict or iterable are supported as positional arguments')
@ -886,7 +908,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
parameters.update(new_parameters) parameters.update(new_parameters)
# force cast all variables to strings (so that we can later edit them in UI) # force cast all variables to strings (so that we can later edit them in UI)
parameters = {k: str(v) if v is not None else "" for k, v in parameters.items()} parameters = {k: stringify(v) for k, v in parameters.items()}
if use_hyperparams: if use_hyperparams:
# build nested dict from flat parameters dict: # build nested dict from flat parameters dict:
@ -971,6 +993,25 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
params = self.get_parameters() params = self.get_parameters()
return params.get(name, default) return params.get(name, default)
def delete_parameter(self, name):
# type: (str) -> bool
"""
Delete a parameter byt it's full name Section/name.
:param name: Parameter name in full, i.e. Section/name. For example, 'Args/batch_size'
:return: True if the parameter was deleted successfully
"""
if not Session.check_min_api_version('2.9'):
raise ValueError("Delete hyper parameter is not supported by your trains-server, "
"upgrade to the latest version")
with self._edit_lock:
paramkey = tasks.ParamKey(section=name.split('/', 1)[0], name=name.split('/', 1)[1])
res = self.send(tasks.DeleteHyperParamsRequest(
task=self.id, hyperparams=[paramkey]), raise_on_errors=False)
self.reload()
return res.ok()
def update_parameters(self, *args, **kwargs): def update_parameters(self, *args, **kwargs):
# type: (*dict, **Any) -> () # type: (*dict, **Any) -> ()
""" """
@ -1678,6 +1719,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
extra['hyperparams'] = task.hyperparams extra['hyperparams'] = task.hyperparams
if hasattr(task, 'configuration'): if hasattr(task, 'configuration'):
extra['configuration'] = task.configuration extra['configuration'] = task.configuration
if getattr(task, 'system_tags', None):
extra['system_tags'] = [t for t in task.system_tags if t not in (cls._development_tag, cls.archived_tag)]
req = tasks.CreateRequest( req = tasks.CreateRequest(
name=name or task.name, name=name or task.name,

View File

@ -319,7 +319,8 @@ class Artifacts(object):
if preview: if preview:
preview = str(preview) preview = str(preview)
# convert string to object if try is a file/folder (dont try to serialize long texts # try to convert string Path object (it might reference a file/folder)
# dont not try to serialize long texts.
if isinstance(artifact_object, six.string_types) and len(artifact_object) < 2048: if isinstance(artifact_object, six.string_types) and len(artifact_object) < 2048:
# noinspection PyBroadException # noinspection PyBroadException
try: try:

View File

@ -26,8 +26,6 @@ from .config import running_remotely, get_cache_dir
if TYPE_CHECKING: if TYPE_CHECKING:
from .task import Task from .task import Task
ARCHIVED_TAG = "archived"
class Framework(Options): class Framework(Options):
""" """
@ -77,7 +75,7 @@ class Framework(Options):
'.cfg': (darknet, ), '.cfg': (darknet, ),
'__model__': (paddlepaddle, ), '__model__': (paddlepaddle, ),
'.pkl': (scikitlearn, keras, xgboost), '.pkl': (scikitlearn, keras, xgboost),
'.parquet': (parquet), '.parquet': (parquet, ),
} }
@classmethod @classmethod
@ -90,7 +88,7 @@ class Framework(Options):
if frameworks and filename.endswith(ext): if frameworks and filename.endswith(ext):
fw = framework_selector(frameworks) fw = framework_selector(frameworks)
if fw: if fw:
return (fw, ext) return fw, ext
# If no framework, try finding first framework matching the extension, otherwise (or if no match) try matching # If no framework, try finding first framework matching the extension, otherwise (or if no match) try matching
# the given extension to the given framework. If no match return an empty extension # the given extension to the given framework. If no match return an empty extension
@ -103,6 +101,8 @@ class Framework(Options):
@six.add_metaclass(abc.ABCMeta) @six.add_metaclass(abc.ABCMeta)
class BaseModel(object): class BaseModel(object):
# noinspection PyProtectedMember
_archived_tag = _Task.archived_tag
_package_tag = "package" _package_tag = "package"
@property @property
@ -208,6 +208,7 @@ class BaseModel(object):
:return: The configuration. :return: The configuration.
""" """
# noinspection PyProtectedMember
return _Model._unwrap_design(self._get_model_data().design) return _Model._unwrap_design(self._get_model_data().design)
@property @property
@ -234,11 +235,11 @@ class BaseModel(object):
@property @property
def task(self): def task(self):
# type: () -> str # type: () -> _Task
""" """
Return the creating task id (str) Return the creating task object
:return: The Task ID. :return: The Task object.
""" """
return self._task or self._get_base_model().task return self._task or self._get_base_model().task
@ -446,6 +447,7 @@ class InputModel(Model):
We can connect the model to a task as input model, then when running remotely override it with the UI. We can connect the model to a task as input model, then when running remotely override it with the UI.
""" """
# noinspection PyProtectedMember
_EMPTY_MODEL_ID = _Model._EMPTY_MODEL_ID _EMPTY_MODEL_ID = _Model._EMPTY_MODEL_ID
@classmethod @classmethod
@ -512,7 +514,7 @@ class InputModel(Model):
:param tags: The list of tags which describe the model. (Optional) :param tags: The list of tags which describe the model. (Optional)
:type tags: list(str) :type tags: list(str)
:param str comment: A comment / description for the model. (Optional) :param str comment: A comment / description for the model. (Optional)
:type comment str: :type comment: str
:param is_package: Is the imported weights file is a package (Optional) :param is_package: Is the imported weights file is a package (Optional)
- ``True`` - Is a package. Add a package tag to the model. - ``True`` - Is a package. Add a package tag to the model.
@ -536,8 +538,9 @@ class InputModel(Model):
# convert local to file to remote one # convert local to file to remote one
weights_url = CacheManager.get_remote_url(weights_url) weights_url = CacheManager.get_remote_url(weights_url)
extra = {'system_tags': ["-" + ARCHIVED_TAG]} \ extra = {'system_tags': ["-" + cls.archived_tag]} \
if Session.check_min_api_version('2.3') else {'tags': ["-" + ARCHIVED_TAG]} if Session.check_min_api_version('2.3') else {'tags': ["-" + cls.archived_tag]}
# noinspection PyProtectedMember
result = _Model._get_default_session().send(models.GetAllRequest( result = _Model._get_default_session().send(models.GetAllRequest(
uri=[weights_url], uri=[weights_url],
only_fields=["id", "name", "created"], only_fields=["id", "name", "created"],
@ -580,6 +583,7 @@ class InputModel(Model):
task_id = None task_id = None
if not framework: if not framework:
# noinspection PyProtectedMember
framework, file_ext = Framework._get_file_ext( framework, file_ext = Framework._get_file_ext(
framework=framework, framework=framework,
filename=weights_url filename=weights_url
@ -642,11 +646,13 @@ class InputModel(Model):
weights_url = CacheManager.get_remote_url(weights_url) weights_url = CacheManager.get_remote_url(weights_url)
if not load_archived: if not load_archived:
extra = {'system_tags': ["-" + ARCHIVED_TAG]} \ # noinspection PyTypeChecker
if Session.check_min_api_version('2.3') else {'tags': ["-" + ARCHIVED_TAG]} extra = {'system_tags': ["-" + _Task.archived_tag]} \
if Session.check_min_api_version('2.3') else {'tags': ["-" + cls._archived_tag]}
else: else:
extra = {} extra = {}
# noinspection PyProtectedMember
result = _Model._get_default_session().send(models.GetAllRequest( result = _Model._get_default_session().send(models.GetAllRequest(
uri=[weights_url], uri=[weights_url],
only_fields=["id", "name", "created"], only_fields=["id", "name", "created"],
@ -700,7 +706,9 @@ class InputModel(Model):
upload_storage_uri=None, upload_storage_uri=None,
model_id=cls._EMPTY_MODEL_ID, model_id=cls._EMPTY_MODEL_ID,
) )
# noinspection PyProtectedMember
m._data.design = _Model._wrap_design(design) m._data.design = _Model._wrap_design(design)
# noinspection PyProtectedMember
m._data.labels = label_enumeration m._data.labels = label_enumeration
return this_model return this_model
@ -749,13 +757,16 @@ class InputModel(Model):
if model.id != self._EMPTY_MODEL_ID: if model.id != self._EMPTY_MODEL_ID:
task.set_input_model(model_id=model.id) task.set_input_model(model_id=model.id)
# only copy the model design if the task has no design to begin with # only copy the model design if the task has no design to begin with
# noinspection PyProtectedMember
if not self._task._get_model_config_text(): if not self._task._get_model_config_text():
# noinspection PyProtectedMember
task._set_model_config(config_text=model.model_design) task._set_model_config(config_text=model.model_design)
if not self._task.get_labels_enumeration(): if not self._task.get_labels_enumeration():
task.set_model_label_enumeration(model.data.labels) task.set_model_label_enumeration(model.data.labels)
# If there was an output model connected, it may need to be updated by # If there was an output model connected, it may need to be updated by
# the newly connected input model # the newly connected input model
# noinspection PyProtectedMember
self.task._reconnect_output_model() self.task._reconnect_output_model()
@ -801,6 +812,7 @@ class OutputModel(BaseModel):
:return: The configuration. :return: The configuration.
""" """
# noinspection PyProtectedMember
return _Model._unwrap_design(self._get_model_data().design) return _Model._unwrap_design(self._get_model_data().design)
@config_text.setter @config_text.setter
@ -933,6 +945,7 @@ class OutputModel(BaseModel):
self._model_local_filename = None self._model_local_filename = None
self._base_model = None self._base_model = None
# noinspection PyProtectedMember
self._floating_data = create_dummy_model( self._floating_data = create_dummy_model(
design=_Model._wrap_design(config_text), design=_Model._wrap_design(config_text),
labels=label_enumeration or task.get_labels_enumeration(), labels=label_enumeration or task.get_labels_enumeration(),
@ -944,7 +957,9 @@ class OutputModel(BaseModel):
upload_storage_uri=task.output_uri, upload_storage_uri=task.output_uri,
) )
if base_model_id: if base_model_id:
# noinspection PyBroadException
try: try:
# noinspection PyProtectedMember
_base_model = self._task._get_output_model(model_id=base_model_id) _base_model = self._task._get_output_model(model_id=base_model_id)
_base_model.update( _base_model.update(
labels=self._floating_data.labels, labels=self._floating_data.labels,
@ -983,24 +998,30 @@ class OutputModel(BaseModel):
if running_remotely() and task.is_main_task(): if running_remotely() and task.is_main_task():
if self._floating_data: if self._floating_data:
# noinspection PyProtectedMember
self._floating_data.design = _Model._wrap_design(self._task._get_model_config_text()) or \ self._floating_data.design = _Model._wrap_design(self._task._get_model_config_text()) or \
self._floating_data.design self._floating_data.design
self._floating_data.labels = self._task.get_labels_enumeration() or \ self._floating_data.labels = self._task.get_labels_enumeration() or \
self._floating_data.labels self._floating_data.labels
elif self._base_model: elif self._base_model:
# noinspection PyProtectedMember
self._base_model.update(design=_Model._wrap_design(self._task._get_model_config_text()) or self._base_model.update(design=_Model._wrap_design(self._task._get_model_config_text()) or
self._base_model.design) self._base_model.design)
self._base_model.update(labels=self._task.get_labels_enumeration() or self._base_model.labels) self._base_model.update(labels=self._task.get_labels_enumeration() or self._base_model.labels)
elif self._floating_data is not None: elif self._floating_data is not None:
# we copy configuration / labels if they exist, obviously someone wants them as the output base model # we copy configuration / labels if they exist, obviously someone wants them as the output base model
# noinspection PyProtectedMember
design = _Model._unwrap_design(self._floating_data.design) design = _Model._unwrap_design(self._floating_data.design)
if design: if design:
# noinspection PyProtectedMember
if not task._get_model_config_text(): if not task._get_model_config_text():
if not Session.check_min_api_version('2.9'): if not Session.check_min_api_version('2.9'):
design = self._floating_data.design design = self._floating_data.design
# noinspection PyProtectedMember
task._set_model_config(config_text=design) task._set_model_config(config_text=design)
else: else:
# noinspection PyProtectedMember
self._floating_data.design = _Model._wrap_design(self._task._get_model_config_text()) self._floating_data.design = _Model._wrap_design(self._task._get_model_config_text())
if self._floating_data.labels: if self._floating_data.labels:
@ -1008,6 +1029,7 @@ class OutputModel(BaseModel):
else: else:
self._floating_data.labels = self._task.get_labels_enumeration() self._floating_data.labels = self._task.get_labels_enumeration()
# noinspection PyProtectedMember
self.task._save_output_model(self) self.task._save_output_model(self)
def set_upload_destination(self, uri): def set_upload_destination(self, uri):
@ -1128,6 +1150,7 @@ class OutputModel(BaseModel):
# select the correct file extension based on the framework, # select the correct file extension based on the framework,
# or update the framework based on the file extension # or update the framework based on the file extension
# noinspection PyProtectedMember
framework, file_ext = Framework._get_file_ext( framework, file_ext = Framework._get_file_ext(
framework=self._get_model_data().framework, framework=self._get_model_data().framework,
filename=target_filename or weights_filename or register_uri filename=target_filename or weights_filename or register_uri
@ -1184,6 +1207,7 @@ class OutputModel(BaseModel):
self._set_package_tag() self._set_package_tag()
# make sure that if we are in dev move we report that we are training (not debugging) # make sure that if we are in dev move we report that we are training (not debugging)
# noinspection PyProtectedMember
self._task._output_model_updated() self._task._output_model_updated()
return output_uri return output_uri
@ -1299,6 +1323,7 @@ class OutputModel(BaseModel):
# update the model object (this will happen if we resumed a training task) # update the model object (this will happen if we resumed a training task)
result = self._get_force_base_model().edit(design=config_text) result = self._get_force_base_model().edit(design=config_text)
else: else:
# noinspection PyProtectedMember
self._floating_data.design = _Model._wrap_design(config_text) self._floating_data.design = _Model._wrap_design(config_text)
result = Waitable() result = Waitable()
@ -1361,6 +1386,7 @@ class OutputModel(BaseModel):
self._base_model = self._task.create_output_model() self._base_model = self._task.create_output_model()
# update the model from the task inputs # update the model from the task inputs
labels = self._task.get_labels_enumeration() labels = self._task.get_labels_enumeration()
# noinspection PyProtectedMember
config_text = self._task._get_model_config_text() config_text = self._task._get_model_config_text()
parent = self._task.output_model_id or self._task.input_model_id parent = self._task.output_model_id or self._task.input_model_id
self._base_model.update( self._base_model.update(

View File

@ -48,8 +48,9 @@ from .config.cache import SessionCache
from .debugging.log import LoggerRoot from .debugging.log import LoggerRoot
from .errors import UsageError from .errors import UsageError
from .logger import Logger from .logger import Logger
from .model import Model, InputModel, OutputModel, ARCHIVED_TAG from .model import Model, InputModel, OutputModel
from .task_parameters import TaskParameters from .task_parameters import TaskParameters
from .utilities.config import verify_basic_value
from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \ from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
argparser_update_currenttask argparser_update_currenttask
from .utilities.dicts import ReadOnlyDict, merge_dicts from .utilities.dicts import ReadOnlyDict, merge_dicts
@ -2075,7 +2076,7 @@ class Task(_Task):
if hasattr(task.data.execution, 'artifacts') else None if hasattr(task.data.execution, 'artifacts') else None
if ((str(task._status) in ( if ((str(task._status) in (
str(tasks.TaskStatusEnum.published), str(tasks.TaskStatusEnum.closed))) str(tasks.TaskStatusEnum.published), str(tasks.TaskStatusEnum.closed)))
or task.output_model_id or (ARCHIVED_TAG in task_tags) or task.output_model_id or (cls.archived_tag in task_tags)
or (cls._development_tag not in task_tags) or (cls._development_tag not in task_tags)
or task_artifacts): or task_artifacts):
# If the task is published or closed, we shouldn't reset it so we can't use it in dev mode # If the task is published or closed, we shouldn't reset it so we can't use it in dev mode