Add multi-models support

This commit is contained in:
allegroai
2021-05-03 17:46:00 +03:00
parent 3c5195028e
commit ef42d0265d
23 changed files with 690 additions and 113 deletions

View File

@@ -11,6 +11,5 @@ class Result(object):
class Output(EmbeddedDocument):
destination = StrippedStringField()
model = StringField(reference_field='Model')
error = StringField(user_set_allowed=True)
result = StringField(choices=get_options(Result))

View File

@@ -1,4 +1,4 @@
from typing import Dict
from typing import Dict, Sequence
from mongoengine import (
StringField,
@@ -17,6 +17,7 @@ from apiserver.database.fields import (
SafeDictField,
UnionField,
SafeSortedListField,
EmbeddedDocumentListField,
)
from apiserver.database.model import AttributedDocument
from apiserver.database.model.base import ProperDictMixin, GetMixin
@@ -105,11 +106,21 @@ class ConfigurationItem(EmbeddedDocument, ProperDictMixin):
description = StringField()
class ModelItem(EmbeddedDocument, ProperDictMixin):
name = StringField(required=True)
model = StringField(required=True, reference_field="Model")
updated = DateTimeField()
class Models(EmbeddedDocument, ProperDictMixin):
input: Sequence[ModelItem] = EmbeddedDocumentListField(ModelItem, default=list)
output: Sequence[ModelItem] = EmbeddedDocumentListField(ModelItem, default=list)
class Execution(EmbeddedDocument, ProperDictMixin):
meta = {"strict": strict}
test_split = IntField(default=0)
parameters = SafeDictField(default=dict)
model = StringField(reference_field="Model")
model_desc = SafeMapField(StringField(default=""))
model_labels = ModelLabels()
framework = StringField()
@@ -155,7 +166,7 @@ class Task(AttributedDocument):
"active_duration",
"parent",
"project",
"execution.model",
"models.input.model",
("company", "name"),
("company", "user"),
("company", "status", "type"),
@@ -169,8 +180,8 @@ class Task(AttributedDocument):
"$name",
"$id",
"$comment",
"$execution.model",
"$output.model",
"$models.input.model",
"$models.output.model",
"$script.repository",
"$script.entry_point",
],
@@ -179,8 +190,8 @@ class Task(AttributedDocument):
"name": 10,
"id": 10,
"comment": 10,
"execution.model": 2,
"output.model": 2,
"models.output.model": 2,
"models.input.model": 2,
"script.repository": 1,
"script.entry_point": 1,
},
@@ -238,6 +249,7 @@ class Task(AttributedDocument):
hyperparams = SafeMapField(field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
runtime = SafeDictField(default=dict)
models: Models = EmbeddedDocumentField(Models, default=Models)
docker_init_script = StringField()
def get_index_company(self) -> str: