Moved archived_tag definition + pep8

This commit is contained in:
allegroai 2020-11-08 00:12:37 +02:00
parent 9b3d107934
commit de85580faa
4 changed files with 87 additions and 16 deletions

View File

@ -26,6 +26,7 @@ from collections import OrderedDict
from six.moves.urllib.parse import quote
from ...utilities.locks import RLock as FileRLock
from ...utilities.attrs import readonly
from ...binding.artifacts import Artifacts
from ...backend_interface.task.development.worker import DevWorker
from ...backend_api import Session
@ -61,6 +62,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
_anonymous_dataview_id = '__anonymous__'
_development_tag = 'development'
archived_tag = readonly('archived')
_default_configuration_section_name = 'General'
_legacy_parameters_section_name = 'Args'
_force_requirements = {}
@ -841,6 +843,26 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
merged into a single key-value pair dictionary.
:param kwargs: Key-value pairs, merged into the parameters dictionary created from ``args``.
"""
def stringify(value):
# return empty string if value is None
if value is None:
return ""
str_value = str(value)
if isinstance(value, (tuple, list, dict)) and 'None' in re.split(r'[ ,\[\]{}()]', str_value):
# If we have None in the string we have to use json to replace it with null,
# otherwise we end up with None as string when running remotely
try:
str_json = json.dumps(value)
# verify we actually have a null in the string, otherwise prefer the str cast
# This is because we prefer to have \' as in str and not \" used in json
if 'null' in re.split(r'[ ,\[\]{}()]', str_json):
return str_json
except TypeError:
# if we somehow failed to json serialize, revert to previous std casting
pass
return str_value
if not all(isinstance(x, (dict, Iterable)) for x in args):
raise ValueError('only dict or iterable are supported as positional arguments')
@ -886,7 +908,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
parameters.update(new_parameters)
# force cast all variables to strings (so that we can later edit them in UI)
parameters = {k: str(v) if v is not None else "" for k, v in parameters.items()}
parameters = {k: stringify(v) for k, v in parameters.items()}
if use_hyperparams:
# build nested dict from flat parameters dict:
@ -971,6 +993,25 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
params = self.get_parameters()
return params.get(name, default)
def delete_parameter(self, name):
# type: (str) -> bool
"""
Delete a parameter byt it's full name Section/name.
:param name: Parameter name in full, i.e. Section/name. For example, 'Args/batch_size'
:return: True if the parameter was deleted successfully
"""
if not Session.check_min_api_version('2.9'):
raise ValueError("Delete hyper parameter is not supported by your trains-server, "
"upgrade to the latest version")
with self._edit_lock:
paramkey = tasks.ParamKey(section=name.split('/', 1)[0], name=name.split('/', 1)[1])
res = self.send(tasks.DeleteHyperParamsRequest(
task=self.id, hyperparams=[paramkey]), raise_on_errors=False)
self.reload()
return res.ok()
def update_parameters(self, *args, **kwargs):
# type: (*dict, **Any) -> ()
"""
@ -1678,6 +1719,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
extra['hyperparams'] = task.hyperparams
if hasattr(task, 'configuration'):
extra['configuration'] = task.configuration
if getattr(task, 'system_tags', None):
extra['system_tags'] = [t for t in task.system_tags if t not in (cls._development_tag, cls.archived_tag)]
req = tasks.CreateRequest(
name=name or task.name,

View File

@ -319,7 +319,8 @@ class Artifacts(object):
if preview:
preview = str(preview)
# convert string to object if try is a file/folder (dont try to serialize long texts
# try to convert string Path object (it might reference a file/folder)
# dont not try to serialize long texts.
if isinstance(artifact_object, six.string_types) and len(artifact_object) < 2048:
# noinspection PyBroadException
try:

View File

@ -26,8 +26,6 @@ from .config import running_remotely, get_cache_dir
if TYPE_CHECKING:
from .task import Task
ARCHIVED_TAG = "archived"
class Framework(Options):
"""
@ -77,7 +75,7 @@ class Framework(Options):
'.cfg': (darknet, ),
'__model__': (paddlepaddle, ),
'.pkl': (scikitlearn, keras, xgboost),
'.parquet': (parquet),
'.parquet': (parquet, ),
}
@classmethod
@ -90,7 +88,7 @@ class Framework(Options):
if frameworks and filename.endswith(ext):
fw = framework_selector(frameworks)
if fw:
return (fw, ext)
return fw, ext
# If no framework, try finding first framework matching the extension, otherwise (or if no match) try matching
# the given extension to the given framework. If no match return an empty extension
@ -103,6 +101,8 @@ class Framework(Options):
@six.add_metaclass(abc.ABCMeta)
class BaseModel(object):
# noinspection PyProtectedMember
_archived_tag = _Task.archived_tag
_package_tag = "package"
@property
@ -208,6 +208,7 @@ class BaseModel(object):
:return: The configuration.
"""
# noinspection PyProtectedMember
return _Model._unwrap_design(self._get_model_data().design)
@property
@ -234,11 +235,11 @@ class BaseModel(object):
@property
def task(self):
# type: () -> str
# type: () -> _Task
"""
Return the creating task id (str)
Return the creating task object
:return: The Task ID.
:return: The Task object.
"""
return self._task or self._get_base_model().task
@ -446,6 +447,7 @@ class InputModel(Model):
We can connect the model to a task as input model, then when running remotely override it with the UI.
"""
# noinspection PyProtectedMember
_EMPTY_MODEL_ID = _Model._EMPTY_MODEL_ID
@classmethod
@ -512,7 +514,7 @@ class InputModel(Model):
:param tags: The list of tags which describe the model. (Optional)
:type tags: list(str)
:param str comment: A comment / description for the model. (Optional)
:type comment str:
:type comment: str
:param is_package: Is the imported weights file is a package (Optional)
- ``True`` - Is a package. Add a package tag to the model.
@ -536,8 +538,9 @@ class InputModel(Model):
# convert local to file to remote one
weights_url = CacheManager.get_remote_url(weights_url)
extra = {'system_tags': ["-" + ARCHIVED_TAG]} \
if Session.check_min_api_version('2.3') else {'tags': ["-" + ARCHIVED_TAG]}
extra = {'system_tags': ["-" + cls.archived_tag]} \
if Session.check_min_api_version('2.3') else {'tags': ["-" + cls.archived_tag]}
# noinspection PyProtectedMember
result = _Model._get_default_session().send(models.GetAllRequest(
uri=[weights_url],
only_fields=["id", "name", "created"],
@ -580,6 +583,7 @@ class InputModel(Model):
task_id = None
if not framework:
# noinspection PyProtectedMember
framework, file_ext = Framework._get_file_ext(
framework=framework,
filename=weights_url
@ -642,11 +646,13 @@ class InputModel(Model):
weights_url = CacheManager.get_remote_url(weights_url)
if not load_archived:
extra = {'system_tags': ["-" + ARCHIVED_TAG]} \
if Session.check_min_api_version('2.3') else {'tags': ["-" + ARCHIVED_TAG]}
# noinspection PyTypeChecker
extra = {'system_tags': ["-" + _Task.archived_tag]} \
if Session.check_min_api_version('2.3') else {'tags': ["-" + cls._archived_tag]}
else:
extra = {}
# noinspection PyProtectedMember
result = _Model._get_default_session().send(models.GetAllRequest(
uri=[weights_url],
only_fields=["id", "name", "created"],
@ -700,7 +706,9 @@ class InputModel(Model):
upload_storage_uri=None,
model_id=cls._EMPTY_MODEL_ID,
)
# noinspection PyProtectedMember
m._data.design = _Model._wrap_design(design)
# noinspection PyProtectedMember
m._data.labels = label_enumeration
return this_model
@ -749,13 +757,16 @@ class InputModel(Model):
if model.id != self._EMPTY_MODEL_ID:
task.set_input_model(model_id=model.id)
# only copy the model design if the task has no design to begin with
# noinspection PyProtectedMember
if not self._task._get_model_config_text():
# noinspection PyProtectedMember
task._set_model_config(config_text=model.model_design)
if not self._task.get_labels_enumeration():
task.set_model_label_enumeration(model.data.labels)
# If there was an output model connected, it may need to be updated by
# the newly connected input model
# noinspection PyProtectedMember
self.task._reconnect_output_model()
@ -801,6 +812,7 @@ class OutputModel(BaseModel):
:return: The configuration.
"""
# noinspection PyProtectedMember
return _Model._unwrap_design(self._get_model_data().design)
@config_text.setter
@ -933,6 +945,7 @@ class OutputModel(BaseModel):
self._model_local_filename = None
self._base_model = None
# noinspection PyProtectedMember
self._floating_data = create_dummy_model(
design=_Model._wrap_design(config_text),
labels=label_enumeration or task.get_labels_enumeration(),
@ -944,7 +957,9 @@ class OutputModel(BaseModel):
upload_storage_uri=task.output_uri,
)
if base_model_id:
# noinspection PyBroadException
try:
# noinspection PyProtectedMember
_base_model = self._task._get_output_model(model_id=base_model_id)
_base_model.update(
labels=self._floating_data.labels,
@ -983,24 +998,30 @@ class OutputModel(BaseModel):
if running_remotely() and task.is_main_task():
if self._floating_data:
# noinspection PyProtectedMember
self._floating_data.design = _Model._wrap_design(self._task._get_model_config_text()) or \
self._floating_data.design
self._floating_data.labels = self._task.get_labels_enumeration() or \
self._floating_data.labels
elif self._base_model:
# noinspection PyProtectedMember
self._base_model.update(design=_Model._wrap_design(self._task._get_model_config_text()) or
self._base_model.design)
self._base_model.update(labels=self._task.get_labels_enumeration() or self._base_model.labels)
elif self._floating_data is not None:
# we copy configuration / labels if they exist, obviously someone wants them as the output base model
# noinspection PyProtectedMember
design = _Model._unwrap_design(self._floating_data.design)
if design:
# noinspection PyProtectedMember
if not task._get_model_config_text():
if not Session.check_min_api_version('2.9'):
design = self._floating_data.design
# noinspection PyProtectedMember
task._set_model_config(config_text=design)
else:
# noinspection PyProtectedMember
self._floating_data.design = _Model._wrap_design(self._task._get_model_config_text())
if self._floating_data.labels:
@ -1008,6 +1029,7 @@ class OutputModel(BaseModel):
else:
self._floating_data.labels = self._task.get_labels_enumeration()
# noinspection PyProtectedMember
self.task._save_output_model(self)
def set_upload_destination(self, uri):
@ -1128,6 +1150,7 @@ class OutputModel(BaseModel):
# select the correct file extension based on the framework,
# or update the framework based on the file extension
# noinspection PyProtectedMember
framework, file_ext = Framework._get_file_ext(
framework=self._get_model_data().framework,
filename=target_filename or weights_filename or register_uri
@ -1184,6 +1207,7 @@ class OutputModel(BaseModel):
self._set_package_tag()
# make sure that if we are in dev move we report that we are training (not debugging)
# noinspection PyProtectedMember
self._task._output_model_updated()
return output_uri
@ -1299,6 +1323,7 @@ class OutputModel(BaseModel):
# update the model object (this will happen if we resumed a training task)
result = self._get_force_base_model().edit(design=config_text)
else:
# noinspection PyProtectedMember
self._floating_data.design = _Model._wrap_design(config_text)
result = Waitable()
@ -1361,6 +1386,7 @@ class OutputModel(BaseModel):
self._base_model = self._task.create_output_model()
# update the model from the task inputs
labels = self._task.get_labels_enumeration()
# noinspection PyProtectedMember
config_text = self._task._get_model_config_text()
parent = self._task.output_model_id or self._task.input_model_id
self._base_model.update(

View File

@ -48,8 +48,9 @@ from .config.cache import SessionCache
from .debugging.log import LoggerRoot
from .errors import UsageError
from .logger import Logger
from .model import Model, InputModel, OutputModel, ARCHIVED_TAG
from .model import Model, InputModel, OutputModel
from .task_parameters import TaskParameters
from .utilities.config import verify_basic_value
from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
argparser_update_currenttask
from .utilities.dicts import ReadOnlyDict, merge_dicts
@ -2075,7 +2076,7 @@ class Task(_Task):
if hasattr(task.data.execution, 'artifacts') else None
if ((str(task._status) in (
str(tasks.TaskStatusEnum.published), str(tasks.TaskStatusEnum.closed)))
or task.output_model_id or (ARCHIVED_TAG in task_tags)
or task.output_model_id or (cls.archived_tag in task_tags)
or (cls._development_tag not in task_tags)
or task_artifacts):
# If the task is published or closed, we shouldn't reset it so we can't use it in dev mode