Add support for Task hyper-parameter sections and meta-data

Add new Task configuration section
This commit is contained in:
allegroai
2020-08-10 08:45:25 +03:00
parent 42ba696518
commit 8c7e230898
14 changed files with 1076 additions and 107 deletions

View File

@@ -49,13 +49,13 @@ class TaskSystemTags(object):
development = "development"
class Script(EmbeddedDocument):
class Script(EmbeddedDocument, ProperDictMixin):
binary = StringField(default="python")
repository = StringField(required=True)
repository = StringField(default="")
tag = StringField()
branch = StringField()
version_num = StringField()
entry_point = StringField(required=True)
entry_point = StringField(default="")
working_dir = StringField()
requirements = SafeDictField()
diff = StringField()
@@ -84,6 +84,21 @@ class Artifact(EmbeddedDocument):
display_data = SafeSortedListField(ListField(UnionField((int, float, str))))
class ParamsItem(EmbeddedDocument, ProperDictMixin):
section = StringField(required=True)
name = StringField(required=True)
value = StringField(required=True)
type = StringField()
description = StringField()
class ConfigurationItem(EmbeddedDocument, ProperDictMixin):
name = StringField(required=True)
value = StringField(required=True)
type = StringField()
description = StringField()
class Execution(EmbeddedDocument, ProperDictMixin):
meta = {"strict": strict}
test_split = IntField(default=0)
@@ -116,9 +131,12 @@ external_task_types = set(get_options(TaskType))
class Task(AttributedDocument):
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
_field_collation_overrides = {
"execution.parameters.": {"locale": "en_US", "numericOrdering": True},
"last_metrics.": {"locale": "en_US", "numericOrdering": True},
"execution.parameters.": _numeric_locale,
"last_metrics.": _numeric_locale,
"hyperparams.": _numeric_locale,
"configuration.": _numeric_locale,
}
meta = {
@@ -187,7 +205,7 @@ class Task(AttributedDocument):
execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
script: Script = EmbeddedDocumentField(Script)
script: Script = EmbeddedDocumentField(Script, default=Script)
last_worker = StringField()
last_worker_report = DateTimeField()
last_update = DateTimeField()
@@ -196,3 +214,6 @@ class Task(AttributedDocument):
metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
company_origin = StringField(exclude_by_default=True)
duration = IntField() # task duration in seconds
hyperparams = SafeMapField(field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
runtime = SafeDictField(default=dict)