mirror of
https://github.com/clearml/clearml-server
synced 2025-01-31 19:06:55 +00:00
283 lines
9.0 KiB
Python
283 lines
9.0 KiB
Python
from typing import Dict, Sequence
|
|
|
|
from mongoengine import (
|
|
StringField,
|
|
EmbeddedDocumentField,
|
|
EmbeddedDocument,
|
|
DateTimeField,
|
|
IntField,
|
|
ListField,
|
|
LongField,
|
|
)
|
|
|
|
from apiserver.database import Database, strict
|
|
from apiserver.database.fields import (
|
|
StrippedStringField,
|
|
SafeMapField,
|
|
SafeDictField,
|
|
UnionField,
|
|
SafeSortedListField,
|
|
EmbeddedDocumentListField,
|
|
NullableStringField,
|
|
)
|
|
from apiserver.database.model import AttributedDocument
|
|
from apiserver.database.model.base import ProperDictMixin, GetMixin
|
|
from apiserver.database.model.model_labels import ModelLabels
|
|
from apiserver.database.model.project import Project
|
|
from apiserver.database.utils import get_options
|
|
from .metrics import MetricEvent, MetricEventStats
|
|
from .output import Output
|
|
|
|
# Default value assigned to Task.last_iteration.
DEFAULT_LAST_ITERATION = 0
|
|
|
|
|
|
# Enum-like holder of task status strings.
# The values are offered as the allowed ``choices`` of ``Task.status`` and
# ``Task.enqueue_status`` through ``get_options(TaskStatus)``.
class TaskStatus(object):
    created = "created"
    queued = "queued"
    in_progress = "in_progress"
    stopped = "stopped"
    publishing = "publishing"
    published = "published"
    closed = "closed"
    failed = "failed"
    completed = "completed"
    unknown = "unknown"
|
|
|
|
|
|
# Enum-like holder of well-known status message strings.
class TaskStatusMessage(object):
    stopping = "stopping"
|
|
|
|
|
|
# Enum-like holder of well-known system tag strings.
class TaskSystemTags(object):
    development = "development"
|
|
|
|
|
|
class Script(EmbeddedDocument, ProperDictMixin):
    """Embedded document holding a task's script details.

    Embedded in ``Task`` as the ``script`` field. Covers the binary,
    repository, tag/branch/version, entry point, working directory,
    requirements and diff.
    """

    binary = StringField(default="python", strip=True)
    repository = StringField(default="", strip=True)
    tag = StringField(strip=True)
    branch = StringField(strip=True)
    version_num = StringField(strip=True)
    entry_point = StringField(default="", strip=True)
    working_dir = StringField(strip=True)
    requirements = SafeDictField()
    diff = StringField()
|
|
|
|
|
|
class ArtifactTypeData(EmbeddedDocument):
    """Embedded document stored in ``Artifact.type_data``.

    Holds three optional strings: preview, content type and data hash.
    """

    preview = StringField()
    content_type = StringField()
    data_hash = StringField()
|
|
|
|
|
|
# Enum-like holder of artifact mode strings.
# The values are offered as the allowed ``choices`` of ``Artifact.mode``
# through ``get_options(ArtifactModes)``.
class ArtifactModes:
    input = "input"
    output = "output"
|
|
|
|
|
|
# Mode used as the default of Artifact.mode when none is supplied.
DEFAULT_ARTIFACT_MODE = ArtifactModes.output
|
|
|
|
|
|
class Artifact(EmbeddedDocument):
    """Embedded document describing a single task artifact.

    Instances are the values of the ``Execution.artifacts`` map. ``key`` and
    ``type`` are required; ``mode`` is restricted to the ``ArtifactModes``
    options and falls back to ``DEFAULT_ARTIFACT_MODE``.
    """

    key = StringField(required=True)
    type = StringField(required=True)
    mode = StringField(
        choices=get_options(ArtifactModes), default=DEFAULT_ARTIFACT_MODE
    )
    uri = StringField()
    hash = StringField()
    content_size = LongField()
    timestamp = LongField()
    type_data = EmbeddedDocumentField(ArtifactTypeData)
    # List of lists whose items may be int, float or str.
    display_data = SafeSortedListField(ListField(UnionField((int, float, str))))
|
|
|
|
|
|
class ParamsItem(EmbeddedDocument, ProperDictMixin):
    """Embedded document for one hyper-parameter entry.

    Instances are the leaf values of the two-level ``Task.hyperparams`` map.
    ``section``, ``name`` and ``value`` are required strings; ``type`` and
    ``description`` are optional.
    """

    section = StringField(required=True)
    name = StringField(required=True)
    value = StringField(required=True)
    type = StringField()
    description = StringField()
|
|
|
|
|
|
class ConfigurationItem(EmbeddedDocument, ProperDictMixin):
    """Embedded document for one configuration entry.

    Instances are the values of the ``Task.configuration`` map. ``name`` and
    ``value`` are required strings; ``type`` and ``description`` are optional.
    """

    name = StringField(required=True)
    value = StringField(required=True)
    type = StringField()
    description = StringField()
|
|
|
|
|
|
# Enum-like holder of the two task model kinds; the values are also used as
# the keys of ``TaskModelNames``.
class TaskModelTypes:
    input = "input"
    output = "output"
|
|
|
|
|
|
# Display name for each task model type.
TaskModelNames = {
    TaskModelTypes.input: "Input Model",
    TaskModelTypes.output: "Output Model",
}
|
|
|
|
|
|
class ModelItem(EmbeddedDocument, ProperDictMixin):
    """Embedded document linking a task to one model.

    Instances populate the ``Models.input`` and ``Models.output`` lists.
    ``model`` is a required string declared with ``reference_field="Model"``.
    """

    name = StringField(required=True)
    model = StringField(required=True, reference_field="Model")
    updated = DateTimeField()
|
|
|
|
|
|
class Models(EmbeddedDocument, ProperDictMixin):
    """Embedded document with a task's input and output model lists.

    Embedded in ``Task`` as the ``models`` field. Both lists hold
    ``ModelItem`` entries and default to an empty list.
    """

    input: Sequence[ModelItem] = EmbeddedDocumentListField(ModelItem, default=list)
    output: Sequence[ModelItem] = EmbeddedDocumentListField(ModelItem, default=list)
|
|
|
|
|
|
class Execution(EmbeddedDocument, ProperDictMixin):
    """Embedded document with a task's execution details.

    Embedded in ``Task`` as the ``execution`` field. Strictness follows the
    module-wide ``strict`` setting imported from ``apiserver.database``.
    """

    meta = {"strict": strict}

    test_split = IntField(default=0)
    parameters = SafeDictField(default=dict)
    model_desc = SafeMapField(StringField(default=""))
    model_labels = ModelLabels()
    framework = StringField()
    # Map of Artifact documents.
    artifacts: Dict[str, Artifact] = SafeMapField(field=EmbeddedDocumentField(Artifact))
    # Queue ID where task was queued
    queue = StringField(reference_field="Queue")
|
|
|
|
|
|
# Enum-like holder of task type strings.
# The values are offered as the allowed ``choices`` of ``Task.type`` through
# ``get_options(TaskType)`` and also feed ``external_task_types``.
class TaskType(object):
    training = "training"
    testing = "testing"
    inference = "inference"
    data_processing = "data_processing"
    application = "application"
    monitor = "monitor"
    controller = "controller"
    optimizer = "optimizer"
    service = "service"
    qc = "qc"
    custom = "custom"
|
|
|
|
|
|
# Set of the options get_options() yields for TaskType (presumably every task
# type value declared on the class; confirm against apiserver.database.utils).
external_task_types = set(get_options(TaskType))
|
|
|
|
|
|
class Task(AttributedDocument):
    """Task document, stored under the ``Database.backend`` alias.

    Declares the task's indexes (including a weighted text index), the field
    groups exposed to ``GetMixin.QueryParameterOptions`` and the task fields
    themselves. Embedded sub-documents (``Output``, ``Execution``, ``Script``,
    ``Models``) default to a fresh empty instance.
    """

    # Maps field-path prefixes to the numeric locale collation (presumably
    # applied when querying on those paths; confirm in AttributedDocument).
    _field_collation_overrides = {
        "execution.parameters.": AttributedDocument._numeric_locale,
        "last_metrics.": AttributedDocument._numeric_locale,
        "hyperparams.": AttributedDocument._numeric_locale,
    }

    meta = {
        "db_alias": Database.backend,
        "strict": strict,
        "indexes": [
            "created",
            "started",
            "completed",
            "active_duration",
            "parent",
            "project",
            "last_update",
            "status_changed",
            "models.input.model",
            ("company", "name"),
            ("company", "user"),
            ("company", "status", "type"),
            ("company", "system_tags", "last_update"),
            ("company", "type", "system_tags", "status"),
            ("company", "project", "type", "system_tags", "status"),
            ("status", "last_update"),  # for maintenance tasks
            {
                "fields": ["company", "project"],
                "collation": AttributedDocument._numeric_locale,
            },
            # Weighted text index over the "$"-prefixed fields.
            {
                "name": "%s.task.main_text_index" % Database.backend,
                "fields": [
                    "$name",
                    "$id",
                    "$comment",
                    "$models.input.model",
                    "$models.output.model",
                    "$script.repository",
                    "$script.entry_point",
                ],
                "default_language": "english",
                "weights": {
                    "name": 10,
                    "id": 10,
                    "comment": 10,
                    "models.output.model": 2,
                    "models.input.model": 2,
                    "script.repository": 1,
                    "script.entry_point": 1,
                },
            },
        ],
    }
    # Field groups handed to GetMixin.QueryParameterOptions; what each group
    # means is defined by GetMixin.
    get_all_query_options = GetMixin.QueryParameterOptions(
        list_fields=(
            "id",
            "user",
            "tags",
            "system_tags",
            "type",
            "status",
            "project",
            "parent",
            "hyperparams.*",
        ),
        range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
        datetime_fields=("status_changed", "last_update"),
        pattern_fields=("name", "comment"),
    )

    id = StringField(primary_key=True)
    name = StrippedStringField(
        required=True, user_set_allowed=True, sparse=False, min_length=3
    )

    type = StringField(required=True, choices=get_options(TaskType))
    status = StringField(default=TaskStatus.created, choices=get_options(TaskStatus))
    status_reason = StringField()
    status_message = StringField(user_set_allowed=True)
    status_changed = DateTimeField()
    comment = StringField(user_set_allowed=True)
    created = DateTimeField(required=True, user_set_allowed=True)
    started = DateTimeField()
    completed = DateTimeField()
    published = DateTimeField()
    active_duration = IntField(default=None)
    parent = StringField(reference_field="Task")
    project = StringField(reference_field=Project, user_set_allowed=True)
    output: Output = EmbeddedDocumentField(Output, default=Output)
    execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
    tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
    system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
    script: Script = EmbeddedDocumentField(Script, default=Script)
    last_worker = StringField()
    last_worker_report = DateTimeField()
    last_update = DateTimeField()
    last_change = DateTimeField()
    last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
    # Two-level map whose leaves are MetricEvent documents.
    last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
    metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
    company_origin = StringField(exclude_by_default=True)
    duration = IntField()  # task duration in seconds
    # Two-level map whose leaves are ParamsItem documents.
    hyperparams = SafeMapField(field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
    configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
    runtime = SafeDictField(default=dict)
    models: Models = EmbeddedDocumentField(Models, default=Models)
    container = SafeMapField(field=NullableStringField())
    enqueue_status = StringField(
        choices=get_options(TaskStatus), exclude_by_default=True
    )

    def get_index_company(self) -> str:
        """
        Returns the company ID used for locating indices containing task data.
        In case the task has a valid company, this is the company ID.
        Otherwise, if the task has a company_origin, this is a task that has been made public and the
        origin company should be used.
        Otherwise, an empty company is used.
        """
        return self.company or self.company_origin or ""
|