Moved archived_tag definition + pep8

This commit is contained in:
allegroai
2020-11-08 00:12:37 +02:00
parent 9b3d107934
commit de85580faa
4 changed files with 87 additions and 16 deletions

View File

@@ -26,6 +26,7 @@ from collections import OrderedDict
from six.moves.urllib.parse import quote
from ...utilities.locks import RLock as FileRLock
from ...utilities.attrs import readonly
from ...binding.artifacts import Artifacts
from ...backend_interface.task.development.worker import DevWorker
from ...backend_api import Session
@@ -61,6 +62,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
_anonymous_dataview_id = '__anonymous__'
_development_tag = 'development'
archived_tag = readonly('archived')
_default_configuration_section_name = 'General'
_legacy_parameters_section_name = 'Args'
_force_requirements = {}
@@ -841,6 +843,26 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
merged into a single key-value pair dictionary.
:param kwargs: Key-value pairs, merged into the parameters dictionary created from ``args``.
"""
def stringify(value):
# return empty string if value is None
if value is None:
return ""
str_value = str(value)
if isinstance(value, (tuple, list, dict)) and 'None' in re.split(r'[ ,\[\]{}()]', str_value):
# If we have None in the string we have to use json to replace it with null,
# otherwise we end up with None as string when running remotely
try:
str_json = json.dumps(value)
# verify we actually have a null in the string, otherwise prefer the str cast
# This is because we prefer to have \' as in str and not \" used in json
if 'null' in re.split(r'[ ,\[\]{}()]', str_json):
return str_json
except TypeError:
# if we somehow failed to json serialize, revert to previous std casting
pass
return str_value
if not all(isinstance(x, (dict, Iterable)) for x in args):
raise ValueError('only dict or iterable are supported as positional arguments')
@@ -886,7 +908,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
parameters.update(new_parameters)
# force cast all variables to strings (so that we can later edit them in UI)
parameters = {k: str(v) if v is not None else "" for k, v in parameters.items()}
parameters = {k: stringify(v) for k, v in parameters.items()}
if use_hyperparams:
# build nested dict from flat parameters dict:
@@ -971,6 +993,25 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
params = self.get_parameters()
return params.get(name, default)
def delete_parameter(self, name):
# type: (str) -> bool
"""
Delete a parameter byt it's full name Section/name.
:param name: Parameter name in full, i.e. Section/name. For example, 'Args/batch_size'
:return: True if the parameter was deleted successfully
"""
if not Session.check_min_api_version('2.9'):
raise ValueError("Delete hyper parameter is not supported by your trains-server, "
"upgrade to the latest version")
with self._edit_lock:
paramkey = tasks.ParamKey(section=name.split('/', 1)[0], name=name.split('/', 1)[1])
res = self.send(tasks.DeleteHyperParamsRequest(
task=self.id, hyperparams=[paramkey]), raise_on_errors=False)
self.reload()
return res.ok()
def update_parameters(self, *args, **kwargs):
# type: (*dict, **Any) -> ()
"""
@@ -1678,6 +1719,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
extra['hyperparams'] = task.hyperparams
if hasattr(task, 'configuration'):
extra['configuration'] = task.configuration
if getattr(task, 'system_tags', None):
extra['system_tags'] = [t for t in task.system_tags if t not in (cls._development_tag, cls.archived_tag)]
req = tasks.CreateRequest(
name=name or task.name,