mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Compare commits
81 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2216bfe875 | ||
|
|
9beefa7473 | ||
|
|
8ebc334889 | ||
|
|
e662c850af | ||
|
|
1e5163e530 | ||
|
|
1567774765 | ||
|
|
babfcbb707 | ||
|
|
027edd86bb | ||
|
|
cc83aadae6 | ||
|
|
8c18660a82 | ||
|
|
4fe61ee25c | ||
|
|
e18b21639c | ||
|
|
1cef03b8c2 | ||
|
|
d60d6dfe99 | ||
|
|
27d086bca2 | ||
|
|
add3f011a0 | ||
|
|
ee90b0b024 | ||
|
|
9bf107866f | ||
|
|
4d2f282950 | ||
|
|
b55fad1b59 | ||
|
|
ba77ff11e9 | ||
|
|
b67aa05d6f | ||
|
|
6b0c45a861 | ||
|
|
dc9623e964 | ||
|
|
3d73d60826 | ||
|
|
9f0c9c3690 | ||
|
|
1a3d3494ce | ||
|
|
b99f620073 | ||
|
|
e2f265b4bc | ||
|
|
251ee57ffd | ||
|
|
7e03104f1c | ||
|
|
f1a258208e | ||
|
|
66cc49313b | ||
|
|
9ae2943f7d | ||
|
|
54326f707b | ||
|
|
3a3b57c15f | ||
|
|
8ea8ad34e6 | ||
|
|
179661a0d4 | ||
|
|
3d22ca1888 | ||
|
|
fdf6798d0c | ||
|
|
9d9a44b927 | ||
|
|
dad935e81d | ||
|
|
a75534ec34 | ||
|
|
eab33de97e | ||
|
|
29de110abb | ||
|
|
2e7f418ee2 | ||
|
|
dadb996d22 | ||
|
|
174f692edf | ||
|
|
f4d5168a20 | ||
|
|
5a438e8435 | ||
|
|
ce4814dc47 | ||
|
|
ef42d0265d | ||
|
|
3c5195028e | ||
|
|
0d5174c453 | ||
|
|
c034c1a986 | ||
|
|
1b49da8748 | ||
|
|
26bda01a28 | ||
|
|
f5008d80ad | ||
|
|
8b464e7ae6 | ||
|
|
78e4a58c91 | ||
|
|
7a4a5eb03e | ||
|
|
d029d56508 | ||
|
|
6411954002 | ||
|
|
7f4ad0d1ca | ||
|
|
4cd4b2914d | ||
|
|
1d55710a0b | ||
|
|
8f646043bb | ||
|
|
4b11a6efcd | ||
|
|
cb3a7c90a8 | ||
|
|
074842a122 | ||
|
|
749ff4a44f | ||
|
|
7d6918ecb0 | ||
|
|
47184c2833 | ||
|
|
6434f1028e | ||
|
|
daade08940 | ||
|
|
a1d289822f | ||
|
|
1ce34f2c74 | ||
|
|
c2dc73a71f | ||
|
|
07bb3b5df8 | ||
|
|
067ef82576 | ||
|
|
59fc98e0c4 |
20
README.md
20
README.md
@@ -20,7 +20,7 @@
|
||||
|
||||
In v0.16, the Elasticsearch subsystem of ClearML Server has been upgraded from version 5.6 to version 7.6. This change necessitates the migration of the database contents to accommodate the change in index structure across the different versions.
|
||||
|
||||
Follow [this procedure](https://allegro.ai/docs/deploying_trains/trains_server_es7_migration/) to migrate existing data.
|
||||
Follow [this procedure](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_es7_migration.html) to migrate existing data.
|
||||
|
||||
---
|
||||
|
||||
@@ -78,15 +78,15 @@ For example, to see if port `8080` is in use:
|
||||
|
||||
Launch the **ClearML Server** in any of the following formats:
|
||||
|
||||
- Pre-built [AWS EC2 AMI](https://allegro.ai/docs/deploying_trains/trains_server_aws_ec2_ami/)
|
||||
- Pre-built [GCP Custom Image](https://allegro.ai/docs/deploying_trains/trains_server_gcp/)
|
||||
- Pre-built [AWS EC2 AMI](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_aws_ec2_ami.html)
|
||||
- Pre-built [GCP Custom Image](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_gcp.html)
|
||||
- Pre-built Docker Image
|
||||
- [Linux](https://allegro.ai/docs/deploying_trains/trains_server_linux_mac/)
|
||||
- [macOS](https://allegro.ai/docs/deploying_trains/trains_server_linux_mac/)
|
||||
- [Windows 10](https://allegro.ai/docs/deploying_trains/trains_server_win/)
|
||||
- [Linux](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_linux_mac.html)
|
||||
- [macOS](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_linux_mac.html)
|
||||
- [Windows 10](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_win.html)
|
||||
- Kubernetes
|
||||
- [Kubernetes Helm](https://allegro.ai/docs/deploying_trains/trains_server_kubernetes_helm/)
|
||||
- Manual [Kubernetes installation](https://allegro.ai/docs/deploying_trains/trains_server_kubernetes/)
|
||||
- [Kubernetes Helm](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_kubernetes_helm.html)
|
||||
- Manual [Kubernetes installation](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_kubernetes.html)
|
||||
|
||||
## Connecting ClearML to your ClearML Server
|
||||
|
||||
@@ -211,9 +211,9 @@ To upgrade your existing **ClearML Server** deployment:
|
||||
## Community & Support
|
||||
|
||||
If you have any questions, look to the ClearML [FAQ](https://allegro.ai/clearml/docs/docs/faq/faq.html), or
|
||||
tag your questions on [stackoverflow](https://stackoverflow.com/questions/tagged/trains) with '**trains**' tag.
|
||||
tag your questions on [stackoverflow](https://stackoverflow.com/questions/tagged/clearml) with '**clearml**' tag.
|
||||
|
||||
For feature requests or bug reports, please use [GitHub issues](https://github.com/allegroai/trains-server/issues).
|
||||
For feature requests or bug reports, please use [GitHub issues](https://github.com/allegroai/clearml-server/issues).
|
||||
|
||||
Additionally, you can always find us at *clearml@allegro.ai*
|
||||
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
301 {
|
||||
_: "moved_permanently"
|
||||
1: ["not_supported", "this endpoint is no longer supported for the requested API version"]
|
||||
}
|
||||
|
||||
400 {
|
||||
_: "bad_request"
|
||||
1: ["not_supported", "endpoint is not supported"]
|
||||
@@ -62,6 +67,10 @@
|
||||
402: ["project_has_tasks", "project has associated tasks"]
|
||||
403: ["project_not_found", "project not found"]
|
||||
405: ["project_has_models", "project has associated models"]
|
||||
407: ["invalid_project_name", "invalid project name"]
|
||||
408: ["cannot_update_project_location", "Cannot update project location. Use projects.move instead"]
|
||||
409: ["project_path_exceeds_max", "Project path exceeds the maximum allowed depth"]
|
||||
410: ["project_source_and_destination_are_the_same", "Project has the same source and destination paths"]
|
||||
|
||||
# Queues
|
||||
701: ["invalid_queue_id", "invalid queue id"]
|
||||
@@ -108,6 +117,11 @@
|
||||
21: ["no_write_permission", "forbidden (modification not allowed)"]
|
||||
}
|
||||
|
||||
410: {
|
||||
_: "gone"
|
||||
1: ["not_supported", "this endpoint is not supported any more"]
|
||||
}
|
||||
|
||||
500 {
|
||||
_: "server_error"
|
||||
0: ["general_error", "general server error"]
|
||||
|
||||
@@ -218,7 +218,7 @@ class ActualEnumField(fields.StringField):
|
||||
)
|
||||
|
||||
def parse_value(self, value):
|
||||
if value is None and not self.required:
|
||||
if value is NotSet and not self.required:
|
||||
return self.get_default_value()
|
||||
try:
|
||||
# noinspection PyArgumentList
|
||||
|
||||
25
apiserver/apimodels/batch.py
Normal file
25
apiserver/apimodels/batch.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from typing import Sequence
|
||||
|
||||
from jsonmodels.fields import StringField
|
||||
from jsonmodels.models import Base
|
||||
from jsonmodels.validators import Length
|
||||
|
||||
from apiserver.apimodels import ListField
|
||||
from apiserver.apimodels.base import UpdateResponse
|
||||
|
||||
|
||||
class BatchRequest(Base):
|
||||
ids: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
|
||||
|
||||
|
||||
class BatchResponse(Base):
|
||||
succeeded: Sequence[dict] = ListField([dict])
|
||||
failed: Sequence[dict] = ListField([dict])
|
||||
|
||||
|
||||
class UpdateBatchItem(UpdateResponse):
|
||||
id: str = StringField()
|
||||
|
||||
|
||||
class UpdateBatchResponse(BatchResponse):
|
||||
succeeded: Sequence[UpdateBatchItem] = ListField(UpdateBatchItem)
|
||||
@@ -38,7 +38,7 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
|
||||
class TaskMetric(Base):
|
||||
task: str = StringField(required=True)
|
||||
metric: str = StringField(required=True)
|
||||
metric: str = StringField(default=None)
|
||||
|
||||
|
||||
class DebugImagesRequest(Base):
|
||||
@@ -89,7 +89,6 @@ class IterationEvents(Base):
|
||||
|
||||
class MetricEvents(Base):
|
||||
task: str = StringField()
|
||||
metric: str = StringField()
|
||||
iterations: Sequence[IterationEvents] = ListField(items_types=IterationEvents)
|
||||
|
||||
|
||||
|
||||
@@ -31,3 +31,4 @@ class GetSupportedModesResponse(Base):
|
||||
server_errors = EmbeddedField(ServerErrors)
|
||||
sso = DictField([str, type(None)])
|
||||
sso_providers = ListField([dict])
|
||||
authenticated = BoolField(default=False)
|
||||
|
||||
23
apiserver/apimodels/metadata.py
Normal file
23
apiserver/apimodels/metadata.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import Sequence
|
||||
|
||||
from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apiserver.apimodels import ListField
|
||||
|
||||
|
||||
class MetadataItem(Base):
|
||||
key = StringField(required=True)
|
||||
type = StringField(required=True)
|
||||
value = StringField(required=True)
|
||||
|
||||
|
||||
class DeleteMetadata(Base):
|
||||
keys: Sequence[str] = ListField(str, validators=validators.Length(minimum_value=1))
|
||||
|
||||
|
||||
class AddOrUpdateMetadata(Base):
|
||||
metadata: Sequence[MetadataItem] = ListField(
|
||||
[MetadataItem], validators=validators.Length(minimum_value=1)
|
||||
)
|
||||
@@ -3,7 +3,12 @@ from six import string_types
|
||||
|
||||
from apiserver.apimodels import ListField, DictField
|
||||
from apiserver.apimodels.base import UpdateResponse
|
||||
from apiserver.apimodels.tasks import PublishResponse as TaskPublishResponse
|
||||
from apiserver.apimodels.batch import BatchRequest
|
||||
from apiserver.apimodels.metadata import (
|
||||
MetadataItem,
|
||||
DeleteMetadata,
|
||||
AddOrUpdateMetadata,
|
||||
)
|
||||
|
||||
|
||||
class GetFrameworksRequest(models.Base):
|
||||
@@ -13,7 +18,7 @@ class GetFrameworksRequest(models.Base):
|
||||
class CreateModelRequest(models.Base):
|
||||
name = fields.StringField(required=True)
|
||||
uri = fields.StringField(required=True)
|
||||
labels = DictField(value_types=string_types+(int,))
|
||||
labels = DictField(value_types=string_types + (int,))
|
||||
tags = ListField(items_types=string_types)
|
||||
system_tags = ListField(items_types=string_types)
|
||||
comment = fields.StringField()
|
||||
@@ -25,6 +30,7 @@ class CreateModelRequest(models.Base):
|
||||
ready = fields.BoolField(default=True)
|
||||
ui_cache = DictField()
|
||||
task = fields.StringField()
|
||||
metadata = ListField(items_types=[MetadataItem])
|
||||
|
||||
|
||||
class CreateModelResponse(models.Base):
|
||||
@@ -32,17 +38,40 @@ class CreateModelResponse(models.Base):
|
||||
created = fields.BoolField(required=True)
|
||||
|
||||
|
||||
class PublishModelRequest(models.Base):
|
||||
class ModelRequest(models.Base):
|
||||
model = fields.StringField(required=True)
|
||||
|
||||
|
||||
class DeleteModelRequest(ModelRequest):
|
||||
force = fields.BoolField(default=False)
|
||||
|
||||
|
||||
class ModelsDeleteManyRequest(BatchRequest):
|
||||
force = fields.BoolField(default=False)
|
||||
|
||||
|
||||
class PublishModelRequest(ModelRequest):
|
||||
force_publish_task = fields.BoolField(default=False)
|
||||
publish_task = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class ModelTaskPublishResponse(models.Base):
|
||||
id = fields.StringField(required=True)
|
||||
data = fields.EmbeddedField(TaskPublishResponse)
|
||||
data = fields.EmbeddedField(UpdateResponse)
|
||||
|
||||
|
||||
class PublishModelResponse(UpdateResponse):
|
||||
published_task = fields.EmbeddedField(ModelTaskPublishResponse)
|
||||
updated = fields.IntField()
|
||||
|
||||
|
||||
class ModelsPublishManyRequest(BatchRequest):
|
||||
force_publish_task = fields.BoolField(default=False)
|
||||
publish_task = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class DeleteMetadataRequest(DeleteMetadata):
|
||||
model = fields.StringField(required=True)
|
||||
|
||||
|
||||
class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
|
||||
model = fields.StringField(required=True)
|
||||
|
||||
@@ -5,11 +5,29 @@ from apiserver.apimodels.organization import TagsRequest
|
||||
from apiserver.database.model import EntityVisibility
|
||||
|
||||
|
||||
class ProjectReq(models.Base):
|
||||
class ProjectRequest(models.Base):
|
||||
project = fields.StringField(required=True)
|
||||
|
||||
|
||||
class MergeRequest(ProjectRequest):
|
||||
destination_project = fields.StringField()
|
||||
|
||||
|
||||
class MoveRequest(ProjectRequest):
|
||||
new_location = fields.StringField()
|
||||
|
||||
|
||||
class DeleteRequest(ProjectRequest):
|
||||
force = fields.BoolField(default=False)
|
||||
delete_contents = fields.BoolField(default=False)
|
||||
|
||||
|
||||
class ProjectOrNoneRequest(models.Base):
|
||||
project = fields.StringField()
|
||||
include_subprojects = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class GetHyperParamReq(ProjectReq):
|
||||
class GetHyperParamRequest(ProjectOrNoneRequest):
|
||||
page = fields.IntField(default=0)
|
||||
page_size = fields.IntField(default=500)
|
||||
|
||||
@@ -18,7 +36,25 @@ class ProjectTagsRequest(TagsRequest):
|
||||
projects = ListField(str)
|
||||
|
||||
|
||||
class ProjectTaskParentsRequest(ProjectReq):
|
||||
projects = ListField(str)
|
||||
class MultiProjectRequest(models.Base):
|
||||
projects = fields.ListField(str)
|
||||
include_subprojects = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class ProjectTaskParentsRequest(MultiProjectRequest):
|
||||
tasks_state = ActualEnumField(EntityVisibility)
|
||||
|
||||
|
||||
class ProjectHyperparamValuesRequest(MultiProjectRequest):
|
||||
section = fields.StringField(required=True)
|
||||
name = fields.StringField(required=True)
|
||||
allow_public = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class ProjectsGetRequest(models.Base):
|
||||
include_stats = fields.BoolField(default=False)
|
||||
stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active)
|
||||
non_public = fields.BoolField(default=False)
|
||||
active_users = fields.ListField(str)
|
||||
check_own_contents = fields.BoolField(default=False)
|
||||
shallow_search = fields.BoolField(default=False)
|
||||
|
||||
@@ -3,6 +3,11 @@ from jsonmodels.fields import StringField, IntField, BoolField, FloatField
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apiserver.apimodels import ListField
|
||||
from apiserver.apimodels.metadata import (
|
||||
MetadataItem,
|
||||
DeleteMetadata,
|
||||
AddOrUpdateMetadata,
|
||||
)
|
||||
|
||||
|
||||
class GetDefaultResp(Base):
|
||||
@@ -14,6 +19,7 @@ class CreateRequest(Base):
|
||||
name = StringField(required=True)
|
||||
tags = ListField(items_types=[str])
|
||||
system_tags = ListField(items_types=[str])
|
||||
metadata = ListField(items_types=[MetadataItem])
|
||||
|
||||
|
||||
class QueueRequest(Base):
|
||||
@@ -28,6 +34,7 @@ class UpdateRequest(QueueRequest):
|
||||
name = StringField()
|
||||
tags = ListField(items_types=[str])
|
||||
system_tags = ListField(items_types=[str])
|
||||
metadata = ListField(items_types=[MetadataItem])
|
||||
|
||||
|
||||
class TaskRequest(QueueRequest):
|
||||
@@ -58,3 +65,11 @@ class QueueMetrics(Base):
|
||||
|
||||
class GetMetricsResponse(Base):
|
||||
queues = ListField(QueueMetrics)
|
||||
|
||||
|
||||
class DeleteMetadataRequest(DeleteMetadata):
|
||||
queue = StringField(required=True)
|
||||
|
||||
|
||||
class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
|
||||
queue = StringField(required=True)
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
from typing import Sequence
|
||||
|
||||
import six
|
||||
from jsonmodels import models
|
||||
from jsonmodels.fields import StringField, BoolField, IntField, EmbeddedField
|
||||
from jsonmodels.validators import Enum, Length
|
||||
|
||||
from apiserver.apimodels import DictField, ListField
|
||||
from apiserver.apimodels.base import UpdateResponse
|
||||
from apiserver.apimodels.batch import BatchRequest, UpdateBatchItem, BatchResponse
|
||||
from apiserver.database.model.task.task import (
|
||||
TaskType,
|
||||
ArtifactModes,
|
||||
DEFAULT_ARTIFACT_MODE,
|
||||
TaskModelTypes,
|
||||
)
|
||||
from apiserver.database.utils import get_options
|
||||
|
||||
@@ -43,26 +44,54 @@ class EnqueueResponse(UpdateResponse):
|
||||
queued = IntField()
|
||||
|
||||
|
||||
class EnqueueBatchItem(UpdateBatchItem):
|
||||
queued: bool = BoolField()
|
||||
|
||||
|
||||
class EnqueueManyResponse(BatchResponse):
|
||||
succeeded: Sequence[EnqueueBatchItem] = ListField(EnqueueBatchItem)
|
||||
|
||||
|
||||
class DequeueResponse(UpdateResponse):
|
||||
dequeued = IntField()
|
||||
|
||||
|
||||
class DequeueBatchItem(UpdateBatchItem):
|
||||
dequeued: bool = BoolField()
|
||||
|
||||
|
||||
class DequeueManyResponse(BatchResponse):
|
||||
succeeded: Sequence[DequeueBatchItem] = ListField(DequeueBatchItem)
|
||||
|
||||
|
||||
class ResetResponse(UpdateResponse):
|
||||
deleted_indices = ListField(items_types=six.string_types)
|
||||
dequeued = DictField()
|
||||
frames = DictField()
|
||||
events = DictField()
|
||||
model_deleted = IntField()
|
||||
deleted_models = IntField()
|
||||
urls = DictField()
|
||||
|
||||
|
||||
class ResetBatchItem(UpdateBatchItem):
|
||||
dequeued: bool = BoolField()
|
||||
deleted_models = IntField()
|
||||
urls = DictField()
|
||||
|
||||
|
||||
class ResetManyResponse(BatchResponse):
|
||||
succeeded: Sequence[ResetBatchItem] = ListField(ResetBatchItem)
|
||||
|
||||
|
||||
class TaskRequest(models.Base):
|
||||
task = StringField(required=True)
|
||||
|
||||
|
||||
class UpdateRequest(TaskRequest):
|
||||
class TaskUpdateRequest(TaskRequest):
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class UpdateRequest(TaskUpdateRequest):
|
||||
status_reason = StringField(default="")
|
||||
status_message = StringField(default="")
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class EnqueueRequest(UpdateRequest):
|
||||
@@ -71,6 +100,8 @@ class EnqueueRequest(UpdateRequest):
|
||||
|
||||
class DeleteRequest(UpdateRequest):
|
||||
move_to_trash = BoolField(default=True)
|
||||
return_file_urls = BoolField(default=False)
|
||||
delete_output_models = BoolField(default=True)
|
||||
|
||||
|
||||
class SetRequirementsRequest(TaskRequest):
|
||||
@@ -81,10 +112,6 @@ class PublishRequest(UpdateRequest):
|
||||
publish_model = BoolField(default=True)
|
||||
|
||||
|
||||
class PublishResponse(UpdateResponse):
|
||||
pass
|
||||
|
||||
|
||||
class TaskData(models.Base):
|
||||
"""
|
||||
This is a partial description of a task that can be updated incrementally
|
||||
@@ -104,6 +131,11 @@ class GetTypesRequest(models.Base):
|
||||
projects = ListField(items_types=[str])
|
||||
|
||||
|
||||
class TaskInputModel(models.Base):
|
||||
name = StringField()
|
||||
model = StringField()
|
||||
|
||||
|
||||
class CloneRequest(TaskRequest):
|
||||
new_task_name = StringField()
|
||||
new_task_comment = StringField()
|
||||
@@ -113,14 +145,15 @@ class CloneRequest(TaskRequest):
|
||||
new_task_project = StringField()
|
||||
new_task_hyperparams = DictField()
|
||||
new_task_configuration = DictField()
|
||||
new_task_container = DictField()
|
||||
new_task_input_models = ListField([TaskInputModel])
|
||||
execution_overrides = DictField()
|
||||
validate_references = BoolField(default=False)
|
||||
new_project_name = StringField()
|
||||
|
||||
|
||||
class AddOrUpdateArtifactsRequest(TaskRequest):
|
||||
class AddOrUpdateArtifactsRequest(TaskUpdateRequest):
|
||||
artifacts = ListField([Artifact], validators=Length(minimum_value=1))
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class ArtifactId(models.Base):
|
||||
@@ -130,13 +163,14 @@ class ArtifactId(models.Base):
|
||||
)
|
||||
|
||||
|
||||
class DeleteArtifactsRequest(TaskRequest):
|
||||
class DeleteArtifactsRequest(TaskUpdateRequest):
|
||||
artifacts = ListField([ArtifactId], validators=Length(minimum_value=1))
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class ResetRequest(UpdateRequest):
|
||||
clear_all = BoolField(default=False)
|
||||
return_file_urls = BoolField(default=False)
|
||||
delete_output_models = BoolField(default=True)
|
||||
|
||||
|
||||
class MultiTaskRequest(models.Base):
|
||||
@@ -161,7 +195,7 @@ class ReplaceHyperparams(object):
|
||||
all = "all"
|
||||
|
||||
|
||||
class EditHyperParamsRequest(TaskRequest):
|
||||
class EditHyperParamsRequest(TaskUpdateRequest):
|
||||
hyperparams: Sequence[HyperParamItem] = ListField(
|
||||
[HyperParamItem], validators=Length(minimum_value=1)
|
||||
)
|
||||
@@ -169,7 +203,6 @@ class EditHyperParamsRequest(TaskRequest):
|
||||
validators=Enum(*get_options(ReplaceHyperparams)),
|
||||
default=ReplaceHyperparams.none,
|
||||
)
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class HyperParamKey(models.Base):
|
||||
@@ -177,11 +210,10 @@ class HyperParamKey(models.Base):
|
||||
name = StringField(nullable=True)
|
||||
|
||||
|
||||
class DeleteHyperParamsRequest(TaskRequest):
|
||||
class DeleteHyperParamsRequest(TaskUpdateRequest):
|
||||
hyperparams: Sequence[HyperParamKey] = ListField(
|
||||
[HyperParamKey], validators=Length(minimum_value=1)
|
||||
)
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class GetConfigurationsRequest(MultiTaskRequest):
|
||||
@@ -189,7 +221,7 @@ class GetConfigurationsRequest(MultiTaskRequest):
|
||||
|
||||
|
||||
class GetConfigurationNamesRequest(MultiTaskRequest):
|
||||
pass
|
||||
skip_empty = BoolField(default=True)
|
||||
|
||||
|
||||
class Configuration(models.Base):
|
||||
@@ -199,17 +231,15 @@ class Configuration(models.Base):
|
||||
description = StringField()
|
||||
|
||||
|
||||
class EditConfigurationRequest(TaskRequest):
|
||||
class EditConfigurationRequest(TaskUpdateRequest):
|
||||
configuration: Sequence[Configuration] = ListField(
|
||||
[Configuration], validators=Length(minimum_value=1)
|
||||
)
|
||||
replace_configuration = BoolField(default=False)
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class DeleteConfigurationRequest(TaskRequest):
|
||||
class DeleteConfigurationRequest(TaskUpdateRequest):
|
||||
configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class ArchiveRequest(MultiTaskRequest):
|
||||
@@ -219,3 +249,54 @@ class ArchiveRequest(MultiTaskRequest):
|
||||
|
||||
class ArchiveResponse(models.Base):
|
||||
archived = IntField()
|
||||
|
||||
|
||||
class TaskBatchRequest(BatchRequest):
|
||||
status_reason = StringField(default="")
|
||||
status_message = StringField(default="")
|
||||
|
||||
|
||||
class StopManyRequest(TaskBatchRequest):
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class EnqueueManyRequest(TaskBatchRequest):
|
||||
queue = StringField()
|
||||
validate_tasks = BoolField(default=False)
|
||||
|
||||
|
||||
class DeleteManyRequest(TaskBatchRequest):
|
||||
move_to_trash = BoolField(default=True)
|
||||
return_file_urls = BoolField(default=False)
|
||||
delete_output_models = BoolField(default=True)
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class ResetManyRequest(TaskBatchRequest):
|
||||
clear_all = BoolField(default=False)
|
||||
return_file_urls = BoolField(default=False)
|
||||
delete_output_models = BoolField(default=True)
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class PublishManyRequest(TaskBatchRequest):
|
||||
publish_model = BoolField(default=True)
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class AddUpdateModelRequest(TaskRequest):
|
||||
name = StringField(required=True)
|
||||
model = StringField(required=True)
|
||||
type = StringField(required=True, validators=Enum(*get_options(TaskModelTypes)))
|
||||
iteration = IntField()
|
||||
|
||||
|
||||
class ModelItemKey(models.Base):
|
||||
name = StringField(required=True)
|
||||
type = StringField(required=True, validators=Enum(*get_options(TaskModelTypes)))
|
||||
|
||||
|
||||
class DeleteModelsRequest(TaskRequest):
|
||||
models: Sequence[ModelItemKey] = ListField(
|
||||
[ModelItemKey], validators=Length(minimum_value=1)
|
||||
)
|
||||
|
||||
@@ -1,19 +1,17 @@
|
||||
from collections import defaultdict
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from operator import attrgetter, itemgetter
|
||||
from typing import Sequence, Tuple, Optional, Mapping
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Tuple, Optional, Mapping, Set
|
||||
|
||||
import attr
|
||||
import dpath
|
||||
from boltons.iterutils import bucketize
|
||||
from boltons.iterutils import first
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, ListField, IntField
|
||||
from jsonmodels.models import Base
|
||||
from redis import StrictRedis
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels import JsonSerializableMixin
|
||||
from apiserver.bll.event.event_common import (
|
||||
EventSettings,
|
||||
@@ -28,19 +26,22 @@ from apiserver.database.model.task.task import Task
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
|
||||
class VariantScrollState(Base):
|
||||
name: str = StringField(required=True)
|
||||
recycle_url_marker: str = StringField()
|
||||
class VariantState(Base):
|
||||
variant: str = StringField(required=True)
|
||||
last_invalid_iteration: int = IntField()
|
||||
|
||||
|
||||
class MetricScrollState(Base):
|
||||
class MetricState(Base):
|
||||
metric: str = StringField(required=True)
|
||||
variants: Sequence[VariantState] = ListField([VariantState], required=True)
|
||||
timestamp: int = IntField(default=0)
|
||||
|
||||
|
||||
class TaskScrollState(Base):
|
||||
task: str = StringField(required=True)
|
||||
name: str = StringField(required=True)
|
||||
metrics: Sequence[MetricState] = ListField([MetricState], required=True)
|
||||
last_min_iter: Optional[int] = IntField()
|
||||
last_max_iter: Optional[int] = IntField()
|
||||
timestamp: int = IntField(default=0)
|
||||
variants: Sequence[VariantScrollState] = ListField([VariantScrollState])
|
||||
|
||||
def reset(self):
|
||||
"""Reset the scrolling state for the metric"""
|
||||
@@ -49,7 +50,7 @@ class MetricScrollState(Base):
|
||||
|
||||
class DebugImageEventsScrollState(Base, JsonSerializableMixin):
|
||||
id: str = StringField(required=True)
|
||||
metrics: Sequence[MetricScrollState] = ListField([MetricScrollState])
|
||||
tasks: Sequence[TaskScrollState] = ListField([TaskScrollState])
|
||||
warning: str = StringField()
|
||||
|
||||
|
||||
@@ -73,7 +74,7 @@ class DebugImagesIterator:
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id: str,
|
||||
metrics: Sequence[Tuple[str, str]],
|
||||
task_metrics: Mapping[str, Set[str]],
|
||||
iter_count: int,
|
||||
navigate_earlier: bool = True,
|
||||
refresh: bool = False,
|
||||
@@ -83,8 +84,7 @@ class DebugImagesIterator:
|
||||
return DebugImagesResult()
|
||||
|
||||
def init_state(state_: DebugImageEventsScrollState):
|
||||
unique_metrics = set(metrics)
|
||||
state_.metrics = self._init_metric_states(company_id, list(unique_metrics))
|
||||
state_.tasks = self._init_task_states(company_id, task_metrics)
|
||||
|
||||
def validate_state(state_: DebugImageEventsScrollState):
|
||||
"""
|
||||
@@ -92,16 +92,8 @@ class DebugImagesIterator:
|
||||
as requested in the current call.
|
||||
Refresh the state if requested
|
||||
"""
|
||||
state_metrics = set((m.task, m.name) for m in state_.metrics)
|
||||
if state_metrics != set(metrics):
|
||||
raise errors.bad_request.InvalidScrollId(
|
||||
"Task metrics stored in the state do not match the passed ones",
|
||||
scroll_id=state_.id,
|
||||
)
|
||||
if refresh:
|
||||
self._reinit_outdated_metric_states(company_id, state_)
|
||||
for metric_state in state_.metrics:
|
||||
metric_state.reset()
|
||||
self._reinit_outdated_task_states(company_id, state_, task_metrics)
|
||||
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state
|
||||
@@ -116,101 +108,124 @@ class DebugImagesIterator:
|
||||
iter_count=iter_count,
|
||||
navigate_earlier=navigate_earlier,
|
||||
),
|
||||
state.metrics,
|
||||
state.tasks,
|
||||
)
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
def _reinit_outdated_metric_states(
|
||||
self, company_id, state: DebugImageEventsScrollState
|
||||
def _reinit_outdated_task_states(
|
||||
self,
|
||||
company_id,
|
||||
state: DebugImageEventsScrollState,
|
||||
task_metrics: Mapping[str, Set[str]],
|
||||
):
|
||||
"""
|
||||
Determines the metrics for which new debug image events were added
|
||||
since their states were initialized and reinits these states
|
||||
Determine the metrics for which new debug image events were added
|
||||
since their states were initialized and re-init these states
|
||||
"""
|
||||
task_ids = set(metric.task for metric in state.metrics)
|
||||
tasks = Task.objects(id__in=list(task_ids), company=company_id).only(
|
||||
tasks = Task.objects(id__in=list(task_metrics), company=company_id).only(
|
||||
"id", "metric_stats"
|
||||
)
|
||||
|
||||
def get_last_update_times_for_task_metrics(task: Task) -> Sequence[Tuple]:
|
||||
"""For metrics that reported debug image events get tuples of task_id/metric_name and last update times"""
|
||||
def get_last_update_times_for_task_metrics(
|
||||
task: Task,
|
||||
) -> Mapping[str, datetime]:
|
||||
"""For metrics that reported debug image events get mapping of the metric name to the last update times"""
|
||||
metric_stats: Mapping[str, MetricEventStats] = task.metric_stats
|
||||
if not metric_stats:
|
||||
return []
|
||||
return {}
|
||||
|
||||
return [
|
||||
(
|
||||
(task.id, stats.metric),
|
||||
stats.event_stats_by_type[self.EVENT_TYPE.value].last_update,
|
||||
)
|
||||
requested_metrics = task_metrics[task.id]
|
||||
return {
|
||||
stats.metric: stats.event_stats_by_type[
|
||||
self.EVENT_TYPE.value
|
||||
].last_update
|
||||
for stats in metric_stats.values()
|
||||
if self.EVENT_TYPE.value in stats.event_stats_by_type
|
||||
]
|
||||
and (not requested_metrics or stats.metric in requested_metrics)
|
||||
}
|
||||
|
||||
update_times = dict(
|
||||
chain.from_iterable(
|
||||
get_last_update_times_for_task_metrics(task) for task in tasks
|
||||
update_times = {
|
||||
task.id: get_last_update_times_for_task_metrics(task) for task in tasks
|
||||
}
|
||||
task_metric_states = {
|
||||
task_state.task: {
|
||||
metric_state.metric: metric_state for metric_state in task_state.metrics
|
||||
}
|
||||
for task_state in state.tasks
|
||||
}
|
||||
task_metrics_to_recalc = {}
|
||||
for task, metrics_times in update_times.items():
|
||||
old_metric_states = task_metric_states[task]
|
||||
metrics_to_recalc = set(
|
||||
m
|
||||
for m, t in metrics_times.items()
|
||||
if m not in old_metric_states or old_metric_states[m].timestamp < t
|
||||
)
|
||||
)
|
||||
outdated_metrics = [
|
||||
metric
|
||||
for metric in state.metrics
|
||||
if (metric.task, metric.name) in update_times
|
||||
and update_times[metric.task, metric.name] > metric.timestamp
|
||||
]
|
||||
state.metrics = [
|
||||
*(metric for metric in state.metrics if metric not in outdated_metrics),
|
||||
*(
|
||||
self._init_metric_states(
|
||||
company_id,
|
||||
[(metric.task, metric.name) for metric in outdated_metrics],
|
||||
)
|
||||
),
|
||||
if metrics_to_recalc:
|
||||
task_metrics_to_recalc[task] = metrics_to_recalc
|
||||
|
||||
updated_task_states = self._init_task_states(company_id, task_metrics_to_recalc)
|
||||
|
||||
def merge_with_updated_task_states(
    old_state: TaskScrollState, updates: Sequence[TaskScrollState]
) -> TaskScrollState:
    """
    Combine the stored scroll state of a task with its recalculated state
    If the updates contain no state for the task then the old state
    is reset in place and returned
    Otherwise a new task state is built from the recalculated metric states
    followed by the old metric states whose metrics were not recalculated
    """
    task_id = old_state.task
    refreshed = first(candidate for candidate in updates if candidate.task == task_id)
    if not refreshed:
        old_state.reset()
        return old_state

    # names of the metrics that got a fresh state; their old states are dropped
    refreshed_names = [metric_state.metric for metric_state in refreshed.metrics]
    merged_metrics = [*refreshed.metrics]
    merged_metrics.extend(
        metric_state
        for metric_state in old_state.metrics
        if metric_state.metric not in refreshed_names
    )
    return TaskScrollState(task=task_id, metrics=merged_metrics)
|
||||
|
||||
state.tasks = [
|
||||
merge_with_updated_task_states(task_state, updated_task_states)
|
||||
for task_state in state.tasks
|
||||
]
|
||||
|
||||
def _init_metric_states(
|
||||
self, company_id: str, metrics: Sequence[Tuple[str, str]]
|
||||
) -> Sequence[MetricScrollState]:
|
||||
def _init_task_states(
|
||||
self, company_id: str, task_metrics: Mapping[str, Set[str]]
|
||||
) -> Sequence[TaskScrollState]:
|
||||
"""
|
||||
Returned initialized metric scroll stated for the requested task metrics
|
||||
"""
|
||||
tasks = defaultdict(list)
|
||||
for (task, metric) in metrics:
|
||||
tasks[task].append(metric)
|
||||
|
||||
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
|
||||
return list(
|
||||
chain.from_iterable(
|
||||
pool.map(
|
||||
partial(
|
||||
self._init_metric_states_for_task, company_id=company_id
|
||||
),
|
||||
tasks.items(),
|
||||
)
|
||||
)
|
||||
task_metric_states = pool.map(
|
||||
partial(self._init_metric_states_for_task, company_id=company_id),
|
||||
task_metrics.items(),
|
||||
)
|
||||
|
||||
return [
|
||||
TaskScrollState(task=task, metrics=metric_states,)
|
||||
for task, metric_states in zip(task_metrics, task_metric_states)
|
||||
]
|
||||
|
||||
def _init_metric_states_for_task(
|
||||
self, task_metrics: Tuple[str, Sequence[str]], company_id: str
|
||||
) -> Sequence[MetricScrollState]:
|
||||
self, task_metrics: Tuple[str, Set[str]], company_id: str
|
||||
) -> Sequence[MetricState]:
|
||||
"""
|
||||
Return metric scroll states for the task filled with the variant states
|
||||
for the variants that reported any debug images
|
||||
"""
|
||||
task, metrics = task_metrics
|
||||
must = [{"term": {"task": task}}, {"exists": {"field": "url"}}]
|
||||
if metrics:
|
||||
must.append({"terms": {"metric": list(metrics)}})
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task}},
|
||||
{"terms": {"metric": metrics}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
}
|
||||
},
|
||||
"query": {"bool": {"must": must}},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
@@ -254,20 +269,17 @@ class DebugImagesIterator:
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.EVENT_TYPE,
|
||||
body=es_req,
|
||||
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req,
|
||||
)
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
def init_variant_scroll_state(variant: dict):
|
||||
def init_variant_state(variant: dict):
|
||||
"""
|
||||
Return new variant scroll state for the passed variant bucket
|
||||
Return new variant state for the passed variant bucket
|
||||
If the image urls get recycled then fill the last_invalid_iteration field
|
||||
"""
|
||||
state = VariantScrollState(name=variant["key"])
|
||||
state = VariantState(variant=variant["key"])
|
||||
top_iter_url = dpath.get(variant, "urls/buckets")[0]
|
||||
iters = dpath.get(top_iter_url, "iters/hits/hits")
|
||||
if len(iters) > 1:
|
||||
@@ -275,102 +287,52 @@ class DebugImagesIterator:
|
||||
return state
|
||||
|
||||
return [
|
||||
MetricScrollState(
|
||||
task=task,
|
||||
name=metric["key"],
|
||||
MetricState(
|
||||
metric=metric["key"],
|
||||
timestamp=dpath.get(metric, "last_event_timestamp/value"),
|
||||
variants=[
|
||||
init_variant_scroll_state(variant)
|
||||
init_variant_state(variant)
|
||||
for variant in dpath.get(metric, "variants/buckets")
|
||||
],
|
||||
timestamp=dpath.get(metric, "last_event_timestamp/value"),
|
||||
)
|
||||
for metric in dpath.get(es_res, "aggregations/metrics/buckets")
|
||||
]
|
||||
|
||||
def _get_task_metric_events(
|
||||
self,
|
||||
metric: MetricScrollState,
|
||||
task_state: TaskScrollState,
|
||||
company_id: str,
|
||||
iter_count: int,
|
||||
navigate_earlier: bool,
|
||||
) -> Tuple:
|
||||
"""
|
||||
Return task metric events grouped by iterations
|
||||
Update metric scroll state
|
||||
Update task scroll state
|
||||
"""
|
||||
if metric.last_max_iter is None:
|
||||
if not task_state.metrics:
|
||||
return task_state.task, []
|
||||
|
||||
if task_state.last_max_iter is None:
|
||||
# the first fetch is always from the latest iteration to the earlier ones
|
||||
navigate_earlier = True
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": metric.task}},
|
||||
{"term": {"metric": metric.name}},
|
||||
{"term": {"task": task_state.task}},
|
||||
{"terms": {"metric": [m.metric for m in task_state.metrics]}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
must_not_conditions = []
|
||||
|
||||
range_condition = None
|
||||
if navigate_earlier and metric.last_min_iter is not None:
|
||||
range_condition = {"lt": metric.last_min_iter}
|
||||
elif not navigate_earlier and metric.last_max_iter is not None:
|
||||
range_condition = {"gt": metric.last_max_iter}
|
||||
if navigate_earlier and task_state.last_min_iter is not None:
|
||||
range_condition = {"lt": task_state.last_min_iter}
|
||||
elif not navigate_earlier and task_state.last_max_iter is not None:
|
||||
range_condition = {"gt": task_state.last_max_iter}
|
||||
if range_condition:
|
||||
must_conditions.append({"range": {"iter": range_condition}})
|
||||
|
||||
if navigate_earlier:
|
||||
"""
|
||||
When navigating to earlier iterations consider only
|
||||
variants whose invalid iterations border is lower than
|
||||
our starting iteration. For these variants make sure
|
||||
that only events from the valid iterations are returned
|
||||
"""
|
||||
if not metric.last_min_iter:
|
||||
variants = metric.variants
|
||||
else:
|
||||
variants = list(
|
||||
v
|
||||
for v in metric.variants
|
||||
if v.last_invalid_iteration is None
|
||||
or v.last_invalid_iteration < metric.last_min_iter
|
||||
)
|
||||
if not variants:
|
||||
return metric.task, metric.name, []
|
||||
must_conditions.append(
|
||||
{"terms": {"variant": list(v.name for v in variants)}}
|
||||
)
|
||||
else:
|
||||
"""
|
||||
When navigating to later iterations all variants may be relevant.
|
||||
For the variants whose invalid border is higher than our starting
|
||||
iteration make sure that only events from valid iterations are returned
|
||||
"""
|
||||
variants = list(
|
||||
v
|
||||
for v in metric.variants
|
||||
if v.last_invalid_iteration is not None
|
||||
and v.last_invalid_iteration > metric.last_max_iter
|
||||
)
|
||||
|
||||
variants_conditions = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"variant": v.name}},
|
||||
{"range": {"iter": {"lte": v.last_invalid_iteration}}},
|
||||
]
|
||||
}
|
||||
}
|
||||
for v in variants
|
||||
if v.last_invalid_iteration is not None
|
||||
]
|
||||
if variants_conditions:
|
||||
must_not_conditions.append({"bool": {"should": variants_conditions}})
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {
|
||||
"bool": {"must": must_conditions, "must_not": must_not_conditions}
|
||||
},
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"terms": {
|
||||
@@ -379,15 +341,26 @@ class DebugImagesIterator:
|
||||
"order": {"_key": "desc" if navigate_earlier else "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventSettings.max_variants_count,
|
||||
"field": "metric",
|
||||
"size": EventSettings.max_metrics_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"events": {
|
||||
"top_hits": {"sort": {"url": {"order": "desc"}}}
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventSettings.max_variants_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"events": {
|
||||
"top_hits": {
|
||||
"sort": {"url": {"order": "desc"}}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -397,80 +370,44 @@ class DebugImagesIterator:
|
||||
}
|
||||
with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.EVENT_TYPE,
|
||||
body=es_req,
|
||||
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req,
|
||||
)
|
||||
if "aggregations" not in es_res:
|
||||
return metric.task, metric.name, []
|
||||
return task_state.task, []
|
||||
|
||||
def get_iteration_events(variant_buckets: Sequence[dict]) -> Sequence:
|
||||
invalid_iterations = {
|
||||
(m.metric, v.variant): v.last_invalid_iteration
|
||||
for m in task_state.metrics
|
||||
for v in m.variants
|
||||
}
|
||||
|
||||
def is_valid_event(event: dict) -> bool:
|
||||
key = event.get("metric"), event.get("variant")
|
||||
if key not in invalid_iterations:
|
||||
return False
|
||||
|
||||
max_invalid = invalid_iterations[key]
|
||||
return max_invalid is None or event.get("iter") > max_invalid
|
||||
|
||||
def get_iteration_events(it_: dict) -> Sequence:
|
||||
return [
|
||||
ev["_source"]
|
||||
for v in variant_buckets
|
||||
for m in dpath.get(it_, "metrics/buckets")
|
||||
for v in dpath.get(m, "variants/buckets")
|
||||
for ev in dpath.get(v, "events/hits/hits")
|
||||
if is_valid_event(ev["_source"])
|
||||
]
|
||||
|
||||
iterations = [
|
||||
{
|
||||
"iter": it["key"],
|
||||
"events": get_iteration_events(dpath.get(it, "variants/buckets")),
|
||||
}
|
||||
for it in dpath.get(es_res, "aggregations/iters/buckets")
|
||||
]
|
||||
iterations = []
|
||||
for it in dpath.get(es_res, "aggregations/iters/buckets"):
|
||||
events = get_iteration_events(it)
|
||||
if events:
|
||||
iterations.append({"iter": it["key"], "events": events})
|
||||
|
||||
if not navigate_earlier:
|
||||
iterations.sort(key=itemgetter("iter"), reverse=True)
|
||||
if iterations:
|
||||
metric.last_max_iter = iterations[0]["iter"]
|
||||
metric.last_min_iter = iterations[-1]["iter"]
|
||||
task_state.last_max_iter = iterations[0]["iter"]
|
||||
task_state.last_min_iter = iterations[-1]["iter"]
|
||||
|
||||
# Commented for now since the last invalid iteration is calculated in the beginning
|
||||
# if navigate_earlier and any(
|
||||
# variant.last_invalid_iteration is None for variant in variants
|
||||
# ):
|
||||
# """
|
||||
# Variants validation flags due to recycling can
|
||||
# be set only on navigation to earlier frames
|
||||
# """
|
||||
# iterations = self._update_variants_invalid_iterations(variants, iterations)
|
||||
|
||||
return metric.task, metric.name, iterations
|
||||
|
||||
@staticmethod
def _update_variants_invalid_iterations(
    variants: Sequence[VariantScrollState], iterations: Sequence[dict]
) -> Sequence[dict]:
    """
    This code is currently not in use since the invalid iterations
    are calculated during MetricState initialization
    A variant without a recycle url marker gets it from the url of
    its first encountered event
    When an event url equals the variant recycle marker, the iteration
    of that event becomes the variant last_invalid_iteration
    Events from iterations that are not above the variant
    last_invalid_iteration are dropped from the iteration events
    Return only the iterations that still have events
    """
    states_by_name = bucketize(variants, attrgetter("name"))
    for entry in iterations:
        iter_number = entry["iter"]
        rejected = []
        for ev in entry["events"]:
            variant_state = states_by_name[ev["variant"]][0]
            invalid_border = variant_state.last_invalid_iteration
            if invalid_border and invalid_border >= iter_number:
                # the event belongs to an already invalidated iteration
                rejected.append(ev)
                continue

            url = ev.get("url")
            if not variant_state.recycle_url_marker:
                variant_state.recycle_url_marker = url
            elif variant_state.recycle_url_marker == url:
                # the marker url is seen again: urls got recycled at this iteration
                variant_state.last_invalid_iteration = iter_number
                rejected.append(ev)

        if rejected:
            # removal is by equality, exactly as in the collected list
            entry["events"] = [ev for ev in entry["events"] if ev not in rejected]

    return [entry for entry in iterations if entry["events"]]
|
||||
return task_state.task, iterations
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import re
|
||||
import zlib
|
||||
from collections import defaultdict
|
||||
from contextlib import closing
|
||||
@@ -9,6 +10,7 @@ from typing import Sequence, Set, Tuple, Optional, Dict
|
||||
|
||||
import six
|
||||
from elasticsearch import helpers
|
||||
from elasticsearch.helpers import BulkIndexError
|
||||
from mongoengine import Q
|
||||
from nested_dict import nested_dict
|
||||
|
||||
@@ -36,12 +38,13 @@ from apiserver.redis_manager import redman
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.tools import safe_get
|
||||
from apiserver.utilities.dicts import flatten_nested_items
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
from apiserver.utilities.json import loads
|
||||
|
||||
EVENT_TYPES = set(map(attrgetter("value"), EventType))
|
||||
# noinspection PyTypeChecker
|
||||
EVENT_TYPES: Set[str] = set(map(attrgetter("value"), EventType))
|
||||
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
|
||||
MAX_LONG = 2**63 - 1
|
||||
MIN_LONG = -2**63
|
||||
|
||||
|
||||
class PlotFields:
|
||||
@@ -49,11 +52,16 @@ class PlotFields:
|
||||
plot_len = "plot_len"
|
||||
plot_str = "plot_str"
|
||||
plot_data = "plot_data"
|
||||
source_urls = "source_urls"
|
||||
|
||||
|
||||
class EventBLL(object):
|
||||
id_fields = ("task", "iter", "metric", "variant", "key")
|
||||
empty_scroll = "FFFF"
|
||||
img_source_regex = re.compile(
|
||||
r"['\"]source['\"]:\s?['\"]([a-z][a-z0-9+\-.]*://.*?)['\"]",
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
|
||||
def __init__(self, events_es=None, redis=None):
|
||||
self.es = events_es or es_factory.connect("events")
|
||||
@@ -96,6 +104,7 @@ class EventBLL(object):
|
||||
3, dict
|
||||
) # task_id -> metric_hash -> event_type -> MetricEvent
|
||||
errors_per_type = defaultdict(int)
|
||||
invalid_iteration_error = f"Iteration number should not exceed {MAX_LONG}"
|
||||
valid_tasks = self._get_valid_tasks(
|
||||
company_id,
|
||||
task_ids={
|
||||
@@ -145,6 +154,9 @@ class EventBLL(object):
|
||||
iter = event.get("iter")
|
||||
if iter is not None:
|
||||
iter = int(iter)
|
||||
if iter > MAX_LONG or iter < MIN_LONG:
|
||||
errors_per_type[invalid_iteration_error] += 1
|
||||
continue
|
||||
event["iter"] = iter
|
||||
|
||||
# used to have "values" to indicate array. no need anymore
|
||||
@@ -201,47 +213,55 @@ class EventBLL(object):
|
||||
)
|
||||
|
||||
added = 0
|
||||
if actions:
|
||||
chunk_size = 500
|
||||
with translate_errors_context(), TimingContext("es", "events_add_batch"):
|
||||
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
|
||||
with closing(
|
||||
helpers.streaming_bulk(
|
||||
self.es,
|
||||
actions,
|
||||
chunk_size=chunk_size,
|
||||
# thread_count=8,
|
||||
refresh=True,
|
||||
)
|
||||
) as it:
|
||||
for success, info in it:
|
||||
if success:
|
||||
added += 1
|
||||
else:
|
||||
errors_per_type["Error when indexing events batch"] += 1
|
||||
with translate_errors_context():
|
||||
if actions:
|
||||
chunk_size = 500
|
||||
with TimingContext("es", "events_add_batch"):
|
||||
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
|
||||
with closing(
|
||||
helpers.streaming_bulk(
|
||||
self.es,
|
||||
actions,
|
||||
chunk_size=chunk_size,
|
||||
# thread_count=8,
|
||||
refresh=True,
|
||||
)
|
||||
) as it:
|
||||
for success, info in it:
|
||||
if success:
|
||||
added += 1
|
||||
else:
|
||||
errors_per_type["Error when indexing events batch"] += 1
|
||||
|
||||
remaining_tasks = set()
|
||||
now = datetime.utcnow()
|
||||
for task_id in task_ids:
|
||||
# Update related tasks. For reasons of performance, we prefer to update
|
||||
# all of them and not only those who's events were successful
|
||||
updated = self._update_task(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
now=now,
|
||||
iter_max=task_iteration.get(task_id),
|
||||
last_scalar_events=task_last_scalar_events.get(task_id),
|
||||
last_events=task_last_events.get(task_id),
|
||||
)
|
||||
remaining_tasks = set()
|
||||
now = datetime.utcnow()
|
||||
for task_id in task_ids:
|
||||
# Update related tasks. For reasons of performance, we prefer to update
|
||||
# all of them and not only those who's events were successful
|
||||
updated = self._update_task(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
now=now,
|
||||
iter_max=task_iteration.get(task_id),
|
||||
last_scalar_events=task_last_scalar_events.get(task_id),
|
||||
last_events=task_last_events.get(task_id),
|
||||
)
|
||||
|
||||
if not updated:
|
||||
remaining_tasks.add(task_id)
|
||||
continue
|
||||
if not updated:
|
||||
remaining_tasks.add(task_id)
|
||||
continue
|
||||
|
||||
if remaining_tasks:
|
||||
TaskBLL.set_last_update(
|
||||
remaining_tasks, company_id, last_update=now
|
||||
)
|
||||
if remaining_tasks:
|
||||
TaskBLL.set_last_update(
|
||||
remaining_tasks, company_id, last_update=now
|
||||
)
|
||||
|
||||
# this is for backwards compatibility with streaming bulk throwing exception on those
|
||||
invalid_iterations_count = errors_per_type.get(invalid_iteration_error)
|
||||
if invalid_iterations_count:
|
||||
raise BulkIndexError(
|
||||
f"{invalid_iterations_count} document(s) failed to index.", [invalid_iteration_error]
|
||||
)
|
||||
|
||||
if not added:
|
||||
raise errors.bad_request.EventsNotAdded(**errors_per_type)
|
||||
@@ -269,6 +289,11 @@ class EventBLL(object):
|
||||
event[PlotFields.plot_len] = plot_len
|
||||
if validate:
|
||||
event[PlotFields.valid_plot] = self._is_valid_json(plot_str)
|
||||
|
||||
urls = {match for match in self.img_source_regex.findall(plot_str)}
|
||||
if urls:
|
||||
event[PlotFields.source_urls] = list(urls)
|
||||
|
||||
if compression_threshold and plot_len >= compression_threshold:
|
||||
event[PlotFields.plot_data] = base64.encodebytes(
|
||||
zlib.compress(plot_str.encode(), level=1)
|
||||
@@ -430,6 +455,9 @@ class EventBLL(object):
|
||||
)
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
if event_type in (EventType.metrics_plot, EventType.all):
|
||||
self.uncompress_plots(events)
|
||||
|
||||
return events, next_scroll_id, total_events
|
||||
|
||||
def get_last_iterations_per_event_metric_variant(
|
||||
@@ -501,7 +529,7 @@ class EventBLL(object):
|
||||
scroll_id: str = None,
|
||||
):
|
||||
if scroll_id == self.empty_scroll:
|
||||
return [], scroll_id, 0
|
||||
return TaskEventsResult()
|
||||
|
||||
if scroll_id:
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
@@ -595,6 +623,41 @@ class EventBLL(object):
|
||||
|
||||
return events, total_events, next_scroll_id
|
||||
|
||||
def get_plot_image_urls(
    self, company_id: str, task_id: str, scroll_id: Optional[str]
) -> Tuple[Sequence[dict], Optional[str]]:
    """
    Return a page of the task plot events that have source urls
    together with the scroll id for fetching the next page
    Only the source urls field is requested from the events
    If the passed scroll id is the empty scroll marker or there is no
    plot events data for the company then ([], None) is returned
    """
    if scroll_id == self.empty_scroll:
        return [], None

    if not scroll_id:
        # first call: start a new scrolled search
        if check_empty_data(self.es, company_id, EventType.metrics_plot):
            return [], None

        task_condition = {"term": {"task": task_id}}
        has_urls_condition = {"exists": {"field": PlotFields.source_urls}}
        es_req = {
            "size": 1000,
            "_source": [PlotFields.source_urls],
            "query": {"bool": {"must": [task_condition, has_urls_condition]}},
        }
        es_res = search_company_events(
            self.es,
            company_id=company_id,
            event_type=EventType.metrics_plot,
            body=es_req,
            scroll="10m",
        )
    else:
        # continue the previously started scroll
        es_res = self.es.scroll(scroll_id=scroll_id, scroll="10m")

    events, _total, next_scroll_id = self._get_events_from_es_res(es_res)
    return events, next_scroll_id
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id: str,
|
||||
@@ -672,6 +735,9 @@ class EventBLL(object):
|
||||
)
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
if event_type in (EventType.metrics_plot, EventType.all):
|
||||
self.uncompress_plots(events)
|
||||
|
||||
return TaskEventsResult(
|
||||
events=events, next_scroll_id=next_scroll_id, total_events=total_events
|
||||
)
|
||||
@@ -892,3 +958,20 @@ class EventBLL(object):
|
||||
)
|
||||
|
||||
return es_res.get("deleted", 0)
|
||||
|
||||
def delete_multi_task_events(self, company_id: str, task_ids: Sequence[str]):
    """
    Delete the events of multiple tasks in a single request
    Write access to the tasks is not verified here
    so the calling code is responsible for checking it
    Return the "deleted" count from the response (0 if it is missing)
    """
    tasks_query = {"terms": {"task": task_ids}}
    with translate_errors_context(), TimingContext("es", "delete_multi_tasks_events"):
        response = delete_company_events(
            es=self.es,
            company_id=company_id,
            event_type=EventType.all,
            body={"query": tasks_query},
            refresh=True,
        )

    return response.get("deleted", 0)
|
||||
|
||||
@@ -4,8 +4,8 @@ Module for polymorphism over different types of X axes in scalar aggregations
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import auto
|
||||
|
||||
from apiserver.utilities import extract_properties_to_lists
|
||||
from apiserver.utilities.stringenum import StringEnum
|
||||
from apiserver.bll.util import extract_properties_to_lists
|
||||
from apiserver.config_repo import config
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
@@ -1,18 +1,129 @@
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from mongoengine import Q
|
||||
from datetime import datetime
|
||||
from typing import Callable, Tuple
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.models import ModelTaskPublishResponse
|
||||
from apiserver.bll.task.utils import deleted_prefix
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.utils import get_company_or_none_constraint
|
||||
from apiserver.database.model.task.task import Task, TaskStatus
|
||||
|
||||
|
||||
class ModelBLL:
|
||||
def get_frameworks(self, company, project_ids: Optional[Sequence]) -> Sequence:
    """
    Return the distinct "framework" values of the models selected by the
    company-or-none constraint
    When project ids are passed the models are further limited to these projects
    """
    constraint = get_company_or_none_constraint(company)
    if project_ids:
        constraint = constraint & Q(project__in=project_ids)

    matching_models = Model.objects(constraint)
    return matching_models.distinct(field="framework")
|
||||
@classmethod
def get_company_model_by_id(
    cls, company_id: str, model_id: str, only_fields=None
) -> Model:
    """
    Return the model with the passed id that belongs to the company
    If only_fields are passed then only these fields are loaded
    Raise bad_request.InvalidModelId if no such model is found
    """
    lookup = {"company": company_id, "id": model_id}
    candidates = Model.objects(**lookup)
    if only_fields:
        candidates = candidates.only(*only_fields)

    found = candidates.first()
    if found:
        return found

    raise errors.bad_request.InvalidModelId(**lookup)
|
||||
|
||||
@classmethod
def publish_model(
    cls,
    model_id: str,
    company_id: str,
    force_publish_task: bool = False,
    publish_task_func: Callable[[str, str, bool], dict] = None,
) -> Tuple[int, ModelTaskPublishResponse]:
    """
    Mark the company model as ready and update its last_update time
    If the model has a task, a publish function was passed and the task
    exists in the company and is not published yet then the task is
    published through the passed function
    Return the result of the model update and the task publish response
    (None if no task was published)
    Raise bad_request.ModelIsReady if the model is already ready
    """
    model = cls.get_company_model_by_id(company_id=company_id, model_id=model_id)
    if model.ready:
        raise errors.bad_request.ModelIsReady(company=company_id, model=model_id)

    creating_task = None
    if model.task and publish_task_func:
        task_query = Task.objects(id=model.task, company=company_id)
        creating_task = task_query.only("id", "status").first()

    published_task = None
    if creating_task and creating_task.status != TaskStatus.published:
        publish_result = publish_task_func(model.task, company_id, force_publish_task)
        published_task = ModelTaskPublishResponse(id=model.task, data=publish_result)

    updated = model.update(upsert=False, ready=True, last_update=datetime.utcnow())
    return updated, published_task
|
||||
|
||||
@classmethod
def delete_model(
    cls, model_id: str, company_id: str, force: bool
) -> Tuple[int, Model]:
    """
    Delete the company model and patch the tasks that reference it
    In the tasks that use the model as an input the model id is replaced
    with the deleted-prefixed id
    If the task that created the model is published and lists the model in
    its outputs then the output model id is replaced in the same way, the
    task output error is set and the task last_change is updated
    Without force=True the call fails if the model is used as an input by
    any task or if its creating task is published
    Return the number of deleted models and the model document that was
    loaded before the deletion (id, task, project and uri fields only)
    Raise bad_request.InvalidModelId, bad_request.ModelInUse or
    bad_request.ModelCreatingTaskExists
    """
    model = cls.get_company_model_by_id(
        company_id=company_id,
        model_id=model_id,
        only_fields=("id", "task", "project", "uri"),
    )
    deleted_model_id = f"{deleted_prefix}{model_id}"

    using_tasks = Task.objects(models__input__model=model_id).only("id")
    if using_tasks:
        if not force:
            raise errors.bad_request.ModelInUse(
                "as execution model, use force=True to delete",
                num_tasks=len(using_tasks),
            )
        # update deleted model id in using tasks
        Task._get_collection().update_many(
            filter={"_id": {"$in": [t.id for t in using_tasks]}},
            update={"$set": {"models.input.$[elem].model": deleted_model_id}},
            array_filters=[{"elem.model": model_id}],
            upsert=False,
        )

    if model.task:
        task = Task.objects(id=model.task).first()
        if task and task.status == TaskStatus.published:
            if not force:
                raise errors.bad_request.ModelCreatingTaskExists(
                    "and published, use force=True to delete", task=model.task
                )
            # NOTE(review): the update filter below addresses the output items
            # through their "model" field, so this membership test presumably
            # compares the id string with item documents - confirm that it can match
            if task.models.output and model_id in task.models.output:
                now = datetime.utcnow()
                Task._get_collection().update_one(
                    filter={"_id": model.task, "models.output.model": model_id},
                    update={
                        # an update document may hold only update operators,
                        # so last_change has to be set through $set as well
                        "$set": {
                            "models.output.$[elem].model": deleted_model_id,
                            "output.error": f"model deleted on {now.isoformat()}",
                            "last_change": now,
                        },
                    },
                    array_filters=[{"elem.model": model_id}],
                    upsert=False,
                )

    del_count = Model.objects(id=model_id, company=company_id).delete()
    return del_count, model
|
||||
|
||||
@classmethod
def archive_model(cls, model_id: str, company_id: str):
    """
    Add the archived system tag to the company model and update its last_update time
    Raise bad_request.InvalidModelId if the model is not found in the company
    Return the result of the update call
    """
    # called only for validation: raises if the model does not exist
    cls.get_company_model_by_id(
        company_id=company_id, model_id=model_id, only_fields=("id",)
    )
    target = Model.objects(company=company_id, id=model_id)
    return target.update(
        add_to_set__system_tags=EntityVisibility.archived.value,
        last_update=datetime.utcnow(),
    )
|
||||
|
||||
@classmethod
def unarchive_model(cls, model_id: str, company_id: str):
    """
    Remove the archived system tag from the company model and update its last_update time
    Raise bad_request.InvalidModelId if the model is not found in the company
    Return the result of the update call
    """
    # called only for validation: raises if the model does not exist
    cls.get_company_model_by_id(
        company_id=company_id, model_id=model_id, only_fields=("id",)
    )
    target = Model.objects(company=company_id, id=model_id)
    return target.update(
        pull__system_tags=EntityVisibility.archived.value,
        last_update=datetime.utcnow(),
    )
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Dict, Optional
|
||||
|
||||
from mongoengine import Q
|
||||
from typing import Sequence, Dict
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.redis_manager import redman
|
||||
@@ -65,34 +61,3 @@ class OrgBLL:
|
||||
|
||||
def _get_tags_cache_for_entity(self, entity: Tags) -> _TagsCache:
|
||||
return self._task_tags if entity == Tags.Task else self._model_tags
|
||||
|
||||
@classmethod
def get_parent_tasks(
    cls,
    company_id: str,
    projects: Sequence[str],
    state: Optional[EntityVisibility] = None,
) -> Sequence[dict]:
    """
    Return the unique parents of the company tasks sorted by the parent name
    If projects are passed then only the tasks from these projects are examined,
    otherwise all the company tasks
    The state allows limiting the examined tasks to archived or active ones
    """
    archived_tag = EntityVisibility.archived.value
    conditions = [Q(company=company_id)]
    if projects:
        conditions.append(Q(project__in=projects))
    if state == EntityVisibility.archived:
        conditions.append(Q(system_tags__in=[archived_tag]))
    elif state == EntityVisibility.active:
        conditions.append(Q(system_tags__nin=[archived_tag]))

    # and-combine the collected conditions in their original order
    query = conditions[0]
    for condition in conditions[1:]:
        query &= condition

    parent_ids = set(Task.objects(query).distinct("parent"))
    if not parent_ids:
        return []

    parents = Task.get_many_with_join(
        company_id,
        query=Q(id__in=parent_ids),
        allow_public=True,
        override_projection=("id", "name", "project.name"),
    )
    return sorted(parents, key=itemgetter("name"))
|
||||
|
||||
@@ -5,6 +5,7 @@ from mongoengine import Q
|
||||
from redis import Redis
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.bll.project import project_ids_with_children
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task
|
||||
@@ -40,7 +41,7 @@ class _TagsCache:
|
||||
if vals:
|
||||
query &= GetMixin.get_list_field_query(name, vals)
|
||||
if project:
|
||||
query &= Q(project=project)
|
||||
query &= Q(project__in=project_ids_with_children([project]))
|
||||
|
||||
return self.db_cls.objects(query).distinct(field)
|
||||
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
from .project_bll import ProjectBLL
|
||||
from .sub_projects import _ids_with_children as project_ids_with_children
|
||||
|
||||
@@ -1,40 +1,187 @@
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Sequence, Optional, Type
|
||||
from functools import reduce
|
||||
from itertools import groupby
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
Sequence,
|
||||
Optional,
|
||||
Type,
|
||||
Tuple,
|
||||
Dict,
|
||||
Set,
|
||||
TypeVar,
|
||||
Callable,
|
||||
Mapping,
|
||||
)
|
||||
|
||||
from mongoengine import Q, Document
|
||||
|
||||
from apiserver import database
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility, AttributedDocument
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.database.model.task.task import Task, TaskStatus, external_task_types
|
||||
from apiserver.database.utils import get_options, get_company_or_none_constraint
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from .sub_projects import (
|
||||
_reposition_project_with_children,
|
||||
_ensure_project,
|
||||
_validate_project_name,
|
||||
_update_subproject_names,
|
||||
_save_under_parent,
|
||||
_get_sub_projects,
|
||||
_ids_with_children,
|
||||
_ids_with_parents,
|
||||
_get_project_depth,
|
||||
)
|
||||
|
||||
log = config.logger(__file__)
|
||||
max_depth = config.get("services.projects.sub_projects.max_depth", 10)
|
||||
|
||||
|
||||
class ProjectBLL:
|
||||
@classmethod
|
||||
def get_active_users(
|
||||
cls, company, project_ids: Sequence, user_ids: Optional[Sequence] = None
|
||||
) -> set:
|
||||
def merge_project(
|
||||
cls, company, source_id: str, destination_id: str
|
||||
) -> Tuple[int, int, Set[str]]:
|
||||
"""
|
||||
Get the set of user ids that created tasks/models in the given projects
|
||||
If project_ids is empty then all projects are examined
|
||||
If user_ids are passed then only subset of these users is returned
|
||||
Move all the tasks and sub projects from the source project to the destination
|
||||
Remove the source project
|
||||
Return the amounts of moved entities and subprojects + set of all the affected project ids
|
||||
"""
|
||||
with TimingContext("mongo", "active_users_in_projects"):
|
||||
res = set()
|
||||
query = Q(company=company)
|
||||
if project_ids:
|
||||
query &= Q(project__in=project_ids)
|
||||
if user_ids:
|
||||
query &= Q(user__in=user_ids)
|
||||
for cls_ in (Task, Model):
|
||||
res |= set(cls_.objects(query).distinct(field="user"))
|
||||
with TimingContext("mongo", "move_project"):
|
||||
if source_id == destination_id:
|
||||
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
|
||||
parent=source_id
|
||||
)
|
||||
source = Project.get(company, source_id)
|
||||
destination = Project.get(company, destination_id)
|
||||
|
||||
return res
|
||||
children = _get_sub_projects(
|
||||
[source.id], _only=("id", "name", "parent", "path")
|
||||
)[source.id]
|
||||
cls.validate_projects_depth(
|
||||
projects=children,
|
||||
old_parent_depth=len(source.path) + 1,
|
||||
new_parent_depth=len(destination.path) + 1,
|
||||
)
|
||||
|
||||
moved_entities = 0
|
||||
for entity_type in (Task, Model):
|
||||
moved_entities += entity_type.objects(
|
||||
company=company,
|
||||
project=source_id,
|
||||
system_tags__nin=[EntityVisibility.archived.value],
|
||||
).update(upsert=False, project=destination_id)
|
||||
|
||||
moved_sub_projects = 0
|
||||
for child in Project.objects(company=company, parent=source_id):
|
||||
_reposition_project_with_children(
|
||||
project=child,
|
||||
children=[c for c in children if c.parent == child.id],
|
||||
parent=destination,
|
||||
)
|
||||
moved_sub_projects += 1
|
||||
|
||||
affected = {source.id, *(source.path or [])}
|
||||
source.delete()
|
||||
|
||||
if destination:
|
||||
destination.update(last_update=datetime.utcnow())
|
||||
affected.update({destination.id, *(destination.path or [])})
|
||||
|
||||
return moved_entities, moved_sub_projects, affected
|
||||
|
||||
@staticmethod
def validate_projects_depth(
    projects: Sequence[Project], old_parent_depth: int, new_parent_depth: int
):
    """
    Make sure that none of the passed projects ends up deeper than max_depth
    once their common ancestor moves from old_parent_depth to new_parent_depth
    Raise ProjectPathExceedsMax for the first project that violates the limit
    """
    # how many levels every project in the moved tree is shifted by
    depth_shift = new_parent_depth - old_parent_depth
    for candidate in projects:
        # own depth of a project is the amount of its ancestors plus one
        if len(candidate.path) + 1 + depth_shift > max_depth:
            raise errors.bad_request.ProjectPathExceedsMax(max_depth=max_depth)
|
||||
|
||||
@classmethod
def move_project(
    cls, company: str, user: str, project_id: str, new_location: str
) -> Tuple[int, Set[str]]:
    """
    Relocate the project together with all of its sub projects under new_location
    The target location is resolved through _ensure_project, so missing projects
    on the way to it are auto-created. The current parent (if any) is fetched for writing
    Raise ProjectPathExceedsMax if the moved tree would exceed max_depth
    Raise ProjectSourceAndDestinationAreTheSame if the project already resides there
    Return the number of moved projects + set of all the affected project ids
    (the old and the new parents together with their ancestors)
    """
    with TimingContext("mongo", "move_project"):
        project = Project.get(company, project_id)
        source_parent_id = project.parent
        source_parent = None
        if source_parent_id:
            source_parent = Project.get_for_writing(
                company=project.company, id=source_parent_id
            )

        sub_projects = _get_sub_projects([project.id], _only=("id", "name", "path"))
        children = sub_projects[project.id]
        # validate before anything is created or modified
        cls.validate_projects_depth(
            projects=[project, *children],
            old_parent_depth=len(project.path),
            new_parent_depth=_get_project_depth(new_location),
        )

        target_parent = _ensure_project(company=company, user=user, name=new_location)
        target_parent_id = target_parent.id if target_parent else None
        if source_parent_id == target_parent_id:
            raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
                location=target_parent.name if target_parent else ""
            )

        moved = _reposition_project_with_children(
            project, children=children, parent=target_parent
        )

        now = datetime.utcnow()
        affected = set()
        for touched in (source_parent, target_parent):
            if not touched:
                # top level has no project document to refresh
                continue
            touched.update(last_update=now)
            affected.add(touched.id)
            affected.update(touched.path or [])

        return moved, affected
|
||||
|
||||
@classmethod
def update(cls, company: str, project_id: str, **fields):
    """
    Update the passed fields of a project and stamp its last_update
    A new name may only change the last part of the project name: if its location
    differs from the current one then CannotUpdateProjectLocation is raised
    When the name changes the names of all the sub projects are updated accordingly
    Raise InvalidProjectId if the project cannot be fetched for writing
    Return the result of the project update call
    """
    with TimingContext("mongo", "projects_update"):
        project = Project.get_for_writing(company=company, id=project_id)
        if not project:
            raise errors.bad_request.InvalidProjectId(id=project_id)

        requested_name = fields.pop("name", None)
        if requested_name:
            requested_name, requested_location = _validate_project_name(
                requested_name
            )
            _, current_location = _validate_project_name(project.name)
            if requested_location != current_location:
                raise errors.bad_request.CannotUpdateProjectLocation(
                    name=requested_name
                )
            fields["name"] = requested_name

        fields["last_update"] = datetime.utcnow()
        updated = project.update(upsert=False, **fields)

        if requested_name:
            # the in-memory document is read for the previous name here,
            # only then it is switched to the new one for the children rename
            previous_name = project.name
            project.name = requested_name
            sub_projects = _get_sub_projects(
                [project.id], _only=("id", "name", "path")
            )
            _update_subproject_names(
                project=project,
                children=sub_projects[project.id],
                old_name=previous_name,
            )

        return updated
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@@ -42,7 +189,7 @@ class ProjectBLL:
|
||||
user: str,
|
||||
company: str,
|
||||
name: str,
|
||||
description: str,
|
||||
description: str = "",
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
default_output_destination: str = None,
|
||||
@@ -51,6 +198,10 @@ class ProjectBLL:
|
||||
Create a new project.
|
||||
Returns project ID
|
||||
"""
|
||||
if _get_project_depth(name) > max_depth:
|
||||
raise errors.bad_request.ProjectPathExceedsMax(max_depth=max_depth)
|
||||
|
||||
name, location = _validate_project_name(name)
|
||||
now = datetime.utcnow()
|
||||
project = Project(
|
||||
id=database.utils.id(),
|
||||
@@ -64,7 +215,11 @@ class ProjectBLL:
|
||||
created=now,
|
||||
last_update=now,
|
||||
)
|
||||
project.save()
|
||||
parent = _ensure_project(company=company, user=user, name=location)
|
||||
_save_under_parent(project=project, parent=parent)
|
||||
if parent:
|
||||
parent.update(last_update=now)
|
||||
|
||||
return project.id
|
||||
|
||||
@classmethod
|
||||
@@ -92,6 +247,7 @@ class ProjectBLL:
|
||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||
return project_id
|
||||
|
||||
project_name, _ = _validate_project_name(project_name)
|
||||
project = Project.objects(company=company, name=project_name).only("id").first()
|
||||
if project:
|
||||
return project.id
|
||||
@@ -125,13 +281,439 @@ class ProjectBLL:
|
||||
company=company,
|
||||
project_id=project,
|
||||
project_name=project_name,
|
||||
description="Auto-generated during move",
|
||||
description="",
|
||||
)
|
||||
extra = (
|
||||
{"set__last_change": datetime.utcnow()}
|
||||
if hasattr(entity_cls, "last_change")
|
||||
else {}
|
||||
)
|
||||
entity_cls.objects(company=company, id__in=ids).update(set__project=project, **extra)
|
||||
entity_cls.objects(company=company, id__in=ids).update(
|
||||
set__project=project, **extra
|
||||
)
|
||||
|
||||
return project
|
||||
|
||||
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}
|
||||
|
||||
@classmethod
def make_projects_get_all_pipelines(
    cls,
    company_id: str,
    project_ids: Sequence[str],
    specific_state: Optional[EntityVisibility] = None,
) -> Tuple[Sequence, Sequence]:
    """
    Build two aggregation pipelines over the tasks of the passed projects
    (tasks with company None, "" or company_id are matched):
    - status count pipeline: per project list of (status, count, archived) entries
    - runtime pipeline: per project sum of task run times, one field per
      EntityVisibility state (only the specific_state field if it is passed)
    Return the tuple (status_count_pipeline, runtime_pipeline)
    """
    archived = EntityVisibility.archived.value

    def ensure_valid_fields():
        """
        Make sure system tags is always an array (required by the subsequent $in
        in archived_tasks_cond) and that a missing status is reported as "unknown"
        """
        return {
            "$addFields": {
                "system_tags": {
                    "$cond": {
                        "if": {"$ne": [{"$type": "$system_tags"}, "array"]},
                        "then": [],
                        "else": "$system_tags",
                    }
                },
                "status": {"$ifNull": ["$status", "unknown"]},
            }
        }

    status_count_pipeline = [
        # count tasks per project per status
        {
            "$match": {
                "company": {"$in": [None, "", company_id]},
                "project": {"$in": project_ids},
            }
        },
        ensure_valid_fields(),
        {
            "$group": {
                "_id": {
                    "project": "$project",
                    "status": "$status",
                    # the archived flag is part of the grouping key
                    archived: cls.archived_tasks_cond,
                },
                "count": {"$sum": 1},
            }
        },
        # for each project, create a list of (status, count, archived)
        {
            "$group": {
                "_id": "$_id.project",
                "counts": {
                    "$push": {
                        "status": "$_id.status",
                        "count": "$count",
                        archived: "$_id.%s" % archived,
                    }
                },
            }
        },
    ]

    def runtime_subquery(additional_cond):
        # accumulator expression summing the run time of the tasks
        # that also satisfy additional_cond
        return {
            # the sum of
            "$sum": {
                # for each task
                "$cond": {
                    # if completed and started and completed > started
                    "if": {
                        "$and": [
                            "$started",
                            "$completed",
                            {"$gt": ["$completed", "$started"]},
                            additional_cond,
                        ]
                    },
                    # then: floor((completed - started) / 1000)
                    # NOTE(review): presumably converts a milliseconds difference
                    # between two dates into whole seconds - confirm field types
                    "then": {
                        "$floor": {
                            "$divide": [
                                {"$subtract": ["$completed", "$started"]},
                                1000.0,
                            ]
                        }
                    },
                    "else": 0,
                }
            }
        }

    group_step = {"_id": "$project"}

    # add one runtime accumulator per requested visibility state
    for state in EntityVisibility:
        if specific_state and state != specific_state:
            continue
        if state == EntityVisibility.active:
            group_step[state.value] = runtime_subquery(
                {"$not": cls.archived_tasks_cond}
            )
        elif state == EntityVisibility.archived:
            group_step[state.value] = runtime_subquery(cls.archived_tasks_cond)

    runtime_pipeline = [
        # only count run time for these types of tasks
        {
            "$match": {
                "company": {"$in": [None, "", company_id]},
                "type": {"$in": ["training", "testing", "annotation"]},
                "project": {"$in": project_ids},
            }
        },
        ensure_valid_fields(),
        {
            # for each project
            "$group": group_step
        },
    ]

    return status_count_pipeline, runtime_pipeline
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@staticmethod
def aggregate_project_data(
    func: Callable[[T, T], T],
    project_ids: Sequence[str],
    child_projects: Mapping[str, Sequence[Project]],
    data: Mapping[str, T],
) -> Dict[str, T]:
    """
    Fold the per project data of every requested project and of all its sub projects
    into a single value using func (applied through reduce)
    Projects for which neither they nor their sub projects have data are omitted
    from the result. Empty data yields an empty result
    """
    aggregated = {}
    if not data:
        return aggregated

    for pid in project_ids:
        # the project itself plus all of its sub projects
        subtree_ids = {pid, *(child.id for child in child_projects.get(pid, []))}
        subtree_data = [
            entry for owner, entry in data.items() if owner in subtree_ids
        ]
        if subtree_data:
            aggregated[pid] = reduce(func, subtree_data)

    return aggregated
|
||||
|
||||
@classmethod
def get_project_stats(
    cls,
    company: str,
    project_ids: Sequence[str],
    specific_state: Optional[EntityVisibility] = None,
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
    """
    Calculate task statistics for the passed projects, each aggregated over the
    project itself and all of its sub projects
    Return a tuple of two dictionaries keyed by the requested project ids:
    - stats: for each reported visibility state "total_runtime" and "status_count"
    - children: the list of {"id", "name"} of the sub projects sorted by name
    If specific_state is passed then only this state is reported
    Empty project_ids yield two empty dictionaries
    """
    if not project_ids:
        return {}, {}

    # statistics are collected over the requested projects and all their sub projects
    child_projects = _get_sub_projects(project_ids, _only=("id", "name"))
    project_ids_with_children = set(project_ids) | {
        c.id for c in itertools.chain.from_iterable(child_projects.values())
    }
    status_count_pipeline, runtime_pipeline = cls.make_projects_get_all_pipelines(
        company,
        project_ids=list(project_ids_with_children),
        specific_state=specific_state,
    )

    # zero count for every task status option
    default_counts = dict.fromkeys(get_options(TaskStatus), 0)

    def set_default_count(entry):
        # complete the reported counts with zeros for the missing statuses
        return dict(default_counts, **entry)

    # project id -> visibility section -> status -> count
    status_count = defaultdict(lambda: {})
    key = itemgetter(EntityVisibility.archived.value)
    for result in Task.aggregate(status_count_pipeline):
        # groupby requires the entries to be sorted by the same key
        for k, group in groupby(sorted(result["counts"], key=key), key):
            section = (
                EntityVisibility.archived if k else EntityVisibility.active
            ).value
            status_count[result["_id"]][section] = set_default_count(
                {
                    count_entry["status"]: count_entry["count"]
                    for count_entry in group
                }
            )

    def sum_status_count(
        a: Mapping[str, Mapping], b: Mapping[str, Mapping]
    ) -> Dict[str, dict]:
        # per section, per status sum of the counts from both operands
        return {
            section: {
                status: nested_get(a, (section, status), 0)
                + nested_get(b, (section, status), 0)
                for status in set(a.get(section, {})) | set(b.get(section, {}))
            }
            for section in set(a) | set(b)
        }

    # roll the sub projects counts up into the requested projects
    status_count = cls.aggregate_project_data(
        func=sum_status_count,
        project_ids=project_ids,
        child_projects=child_projects,
        data=status_count,
    )

    # project id -> aggregation result without the "_id" field
    runtime = {
        result["_id"]: {k: v for k, v in result.items() if k != "_id"}
        for result in Task.aggregate(runtime_pipeline)
    }

    def sum_runtime(
        a: Mapping[str, Mapping], b: Mapping[str, Mapping]
    ) -> Dict[str, dict]:
        # per section sum of the values from both operands
        return {
            section: a.get(section, 0) + b.get(section, 0)
            for section in set(a) | set(b)
        }

    # roll the sub projects run time up into the requested projects
    runtime = cls.aggregate_project_data(
        func=sum_runtime,
        project_ids=project_ids,
        child_projects=child_projects,
        data=runtime,
    )

    def get_status_counts(project_id, section):
        # NOTE(review): the same default_counts object is passed as the default
        # for every project without data - presumably callers do not mutate it
        return {
            "total_runtime": nested_get(runtime, (project_id, section), 0),
            "status_count": nested_get(
                status_count, (project_id, section), default_counts
            ),
        }

    report_for_states = [
        s for s in EntityVisibility if not specific_state or specific_state == s
    ]

    stats = {
        project: {
            task_state.value: get_status_counts(project, task_state.value)
            for task_state in report_for_states
        }
        for project in project_ids
    }

    children = {
        project: sorted(
            [{"id": c.id, "name": c.name} for c in child_projects.get(project, [])],
            key=itemgetter("name"),
        )
        for project in project_ids
    }
    return stats, children
|
||||
|
||||
@classmethod
def get_active_users(
    cls,
    company,
    project_ids: Sequence[str],
    user_ids: Optional[Sequence[str]] = None,
) -> Set[str]:
    """
    Get the set of user ids that own projects, tasks or models in the given
    projects (sub projects included)
    If project_ids is empty then all the company projects are examined
    If user_ids are passed then only a subset of these users is returned
    """
    with TimingContext("mongo", "active_users_in_projects"):
        base_query = Q(company=company)
        if user_ids:
            base_query &= Q(user__in=user_ids)

        # projects are filtered by their own id, tasks and models by their project
        projects_query = entities_query = base_query
        if project_ids:
            scope = _ids_with_children(project_ids)
            entities_query = base_query & Q(project__in=scope)
            projects_query = base_query & Q(id__in=scope)

        active = set(Project.objects(projects_query).distinct(field="user"))
        for entity_cls in (Task, Model):
            active.update(entity_cls.objects(entities_query).distinct(field="user"))

        return active
|
||||
|
||||
@classmethod
def get_projects_with_active_user(
    cls,
    company: str,
    users: Sequence[str],
    project_ids: Optional[Sequence[str]] = None,
    allow_public: bool = True,
) -> Sequence[str]:
    """
    Get the ids of the projects that are owned by any of the users or contain
    tasks or models owned by them, including all the parents of these projects
    If allow_public is set then the company constraint is built by
    get_company_or_none_constraint, otherwise only the company entities are matched
    If project ids are specified then the search is limited to these projects with
    their sub projects and only the ids from project_ids are returned
    """
    query = Q(user__in=users)
    if allow_public:
        query &= get_company_or_none_constraint(company)
    else:
        query &= Q(company=company)

    # projects are filtered by their own id, tasks and models by their project
    user_projects_query = query
    if project_ids:
        ids_with_children = _ids_with_children(project_ids)
        query &= Q(project__in=ids_with_children)
        user_projects_query &= Q(id__in=ids_with_children)

    res = {p.id for p in Project.objects(user_projects_query).only("id")}
    for cls_ in (Task, Model):
        res |= set(cls_.objects(query).distinct(field="project"))

    res = list(res)
    if not res:
        return res

    ids_with_parents = _ids_with_parents(res)
    if project_ids:
        # build the lookup set once instead of scanning the sequence per result
        requested = set(project_ids)
        return [pid for pid in ids_with_parents if pid in requested]

    return ids_with_parents
|
||||
|
||||
@classmethod
def get_task_parents(
    cls,
    company_id: str,
    projects: Sequence[str],
    include_subprojects: bool,
    state: Optional[EntityVisibility] = None,
) -> Sequence[dict]:
    """
    Get the unique parents of the company tasks, sorted by the parent task name
    If projects are passed then only the tasks from these projects (and from
    their sub projects if include_subprojects is set) are considered
    If state is passed then only archived or only active tasks are considered
    Each parent is returned with the "id", "name" and "project.name" projection
    """
    archived_tags = [EntityVisibility.archived.value]

    conditions = Q(company=company_id)
    if projects:
        scope = _ids_with_children(projects) if include_subprojects else projects
        conditions &= Q(project__in=scope)
    if state == EntityVisibility.archived:
        conditions &= Q(system_tags__in=archived_tags)
    elif state == EntityVisibility.active:
        conditions &= Q(system_tags__nin=archived_tags)

    parent_ids = set(Task.objects(conditions).distinct("parent"))
    if not parent_ids:
        return []

    # parents are looked up with allow_public=True
    parents = Task.get_many_with_join(
        company_id,
        query=Q(id__in=parent_ids),
        allow_public=True,
        override_projection=("id", "name", "project.name"),
    )
    return sorted(parents, key=itemgetter("name"))
|
||||
|
||||
@classmethod
def get_task_types(cls, company, project_ids: Optional[Sequence]) -> set:
    """
    Return the set of unique task types used by the tasks matching the
    company-or-none constraint, limited to the types listed in external_task_types
    If project ids are passed then only the tasks from these projects and
    their sub projects are considered
    """
    conditions = get_company_or_none_constraint(company)
    if project_ids:
        conditions &= Q(project__in=_ids_with_children(project_ids))

    used_types = set(Task.objects(conditions).distinct(field="type"))
    return used_types.intersection(external_task_types)
|
||||
|
||||
@classmethod
def get_model_frameworks(cls, company, project_ids: Optional[Sequence]) -> Sequence:
    """
    Return the unique frameworks used by the models matching the
    company-or-none constraint
    If project ids are passed then only the models from these projects and
    their sub projects are considered
    """
    conditions = get_company_or_none_constraint(company)
    if project_ids:
        conditions &= Q(project__in=_ids_with_children(project_ids))

    return Model.objects(conditions).distinct(field="framework")
|
||||
|
||||
@classmethod
def calc_own_contents(cls, company: str, project_ids: Sequence[str]) -> Dict[str, dict]:
    """
    Return the amount of tasks and models that belong directly to each of the
    requested projects as {project id: {"own_tasks": .., "own_models": ..}}
    Separate aggregation calls on Task and Model are used instead of a lookup
    aggregation on projects in order not to hit memory limits on large tasks
    Empty project_ids yield an empty dictionary
    """
    if not project_ids:
        return {}

    # count documents per project; the same pipeline serves both collections
    pipeline = [
        {
            "$match": {
                "company": {"$in": [None, "", company]},
                "project": {"$in": project_ids},
            }
        },
        {"$project": {"project": 1}},
        {"$group": {"_id": "$project", "count": {"$sum": 1}}},
    ]

    def count_per_project(entity_cls: Type[AttributedDocument]) -> dict:
        return {row["_id"]: row["count"] for row in entity_cls.aggregate(pipeline)}

    # NOTE(review): the timing label looks copied from unrelated code - confirm
    with TimingContext("mongo", "get_security_groups"):
        task_counts = count_per_project(Task)
        model_counts = count_per_project(Model)
        own_contents = {}
        for pid in project_ids:
            own_contents[pid] = {
                "own_tasks": task_counts.get(pid, 0),
                "own_models": model_counts.get(pid, 0),
            }
        return own_contents
|
||||
|
||||
154
apiserver/bll/project/project_cleanup.py
Normal file
154
apiserver/bll/project/project_cleanup.py
Normal file
@@ -0,0 +1,154 @@
|
||||
from typing import Tuple, Set, Sequence
|
||||
|
||||
import attr
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.task.task_cleanup import (
|
||||
collect_debug_image_urls,
|
||||
collect_plot_image_urls,
|
||||
TaskUrls,
|
||||
)
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, ArtifactModes
|
||||
from apiserver.timing_context import TimingContext
|
||||
from .sub_projects import _ids_with_children
|
||||
|
||||
log = config.logger(__file__)
|
||||
event_bll = EventBLL()
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
class DeleteProjectResult:
    """
    Summary of a project deletion as filled in by delete_project
    """

    # number of deleted project documents (the project and its sub projects)
    deleted: int = 0
    # number of tasks detached from the deleted projects (delete_contents=False)
    disassociated_tasks: int = 0
    # number of deleted models (delete_contents=True)
    deleted_models: int = 0
    # number of deleted tasks (delete_contents=True)
    deleted_tasks: int = 0
    # model/event/artifact urls of the deleted contents; stays None
    # when the contents were not deleted
    urls: TaskUrls = None
|
||||
|
||||
|
||||
def delete_project(
    company: str, project_id: str, force: bool, delete_contents: bool
) -> Tuple[DeleteProjectResult, Set[str]]:
    """
    Delete the project together with all of its sub projects
    Unless force is set, the deletion is refused if any of these projects
    contains non-archived tasks (ProjectHasTasks) or models (ProjectHasModels)
    If delete_contents is set then the models and tasks of the projects are deleted
    too and their urls are reported, otherwise they are detached (project=None)
    Raise InvalidProjectId if the project cannot be fetched for writing
    Return the deletion summary + set of the affected project ids
    (the deleted projects and the ancestors of the requested project)
    """
    project = Project.get_for_writing(
        company=company, id=project_id, _only=("id", "path")
    )
    if not project:
        raise errors.bad_request.InvalidProjectId(id=project_id)

    project_ids = _ids_with_children([project_id])
    if not force:
        archived_tags = [EntityVisibility.archived.value]
        guarded_entities = (
            (Task, errors.bad_request.ProjectHasTasks),
            (Model, errors.bad_request.ProjectHasModels),
        )
        for entity_cls, error in guarded_entities:
            non_archived = entity_cls.objects(
                project__in=project_ids, system_tags__nin=archived_tags,
            ).only("id")
            if non_archived:
                raise error("use force=true to delete", id=project_id)

    if delete_contents:
        deleted_models, model_urls = _delete_models(projects=project_ids)
        deleted_tasks, event_urls, artifact_urls = _delete_tasks(
            company=company, projects=project_ids
        )
        urls = TaskUrls(
            model_urls=list(model_urls),
            event_urls=list(event_urls),
            artifact_urls=list(artifact_urls),
        )
        res = DeleteProjectResult(
            deleted_tasks=deleted_tasks, deleted_models=deleted_models, urls=urls,
        )
    else:
        with TimingContext("mongo", "update_children"):
            # models are detached first, only the tasks count is reported
            Model.objects(project__in=project_ids).update(project=None)
            detached_tasks = Task.objects(project__in=project_ids).update(
                project=None
            )
        res = DeleteProjectResult(disassociated_tasks=detached_tasks)

    affected = {*project_ids, *(project.path or [])}
    res.deleted = Project.objects(id__in=project_ids).delete()

    return res, affected
|
||||
|
||||
|
||||
def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]:
    """
    Delete all the tasks of the passed projects together with their events
    Tasks and models from other projects that reference the deleted tasks
    (as parent or as creating task) get the reference reset to None
    Return the number of deleted tasks, the set of the debug image and plot image
    urls of the tasks and the set of the uris of their output artifacts
    """
    tasks = Task.objects(project__in=projects).only("id", "execution__artifacts")
    if not tasks:
        return 0, set(), set()

    task_ids = {task.id for task in tasks}
    with TimingContext("mongo", "delete_tasks_update_children"):
        # only entities outside of the deleted projects need to be detached
        Task.objects(parent__in=task_ids, project__nin=projects).update(parent=None)
        Model.objects(task__in=task_ids, project__nin=projects).update(task=None)

    event_urls = set()
    artifact_urls = set()
    for task in tasks:
        event_urls.update(collect_debug_image_urls(company, task.id))
        event_urls.update(collect_plot_image_urls(company, task.id))
        if not (task.execution and task.execution.artifacts):
            continue
        for artifact in task.execution.artifacts.values():
            if artifact.mode == ArtifactModes.output and artifact.uri:
                artifact_urls.add(artifact.uri)

    event_bll.delete_multi_task_events(company, list(task_ids))
    deleted = tasks.delete()
    return deleted, event_urls, artifact_urls
|
||||
|
||||
|
||||
def _delete_models(projects: Sequence[str]) -> Tuple[int, Set[str]]:
    """
    Delete all the models of the passed projects
    Tasks from other projects that reference the deleted models as their input
    or output models get these references reset to None
    Return the number of deleted models and the set of their non-empty uris
    """
    with TimingContext("mongo", "delete_models"):
        models = Model.objects(project__in=projects).only("task", "id", "uri")
        if not models:
            return 0, set()

        model_ids = list({model.id for model in models})
        # matches the array elements that point to one of the deleted models
        deleted_model_elem = [{"elem.model": {"$in": model_ids}}]

        # input model references of the tasks outside of the deleted projects
        Task._get_collection().update_many(
            filter={
                "project": {"$nin": projects},
                "models.input.model": {"$in": model_ids},
            },
            update={"$set": {"models.input.$[elem].model": None}},
            array_filters=deleted_model_elem,
            upsert=False,
        )

        # output model references; only the tasks recorded on the models are checked
        creating_tasks = list({model.task for model in models if model.task})
        if creating_tasks:
            Task._get_collection().update_many(
                filter={
                    "_id": {"$in": creating_tasks},
                    "project": {"$nin": projects},
                    "models.output.model": {"$in": model_ids},
                },
                update={"$set": {"models.output.$[elem].model": None}},
                array_filters=deleted_model_elem,
                upsert=False,
            )

        urls = {model.uri for model in models if model.uri}
        deleted = models.delete()
        return deleted, urls
|
||||
176
apiserver/bll/project/sub_projects.py
Normal file
176
apiserver/bll/project/sub_projects.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import itertools
|
||||
from datetime import datetime
|
||||
from typing import Tuple, Optional, Sequence, Mapping
|
||||
|
||||
from apiserver import database
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.database.model.project import Project
|
||||
|
||||
name_separator = "/"
|
||||
|
||||
|
||||
def _get_project_depth(project_name: str) -> int:
    """
    Return the number of the non-empty parts in the project name
    """
    return sum(1 for part in project_name.split(name_separator) if part)
|
||||
|
||||
|
||||
def _validate_project_name(project_name: str) -> Tuple[str, str]:
    """
    Normalize the project name by dropping the empty parts produced by
    leading, trailing or repeated separators
    Raise InvalidProjectName if nothing is left after the cleanup
    Return the cleaned up project name and its location (the name without the last part)
    """
    segments = [part for part in project_name.split(name_separator) if part]
    if not segments:
        raise errors.bad_request.InvalidProjectName(name=project_name)

    clean_name = name_separator.join(segments)
    location = name_separator.join(segments[:-1])
    return clean_name, location
|
||||
|
||||
|
||||
def _ensure_project(company: str, user: str, name: str) -> Optional[Project]:
    """
    Return the company project with the given name, creating it if it does not exist
    The missing projects on the path to it are created recursively and
    the last_update of the direct parent is refreshed on creation
    Return None for the empty name (top level)
    """
    name = name.strip(name_separator)
    if not name:
        return None

    existing = _get_writable_project_from_name(company, name)
    if existing:
        return existing

    now = datetime.utcnow()
    clean_name, location = _validate_project_name(name)
    created = Project(
        id=database.utils.id(),
        user=user,
        company=company,
        created=now,
        last_update=now,
        name=clean_name,
        description="",
    )
    # make sure that the whole chain of parents exists before saving under it
    parent = _ensure_project(company, user, location)
    _save_under_parent(project=created, parent=parent)
    if parent:
        parent.update(last_update=now)

    return created
|
||||
|
||||
|
||||
def _save_under_parent(project: Project, parent: Optional[Project]):
    """
    Set the parent and the path of the project according to the passed parent
    (top level for the empty parent) and save the project
    Raise ValueError if the location part of the project name
    does not match the parent name
    """
    location = project.name.rpartition(name_separator)[0]
    if parent:
        if location != parent.name:
            raise ValueError(
                f"Project location {location} does not match parent name {parent.name}"
            )
        project.parent = parent.id
        # the path lists all the ancestors ending with the direct parent
        project.path = [*(parent.path or []), parent.id]
    else:
        if location:
            raise ValueError(
                f"Project location {location} does not match empty parent name"
            )
        project.parent = None
        project.path = []

    project.save()
|
||||
|
||||
|
||||
def _get_writable_project_from_name(
    company,
    name,
    _only: Optional[Sequence[str]] = ("id", "name", "path", "company", "parent"),
) -> Optional[Project]:
    """
    Look up a project by its exact name inside the given company
    Only the fields listed in _only are loaded (all the fields if _only is empty)
    Return the first match or None if the project is not found
    """
    matches = Project.objects(company=company, name=name)
    return (matches.only(*_only) if _only else matches).first()
|
||||
|
||||
|
||||
def _get_sub_projects(
    project_ids: Sequence[str], _only: Sequence[str] = ("id", "path")
) -> Mapping[str, Sequence[Project]]:
    """
    Return for each of the passed project ids the list of its sub projects
    from all the levels (the projects that have this id in their path)
    The "path" field is always loaded in addition to the fields listed in _only
    """
    query = Project.objects(path__in=project_ids)
    if _only:
        # path is required below to attribute the sub projects to their ancestors
        query = query.only(*(set(_only) | {"path"}))
    descendants = list(query)

    result = {}
    for pid in project_ids:
        result[pid] = [
            candidate for candidate in descendants if pid in (candidate.path or [])
        ]
    return result
|
||||
|
||||
|
||||
def _ids_with_parents(project_ids: Sequence[str]) -> Sequence[str]:
    """
    Return the ids of the found projects together with the ids of all their ancestors
    The result contains no duplicates and its order is not defined
    """
    collected = set()
    for project in Project.objects(id__in=project_ids).only("id", "path"):
        collected.add(project.id)
        if project.path:
            collected.update(project.path)
    return list(collected)
|
||||
|
||||
|
||||
def _ids_with_children(project_ids: Sequence[str]) -> Sequence[str]:
    """
    Return the passed project ids together with the ids of all their sub projects
    The result contains no duplicates and its order is not defined
    """
    collected = set(project_ids)
    collected.update(
        child.id for child in Project.objects(path__in=project_ids).only("id")
    )
    return list(collected)
|
||||
|
||||
|
||||
def _update_subproject_names(
    project: Project,
    children: Sequence[Project],
    old_name: str,
    update_path: bool = False,
    old_path: Sequence[str] = None,
) -> int:
    """
    Update sub project names when the base project name changes
    The old_name prefix of every child name is replaced with the current project name
    If update_path is set then the old_path prefix of every child path is replaced
    with the current project path. Missing (None) paths are treated as empty
    Return the sum of the results of the children update calls
    """
    # the number of the leading name parts that belong to the base project
    old_name_depth = len(old_name.split(name_separator))
    # the number of the leading path entries that belong to the old location
    old_path_len = len(old_path or [])
    new_path = project.path or []

    updated = 0
    for child in children:
        child_suffix = name_separator.join(
            child.name.split(name_separator)[old_name_depth:]
        )
        updates = {"name": name_separator.join((project.name, child_suffix))}
        if update_path:
            updates["path"] = [*new_path, *(child.path or [])[old_path_len:]]
        updated += child.update(upsert=False, **updates)

    return updated
|
||||
|
||||
|
||||
def _reposition_project_with_children(
    project: Project, children: Sequence[Project], parent: Project
) -> int:
    """
    Rename the project according to its new parent, save it under that parent
    and then update the names and the paths of its children

    :param project: the project to move
    :param children: the sub projects of the moved project
    :param parent: the new parent project. A falsy value means no parent,
        in which case the project keeps only the last part of its name
    :return: 1 (for the project itself) plus the count returned by
        _update_subproject_names for the children
    """
    previous_name, previous_path = project.name, project.path

    # the new full name is "<parent name><separator><own short name>"
    own_name = previous_name.split(name_separator)[-1]
    parent_name = parent.name if parent else None
    project.name = name_separator.join(
        part for part in (parent_name, own_name) if part
    )
    _save_under_parent(project, parent=parent)

    children_updated = _update_subproject_names(
        project=project,
        children=children,
        old_name=previous_name,
        update_path=True,
        old_path=previous_path,
    )
    return 1 + children_updated
|
||||
@@ -32,6 +32,7 @@ class QueueBLL(object):
|
||||
name: str,
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
system_tags: Optional[Sequence[str]] = None,
|
||||
metadata: Optional[Sequence[dict]] = None,
|
||||
) -> Queue:
|
||||
"""Creates a queue"""
|
||||
with translate_errors_context():
|
||||
@@ -43,6 +44,7 @@ class QueueBLL(object):
|
||||
name=name,
|
||||
tags=tags or [],
|
||||
system_tags=system_tags or [],
|
||||
metadata=metadata,
|
||||
last_update=now,
|
||||
)
|
||||
queue.save()
|
||||
|
||||
@@ -45,7 +45,7 @@ class StatisticsReporter:
|
||||
def start_reporter(cls):
|
||||
"""
|
||||
Periodically send statistics reports for companies who have opted in.
|
||||
Note: in trains we usually have only a single company
|
||||
Note: in clearml we usually have only a single company
|
||||
"""
|
||||
if not cls.supported:
|
||||
return
|
||||
|
||||
@@ -3,5 +3,4 @@ from .utils import (
|
||||
ChangeStatusRequest,
|
||||
update_project_time,
|
||||
validate_status_change,
|
||||
split_by,
|
||||
)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from hashlib import md5
|
||||
from operator import itemgetter
|
||||
from typing import Sequence
|
||||
|
||||
from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
|
||||
from apiserver.bll.task.utils import get_task_for_update, update_task
|
||||
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
|
||||
from apiserver.database.utils import hash_field_name
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.dicts import nested_get, nested_set
|
||||
from apiserver.utilities.parameter_key_escaper import mongoengine_safe
|
||||
@@ -15,7 +15,7 @@ def get_artifact_id(artifact: dict):
|
||||
Calculate id from 'key' and 'mode' fields
|
||||
Return hash on the id so that it will not contain mongo illegal characters
|
||||
"""
|
||||
key_hash: str = md5(artifact["key"].encode()).hexdigest()
|
||||
key_hash: str = hash_field_name(artifact["key"])
|
||||
mode: str = artifact.get("mode", DEFAULT_ARTIFACT_MODE)
|
||||
return f"{key_hash}_{mode}"
|
||||
|
||||
@@ -40,7 +40,7 @@ def artifacts_unprepare_from_saved(fields):
|
||||
nested_set(
|
||||
fields,
|
||||
artifacts_field,
|
||||
value=sorted(artifacts.values(), key=itemgetter("key", "mode")),
|
||||
value=sorted(artifacts.values(), key=itemgetter("key")),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -175,21 +175,23 @@ class HyperParams:
|
||||
|
||||
@classmethod
|
||||
def get_configuration_names(
|
||||
cls, company_id: str, task_ids: Sequence[str]
|
||||
cls, company_id: str, task_ids: Sequence[str], skip_empty: bool
|
||||
) -> Dict[str, list]:
|
||||
with TimingContext("mongo", "get_configuration_names"):
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"_id": {"$in": task_ids},
|
||||
}
|
||||
},
|
||||
{"$project": {"items": {"$objectToArray": "$configuration"}}},
|
||||
{"$unwind": "$items"},
|
||||
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
|
||||
]
|
||||
skip_empty_condition = {"$match": {"items.v.value": {"$nin": [None, ""]}}}
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"_id": {"$in": task_ids},
|
||||
}
|
||||
},
|
||||
{"$project": {"items": {"$objectToArray": "$configuration"}}},
|
||||
{"$unwind": "$items"},
|
||||
*([skip_empty_condition] if skip_empty else []),
|
||||
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
|
||||
]
|
||||
|
||||
with TimingContext("mongo", "get_configuration_names"):
|
||||
tasks = Task.aggregate(pipeline)
|
||||
|
||||
return {
|
||||
|
||||
@@ -185,6 +185,7 @@ def escape_paths(paths: Sequence[str]) -> Sequence[str]:
|
||||
for old_prefix, new_prefix in (
|
||||
("execution.parameters", f"hyperparams.{hyperparams_default_section}"),
|
||||
("execution.model_desc", f"configuration"),
|
||||
("execution.docker_cmd", "container")
|
||||
):
|
||||
path: str
|
||||
paths = [path.replace(old_prefix, new_prefix) for path in paths]
|
||||
|
||||
@@ -1,17 +1,20 @@
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Collection, Sequence, Tuple, Any, Optional, Dict
|
||||
|
||||
import dpath
|
||||
import six
|
||||
from mongoengine import Q
|
||||
from redis import StrictRedis
|
||||
from six import string_types
|
||||
|
||||
import apiserver.database.utils as dbutils
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.tasks import TaskInputModel
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.project import ProjectBLL, project_ids_with_children
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.model import Model
|
||||
@@ -21,21 +24,29 @@ from apiserver.database.model.task.output import Output
|
||||
from apiserver.database.model.task.task import (
|
||||
Task,
|
||||
TaskStatus,
|
||||
TaskStatusMessage,
|
||||
TaskSystemTags,
|
||||
ArtifactModes,
|
||||
external_task_types,
|
||||
ModelItem,
|
||||
Models,
|
||||
DEFAULT_ARTIFACT_MODE,
|
||||
TaskModelNames,
|
||||
TaskModelTypes,
|
||||
)
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.service_repo import APICall
|
||||
from apiserver.services.utils import validate_tags
|
||||
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from .artifacts import artifacts_prepare_for_save
|
||||
from .param_utils import params_prepare_for_save
|
||||
from .utils import ChangeStatusRequest, validate_status_change, update_project_time
|
||||
from .utils import (
|
||||
ChangeStatusRequest,
|
||||
update_project_time,
|
||||
deleted_prefix,
|
||||
)
|
||||
|
||||
log = config.logger(__file__)
|
||||
org_bll = OrgBLL()
|
||||
@@ -44,22 +55,9 @@ project_bll = ProjectBLL()
|
||||
|
||||
|
||||
class TaskBLL:
|
||||
def __init__(self, events_es=None):
|
||||
self.events_es = (
|
||||
events_es if events_es is not None else es_factory.connect("events")
|
||||
)
|
||||
|
||||
@classmethod
def get_types(cls, company, project_ids: Optional[Sequence]) -> set:
    """
    Get the unique task types used by the company tasks and the public tasks

    :param company: company id used to build the company-or-public constraint
    :param project_ids: if not empty then only the tasks from these projects
        are considered
    :return: the found types that also appear in external_task_types
    """
    filters = get_company_or_none_constraint(company)
    if project_ids:
        filters = filters & Q(project__in=project_ids)

    used_types = Task.objects(filters).distinct(field="type")
    return set(used_types) & set(external_task_types)
|
||||
def __init__(self, events_es=None, redis=None):
|
||||
self.events_es = events_es or es_factory.connect("events")
|
||||
self.redis: StrictRedis = redis or redman.connection("apiserver")
|
||||
|
||||
@staticmethod
|
||||
def get_task_with_access(
|
||||
@@ -151,19 +149,20 @@ class TaskBLL:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def validate_execution_model(task, allow_only_public=False):
|
||||
if not task.execution or not task.execution.model:
|
||||
def validate_input_models(task, allow_only_public=False):
|
||||
if not task.models.input:
|
||||
return
|
||||
|
||||
company = None if allow_only_public else task.company
|
||||
model_id = task.execution.model
|
||||
model = Model.objects(
|
||||
Q(id=model_id) & get_company_or_none_constraint(company)
|
||||
).first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(model=model_id)
|
||||
model_ids = set(m.model for m in task.models.input)
|
||||
models = Model.objects(
|
||||
Q(id__in=model_ids) & get_company_or_none_constraint(company)
|
||||
).only("id")
|
||||
missing = model_ids - {m.id for m in models}
|
||||
if missing:
|
||||
raise errors.bad_request.InvalidModelId(models=missing)
|
||||
|
||||
return model
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def clone_task(
|
||||
@@ -179,7 +178,9 @@ class TaskBLL:
|
||||
system_tags: Optional[Sequence[str]] = None,
|
||||
hyperparams: Optional[dict] = None,
|
||||
configuration: Optional[dict] = None,
|
||||
container: Optional[dict] = None,
|
||||
execution_overrides: Optional[dict] = None,
|
||||
input_models: Optional[Sequence[TaskInputModel]] = None,
|
||||
validate_references: bool = False,
|
||||
new_project_name: str = None,
|
||||
) -> Tuple[Task, dict]:
|
||||
@@ -195,10 +196,29 @@ class TaskBLL:
|
||||
|
||||
task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
|
||||
|
||||
now = datetime.utcnow()
|
||||
if input_models:
|
||||
input_models = [
|
||||
ModelItem(model=m.model, name=m.name, updated=now) for m in input_models
|
||||
]
|
||||
|
||||
execution_dict = task.execution.to_proper_dict() if task.execution else {}
|
||||
execution_model_overriden = False
|
||||
if execution_overrides:
|
||||
execution_model_overriden = execution_overrides.get("model") is not None
|
||||
execution_model = execution_overrides.pop("model", None)
|
||||
if not input_models and execution_model:
|
||||
input_models = [
|
||||
ModelItem(
|
||||
model=execution_model,
|
||||
name=TaskModelNames[TaskModelTypes.input],
|
||||
updated=now,
|
||||
)
|
||||
]
|
||||
|
||||
docker_cmd = execution_overrides.pop("docker_cmd", None)
|
||||
if not container and docker_cmd:
|
||||
image, _, arguments = docker_cmd.partition(" ")
|
||||
container = {"image": image, "arguments": arguments}
|
||||
|
||||
artifacts_prepare_for_save({"execution": execution_overrides})
|
||||
|
||||
params_dict["execution"] = {}
|
||||
@@ -207,6 +227,8 @@ class TaskBLL:
|
||||
if legacy_value is not None:
|
||||
params_dict["execution"] = legacy_value
|
||||
|
||||
escape_dict_field(execution_overrides, "model_labels")
|
||||
|
||||
execution_dict.update(execution_overrides)
|
||||
|
||||
params_prepare_for_save(params_dict, previous_task=task)
|
||||
@@ -216,7 +238,7 @@ class TaskBLL:
|
||||
execution_dict["artifacts"] = {
|
||||
k: a
|
||||
for k, a in artifacts.items()
|
||||
if a.get("mode") != ArtifactModes.output
|
||||
if a.get("mode", DEFAULT_ARTIFACT_MODE) != ArtifactModes.output
|
||||
}
|
||||
execution_dict.pop("queue", None)
|
||||
|
||||
@@ -227,12 +249,10 @@ class TaskBLL:
|
||||
project_name=new_project_name,
|
||||
user=user_id,
|
||||
company=company_id,
|
||||
description="Auto-generated while cloning",
|
||||
description="",
|
||||
)
|
||||
new_project_data = {"id": project, "name": new_project_name}
|
||||
|
||||
now = datetime.utcnow()
|
||||
|
||||
def clean_system_tags(input_tags: Sequence[str]) -> Sequence[str]:
|
||||
if not input_tags:
|
||||
return input_tags
|
||||
@@ -240,10 +260,16 @@ class TaskBLL:
|
||||
return [
|
||||
tag
|
||||
for tag in input_tags
|
||||
if tag not in [TaskSystemTags.development, EntityVisibility.archived.value]
|
||||
if tag
|
||||
not in [TaskSystemTags.development, EntityVisibility.archived.value]
|
||||
]
|
||||
|
||||
with TimingContext("mongo", "clone task"):
|
||||
parent_task = (
|
||||
task.parent
|
||||
if task.parent and not task.parent.startswith(deleted_prefix)
|
||||
else None
|
||||
)
|
||||
new_task = Task(
|
||||
id=create_id(),
|
||||
user=user_id,
|
||||
@@ -253,7 +279,7 @@ class TaskBLL:
|
||||
last_change=now,
|
||||
name=name or task.name,
|
||||
comment=comment or task.comment,
|
||||
parent=parent or task.parent,
|
||||
parent=parent or parent_task,
|
||||
project=project or task.project,
|
||||
tags=tags or task.tags,
|
||||
system_tags=system_tags or clean_system_tags(task.system_tags),
|
||||
@@ -262,13 +288,15 @@ class TaskBLL:
|
||||
output=Output(destination=task.output.destination)
|
||||
if task.output
|
||||
else None,
|
||||
models=Models(input=input_models or task.models.input),
|
||||
container=escape_dict(container) or task.container,
|
||||
execution=execution_dict,
|
||||
configuration=params_dict.get("configuration") or task.configuration,
|
||||
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
|
||||
)
|
||||
cls.validate(
|
||||
new_task,
|
||||
validate_model=validate_references or execution_model_overriden,
|
||||
validate_models=validate_references or input_models,
|
||||
validate_parent=validate_references or parent,
|
||||
validate_project=validate_references or project,
|
||||
)
|
||||
@@ -295,7 +323,7 @@ class TaskBLL:
|
||||
def validate(
|
||||
cls,
|
||||
task: Task,
|
||||
validate_model=True,
|
||||
validate_models=True,
|
||||
validate_parent=True,
|
||||
validate_project=True,
|
||||
):
|
||||
@@ -307,6 +335,7 @@ class TaskBLL:
|
||||
if (
|
||||
validate_parent
|
||||
and task.parent
|
||||
and not task.parent.startswith(deleted_prefix)
|
||||
and not Task.get(
|
||||
company=task.company, id=task.parent, _only=("id",), include_public=True
|
||||
)
|
||||
@@ -318,16 +347,23 @@ class TaskBLL:
|
||||
if validate_project and not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=task.project)
|
||||
|
||||
if validate_model:
|
||||
cls.validate_execution_model(task)
|
||||
if validate_models:
|
||||
cls.validate_input_models(task)
|
||||
|
||||
@staticmethod
|
||||
def get_unique_metric_variants(company_id, project_ids=None):
|
||||
def get_unique_metric_variants(
|
||||
company_id, project_ids: Sequence[str], include_subprojects: bool
|
||||
):
|
||||
if project_ids:
|
||||
if include_subprojects:
|
||||
project_ids = project_ids_with_children(project_ids)
|
||||
project_constraint = {"project": {"$in": project_ids}}
|
||||
else:
|
||||
project_constraint = {}
|
||||
pipeline = [
|
||||
{
|
||||
"$match": dict(
|
||||
company={"$in": [None, "", company_id]},
|
||||
**({"project": {"$in": project_ids}} if project_ids else {}),
|
||||
company={"$in": [None, "", company_id]}, **project_constraint,
|
||||
)
|
||||
},
|
||||
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
|
||||
@@ -372,6 +408,7 @@ class TaskBLL:
|
||||
tasks = Task.objects(id__in=task_ids, company=company_id).only(
|
||||
"status", "started"
|
||||
)
|
||||
count = 0
|
||||
for task in tasks:
|
||||
updates = extra_updates
|
||||
if task.status == TaskStatus.in_progress and task.started:
|
||||
@@ -381,12 +418,13 @@ class TaskBLL:
|
||||
).total_seconds(),
|
||||
**extra_updates,
|
||||
}
|
||||
Task.objects(id=task.id, company=company_id).update(
|
||||
count += Task.objects(id=task.id, company=company_id).update(
|
||||
upsert=False,
|
||||
last_update=last_update,
|
||||
last_change=last_update,
|
||||
**updates,
|
||||
)
|
||||
return count
|
||||
|
||||
@staticmethod
|
||||
def update_statistics(
|
||||
@@ -449,168 +487,27 @@ class TaskBLL:
|
||||
}
|
||||
extra_updates["metric_stats"] = metric_stats
|
||||
|
||||
TaskBLL.set_last_update(
|
||||
return TaskBLL.set_last_update(
|
||||
task_ids=[task_id],
|
||||
company_id=company_id,
|
||||
last_update=last_update,
|
||||
**extra_updates,
|
||||
)
|
||||
|
||||
@classmethod
def model_set_ready(
    cls,
    model_id: str,
    company_id: str,
    publish_task: bool,
    force_publish_task: bool = False,
) -> tuple:
    """
    Mark the model as ready and optionally publish the task it belongs to

    :param model_id: id of the model
    :param company_id: company of the model and of its task
    :param publish_task: if set and the model task is not published yet
        then the task is published too (without publishing its model again)
    :param force_publish_task: passed as "force" when publishing the task
    :return: tuple of the model update result and a dictionary that holds the
        published task "id" and publishing result "data" (empty if no task
        was published)
    :raise errors.bad_request.InvalidModelId: the model was not found
    :raise errors.bad_request.ModelIsReady: the model is already ready
    """
    with translate_errors_context():
        lookup = dict(id=model_id, company=company_id)
        model = Model.objects(**lookup).first()
        if not model:
            raise errors.bad_request.InvalidModelId(**lookup)
        if model.ready:
            raise errors.bad_request.ModelIsReady(**lookup)

        published_task_data = {}
        if model.task and publish_task:
            model_task = (
                Task.objects(id=model.task, company=company_id)
                .only("id", "status")
                .first()
            )
            needs_publishing = (
                model_task and model_task.status != TaskStatus.published
            )
            if needs_publishing:
                # publish_model=False: the model is set ready right below
                published_task_data["data"] = cls.publish_task(
                    task_id=model.task,
                    company_id=company_id,
                    publish_model=False,
                    force=force_publish_task,
                )
                published_task_data["id"] = model.task

        return model.update(upsert=False, ready=True), published_task_data
|
||||
|
||||
@classmethod
def publish_task(
    cls,
    task_id: str,
    company_id: str,
    publish_model: bool,
    force: bool,
    status_reason: str = "",
    status_message: str = "",
) -> dict:
    """
    Publish the task and optionally its output model

    The task is first saved with the "publishing" status. If anything fails
    after that, the previous status is saved back and the error is re-raised

    :param task_id: id of the task to publish
    :param company_id: company used for the task access check and for
        setting the output model ready
    :param publish_model: if set then the task output model that is not
        ready yet is set ready
    :param force: skip the status change validation. Also passed to the
        status change request
    :param status_reason: passed to the status change request
    :param status_message: passed to the status change request
    :return: the result of executing the status change request
    """
    task = cls.get_task_with_access(
        task_id, company_id=company_id, requires_write_access=True
    )
    if not force:
        validate_status_change(task.status, TaskStatus.published)

    previous_task_status = task.status
    # The task may have no output at all; use an empty one in that case
    output = task.output or Output()

    try:
        # set state to publishing
        task.status = TaskStatus.publishing
        task.save()

        # publish task models
        # Read the model from "output": "task.output" itself may be None
        if output.model and publish_model:
            output_model = (
                Model.objects(id=output.model)
                .only("id", "task", "ready")
                .first()
            )
            if output_model and not output_model.ready:
                cls.model_set_ready(
                    model_id=output.model,
                    company_id=company_id,
                    publish_task=False,
                )

        # set task status to published, and update (or set) its new output (view and models)
        return ChangeStatusRequest(
            task=task,
            new_status=TaskStatus.published,
            force=force,
            status_reason=status_reason,
            status_message=status_message,
        ).execute(published=datetime.utcnow(), output=output)
    except Exception:
        # Publishing failed: restore the status the task had before
        # and let the original error (with its traceback) propagate
        task.status = previous_task_status
        task.save()
        raise
|
||||
|
||||
@classmethod
def stop_task(
    cls,
    task_id: str,
    company_id: str,
    user_name: str,
    status_reason: str,
    force: bool,
) -> dict:
    """
    Stop a running task.

    A development task, or a task that is not currently run by a worker, gets
    the 'stopped' status immediately. Otherwise the status is left as is and
    only the status message is set to 'stopping', so that the worker can stop
    the task and report by itself.

    A task is considered run by a worker when it has both last_worker and
    last_update and the last update is more recent than the configured
    "apiserver.workers.task_update_timeout" seconds (600 if not configured).

    :param task_id: id of the task to stop
    :param company_id: company used for the task access check
    :param user_name: used in the status message of an immediately stopped task
    :param status_reason: passed to the status change request
    :param force: passed to the status change request
    :return: updated task fields
    """
    task = cls.get_task_with_access(
        task_id,
        company_id=company_id,
        only=(
            "status",
            "project",
            "tags",
            "system_tags",
            "last_worker",
            "last_update",
        ),
        requires_write_access=True,
    )

    stop_immediately = TaskSystemTags.development in task.system_tags
    if not stop_immediately:
        # Check whether there is an active worker running the task
        update_timeout = config.get("apiserver.workers.task_update_timeout", 600)
        run_by_worker = (
            task.last_worker
            and task.last_update
            and (datetime.utcnow() - task.last_update).total_seconds()
            < update_timeout
        )
        stop_immediately = not run_by_worker

    if stop_immediately:
        target_status = TaskStatus.stopped
        message = f"Stopped by {user_name}"
    else:
        target_status = task.status
        message = TaskStatusMessage.stopping

    request = ChangeStatusRequest(
        task=task,
        new_status=target_status,
        status_reason=status_reason,
        status_message=message,
        force=force,
    )
    return request.execute()
|
||||
|
||||
@staticmethod
|
||||
def get_aggregated_project_parameters(
|
||||
company_id,
|
||||
project_ids: Sequence[str] = None,
|
||||
project_ids: Sequence[str],
|
||||
include_subprojects: bool,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> Tuple[int, int, Sequence[dict]]:
|
||||
|
||||
if project_ids:
|
||||
if include_subprojects:
|
||||
project_ids = project_ids_with_children(project_ids)
|
||||
project_constraint = {"project": {"$in": project_ids}}
|
||||
else:
|
||||
project_constraint = {}
|
||||
page = max(0, page)
|
||||
page_size = max(1, page_size)
|
||||
pipeline = [
|
||||
@@ -618,7 +515,7 @@ class TaskBLL:
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"hyperparams": {"$exists": True, "$gt": {}},
|
||||
**({"project": {"$in": project_ids}} if project_ids else {}),
|
||||
**project_constraint,
|
||||
}
|
||||
},
|
||||
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
|
||||
@@ -632,6 +529,8 @@ class TaskBLL:
|
||||
{"$unwind": "$names"},
|
||||
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
|
||||
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
|
||||
{"$skip": page * page_size},
|
||||
{"$limit": page_size},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
@@ -639,16 +538,9 @@ class TaskBLL:
|
||||
"results": {"$push": "$$ROOT"},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$project": {
|
||||
"total": 1,
|
||||
"results": {"$slice": ["$results", page * page_size, page_size]},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
with translate_errors_context():
|
||||
result = next(Task.aggregate(pipeline), None)
|
||||
result = next(Task.aggregate(pipeline), None)
|
||||
|
||||
total = 0
|
||||
remaining = 0
|
||||
@@ -669,6 +561,106 @@ class TaskBLL:
|
||||
|
||||
return total, remaining, results
|
||||
|
||||
HyperParamValues = Tuple[int, Sequence[str]]
|
||||
|
||||
def _get_cached_hyperparam_values(
    self, key: str, last_update: datetime
) -> Optional[HyperParamValues]:
    """
    Read the hyperparam values cached in redis under the given key

    The cached entry is used only if last_update is less than
    "services.tasks.hyperparam_values.cache_allowed_outdate_sec" seconds
    (60 if not configured) later than the last update stored in the entry.
    Any error while reading or parsing the entry is logged and not raised

    :param key: redis key of the cached entry
    :param last_update: the last update time of the current data
    :return: tuple of the cached total and values, or None if there is no
        usable cached entry
    """
    max_outdate_sec = config.get(
        "services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
    )
    allowed_delta = timedelta(seconds=max_outdate_sec)

    try:
        raw = self.redis.get(key)
        if raw:
            entry = json.loads(raw)
            cached_at = datetime.fromtimestamp(entry["last_update"])
            if last_update - cached_at < allowed_delta:
                return entry["total"], entry["values"]
    except Exception as ex:
        log.error(f"Error retrieving hyperparam cached values: {str(ex)}")

    return None
|
||||
|
||||
def get_hyperparam_distinct_values(
    self,
    company_id: str,
    project_ids: Sequence[str],
    section: str,
    name: str,
    include_subprojects: bool,
    allow_public: bool = True,
) -> HyperParamValues:
    """
    Get the distinct values of the hyperparameter section/name over the tasks

    The result is cached in redis together with the last update time of the
    most recently updated matching task and is served from the cache while
    the cached entry is not considered outdated

    :param company_id: the company whose tasks are searched
    :param project_ids: if not empty then only the tasks from these projects
        are searched
    :param section: hyperparameter section
    :param name: hyperparameter name inside the section
    :param include_subprojects: if set then the tasks from the sub projects
        of project_ids are searched too
    :param allow_public: if set then the tasks with no company are searched too
    :return: tuple of the amount of the returned values and the sorted values.
        The amount of the values is limited by
        "services.tasks.hyperparam_values.max_count" (100 if not configured)
    """
    if allow_public:
        company_constraint = {"company": {"$in": [None, "", company_id]}}
    else:
        company_constraint = {"company": company_id}
    if project_ids:
        if include_subprojects:
            project_ids = project_ids_with_children(project_ids)
        project_constraint = {"project": {"$in": project_ids}}
    else:
        project_constraint = {}

    key_path = f"hyperparams.{ParameterKeyEscaper.escape(section)}.{ParameterKeyEscaper.escape(name)}"
    last_updated_task = (
        Task.objects(
            **company_constraint,
            **project_constraint,
            **{f"{key_path.replace('.', '__')}__exists": True},
        )
        .only("last_update")
        .order_by("-last_update")
        .limit(1)
        .first()
    )
    if not last_updated_task:
        return 0, []

    # The ids are sorted so that the same set of projects always maps to the
    # same key, whatever the order in which the ids were passed or expanded.
    # Missing project ids (None) are handled the same as an empty list
    redis_key = f"hyperparam_values_{company_id}_{'_'.join(sorted(project_ids or []))}_{section}_{name}_{allow_public}"
    last_update = last_updated_task.last_update or datetime.utcnow()
    cached_res = self._get_cached_hyperparam_values(
        key=redis_key, last_update=last_update
    )
    if cached_res:
        return cached_res

    max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
    pipeline = [
        {
            "$match": {
                **company_constraint,
                **project_constraint,
                key_path: {"$exists": True},
            }
        },
        {"$project": {"value": f"${key_path}.value"}},
        {"$group": {"_id": "$value"}},
        {"$sort": {"_id": 1}},
        {"$limit": max_values},
        {
            "$group": {
                "_id": 1,
                "total": {"$sum": 1},
                "results": {"$push": "$$ROOT._id"},
            }
        },
    ]

    result = next(Task.aggregate(pipeline, collation=Task._numeric_locale), None)
    if not result:
        return 0, []

    total = int(result.get("total", 0))
    values = result.get("results", [])

    ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)
    cached = dict(last_update=last_update.timestamp(), total=total, values=values)
    try:
        self.redis.setex(redis_key, ttl, json.dumps(cached))
    except Exception as ex:
        # Caching is best effort (same as reading from the cache):
        # the computed values are still returned to the caller
        log.error(f"Error caching hyperparam values: {str(ex)}")

    return total, values
|
||||
|
||||
@classmethod
|
||||
def dequeue_and_change_status(
|
||||
cls, task: Task, company_id: str, status_message: str, status_reason: str,
|
||||
@@ -677,10 +669,10 @@ class TaskBLL:
|
||||
|
||||
return ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=TaskStatus.created,
|
||||
new_status=task.enqueue_status or TaskStatus.created,
|
||||
status_reason=status_reason,
|
||||
status_message=status_message,
|
||||
).execute(unset__execution__queue=1)
|
||||
).execute(enqueue_status=None)
|
||||
|
||||
@classmethod
|
||||
def dequeue(cls, task: Task, company_id: str, silent_fail=False):
|
||||
|
||||
278
apiserver/bll/task/task_cleanup.py
Normal file
278
apiserver/bll/task/task_cleanup.py
Normal file
@@ -0,0 +1,278 @@
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Generic, Callable, Type, Iterable, TypeVar, List, Set
|
||||
|
||||
import attr
|
||||
from boltons.iterutils import partition
|
||||
from mongoengine import QuerySet, Document
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.event.event_bll import PlotFields
|
||||
from apiserver.bll.event.event_common import EventType
|
||||
from apiserver.bll.task.utils import deleted_prefix
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task, TaskStatus, ArtifactModes
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
event_bll = EventBLL()
|
||||
T = TypeVar("T", bound=Document)
|
||||
|
||||
|
||||
class DocumentGroup(List[T]):
    """
    A list of documents that can also be queried like a query result
    """

    def __init__(self, document_type: Type[T], documents: Iterable[T]):
        """
        :param document_type: the document class used by objects() for querying
        :param documents: the documents to hold in the list
        """
        super().__init__(documents)
        self.type = document_type

    @property
    def ids(self) -> Set[str]:
        """The set of the ids of the held documents"""
        return set(doc.id for doc in self)

    def objects(self, *args, **kwargs) -> QuerySet:
        """
        Query the document type for the held documents only (by their ids).
        Any additional arguments are passed to the query as is
        """
        return self.type.objects(*args, id__in=self.ids, **kwargs)
|
||||
|
||||
|
||||
class TaskOutputs(Generic[T]):
    """
    Task outputs of one type, divided into published and draft ones
    """

    # the outputs for which the is_published predicate held
    published: DocumentGroup[T]
    # all the other outputs
    draft: DocumentGroup[T]

    def __init__(
        self,
        is_published: Callable[[T], bool],
        document_type: Type[T],
        children: Iterable[T],
    ):
        """
        :param is_published: predicate returning whether items is considered published
        :param document_type: type of output
        :param children: output documents
        """
        # partition returns the matching items first and the rest second
        published_docs, draft_docs = partition(children, key=is_published)
        self.published = DocumentGroup(document_type, published_docs)
        self.draft = DocumentGroup(document_type, draft_docs)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
class TaskUrls:
    """
    The urls collected for a task, grouped by their source
    """

    model_urls: Sequence[str]
    event_urls: Sequence[str]
    artifact_urls: Sequence[str]

    def __add__(self, other: "TaskUrls"):
        """
        Merge with another urls object. Each url group of the result holds the
        unique urls from the same group of both operands.
        A falsy operand leaves this object as the result
        """
        if not other:
            return self

        merged = {
            group: list(set(getattr(self, group)) | set(getattr(other, group)))
            for group in ("model_urls", "event_urls", "artifact_urls")
        }
        return TaskUrls(**merged)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
class CleanupResult:
    """
    Counts of objects modified in task cleanup operation
    """

    updated_children: int
    updated_models: int
    deleted_models: int
    urls: TaskUrls = None

    def __add__(self, other: "CleanupResult"):
        """
        Sum the counters of both results and merge their urls.
        A falsy operand leaves this object as the result
        """
        if not other:
            return self

        children_total = self.updated_children + other.updated_children
        updated_total = self.updated_models + other.updated_models
        deleted_total = self.deleted_models + other.deleted_models
        # with no urls of our own, the urls of the other result are taken as is
        if self.urls:
            merged_urls = self.urls + other.urls
        else:
            merged_urls = other.urls

        return CleanupResult(
            updated_children=children_total,
            updated_models=updated_total,
            deleted_models=deleted_total,
            urls=merged_urls,
        )
|
||||
|
||||
|
||||
def collect_plot_image_urls(company: str, task: str) -> Set[str]:
    """
    Collect the unique source urls from the plot events of the task

    The events are read batch by batch, passing the scroll id returned with
    each batch to the next call, until an empty batch is returned

    :param company: company id
    :param task: task id
    :return: the set of the collected urls
    """
    collected = set()
    scroll_id = None
    with TimingContext("es", "collect_plot_image_urls"):
        while True:
            events, scroll_id = event_bll.get_plot_image_urls(
                company_id=company, task_id=task, scroll_id=scroll_id
            )
            if not events:
                break
            for event in events:
                # events without source urls contribute nothing
                collected.update(event.get(PlotFields.source_urls) or ())

    return collected
|
||||
|
||||
|
||||
def collect_debug_image_urls(company: str, task: str) -> Set[str]:
    """
    Return the set of unique image urls
    Uses DebugImagesIterator to make sure that we do not retrieve recycled urls

    The debug image events of all the task metrics are read batch by batch
    until a batch with no iterations is returned

    :param company: company id
    :param task: task id
    :return: the set of the urls found in the events. Events without a url
        are ignored, so the result never contains None
    """
    metrics = event_bll.get_metrics_and_variants(
        company_id=company, task_id=task, event_type=EventType.metrics_image
    )
    if not metrics:
        return set()

    task_metrics = {task: set(metrics)}
    scroll_id = None
    urls = set()
    while True:
        res = event_bll.debug_images_iterator.get_task_events(
            company_id=company,
            task_metrics=task_metrics,
            iter_count=100,
            state_id=scroll_id,
        )
        if not res.metric_events or not any(
            iterations for _, iterations in res.metric_events
        ):
            break

        scroll_id = res.next_scroll_id
        # the first item of each pair is unused; "_" avoids shadowing "task"
        for _, iterations in res.metric_events:
            urls.update(ev.get("url") for it in iterations for ev in it["events"])

    # Drop the None collected from the events that have no url.
    # Note: discard({None}) would look for frozenset({None}) and remove nothing
    urls.discard(None)
    return urls
|
||||
|
||||
|
||||
def cleanup_task(
    task: Task,
    force: bool = False,
    update_children=True,
    return_file_urls=False,
    delete_output_models=True,
) -> CleanupResult:
    """
    Validate task deletion and delete/modify all its output.
    :param task: task object
    :param force: whether to delete task with published outputs;
        also passed as allow_locked when deleting the task events
    :param update_children: whether to re-point the child tasks and the published
        models of the task to the "deleted" task id
    :param return_file_urls: whether to collect the event, artifact and draft model
        urls and return them in the result
    :param delete_output_models: whether to delete the draft models of the task
    :return: count of delete and modified items
    """
    models = verify_task_children_and_ouptuts(task, force)

    # the urls are read before the models and the events are removed below
    event_urls, artifact_urls, model_urls = set(), set(), set()
    if return_file_urls:
        event_urls = collect_debug_image_urls(task.company, task.id)
        event_urls.update(collect_plot_image_urls(task.company, task.id))
        execution = task.execution
        if execution and execution.artifacts:
            artifact_urls = {
                artifact.uri
                for artifact in execution.artifacts.values()
                if artifact.mode == ArtifactModes.output and artifact.uri
            }
        model_urls = {
            model.uri for model in models.draft.objects().only("uri") if model.uri
        }

    deleted_task_id = f"{deleted_prefix}{task.id}"

    updated_children = 0
    if update_children:
        with TimingContext("mongo", "update_task_children"):
            updated_children = Task.objects(parent=task.id).update(
                parent=deleted_task_id
            )

    deleted_models = 0
    if models.draft and delete_output_models:
        with TimingContext("mongo", "delete_models"):
            deleted_models = models.draft.objects().delete()

    updated_models = 0
    if models.published and update_children:
        with TimingContext("mongo", "update_task_models"):
            updated_models = models.published.objects().update(task=deleted_task_id)

    event_bll.delete_task_events(task.company, task.id, allow_locked=force)

    urls = None
    if return_file_urls:
        urls = TaskUrls(
            event_urls=list(event_urls),
            artifact_urls=list(artifact_urls),
            model_urls=list(model_urls),
        )

    return CleanupResult(
        deleted_models=deleted_models,
        updated_children=updated_children,
        updated_models=updated_models,
        urls=urls,
    )
|
||||
|
||||
|
||||
def verify_task_children_and_ouptuts(task: Task, force: bool) -> TaskOutputs[Model]:
    """
    Check that the task can be deleted and collect the models related to it

    :param task: task object
    :param force: when False the call fails if the task has published children
        or published (ready) models
    :return: the task models split by their "ready" flag into published and draft.
        Draft models that are used as an input model by some task are left out
    :raise errors.bad_request.TaskCannotBeDeleted: published children or models
        were found and force is False
    """
    if not force:
        with TimingContext("mongo", "count_published_children"):
            children_count = Task.objects(
                parent=task.id, status=TaskStatus.published
            ).count()
            if children_count:
                raise errors.bad_request.TaskCannotBeDeleted(
                    "has children, use force=True",
                    task=task.id,
                    children=children_count,
                )

    with TimingContext("mongo", "get_task_models"):
        models = TaskOutputs(
            attrgetter("ready"),
            Model,
            Model.objects(task=task.id).only("id", "task", "ready"),
        )
        if models.published and not force:
            raise errors.bad_request.TaskCannotBeDeleted(
                "has output models, use force=True",
                task=task.id,
                models=len(models.published),
            )

    # models listed in the task output section are classified as well
    if task.models and task.models.output:
        with TimingContext("mongo", "get_task_output_model"):
            output_ids = [item.model for item in task.models.output]
            for output_model in Model.objects(id__in=output_ids):
                if not output_model.ready:
                    models.draft.append(output_model)
                    continue

                if not force:
                    raise errors.bad_request.TaskCannotBeDeleted(
                        "has output model, use force=True",
                        task=task.id,
                        model=output_model.id,
                    )
                models.published.append(output_model)

    if not models.draft:
        return models

    with TimingContext("mongo", "get_execution_models"):
        draft_ids = models.draft.ids
        dependent_tasks = Task.objects(models__input__model__in=draft_ids).only(
            "id", "models"
        )
        used_as_input = {
            item.model
            for dependent in dependent_tasks
            if dependent.models
            for item in dependent.models.input
        }
        if used_as_input:
            # keep only the drafts that no task references as an input model
            models.draft = DocumentGroup(
                Model, (m for m in models.draft if m.id not in used_as_input)
            )

    return models
|
||||
380
apiserver/bll/task/task_operations.py
Normal file
380
apiserver/bll/task/task_operations.py
Normal file
@@ -0,0 +1,380 @@
|
||||
from datetime import datetime
|
||||
from typing import Callable, Any, Tuple, Union
|
||||
|
||||
from apiserver.apierrors import errors, APIError
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.bll.task import (
|
||||
TaskBLL,
|
||||
validate_status_change,
|
||||
ChangeStatusRequest,
|
||||
update_project_time,
|
||||
)
|
||||
from apiserver.bll.task.task_cleanup import cleanup_task, CleanupResult
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.output import Output
|
||||
from apiserver.database.model.task.task import (
|
||||
TaskStatus,
|
||||
Task,
|
||||
TaskSystemTags,
|
||||
TaskStatusMessage,
|
||||
ArtifactModes,
|
||||
Execution,
|
||||
DEFAULT_LAST_ITERATION,
|
||||
)
|
||||
from apiserver.utilities.dicts import nested_set
|
||||
|
||||
queue_bll = QueueBLL()
|
||||
|
||||
|
||||
def archive_task(
    task: Union[str, Task], company_id: str, status_message: str, status_reason: str,
) -> int:
    """
    Dequeue and archive task
    Return 1 if successful
    :param task: task object or task id; an id is resolved to a task with write access
    :param company_id: company id used for the task lookup and the dequeue call
    :param status_message: status message to store on the task
    :param status_reason: status reason to store on the task
    """
    if isinstance(task, str):
        required_fields = (
            "id",
            "execution",
            "status",
            "project",
            "system_tags",
            "enqueue_status",
        )
        task = TaskBLL.get_task_with_access(
            task,
            company_id=company_id,
            only=required_fields,
            requires_write_access=True,
        )

    try:
        TaskBLL.dequeue_and_change_status(
            task, company_id, status_message, status_reason
        )
    except APIError:
        # dequeue may fail if the task was not enqueued
        pass

    archive_fields = dict(
        status_message=status_message,
        status_reason=status_reason,
        add_to_set__system_tags=EntityVisibility.archived.value,
        last_change=datetime.utcnow(),
    )
    return task.update(**archive_fields)
|
||||
|
||||
|
||||
def unarchive_task(
    task: str, company_id: str, status_message: str, status_reason: str,
) -> int:
    """
    Unarchive task. Return 1 if successful
    :param task: id of the task, resolved with write access
    :param company_id: company id used for the task lookup
    :param status_message: status message to store on the task
    :param status_reason: status reason to store on the task
    """
    task_doc = TaskBLL.get_task_with_access(
        task, company_id=company_id, only=("id",), requires_write_access=True
    )
    unarchive_fields = dict(
        status_message=status_message,
        status_reason=status_reason,
        pull__system_tags=EntityVisibility.archived.value,
        last_change=datetime.utcnow(),
    )
    return task_doc.update(**unarchive_fields)
|
||||
|
||||
|
||||
def dequeue_task(
    task_id: str,
    company_id: str,
    status_message: str,
    status_reason: str,
) -> Tuple[int, dict]:
    """
    Dequeue the task and change its status
    :param task_id: id of the task to dequeue
    :param company_id: id of the company that the task is looked up in
    :param status_message: status message passed to the status change
    :param status_reason: status reason passed to the status change
    :return: the tuple of 1 and the result of TaskBLL.dequeue_and_change_status
    :raise errors.bad_request.InvalidTaskId: no task was found for writing
    """
    lookup = {"id": task_id, "company": company_id}
    task = Task.get_for_writing(**lookup)
    if not task:
        raise errors.bad_request.InvalidTaskId(**lookup)

    dequeue_res = TaskBLL.dequeue_and_change_status(
        task,
        company_id,
        status_message=status_message,
        status_reason=status_reason,
    )
    return 1, dequeue_res
|
||||
|
||||
|
||||
def enqueue_task(
    task_id: str,
    company_id: str,
    queue_id: str,
    status_message: str,
    status_reason: str,
    validate: bool = False,
) -> Tuple[int, dict]:
    """
    Move the task to the queued status and add it to a queue
    :param task_id: id of the task to enqueue
    :param company_id: id of the company that the task is looked up in
    :param queue_id: target queue id; when empty the company default queue is used
    :param status_message: status message passed to the status change
    :param status_reason: status reason passed to the status change
    :param validate: whether to run TaskBLL.validate on the task before enqueueing
    :return: the tuple of 1 and the status change result with the queue id
        set under fields / execution.queue
    :raise errors.bad_request.InvalidTaskId: no task was found for writing
    """
    if not queue_id:
        # try to get default queue
        queue_id = queue_bll.get_default(company_id).id

    lookup = {"id": task_id, "company": company_id}
    task = Task.get_for_writing(**lookup)
    if not task:
        raise errors.bad_request.InvalidTaskId(**lookup)

    if validate:
        TaskBLL.validate(task)

    to_queued = ChangeStatusRequest(
        task=task,
        new_status=TaskStatus.queued,
        status_reason=status_reason,
        status_message=status_message,
        allow_same_state_transition=False,
    )
    res = to_queued.execute(enqueue_status=task.status)

    try:
        queue_bll.add_task(company_id=company_id, queue_id=queue_id, task_id=task.id)
    except Exception:
        # failed enqueueing, revert to previous state
        revert = ChangeStatusRequest(
            task=task,
            current_status_override=TaskStatus.queued,
            new_status=task.status,
            force=True,
            status_reason="failed enqueueing",
        )
        revert.execute(enqueue_status=None)
        raise

    # set the current queue ID in the task
    if task.execution:
        queue_update = {"execution__queue": queue_id}
    else:
        queue_update = {"execution": Execution(queue=queue_id)}
    Task.objects(**lookup).update(multi=False, **queue_update)

    nested_set(res, ("fields", "execution.queue"), queue_id)
    return 1, res
|
||||
|
||||
|
||||
def delete_task(
    task_id: str,
    company_id: str,
    move_to_trash: bool,
    force: bool,
    return_file_urls: bool,
    delete_output_models: bool,
) -> Tuple[int, Task, CleanupResult]:
    """
    Clean up the task outputs and delete the task
    :param task_id: id of the task to delete
    :param company_id: company id used for the task lookup
    :param move_to_trash: whether to insert a copy of the task into the
        "<collection>__trash" collection before deleting it
    :param force: allows deleting a task that is neither in the created status
        nor archived; also passed to cleanup_task
    :param return_file_urls: passed to cleanup_task
    :param delete_output_models: passed to cleanup_task
    :return: the tuple of 1, the deleted task object and the cleanup result
    :raise errors.bad_request.TaskCannotBeDeleted: the task status does not
        allow deletion and force is False
    """
    task = TaskBLL.get_task_with_access(
        task_id, company_id=company_id, requires_write_access=True
    )

    can_delete = (
        task.status == TaskStatus.created
        or EntityVisibility.archived.value in task.system_tags
        or force
    )
    if not can_delete:
        raise errors.bad_request.TaskCannotBeDeleted(
            "due to status, use force=True",
            task=task.id,
            expected=TaskStatus.created,
            current=task.status,
        )

    cleanup_res = cleanup_task(
        task,
        force=force,
        return_file_urls=return_file_urls,
        delete_output_models=delete_output_models,
    )

    if move_to_trash:
        live_collection = task._get_collection_name()
        task.switch_collection(f"{live_collection}__trash")
        try:
            # A simple save() won't do due to mongoengine caching (nothing will be saved), so we have to force
            # an insert. However, if for some reason such an ID exists, let's make sure we'll keep going.
            task.save(force_insert=True)
        except Exception:
            pass
        task.switch_collection(live_collection)

    task.delete()
    update_project_time(task.project)
    return 1, task, cleanup_res
|
||||
|
||||
|
||||
def reset_task(
    task_id: str,
    company_id: str,
    force: bool,
    return_file_urls: bool,
    delete_output_models: bool,
    clear_all: bool,
) -> Tuple[dict, CleanupResult, dict]:
    """
    Dequeue the task, clean up its outputs and return it to the created status
    :param task_id: id of the task to reset
    :param company_id: company id used for the task lookup and the dequeue call
    :param force: allows resetting a published task; also passed to cleanup_task
        and to the status change
    :param return_file_urls: passed to cleanup_task
    :param delete_output_models: passed to cleanup_task
    :param clear_all: when set the execution section is replaced with an empty one
        and the script is unset; otherwise only the execution queue is unset
        and the artifacts are reduced to the input ones
    :return: the tuple of the dequeue result, the cleanup result and the status change result
    :raise errors.bad_request.InvalidTaskStatus: the task is published and force is False
    """
    task = TaskBLL.get_task_with_access(
        task_id, company_id=company_id, requires_write_access=True
    )

    if task.status == TaskStatus.published and not force:
        raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)

    dequeued = {}
    try:
        dequeued = TaskBLL.dequeue(task, company_id, silent_fail=True)
    except APIError:
        # dequeue may fail if the task was not enqueued
        pass

    cleaned_up = cleanup_task(
        task,
        force=force,
        update_children=False,
        return_file_urls=return_file_urls,
        delete_output_models=delete_output_models,
    )

    updates = dict(
        set__last_iteration=DEFAULT_LAST_ITERATION,
        set__last_metrics={},
        set__metric_stats={},
        set__models__output=[],
        unset__output__result=1,
        unset__output__error=1,
        unset__last_worker=1,
        unset__last_worker_report=1,
    )

    if clear_all:
        updates["set__execution"] = Execution()
        updates["unset__script"] = 1
    else:
        updates["unset__execution__queue"] = 1
        if task.execution and task.execution.artifacts:
            # only the input artifacts survive the reset
            updates["set__execution__artifacts"] = {
                name: artifact
                for name, artifact in task.execution.artifacts.items()
                if artifact.mode == ArtifactModes.input
            }

    to_created = ChangeStatusRequest(
        task=task,
        new_status=TaskStatus.created,
        force=force,
        status_reason="reset",
        status_message="reset",
    )
    res = to_created.execute(
        started=None,
        completed=None,
        published=None,
        active_duration=None,
        enqueue_status=None,
        **updates,
    )

    return dequeued, cleaned_up, res
|
||||
|
||||
|
||||
def publish_task(
    task_id: str,
    company_id: str,
    force: bool,
    publish_model_func: Callable[[str, str], Any] = None,
    status_message: str = "",
    status_reason: str = "",
) -> dict:
    """
    Publish the task and, if requested, its last output model
    :param task_id: id of the task to publish
    :param company_id: company id used for the task and model lookups
    :param force: skips the status transition validation; also passed to the status change
    :param publish_model_func: optional callable receiving (model id, company id);
        called for the last output model of the task if that model is not ready
    :param status_message: status message passed to the status change
    :param status_reason: status reason passed to the status change
    :return: the result of the status change to published
    """
    task = TaskBLL.get_task_with_access(
        task_id, company_id=company_id, requires_write_access=True
    )
    if not force:
        validate_status_change(task.status, TaskStatus.published)

    status_before = task.status
    output = task.output or Output()

    try:
        # set state to publishing
        task.status = TaskStatus.publishing
        task.save()

        # publish task models
        if task.models and task.models.output and publish_model_func:
            last_model_id = task.models.output[-1].model
            model_query = Model.objects(id=last_model_id, company=company_id)
            model = model_query.only("id", "ready").first()
            if model and not model.ready:
                publish_model_func(model.id, company_id)

        # set task status to published, and update (or set) its new output (view and models)
        to_published = ChangeStatusRequest(
            task=task,
            new_status=TaskStatus.published,
            force=force,
            status_reason=status_reason,
            status_message=status_message,
        )
        return to_published.execute(published=datetime.utcnow(), output=output)
    except Exception:
        # publishing failed: put the previous status back before propagating the error
        task.status = status_before
        task.save()
        raise
|
||||
|
||||
|
||||
def stop_task(
    task_id: str, company_id: str, user_name: str, status_reason: str, force: bool,
) -> dict:
    """
    Stop a running task. Requires task status 'in_progress' and
    execution_progress 'running', or force=True. Development task or
    task that has no associated worker is stopped immediately.
    For a non-development task with worker only the status message
    is set to 'stopping' to allow the worker to stop the task and report by itself
    :param task_id: id of the task to stop
    :param company_id: company id used for the task lookup
    :param user_name: name placed in the "Stopped by" status message
    :param status_reason: status reason passed to the status change
    :param force: passed to the status change
    :return: updated task fields
    """
    required_fields = (
        "status",
        "project",
        "tags",
        "system_tags",
        "last_worker",
        "last_update",
    )
    task = TaskBLL.get_task_with_access(
        task_id,
        company_id=company_id,
        only=required_fields,
        requires_write_access=True,
    )

    def has_active_worker(t: Task) -> bool:
        """Checks if there is an active worker running the task"""
        timeout_sec = config.get("apiserver.workers.task_update_timeout", 600)
        return (
            t.last_worker
            and t.last_update
            and (datetime.utcnow() - t.last_update).total_seconds() < timeout_sec
        )

    stop_immediately = (
        TaskSystemTags.development in task.system_tags or not has_active_worker(task)
    )
    if stop_immediately:
        new_status, status_message = TaskStatus.stopped, f"Stopped by {user_name}"
    else:
        # the status is kept; the worker is expected to stop the task and report
        new_status, status_message = task.status, TaskStatusMessage.stopping

    request = ChangeStatusRequest(
        task=task,
        new_status=new_status,
        status_reason=status_reason,
        status_message=status_message,
        force=force,
    )
    return request.execute()
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import TypeVar, Callable, Tuple, Sequence, Union
|
||||
from typing import Sequence, Union
|
||||
|
||||
import attr
|
||||
import six
|
||||
@@ -13,6 +13,7 @@ from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.attrs import typed_attrs
|
||||
|
||||
valid_statuses = get_options(TaskStatus)
|
||||
deleted_prefix = "__DELETED__"
|
||||
|
||||
|
||||
@typed_attrs
|
||||
@@ -105,7 +106,7 @@ def validate_status_change(current_status, new_status):
|
||||
|
||||
state_machine = {
|
||||
TaskStatus.created: {TaskStatus.queued, TaskStatus.in_progress},
|
||||
TaskStatus.queued: {TaskStatus.created, TaskStatus.in_progress},
|
||||
TaskStatus.queued: {TaskStatus.created, TaskStatus.in_progress, TaskStatus.stopped},
|
||||
TaskStatus.in_progress: {
|
||||
TaskStatus.stopped,
|
||||
TaskStatus.failed,
|
||||
@@ -116,6 +117,7 @@ state_machine = {
|
||||
TaskStatus.closed,
|
||||
TaskStatus.created,
|
||||
TaskStatus.failed,
|
||||
TaskStatus.queued,
|
||||
TaskStatus.in_progress,
|
||||
TaskStatus.published,
|
||||
TaskStatus.publishing,
|
||||
@@ -163,22 +165,6 @@ def update_project_time(project_ids: Union[str, Sequence[str]]):
|
||||
return Project.objects(id__in=project_ids).update(last_update=datetime.utcnow())
|
||||
|
||||
|
||||
T = TypeVar("T")


def split_by(
    condition: Callable[[T], bool], items: Sequence[T]
) -> Tuple[Sequence[T], Sequence[T]]:
    """
    split "items" to two lists by "condition"
    :param condition: callable applied once to every item, in order
    :param items: the items to split
    :return: the tuple of two lists: the items for which the condition holds
        and the items for which it does not, each in the original order
    """
    matching, rest = [], []
    # a single pass fills both lists; zipping the condition results and then
    # iterating that zip twice leaves the second list empty, since a zip
    # object can be consumed only once
    for item in items:
        (matching if condition(item) else rest).append(item)
    return matching, rest
|
||||
|
||||
|
||||
def get_task_for_update(
|
||||
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
|
||||
) -> Task:
|
||||
|
||||
@@ -1,33 +1,25 @@
|
||||
import functools
|
||||
import itertools
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Optional, Callable, Tuple, Dict, Any, Set, Iterable
|
||||
from typing import (
|
||||
Optional,
|
||||
Callable,
|
||||
Dict,
|
||||
Any,
|
||||
Set,
|
||||
Iterable,
|
||||
Tuple,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
from boltons import iterutils
|
||||
|
||||
from apiserver.apierrors import APIError
|
||||
from apiserver.database.model import AttributedDocument
|
||||
from apiserver.database.model.settings import Settings
|
||||
|
||||
|
||||
def extract_properties_to_lists(
    key_names: Sequence[str],
    data: Sequence[dict],
    extract_func: Optional[Callable[[dict], Tuple]] = None,
) -> dict:
    """
    Given a list of dictionaries and names of dictionary keys
    builds a dictionary with the requested keys and values lists
    :param key_names: names of the keys in the resulting dictionary
    :param data: sequence of dictionaries to extract values from
    :param extract_func: the optional callable that extracts properties
    from a dictionary and put them in a tuple in the order corresponding to
    key_names. If not specified then properties are extracted according to key_names
    :return: dictionary mapping each key name to the list of its values;
    an empty dictionary when data is empty
    """
    # the default extractor always builds a tuple; itemgetter returns a bare value
    # for a single key, which breaks the transposition below
    extractor = extract_func or (lambda item: tuple(item[key] for key in key_names))
    value_sequences = zip(*map(extractor, data))
    return dict(zip(key_names, map(list, value_sequences)))
|
||||
|
||||
|
||||
class SetFieldsResolver:
|
||||
"""
|
||||
The class receives set fields dictionary
|
||||
@@ -115,3 +107,28 @@ def parallel_chunked_decorator(func: Callable = None, chunk_size: int = 100):
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
T = TypeVar("T")


def run_batch_operation(
    func: Callable[[str], T], ids: Sequence[str]
) -> Tuple[Sequence[Tuple[str, T]], Sequence[dict]]:
    """
    Apply the operation to every id, collecting the API errors instead of failing
    :param func: the operation to call with each id
    :param ids: the ids to process, in order
    :return: the tuple of two lists: (id, result) pairs of the calls that succeeded
        and failure descriptions (id and error codes, msg and data) of the calls
        that raised an APIError. Other exceptions are propagated
    """
    succeeded = []
    failed = []
    for item_id in ids:
        try:
            outcome = func(item_id)
        except APIError as err:
            error_info = {
                "codes": [err.code, err.subcode],
                "msg": err.msg,
                "data": err.error_data,
            }
            failed.append({"id": item_id, "error": error_info})
        else:
            succeeded.append((item_id, outcome))

    return succeeded, failed
|
||||
|
||||
@@ -6,9 +6,10 @@ from functools import reduce
|
||||
from os import getenv
|
||||
from os.path import expandvars
|
||||
from pathlib import Path
|
||||
from typing import List, Any, TypeVar
|
||||
from typing import List, Any, TypeVar, Sequence
|
||||
|
||||
from pyhocon import ConfigTree, ConfigFactory
|
||||
from boltons.iterutils import first
|
||||
from pyhocon import ConfigTree, ConfigFactory, ConfigValues
|
||||
from pyparsing import (
|
||||
ParseFatalException,
|
||||
ParseException,
|
||||
@@ -18,8 +19,8 @@ from pyparsing import (
|
||||
|
||||
from apiserver.utilities import json
|
||||
|
||||
EXTRA_CONFIG_PATHS = ("/opt/trains/config",)
|
||||
EXTRA_CONFIG_PATH_OVERRIDE_VAR = "TRAINS_CONFIG_DIR"
|
||||
EXTRA_CONFIG_PATHS = ("/opt/trains/config", "/opt/clearml/config")
|
||||
DEFAULT_PREFIXES = ("clearml", "trains")
|
||||
EXTRA_CONFIG_PATH_SEP = ":" if platform.system() != "Windows" else ";"
|
||||
|
||||
|
||||
@@ -30,7 +31,10 @@ class BasicConfig:
|
||||
default_config_dir = "default"
|
||||
|
||||
def __init__(
|
||||
self, folder: str = None, verbose: bool = True, prefix: str = "trains"
|
||||
self,
|
||||
folder: str = None,
|
||||
verbose: bool = True,
|
||||
prefix: Sequence[str] = DEFAULT_PREFIXES,
|
||||
):
|
||||
folder = (
|
||||
Path(folder)
|
||||
@@ -41,8 +45,16 @@ class BasicConfig:
|
||||
raise ValueError("Invalid configuration folder")
|
||||
|
||||
self.verbose = verbose
|
||||
self.prefix = prefix
|
||||
self.extra_config_values_env_key_prefix = f"{self.prefix.upper()}__"
|
||||
|
||||
self.extra_config_path_override_var = [
|
||||
f"{p.upper()}_CONFIG_DIR" for p in prefix
|
||||
]
|
||||
|
||||
self.prefix = prefix[0]
|
||||
self.extra_config_values_env_key_prefix = [
|
||||
f"{p.upper()}{self.extra_config_values_env_key_sep}"
|
||||
for p in reversed(prefix)
|
||||
]
|
||||
|
||||
self._paths = [folder, *self._get_paths()]
|
||||
self._config = self._reload()
|
||||
@@ -73,24 +85,24 @@ class BasicConfig:
|
||||
def _read_extra_env_config_values(self) -> ConfigTree:
|
||||
""" Loads extra configuration from environment-injected values """
|
||||
result = ConfigTree()
|
||||
prefix = self.extra_config_values_env_key_prefix
|
||||
|
||||
keys = sorted(k for k in os.environ if k.startswith(prefix))
|
||||
for key in keys:
|
||||
path = (
|
||||
key[len(prefix) :]
|
||||
.replace(self.extra_config_values_env_key_sep, ".")
|
||||
.lower()
|
||||
)
|
||||
result = ConfigTree.merge_configs(
|
||||
result, ConfigFactory.parse_string(f"{path}: {os.environ[key]}")
|
||||
)
|
||||
for prefix in self.extra_config_values_env_key_prefix:
|
||||
keys = sorted(k for k in os.environ if k.startswith(prefix))
|
||||
for key in keys:
|
||||
path = (
|
||||
key[len(prefix) :]
|
||||
.replace(self.extra_config_values_env_key_sep, ".")
|
||||
.lower()
|
||||
)
|
||||
result = self._merge_configs(
|
||||
result, ConfigFactory.parse_string(f"{path}: {os.environ[key]}")
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _get_paths(self) -> List[Path]:
|
||||
default_paths = EXTRA_CONFIG_PATH_SEP.join(EXTRA_CONFIG_PATHS)
|
||||
value = getenv(EXTRA_CONFIG_PATH_OVERRIDE_VAR, default_paths)
|
||||
value = first(map(getenv, self.extra_config_path_override_var), default_paths)
|
||||
|
||||
paths = [
|
||||
Path(expandvars(v)).expanduser() for v in value.split(EXTRA_CONFIG_PATH_SEP)
|
||||
@@ -100,7 +112,7 @@ class BasicConfig:
|
||||
invalid = [path for path in paths if not path.is_dir()]
|
||||
if invalid:
|
||||
print(
|
||||
f"WARNING: Invalid paths in {EXTRA_CONFIG_PATH_OVERRIDE_VAR} env var: {' '.join(map(str, invalid))}"
|
||||
f"WARNING: Invalid paths in {self.extra_config_path_override_var} env var: {' '.join(map(str, invalid))}"
|
||||
)
|
||||
|
||||
return [path for path in paths if path.is_dir()]
|
||||
@@ -114,13 +126,40 @@ class BasicConfig:
|
||||
configs = [self._read_recursive(path) for path in self._paths]
|
||||
|
||||
return reduce(
|
||||
lambda last, config: ConfigTree.merge_configs(
|
||||
lambda last, config: self._merge_configs(
|
||||
last, config, copy_trees=True
|
||||
),
|
||||
configs + [extra_config_values],
|
||||
ConfigTree(),
|
||||
)
|
||||
|
||||
@classmethod
def _merge_configs(cls, a, b, copy_trees=False, override_prefix="-"):
    """
    Based on pyhocon.ConfigTree.merge_configs, with dict override support using a `-` key prefix

    Merges the entries of b into a, in place. A key of b that starts with
    override_prefix replaces the value stored in a under the key without the
    prefix, even if both values are trees; otherwise two trees are merged recursively
    :param a: the tree that receives the values
    :param b: the tree whose values are merged in
    :param copy_trees: whether a nested tree of a is copied before merging into it
    :param override_prefix: key prefix that requests replacing instead of merging
    :return: a
    """
    for key, value in b.items():
        override = key.startswith(override_prefix)
        if override:
            key = key[len(override_prefix):]
        # if key is in both a and b and both values are dictionary then merge it otherwise override it
        if not override and key in a and isinstance(a[key], ConfigTree) and isinstance(b[key], ConfigTree):
            if copy_trees:
                a[key] = a[key].copy()
            # forward the prefix so that a non-default one is honoured in nested trees too
            cls._merge_configs(
                a[key], b[key], copy_trees=copy_trees, override_prefix=override_prefix
            )
        else:
            if isinstance(value, ConfigValues):
                value.parent = a
                value.key = key
                if key in a:
                    value.overriden_value = a[key]
            a[key] = value
            if a.root:
                if b.root:
                    a.history[key] = a.history.get(key, []) + b.history.get(key, [value])
                else:
                    a.history[key] = a.history.get(key, []) + [value]

    return a
|
||||
|
||||
def _read_recursive(self, conf_root) -> ConfigTree:
|
||||
conf = ConfigTree()
|
||||
|
||||
|
||||
@@ -69,7 +69,7 @@
|
||||
default_expiration_sec: 2592000
|
||||
|
||||
# cookie containing auth token, for requests arriving from a web-browser
|
||||
session_auth_cookie_name: "trains_token_basic"
|
||||
session_auth_cookie_name: "clearml_token_basic"
|
||||
|
||||
# cookie configuration for authorization cookies generated by auth.login
|
||||
cookies {
|
||||
@@ -80,8 +80,10 @@
|
||||
}
|
||||
|
||||
# # A list of fixed users
|
||||
# # Note: password may be bcrypt-hashed (generate using `python -c 'import bcrypt; print(bcrypt.hashpw("password", bcrypt.gensalt()))'`)
|
||||
# fixed_users {
|
||||
# enabled: true
|
||||
# pass_hashed: false
|
||||
# users: [
|
||||
# {
|
||||
# username: "john"
|
||||
@@ -116,9 +118,9 @@
|
||||
# Check for updates every 24 hours
|
||||
check_interval_sec: 86400
|
||||
|
||||
url: "https://updates.trains.allegro.ai/updates"
|
||||
url: "https://updates.clear.ml/updates"
|
||||
|
||||
component_name: "trains-server"
|
||||
component_name: "clearml-server"
|
||||
|
||||
# GET request timeout
|
||||
request_timeout_sec: 3.0
|
||||
@@ -128,7 +130,7 @@
|
||||
# Note: statistics are sent ONLY if the user has actively opted-in
|
||||
supported: true
|
||||
|
||||
url: "https://updates.trains.allegro.ai/stats"
|
||||
url: "https://updates.clear.ml/stats"
|
||||
|
||||
report_interval_hours: 24
|
||||
agent_relevant_threshold_days: 30
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
backupCount: 3
|
||||
maxBytes: 10240000,
|
||||
class: "logging.handlers.RotatingFileHandler",
|
||||
filename: "/var/log/trains/apiserver.log"
|
||||
filename: "/var/log/clearml/apiserver.log"
|
||||
}
|
||||
}
|
||||
root {
|
||||
|
||||
@@ -28,6 +28,7 @@
|
||||
display_name: "Default User"
|
||||
user_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
|
||||
user_secret: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
|
||||
revoke_in_fixed_mode: true
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -10,4 +10,9 @@ featured {
|
||||
|
||||
# default featured index for public projects not specified in the order
|
||||
public_default: 9999
|
||||
}
|
||||
|
||||
sub_projects {
|
||||
# the max sub project depth
|
||||
max_depth: 10
|
||||
}
|
||||
@@ -9,3 +9,14 @@ non_responsive_tasks_watchdog {
|
||||
}
|
||||
|
||||
multi_task_histogram_limit: 100
|
||||
|
||||
hyperparam_values {
|
||||
# maximal amount of distinct hyperparam values to retrieve
|
||||
max_count: 100
|
||||
|
||||
# max allowed staleness, in seconds, of the cached result
|
||||
cache_allowed_outdate_sec: 60
|
||||
|
||||
# cache ttl sec
|
||||
cache_ttl_sec: 86400
|
||||
}
|
||||
@@ -2,6 +2,8 @@ from functools import lru_cache
|
||||
from os import getenv
|
||||
from pathlib import Path
|
||||
|
||||
from boltons.iterutils import first
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.version import __version__
|
||||
|
||||
@@ -9,7 +11,9 @@ root = Path(__file__).parent.parent
|
||||
|
||||
|
||||
def _get(prop_name, env_suffix=None, default=""):
|
||||
value = getenv(f"TRAINS_SERVER_{env_suffix or prop_name}")
|
||||
suffix = env_suffix or prop_name
|
||||
keys = [f"{p}_SERVER_{suffix}" for p in ("CLEARML", "TRAINS")]
|
||||
value = first(map(getenv, keys))
|
||||
if value:
|
||||
return value
|
||||
|
||||
|
||||
@@ -17,11 +17,16 @@ log = config.logger("database")
|
||||
strict = config.get("apiserver.mongo.strict", True)
|
||||
|
||||
OVERRIDE_HOST_ENV_KEY = (
|
||||
"CLEARML_MONGODB_SERVICE_HOST",
|
||||
"TRAINS_MONGODB_SERVICE_HOST",
|
||||
"MONGODB_SERVICE_HOST",
|
||||
"MONGODB_SERVICE_SERVICE_HOST",
|
||||
)
|
||||
OVERRIDE_PORT_ENV_KEY = ("TRAINS_MONGODB_SERVICE_PORT", "MONGODB_SERVICE_PORT")
|
||||
OVERRIDE_PORT_ENV_KEY = (
|
||||
"CLEARML_MONGODB_SERVICE_PORT",
|
||||
"TRAINS_MONGODB_SERVICE_PORT",
|
||||
"MONGODB_SERVICE_PORT",
|
||||
)
|
||||
|
||||
|
||||
class DatabaseEntry(models.Base):
|
||||
@@ -32,6 +37,10 @@ class DatabaseEntry(models.Base):
|
||||
class DatabaseFactory:
|
||||
_entries = []
|
||||
|
||||
@classmethod
def _create_db_entry(cls, alias: str, settings: dict) -> DatabaseEntry:
    """
    Build the database entry for a connection alias
    :param alias: the connection alias
    :param settings: the entry settings, passed to DatabaseEntry as keyword arguments
    """
    entry = DatabaseEntry(alias=alias, **settings)
    return entry
|
||||
|
||||
@classmethod
|
||||
def initialize(cls):
|
||||
db_entries = config.get("hosts.mongo", {})
|
||||
@@ -51,7 +60,7 @@ class DatabaseFactory:
|
||||
missing.append(key)
|
||||
continue
|
||||
|
||||
entry = DatabaseEntry(alias=alias, **db_entries.get(key))
|
||||
entry = cls._create_db_entry(alias=alias, settings=db_entries.get(key))
|
||||
|
||||
if override_hostname:
|
||||
entry.host = furl(entry.host).set(host=override_hostname).url
|
||||
@@ -64,13 +73,15 @@ class DatabaseFactory:
|
||||
log.info(
|
||||
"Registering connection to %(alias)s (%(host)s)" % entry.to_struct()
|
||||
)
|
||||
register_connection(alias=alias, host=entry.host)
|
||||
register_connection(**entry.to_struct())
|
||||
|
||||
cls._entries.append(entry)
|
||||
except ValidationError as ex:
|
||||
raise Exception("Invalid database entry `%s`: %s" % (key, ex.args[0]))
|
||||
if missing:
|
||||
raise ValueError("Missing database configuration for %s" % ", ".join(missing))
|
||||
raise ValueError(
|
||||
"Missing database configuration for %s" % ", ".join(missing)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_entries(cls):
|
||||
@@ -91,7 +102,7 @@ class DatabaseFactory:
|
||||
# reconnection from work so workaround this
|
||||
# get_connection(entry.alias, reconnect=True)
|
||||
disconnect(entry.alias)
|
||||
register_connection(alias=entry.alias, host=entry.host)
|
||||
register_connection(**entry.to_struct())
|
||||
get_connection(entry.alias)
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import re
|
||||
from collections import namedtuple
|
||||
from functools import reduce
|
||||
from typing import Collection, Sequence, Union, Optional, Type, Tuple
|
||||
from typing import Collection, Sequence, Union, Optional, Type, Tuple, Mapping, Any
|
||||
|
||||
from boltons.iterutils import first, bucketize, partition
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
@@ -86,6 +86,7 @@ class GetMixin(PropsMixin):
|
||||
list_fields=("tags", "system_tags", "id"),
|
||||
datetime_fields=None,
|
||||
fields=None,
|
||||
range_fields=None,
|
||||
):
|
||||
"""
|
||||
:param pattern_fields: Fields for which a "string contains" condition should be generated
|
||||
@@ -97,6 +98,7 @@ class GetMixin(PropsMixin):
|
||||
self.fields = fields
|
||||
self.datetime_fields = datetime_fields
|
||||
self.list_fields = list_fields
|
||||
self.range_fields = range_fields
|
||||
self.pattern_fields = pattern_fields
|
||||
|
||||
class ListFieldBucketHelper:
|
||||
@@ -183,6 +185,53 @@ class GetMixin(PropsMixin):
|
||||
parameters, parameters_options
|
||||
) & cls._prepare_perm_query(company, allow_public=allow_public)
|
||||
|
||||
@staticmethod
|
||||
def _pop_matching_params(
|
||||
patterns: Sequence[str], parameters: dict
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Pop the parameters that match the specified patterns and return
|
||||
the dictionary of matching parameters
|
||||
Pop None parameters since they are not the real queries
|
||||
"""
|
||||
if not patterns:
|
||||
return {}
|
||||
|
||||
fields = set()
|
||||
for pattern in patterns:
|
||||
if pattern.endswith("*"):
|
||||
prefix = pattern[:-1]
|
||||
fields.update(
|
||||
{field for field in parameters if field.startswith(prefix)}
|
||||
)
|
||||
elif pattern in parameters:
|
||||
fields.add(pattern)
|
||||
|
||||
pairs = ((field, parameters.pop(field, None)) for field in fields)
|
||||
return {k: v for k, v in pairs if v is not None}
|
||||
|
||||
@classmethod
|
||||
def _try_convert_to_numeric(cls, value: Union[str, Sequence[str]]):
|
||||
def convert_str(val: str) -> Union[float, str]:
|
||||
try:
|
||||
return float(val)
|
||||
except ValueError:
|
||||
return val
|
||||
|
||||
if isinstance(value, str):
|
||||
return convert_str(value)
|
||||
|
||||
if isinstance(value, (list, tuple)):
|
||||
return [convert_str(v) if isinstance(v, str) else v for v in value]
|
||||
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _get_fixed_field_value(cls, field: str, value):
|
||||
if field.startswith("last_metrics."):
|
||||
return cls._try_convert_to_numeric(value)
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _prepare_query_no_company(
|
||||
cls, parameters=None, parameters_options=QueryParameterOptions()
|
||||
@@ -205,17 +254,24 @@ class GetMixin(PropsMixin):
|
||||
dict_query = {}
|
||||
query = RegexQ()
|
||||
if parameters:
|
||||
parameters = parameters.copy()
|
||||
parameters = {
|
||||
k: cls._get_fixed_field_value(k, v) for k, v in parameters.items()
|
||||
}
|
||||
opts = parameters_options
|
||||
for field in opts.pattern_fields:
|
||||
pattern = parameters.pop(field, None)
|
||||
if pattern:
|
||||
dict_query[field] = RegexWrapper(pattern)
|
||||
|
||||
for field in tuple(opts.list_fields or ()):
|
||||
data = parameters.pop(field, None)
|
||||
if data:
|
||||
query &= cls.get_list_field_query(field, data)
|
||||
for field, data in cls._pop_matching_params(
|
||||
patterns=opts.list_fields, parameters=parameters
|
||||
).items():
|
||||
query &= cls.get_list_field_query(field, data)
|
||||
|
||||
for field, data in cls._pop_matching_params(
|
||||
patterns=opts.range_fields, parameters=parameters
|
||||
).items():
|
||||
query &= cls.get_range_field_query(field, data)
|
||||
|
||||
for field in opts.fields or []:
|
||||
data = parameters.pop(field, None)
|
||||
@@ -250,15 +306,53 @@ class GetMixin(PropsMixin):
|
||||
raise MakeGetAllQueryError("incorrect field format", field)
|
||||
if not data.fields:
|
||||
break
|
||||
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
|
||||
sep_fields = [f.replace(".", "__") for f in data.fields]
|
||||
q = reduce(
|
||||
lambda a, x: func(a, RegexQ(**{x: regex})), sep_fields, RegexQ()
|
||||
)
|
||||
if any("._" in f for f in data.fields):
|
||||
q = reduce(
|
||||
lambda a, x: func(a, Q(__raw__={x: {"$regex": data.pattern, "$options": "i"}})),
|
||||
data.fields,
|
||||
Q()
|
||||
)
|
||||
else:
|
||||
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
|
||||
sep_fields = [f.replace(".", "__") for f in data.fields]
|
||||
q = reduce(
|
||||
lambda a, x: func(a, RegexQ(**{x: regex})), sep_fields, RegexQ()
|
||||
)
|
||||
query = query & q
|
||||
|
||||
return query & RegexQ(**dict_query)
|
||||
|
||||
@classmethod
|
||||
def get_range_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
|
||||
"""
|
||||
Return a range query for the provided field. The data should contain min and max values
|
||||
Both intervals are included. For open range queries either min or max can be None
|
||||
In case the min value is None the records with missing or None value from db are included
|
||||
"""
|
||||
if not isinstance(data, (list, tuple)) or len(data) != 2:
|
||||
raise errors.bad_request.ValidationError(
|
||||
f"Min and max values should be specified for range field {field}"
|
||||
)
|
||||
|
||||
min_val, max_val = data
|
||||
if min_val is None and max_val is None:
|
||||
raise errors.bad_request.ValidationError(
|
||||
f"At least one of min or max values should be provided for field {field}"
|
||||
)
|
||||
|
||||
mongoengine_field = field.replace(".", "__")
|
||||
query = {}
|
||||
if min_val is not None:
|
||||
query[f"{mongoengine_field}__gte"] = min_val
|
||||
if max_val is not None:
|
||||
query[f"{mongoengine_field}__lte"] = max_val
|
||||
|
||||
q = Q(**query)
|
||||
if min_val is None:
|
||||
q |= Q(**{mongoengine_field: None})
|
||||
|
||||
return q
|
||||
|
||||
@classmethod
|
||||
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
|
||||
"""
|
||||
@@ -271,7 +365,8 @@ class GetMixin(PropsMixin):
|
||||
- AND can be achieved using a preceding "__$all" or "__$and" value (operator)
|
||||
"""
|
||||
if not isinstance(data, (list, tuple)):
|
||||
raise MakeGetAllQueryError("expected list", field)
|
||||
data = [data]
|
||||
# raise MakeGetAllQueryError("expected list", field)
|
||||
|
||||
# TODO: backwards compatibility only for older API versions
|
||||
helper = cls.ListFieldBucketHelper(legacy=True)
|
||||
@@ -285,11 +380,7 @@ class GetMixin(PropsMixin):
|
||||
q = RegexQ()
|
||||
for action in filter(None, actions):
|
||||
q &= RegexQ(
|
||||
**{
|
||||
f"{mongoengine_field}__{action}": list(
|
||||
set(filter(None, actions[action]))
|
||||
)
|
||||
}
|
||||
**{f"{mongoengine_field}__{action}": list(set(actions[action]))}
|
||||
)
|
||||
|
||||
if not allow_empty:
|
||||
@@ -448,6 +539,12 @@ class GetMixin(PropsMixin):
|
||||
|
||||
return helper.project(results, projection_func)
|
||||
|
||||
@classmethod
|
||||
def _get_collation_override(cls, field: str) -> Optional[dict]:
|
||||
return first(
|
||||
v for k, v in cls._field_collation_overrides.items() if field.startswith(k)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_many(
|
||||
cls,
|
||||
@@ -485,6 +582,13 @@ class GetMixin(PropsMixin):
|
||||
:param allow_public: If True, objects marked as public (no associated company) are also queried.
|
||||
:return: A list of objects matching the query.
|
||||
"""
|
||||
override_collation = None
|
||||
if query_dict:
|
||||
for field in query_dict:
|
||||
override_collation = cls._get_collation_override(field)
|
||||
if override_collation:
|
||||
break
|
||||
|
||||
if query_dict is not None:
|
||||
q = cls.prepare_query(
|
||||
parameters=query_dict,
|
||||
@@ -501,10 +605,14 @@ class GetMixin(PropsMixin):
|
||||
query=_query,
|
||||
parameters=parameters,
|
||||
override_projection=override_projection,
|
||||
override_collation=override_collation,
|
||||
)
|
||||
|
||||
return cls._get_many_no_company(
|
||||
query=_query, parameters=parameters, override_projection=override_projection
|
||||
query=_query,
|
||||
parameters=parameters,
|
||||
override_projection=override_projection,
|
||||
override_collation=override_collation,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -528,6 +636,7 @@ class GetMixin(PropsMixin):
|
||||
query: Q,
|
||||
parameters=None,
|
||||
override_projection=None,
|
||||
override_collation=None,
|
||||
):
|
||||
"""
|
||||
Fetch all documents matching a provided query.
|
||||
@@ -547,12 +656,16 @@ class GetMixin(PropsMixin):
|
||||
parameters = parameters or {}
|
||||
search_text = parameters.get(cls._search_text_key)
|
||||
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
|
||||
if order_by and not override_collation:
|
||||
override_collation = cls._get_collation_override(order_by[0])
|
||||
page, page_size = cls.validate_paging(parameters=parameters)
|
||||
include, exclude = cls.split_projection(
|
||||
cls.get_projection(parameters, override_projection)
|
||||
)
|
||||
|
||||
qs = cls.objects(query)
|
||||
if override_collation:
|
||||
qs = qs.collation(collation=override_collation)
|
||||
if search_text:
|
||||
qs = qs.search_text(search_text)
|
||||
if order_by:
|
||||
@@ -572,12 +685,41 @@ class GetMixin(PropsMixin):
|
||||
|
||||
return qs
|
||||
|
||||
@classmethod
|
||||
def _get_queries_for_order_field(
|
||||
cls, query: Q, order_field: str
|
||||
) -> Union[None, Tuple[Q, Q]]:
|
||||
"""
|
||||
In case the order_field is one of the cls fields and the sorting is ascending
|
||||
then return the tuple of 2 queries:
|
||||
1. original query with not empty constraint on the order_by field
|
||||
2. original query with empty constraint on the order_by field
|
||||
"""
|
||||
if not order_field or order_field.startswith("-") or "[" in order_field:
|
||||
return
|
||||
|
||||
mongo_field_name = order_field.replace(".", "__")
|
||||
mongo_field = first(
|
||||
v for k, v in cls.get_all_fields_with_instance() if k == mongo_field_name
|
||||
)
|
||||
|
||||
if isinstance(mongo_field, ListField):
|
||||
params = {"is_list": True}
|
||||
elif isinstance(mongo_field, StringField):
|
||||
params = {"empty_value": ""}
|
||||
else:
|
||||
params = {}
|
||||
non_empty = query & field_exists(mongo_field_name, **params)
|
||||
empty = query & field_does_not_exist(mongo_field_name, **params)
|
||||
return non_empty, empty
|
||||
|
||||
@classmethod
|
||||
def _get_many_override_none_ordering(
|
||||
cls: Union[Document, "GetMixin"],
|
||||
query: Q = None,
|
||||
parameters: dict = None,
|
||||
override_projection: Collection[str] = None,
|
||||
override_collation: dict = None,
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
Fetch all documents matching a provided query. For the first order by field
|
||||
@@ -610,32 +752,17 @@ class GetMixin(PropsMixin):
|
||||
order_field = first(
|
||||
field for field in order_by if not field.startswith("$")
|
||||
)
|
||||
if (
|
||||
order_field
|
||||
and not order_field.startswith("-")
|
||||
and "[" not in order_field
|
||||
):
|
||||
params = {}
|
||||
mongo_field = order_field.replace(".", "__")
|
||||
if mongo_field in cls.get_field_names_for_type(of_type=ListField):
|
||||
params["is_list"] = True
|
||||
elif mongo_field in cls.get_field_names_for_type(of_type=StringField):
|
||||
params["empty_value"] = ""
|
||||
non_empty = query & field_exists(mongo_field, **params)
|
||||
empty = query & field_does_not_exist(mongo_field, **params)
|
||||
query_sets = [cls.objects(non_empty), cls.objects(empty)]
|
||||
|
||||
res = cls._get_queries_for_order_field(query, order_field)
|
||||
if res:
|
||||
query_sets = [cls.objects(q) for q in res]
|
||||
query_sets = [qs.order_by(*order_by) for qs in query_sets]
|
||||
if order_field:
|
||||
collation_override = first(
|
||||
v
|
||||
for k, v in cls._field_collation_overrides.items()
|
||||
if order_field.startswith(k)
|
||||
)
|
||||
if collation_override:
|
||||
query_sets = [
|
||||
qs.collation(collation=collation_override) for qs in query_sets
|
||||
]
|
||||
if order_field and not override_collation:
|
||||
override_collation = cls._get_collation_override(order_field)
|
||||
|
||||
if override_collation:
|
||||
query_sets = [
|
||||
qs.collation(collation=override_collation) for qs in query_sets
|
||||
]
|
||||
|
||||
if search_text:
|
||||
query_sets = [qs.search_text(search_text) for qs in query_sets]
|
||||
|
||||
44
apiserver/database/model/metadata.py
Normal file
44
apiserver/database/model/metadata.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from typing import Sequence, Type
|
||||
|
||||
from mongoengine import EmbeddedDocument, StringField, Document
|
||||
from pymongo import UpdateOne
|
||||
from pymongo.collection import Collection
|
||||
|
||||
from apiserver.database.model.base import ProperDictMixin
|
||||
|
||||
|
||||
class MetadataItem(EmbeddedDocument, ProperDictMixin):
|
||||
key = StringField(required=True)
|
||||
type = StringField(required=True)
|
||||
value = StringField(required=True)
|
||||
|
||||
|
||||
def metadata_add_or_update(cls: Type[Document], _id: str, items: Sequence[dict]) -> int:
|
||||
collection: Collection = cls._get_collection()
|
||||
res = collection.update_one(
|
||||
filter={"_id": _id},
|
||||
update={
|
||||
"$set": {f"metadata.$[elem{idx}]": item for idx, item in enumerate(items)}
|
||||
},
|
||||
array_filters=[
|
||||
{f"elem{idx}.key": item["key"]} for idx, item in enumerate(items)
|
||||
],
|
||||
upsert=False,
|
||||
)
|
||||
if len(items) == 1 and res.modified_count == 1:
|
||||
return res.modified_count
|
||||
|
||||
requests = [
|
||||
UpdateOne(
|
||||
filter={"_id": _id, "metadata.key": {"$ne": item["key"]}},
|
||||
update={"$push": {"metadata": item}},
|
||||
)
|
||||
for item in items
|
||||
]
|
||||
res = collection.bulk_write(requests)
|
||||
|
||||
return 1 if res.modified_count else 0
|
||||
|
||||
|
||||
def metadata_delete(cls: Type[Document], _id: str, keys: Sequence[str]) -> int:
|
||||
return cls.objects(id=_id).update_one(pull__metadata__key__in=keys)
|
||||
@@ -1,9 +1,22 @@
|
||||
from mongoengine import Document, StringField, DateTimeField, BooleanField
|
||||
from typing import Sequence
|
||||
|
||||
from mongoengine import (
|
||||
Document,
|
||||
StringField,
|
||||
DateTimeField,
|
||||
BooleanField,
|
||||
EmbeddedDocumentListField,
|
||||
)
|
||||
|
||||
from apiserver.database import Database, strict
|
||||
from apiserver.database.fields import StrippedStringField, SafeDictField, SafeSortedListField
|
||||
from apiserver.database.fields import (
|
||||
StrippedStringField,
|
||||
SafeDictField,
|
||||
SafeSortedListField,
|
||||
)
|
||||
from apiserver.database.model import DbModelMixin
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.model.metadata import MetadataItem
|
||||
from apiserver.database.model.model_labels import ModelLabels
|
||||
from apiserver.database.model.company import Company
|
||||
from apiserver.database.model.project import Project
|
||||
@@ -19,6 +32,9 @@ class Model(DbModelMixin, Document):
|
||||
"parent",
|
||||
"project",
|
||||
"task",
|
||||
"last_update",
|
||||
"metadata.key",
|
||||
"metadata.type",
|
||||
("company", "framework"),
|
||||
("company", "name"),
|
||||
("company", "user"),
|
||||
@@ -51,6 +67,7 @@ class Model(DbModelMixin, Document):
|
||||
"task",
|
||||
"parent",
|
||||
),
|
||||
datetime_fields=("last_update",),
|
||||
)
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
@@ -69,7 +86,11 @@ class Model(DbModelMixin, Document):
|
||||
design = SafeDictField()
|
||||
labels = ModelLabels()
|
||||
ready = BooleanField(required=True)
|
||||
last_update = DateTimeField()
|
||||
ui_cache = SafeDictField(
|
||||
default=dict, user_set_allowed=True, exclude_by_default=True
|
||||
)
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
metadata: Sequence[MetadataItem] = EmbeddedDocumentListField(
|
||||
MetadataItem, default=list, user_set_allowed=True
|
||||
)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from mongoengine import StringField, DateTimeField, IntField
|
||||
from mongoengine import StringField, DateTimeField, IntField, ListField
|
||||
|
||||
from apiserver.database import Database, strict
|
||||
from apiserver.database.fields import StrippedStringField, SafeSortedListField
|
||||
@@ -10,13 +10,15 @@ class Project(AttributedDocument):
|
||||
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
pattern_fields=("name", "description"),
|
||||
list_fields=("tags", "system_tags", "id"),
|
||||
list_fields=("tags", "system_tags", "id", "parent", "path"),
|
||||
)
|
||||
|
||||
meta = {
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
"indexes": [
|
||||
"parent",
|
||||
"path",
|
||||
("company", "name"),
|
||||
{
|
||||
"name": "%s.project.main_text_index" % Database.backend,
|
||||
@@ -34,7 +36,7 @@ class Project(AttributedDocument):
|
||||
min_length=3,
|
||||
sparse=True,
|
||||
)
|
||||
description = StringField(required=True)
|
||||
description = StringField()
|
||||
created = DateTimeField(required=True)
|
||||
tags = SafeSortedListField(StringField(required=True))
|
||||
system_tags = SafeSortedListField(StringField(required=True))
|
||||
@@ -44,3 +46,5 @@ class Project(AttributedDocument):
|
||||
logo_url = StringField()
|
||||
logo_blob = StringField(exclude_by_default=True)
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
parent = StringField(reference_field="Project")
|
||||
path = ListField(StringField(required=True), exclude_by_default=True)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Sequence
|
||||
|
||||
from mongoengine import (
|
||||
Document,
|
||||
EmbeddedDocument,
|
||||
@@ -11,6 +13,7 @@ from apiserver.database.fields import StrippedStringField, SafeSortedListField
|
||||
from apiserver.database.model import DbModelMixin
|
||||
from apiserver.database.model.base import ProperDictMixin, GetMixin
|
||||
from apiserver.database.model.company import Company
|
||||
from apiserver.database.model.metadata import MetadataItem
|
||||
from apiserver.database.model.task.task import Task
|
||||
|
||||
|
||||
@@ -32,6 +35,7 @@ class Queue(DbModelMixin, Document):
|
||||
meta = {
|
||||
'db_alias': Database.backend,
|
||||
'strict': strict,
|
||||
"indexes": ["metadata.key", "metadata.type"],
|
||||
}
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
@@ -44,3 +48,6 @@ class Queue(DbModelMixin, Document):
|
||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
entries = EmbeddedDocumentListField(Entry, default=list)
|
||||
last_update = DateTimeField()
|
||||
metadata: Sequence[MetadataItem] = EmbeddedDocumentListField(
|
||||
MetadataItem, default=list, user_set_allowed=True
|
||||
)
|
||||
|
||||
@@ -11,6 +11,5 @@ class Result(object):
|
||||
|
||||
class Output(EmbeddedDocument):
|
||||
destination = StrippedStringField()
|
||||
model = StringField(reference_field='Model')
|
||||
error = StringField(user_set_allowed=True)
|
||||
result = StringField(choices=get_options(Result))
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Dict
|
||||
from typing import Dict, Sequence
|
||||
|
||||
from mongoengine import (
|
||||
StringField,
|
||||
@@ -17,6 +17,7 @@ from apiserver.database.fields import (
|
||||
SafeDictField,
|
||||
UnionField,
|
||||
SafeSortedListField,
|
||||
EmbeddedDocumentListField,
|
||||
)
|
||||
from apiserver.database.model import AttributedDocument
|
||||
from apiserver.database.model.base import ProperDictMixin, GetMixin
|
||||
@@ -79,7 +80,9 @@ DEFAULT_ARTIFACT_MODE = ArtifactModes.output
|
||||
class Artifact(EmbeddedDocument):
|
||||
key = StringField(required=True)
|
||||
type = StringField(required=True)
|
||||
mode = StringField(choices=get_options(ArtifactModes), default=DEFAULT_ARTIFACT_MODE)
|
||||
mode = StringField(
|
||||
choices=get_options(ArtifactModes), default=DEFAULT_ARTIFACT_MODE
|
||||
)
|
||||
uri = StringField()
|
||||
hash = StringField()
|
||||
content_size = LongField()
|
||||
@@ -103,17 +106,37 @@ class ConfigurationItem(EmbeddedDocument, ProperDictMixin):
|
||||
description = StringField()
|
||||
|
||||
|
||||
class TaskModelTypes:
|
||||
input = "input"
|
||||
output = "output"
|
||||
|
||||
|
||||
TaskModelNames = {
|
||||
TaskModelTypes.input: "Input Model",
|
||||
TaskModelTypes.output: "Output Model",
|
||||
}
|
||||
|
||||
|
||||
class ModelItem(EmbeddedDocument, ProperDictMixin):
|
||||
name = StringField(required=True)
|
||||
model = StringField(required=True, reference_field="Model")
|
||||
updated = DateTimeField()
|
||||
|
||||
|
||||
class Models(EmbeddedDocument, ProperDictMixin):
|
||||
input: Sequence[ModelItem] = EmbeddedDocumentListField(ModelItem, default=list)
|
||||
output: Sequence[ModelItem] = EmbeddedDocumentListField(ModelItem, default=list)
|
||||
|
||||
|
||||
class Execution(EmbeddedDocument, ProperDictMixin):
|
||||
meta = {"strict": strict}
|
||||
test_split = IntField(default=0)
|
||||
parameters = SafeDictField(default=dict)
|
||||
model = StringField(reference_field="Model")
|
||||
model_desc = SafeMapField(StringField(default=""))
|
||||
model_labels = ModelLabels()
|
||||
framework = StringField()
|
||||
artifacts: Dict[str, Artifact] = SafeMapField(field=EmbeddedDocumentField(Artifact))
|
||||
docker_cmd = StringField()
|
||||
queue = StringField()
|
||||
queue = StringField(reference_field="Queue")
|
||||
""" Queue ID where task was queued """
|
||||
|
||||
|
||||
@@ -140,7 +163,6 @@ class Task(AttributedDocument):
|
||||
"execution.parameters.": _numeric_locale,
|
||||
"last_metrics.": _numeric_locale,
|
||||
"hyperparams.": _numeric_locale,
|
||||
"configuration.": _numeric_locale,
|
||||
}
|
||||
|
||||
meta = {
|
||||
@@ -153,6 +175,7 @@ class Task(AttributedDocument):
|
||||
"active_duration",
|
||||
"parent",
|
||||
"project",
|
||||
"models.input.model",
|
||||
("company", "name"),
|
||||
("company", "user"),
|
||||
("company", "status", "type"),
|
||||
@@ -160,14 +183,15 @@ class Task(AttributedDocument):
|
||||
("company", "type", "system_tags", "status"),
|
||||
("company", "project", "type", "system_tags", "status"),
|
||||
("status", "last_update"), # for maintenance tasks
|
||||
{"fields": ["company", "project"], "collation": _numeric_locale},
|
||||
{
|
||||
"name": "%s.task.main_text_index" % Database.backend,
|
||||
"fields": [
|
||||
"$name",
|
||||
"$id",
|
||||
"$comment",
|
||||
"$execution.model",
|
||||
"$output.model",
|
||||
"$models.input.model",
|
||||
"$models.output.model",
|
||||
"$script.repository",
|
||||
"$script.entry_point",
|
||||
],
|
||||
@@ -176,8 +200,8 @@ class Task(AttributedDocument):
|
||||
"name": 10,
|
||||
"id": 10,
|
||||
"comment": 10,
|
||||
"execution.model": 2,
|
||||
"output.model": 2,
|
||||
"models.output.model": 2,
|
||||
"models.input.model": 2,
|
||||
"script.repository": 1,
|
||||
"script.entry_point": 1,
|
||||
},
|
||||
@@ -185,8 +209,18 @@ class Task(AttributedDocument):
|
||||
],
|
||||
}
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
list_fields=("id", "user", "tags", "system_tags", "type", "status", "project", "parent"),
|
||||
datetime_fields=("status_changed",),
|
||||
list_fields=(
|
||||
"id",
|
||||
"user",
|
||||
"tags",
|
||||
"system_tags",
|
||||
"type",
|
||||
"status",
|
||||
"project",
|
||||
"parent",
|
||||
),
|
||||
range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
|
||||
datetime_fields=("status_changed", "last_update"),
|
||||
pattern_fields=("name", "comment"),
|
||||
)
|
||||
|
||||
@@ -225,7 +259,11 @@ class Task(AttributedDocument):
|
||||
hyperparams = SafeMapField(field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
|
||||
configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
|
||||
runtime = SafeDictField(default=dict)
|
||||
docker_init_script = StringField()
|
||||
models: Models = EmbeddedDocumentField(Models, default=Models)
|
||||
container = SafeMapField(field=StringField(default=""))
|
||||
enqueue_status = StringField(
|
||||
choices=get_options(TaskStatus), exclude_by_default=True
|
||||
)
|
||||
|
||||
def get_index_company(self) -> str:
|
||||
"""
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
from collections import OrderedDict, defaultdict
|
||||
from itertools import chain
|
||||
from collections import OrderedDict
|
||||
from operator import attrgetter
|
||||
from threading import Lock
|
||||
from typing import Sequence
|
||||
|
||||
import six
|
||||
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
|
||||
from mongoengine.base import get_document, BaseField
|
||||
from mongoengine.base import get_document
|
||||
|
||||
from apiserver.database.fields import (
|
||||
LengthRangeEmbeddedDocumentListField,
|
||||
@@ -21,7 +20,7 @@ class PropsMixin(object):
|
||||
__cached_reference_fields = None
|
||||
__cached_exclude_fields = None
|
||||
__cached_fields_with_instance = None
|
||||
__cached_field_names_per_type = None
|
||||
__cached_all_fields_with_instance = None
|
||||
|
||||
__cached_dpath_computed_fields_lock = Lock()
|
||||
__cached_dpath_computed_fields = None
|
||||
@@ -33,37 +32,12 @@ class PropsMixin(object):
|
||||
return cls.__cached_fields
|
||||
|
||||
@classmethod
|
||||
def get_field_names_for_type(cls, of_type=BaseField):
|
||||
"""
|
||||
Return field names per type including subfields
|
||||
The fields of derived types are also returned
|
||||
"""
|
||||
assert issubclass(of_type, BaseField)
|
||||
if cls.__cached_field_names_per_type is None:
|
||||
fields = defaultdict(list)
|
||||
for name, field in get_fields(cls, return_instance=True, subfields=True):
|
||||
fields[type(field)].append(name)
|
||||
for type_ in fields:
|
||||
fields[type_].extend(
|
||||
chain.from_iterable(
|
||||
fields[other_type]
|
||||
for other_type in fields
|
||||
if other_type != type_ and issubclass(other_type, type_)
|
||||
)
|
||||
)
|
||||
cls.__cached_field_names_per_type = fields
|
||||
|
||||
if of_type not in cls.__cached_field_names_per_type:
|
||||
names = list(
|
||||
chain.from_iterable(
|
||||
field_names
|
||||
for type_, field_names in cls.__cached_field_names_per_type.items()
|
||||
if issubclass(type_, of_type)
|
||||
)
|
||||
def get_all_fields_with_instance(cls):
|
||||
if cls.__cached_all_fields_with_instance is None:
|
||||
cls.__cached_all_fields_with_instance = get_fields(
|
||||
cls, return_instance=True, subfields=True
|
||||
)
|
||||
cls.__cached_field_names_per_type[of_type] = names
|
||||
|
||||
return cls.__cached_field_names_per_type[of_type]
|
||||
return cls.__cached_all_fields_with_instance
|
||||
|
||||
@classmethod
|
||||
def get_fields_with_instance(cls, doc_cls):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from datetime import datetime
|
||||
from os import getenv
|
||||
from typing import Tuple
|
||||
|
||||
from boltons.iterutils import first
|
||||
from elasticsearch import Elasticsearch, Transport
|
||||
@@ -9,11 +10,16 @@ from apiserver.config_repo import config
|
||||
log = config.logger(__file__)
|
||||
|
||||
OVERRIDE_HOST_ENV_KEY = (
|
||||
"CLEARML_ELASTIC_SERVICE_HOST",
|
||||
"TRAINS_ELASTIC_SERVICE_HOST",
|
||||
"ELASTIC_SERVICE_HOST",
|
||||
"ELASTIC_SERVICE_SERVICE_HOST",
|
||||
)
|
||||
OVERRIDE_PORT_ENV_KEY = ("TRAINS_ELASTIC_SERVICE_PORT", "ELASTIC_SERVICE_PORT")
|
||||
OVERRIDE_PORT_ENV_KEY = (
|
||||
"CLEARML_ELASTIC_SERVICE_PORT",
|
||||
"TRAINS_ELASTIC_SERVICE_PORT",
|
||||
"ELASTIC_SERVICE_PORT",
|
||||
)
|
||||
|
||||
OVERRIDE_HOST = first(filter(None, map(getenv, OVERRIDE_HOST_ENV_KEY)))
|
||||
if OVERRIDE_HOST:
|
||||
@@ -70,6 +76,10 @@ class ESFactory:
|
||||
def get_all_cluster_names(cls):
|
||||
return list(config.get("hosts.elastic"))
|
||||
|
||||
@classmethod
|
||||
def get_override(cls, cluster_name: str) -> Tuple[str, str]:
|
||||
return OVERRIDE_HOST, OVERRIDE_PORT
|
||||
|
||||
@classmethod
|
||||
def get_cluster_config(cls, cluster_name):
|
||||
"""
|
||||
@@ -84,14 +94,16 @@ class ESFactory:
|
||||
raise MissingClusterConfiguration(cluster_name)
|
||||
|
||||
def set_host_prop(key, value):
|
||||
for host in cluster_config.get("hosts", []):
|
||||
host[key] = value
|
||||
for entry in cluster_config.get("hosts", []):
|
||||
entry[key] = value
|
||||
|
||||
if OVERRIDE_HOST:
|
||||
set_host_prop("host", OVERRIDE_HOST)
|
||||
host, port = cls.get_override(cluster_name)
|
||||
|
||||
if OVERRIDE_PORT:
|
||||
set_host_prop("port", OVERRIDE_PORT)
|
||||
if host:
|
||||
set_host_prop("host", host)
|
||||
|
||||
if port:
|
||||
set_host_prop("port", port)
|
||||
|
||||
return cluster_config
|
||||
|
||||
@@ -120,7 +132,9 @@ class ESFactory:
|
||||
@classmethod
|
||||
def get_es_timestamp_str(cls):
|
||||
now = datetime.utcnow()
|
||||
return now.strftime("%Y-%m-%dT%H:%M:%S") + ".%03d" % (now.microsecond / 1000) + "Z"
|
||||
return (
|
||||
now.strftime("%Y-%m-%dT%H:%M:%S") + ".%03d" % (now.microsecond / 1000) + "Z"
|
||||
)
|
||||
|
||||
|
||||
es_factory = ESFactory()
|
||||
|
||||
@@ -4,19 +4,20 @@ from logging import Logger
|
||||
from pathlib import Path
|
||||
|
||||
from mongoengine.connection import get_db
|
||||
from semantic_version import Version
|
||||
from packaging.version import Version, parse
|
||||
|
||||
from apiserver.database import utils
|
||||
from apiserver.database import Database
|
||||
from apiserver.database.model.version import Version as DatabaseVersion
|
||||
|
||||
migration_dir = Path(__file__).resolve().parent.with_name("migrations")
|
||||
_migrations = "migrations"
|
||||
_parent_dir = Path(__file__).resolve().parents[1]
|
||||
_migration_dir = _parent_dir / _migrations
|
||||
|
||||
|
||||
def check_mongo_empty() -> bool:
|
||||
return not all(
|
||||
get_db(alias).collection_names()
|
||||
for alias in utils.get_options(Database)
|
||||
get_db(alias).collection_names() for alias in utils.get_options(Database)
|
||||
)
|
||||
|
||||
|
||||
@@ -41,8 +42,8 @@ def _apply_migrations(log: Logger):
|
||||
|
||||
log.info(f"Started mongodb migrations")
|
||||
|
||||
if not migration_dir.is_dir():
|
||||
raise ValueError(f"Invalid migration dir {migration_dir}")
|
||||
if not _migration_dir.is_dir():
|
||||
raise ValueError(f"Invalid migration dir {_migration_dir}")
|
||||
|
||||
empty_dbs = check_mongo_empty()
|
||||
last_version = get_last_server_version()
|
||||
@@ -50,7 +51,10 @@ def _apply_migrations(log: Logger):
|
||||
try:
|
||||
new_scripts = {
|
||||
ver: path
|
||||
for ver, path in ((Version(f.stem), f) for f in migration_dir.glob("*.py"))
|
||||
for ver, path in (
|
||||
(parse(f.stem.replace("_", ".")), f)
|
||||
for f in _migration_dir.glob("*.py")
|
||||
)
|
||||
if ver > last_version
|
||||
}
|
||||
except ValueError as ex:
|
||||
@@ -64,7 +68,10 @@ def _apply_migrations(log: Logger):
|
||||
if empty_dbs:
|
||||
log.info(f"Skipping migration {script.name} (empty databases)")
|
||||
else:
|
||||
spec = importlib.util.spec_from_file_location(script.stem, str(script))
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
".".join(("apiserver", _parent_dir.name, _migrations, script.stem)),
|
||||
str(script),
|
||||
)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
@@ -83,7 +90,7 @@ def _apply_migrations(log: Logger):
|
||||
|
||||
DatabaseVersion(
|
||||
id=utils.id(),
|
||||
num=script.stem,
|
||||
num=str(script_version),
|
||||
created=datetime.utcnow(),
|
||||
desc="Applied on server startup",
|
||||
).save()
|
||||
|
||||
@@ -25,7 +25,6 @@ from typing import (
|
||||
from urllib.parse import unquote, urlparse
|
||||
from zipfile import ZipFile, ZIP_BZIP2
|
||||
|
||||
import dpath
|
||||
import mongoengine
|
||||
from boltons.iterutils import chunked_iter, first
|
||||
from furl import furl
|
||||
@@ -33,6 +32,7 @@ from mongoengine import Q
|
||||
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.event.event_common import EventType
|
||||
from apiserver.bll.project import project_ids_with_children
|
||||
from apiserver.bll.task.artifacts import get_artifact_id
|
||||
from apiserver.bll.task.param_utils import (
|
||||
split_param_name,
|
||||
@@ -44,11 +44,16 @@ from apiserver.config.info import get_default_company
|
||||
from apiserver.database.model import EntityVisibility, User
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, ArtifactModes, TaskStatus
|
||||
from apiserver.database.model.task.task import (
|
||||
Task,
|
||||
ArtifactModes,
|
||||
TaskStatus,
|
||||
TaskModelTypes,
|
||||
TaskModelNames,
|
||||
)
|
||||
from apiserver.database.utils import get_options
|
||||
from apiserver.tools import safe_get
|
||||
from apiserver.utilities import json
|
||||
from apiserver.utilities.dicts import nested_get, nested_set
|
||||
from apiserver.utilities.dicts import nested_get, nested_set, nested_delete
|
||||
|
||||
|
||||
class PrePopulate:
|
||||
@@ -343,31 +348,10 @@ class PrePopulate:
|
||||
|
||||
return upadated
|
||||
|
||||
@staticmethod
|
||||
def _upgrade_task_data(task_data: dict):
|
||||
for old_param_field, new_param_field, default_section in (
|
||||
("execution/parameters", "hyperparams", hyperparams_default_section),
|
||||
("execution/model_desc", "configuration", None),
|
||||
):
|
||||
legacy = safe_get(task_data, old_param_field)
|
||||
if not legacy:
|
||||
continue
|
||||
for full_name, value in legacy.items():
|
||||
section, name = split_param_name(full_name, default_section)
|
||||
new_path = list(filter(None, (new_param_field, section, name)))
|
||||
if not safe_get(task_data, new_path):
|
||||
new_param = dict(
|
||||
name=name, type=hyperparams_legacy_type, value=str(value)
|
||||
)
|
||||
if section is not None:
|
||||
new_param["section"] = section
|
||||
dpath.new(task_data, new_path, new_param)
|
||||
dpath.delete(task_data, old_param_field)
|
||||
|
||||
@classmethod
|
||||
def _upgrade_tasks(cls, f: IO[bytes]) -> bytes:
|
||||
"""
|
||||
Build content array that contains fixed tasks from the passed file
|
||||
Build content array that contains upgraded tasks from the passed file
|
||||
For each task the old execution.parameters and model.design are
|
||||
converted to the new structure.
|
||||
The fix is done on Task objects (not the dictionary) so that
|
||||
@@ -423,6 +407,24 @@ class PrePopulate:
|
||||
items.append(results[0])
|
||||
return items
|
||||
|
||||
@classmethod
def _check_projects_hierarchy(cls, projects: Set[Project]):
    """
    For any exported project all its parents up to the root should be present

    Only the direct parent of each project is looked up in the passed set.
    When at least one project references a parent that is not in the set,
    the offending project ids are printed and the process is terminated
    with exit code 1.

    :param projects: The projects selected for export
    :raises SystemExit: if a project is exported without its parent
    """
    # 'exit' is injected by the 'site' module for interactive sessions and
    # is not guaranteed to exist (e.g. python -S); sys.exit always is
    import sys

    if not projects:
        return

    project_ids = {p.id for p in projects}
    orphans = [p.id for p in projects if p.parent and p.parent not in project_ids]
    if not orphans:
        return

    print(
        f"ERROR: the following projects are exported without their parents: {orphans}"
    )
    sys.exit(1)
|
||||
|
||||
@classmethod
|
||||
def _resolve_entities(
|
||||
cls,
|
||||
@@ -434,6 +436,7 @@ class PrePopulate:
|
||||
|
||||
if projects:
|
||||
print("Reading projects...")
|
||||
projects = project_ids_with_children(projects)
|
||||
entities[cls.project_cls].update(
|
||||
cls._resolve_type(cls.project_cls, projects)
|
||||
)
|
||||
@@ -463,12 +466,16 @@ class PrePopulate:
|
||||
project_ids = {p.id for p in entities[cls.project_cls]}
|
||||
entities[cls.project_cls].update(o for o in objs if o.id not in project_ids)
|
||||
|
||||
model_ids = {
|
||||
model_id
|
||||
cls._check_projects_hierarchy(entities[cls.project_cls])
|
||||
|
||||
task_models = chain.from_iterable(
|
||||
models
|
||||
for task in entities[cls.task_cls]
|
||||
for model_id in (task.output.model, task.execution.model)
|
||||
if model_id
|
||||
}
|
||||
if task.models
|
||||
for models in (task.models.input, task.models.output)
|
||||
if models
|
||||
)
|
||||
model_ids = {tm.model for tm in task_models}
|
||||
if model_ids:
|
||||
print("Reading models...")
|
||||
entities[cls.model_cls] = set(cls.model_cls.objects(id__in=list(model_ids)))
|
||||
@@ -634,11 +641,12 @@ class PrePopulate:
|
||||
"""
|
||||
Export the requested experiments, projects and models and return the list of artifact files
|
||||
Always do the export on sorted items since the order of items influence hash
|
||||
The projects should be sorted by name so that on import the hierarchy is correctly restored from top to bottom
|
||||
"""
|
||||
artifacts = []
|
||||
now = datetime.utcnow()
|
||||
for cls_ in sorted(entities, key=attrgetter("__name__")):
|
||||
items = sorted(entities[cls_], key=attrgetter("id"))
|
||||
items = sorted(entities[cls_], key=attrgetter("name", "id"))
|
||||
if not items:
|
||||
continue
|
||||
base_filename = cls._get_base_filename(cls_)
|
||||
@@ -735,6 +743,77 @@ class PrePopulate:
|
||||
module = importlib.import_module(module_name)
|
||||
return getattr(module, class_name)
|
||||
|
||||
@staticmethod
|
||||
def _upgrade_task_data(task_data: dict) -> dict:
|
||||
"""
|
||||
Migrate from execution/parameters and model_desc to hyperparams and configuration fields
|
||||
Upgrade artifacts list to dict
|
||||
Migrate from execution.model and output.model to the new models field
|
||||
Move docker_cmd contents into the container field
|
||||
:param task_data: Upgraded in place
|
||||
:return: The upgraded task data
|
||||
"""
|
||||
for old_param_field, new_param_field, default_section in (
|
||||
("execution.parameters", "hyperparams", hyperparams_default_section),
|
||||
("execution.model_desc", "configuration", None),
|
||||
):
|
||||
legacy_path = old_param_field.split(".")
|
||||
legacy = nested_get(task_data, legacy_path)
|
||||
if legacy:
|
||||
for full_name, value in legacy.items():
|
||||
section, name = split_param_name(full_name, default_section)
|
||||
new_path = list(filter(None, (new_param_field, section, name)))
|
||||
if not nested_get(task_data, new_path):
|
||||
new_param = dict(
|
||||
name=name, type=hyperparams_legacy_type, value=str(value)
|
||||
)
|
||||
if section is not None:
|
||||
new_param["section"] = section
|
||||
nested_set(task_data, path=new_path, value=new_param)
|
||||
nested_delete(task_data, legacy_path)
|
||||
|
||||
artifacts_path = ("execution", "artifacts")
|
||||
artifacts = nested_get(task_data, artifacts_path)
|
||||
if isinstance(artifacts, list):
|
||||
nested_set(
|
||||
task_data,
|
||||
path=artifacts_path,
|
||||
value={get_artifact_id(a): a for a in artifacts},
|
||||
)
|
||||
|
||||
models = task_data.get("models", {})
|
||||
now = datetime.utcnow()
|
||||
for old_field, type_ in (
|
||||
("execution.model", TaskModelTypes.input),
|
||||
("output.model", TaskModelTypes.output),
|
||||
):
|
||||
old_path = old_field.split(".")
|
||||
old_model = nested_get(task_data, old_path)
|
||||
new_models = models.get(type_, [])
|
||||
name = TaskModelNames[type_]
|
||||
if old_model and not any(
|
||||
m
|
||||
for m in new_models
|
||||
if m.get("model") == old_model or m.get("name") == name
|
||||
):
|
||||
model_item = {"model": old_model, "name": name, "updated": now}
|
||||
if type_ == TaskModelTypes.input:
|
||||
new_models = [model_item, *new_models]
|
||||
else:
|
||||
new_models = [*new_models, model_item]
|
||||
models[type_] = new_models
|
||||
nested_delete(task_data, old_path)
|
||||
task_data["models"] = models
|
||||
|
||||
docker_cmd_path = ("execution", "docker_cmd")
|
||||
docker_cmd = nested_get(task_data, docker_cmd_path)
|
||||
if docker_cmd and not task_data.get("container"):
|
||||
image, _, arguments = docker_cmd.partition(" ")
|
||||
task_data["container"] = {"image": image, "arguments": arguments}
|
||||
nested_delete(task_data, docker_cmd_path)
|
||||
|
||||
return task_data
|
||||
|
||||
@classmethod
|
||||
def _import_entity(
|
||||
cls,
|
||||
@@ -750,16 +829,7 @@ class PrePopulate:
|
||||
override_project_count = 0
|
||||
for item in cls.json_lines(f):
|
||||
if cls_ == cls.task_cls:
|
||||
task_data = json.loads(item)
|
||||
artifacts_path = ("execution", "artifacts")
|
||||
artifacts = nested_get(task_data, artifacts_path)
|
||||
if isinstance(artifacts, list):
|
||||
nested_set(
|
||||
task_data,
|
||||
artifacts_path,
|
||||
value={get_artifact_id(a): a for a in artifacts},
|
||||
)
|
||||
item = json.dumps(task_data)
|
||||
item = json.dumps(cls._upgrade_task_data(task_data=json.loads(item)))
|
||||
|
||||
doc = cls_.from_json(item, created=True)
|
||||
if hasattr(doc, "user"):
|
||||
|
||||
@@ -5,8 +5,7 @@ from pymongo.database import Database, Collection
|
||||
|
||||
def migrate_auth(db: Database):
|
||||
collection: Collection = db["user"]
|
||||
if "name_1_company_1" in [doc["name"] for doc in collection.list_indexes()]:
|
||||
collection.drop_index("name_1_company_1")
|
||||
collection.drop_indexes()
|
||||
|
||||
|
||||
def migrate_backend(db: Database):
|
||||
@@ -31,8 +31,8 @@ def migrate_auth(db: Database):
|
||||
if not uuids:
|
||||
return
|
||||
|
||||
collection = db["user"]
|
||||
collection.drop_index("name_1_company_1")
|
||||
collection: Collection = db["user"]
|
||||
collection.drop_indexes()
|
||||
|
||||
_switch_uuid(collection=collection, uuid_field="_id", uuids=uuids)
|
||||
|
||||
@@ -1,15 +1,6 @@
|
||||
from collections import Collection
|
||||
from typing import Sequence
|
||||
from pymongo.database import Database
|
||||
|
||||
from pymongo.database import Database, Collection
|
||||
|
||||
|
||||
def _drop_all_indices_from_collections(db: Database, names: Sequence[str]):
|
||||
for collection_name in db.list_collection_names():
|
||||
if collection_name not in names:
|
||||
continue
|
||||
collection: Collection = db[collection_name]
|
||||
collection.drop_indexes()
|
||||
from .utils import _drop_all_indices_from_collections
|
||||
|
||||
|
||||
def migrate_auth(db: Database):
|
||||
145
apiserver/mongo/migrations/1_0_0.py
Normal file
145
apiserver/mongo/migrations/1_0_0.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.database import Database
|
||||
from pymongo.errors import DuplicateKeyError
|
||||
|
||||
from apiserver.database.model.task.task import TaskModelTypes, TaskModelNames
|
||||
from apiserver.services.utils import escape_dict
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from .utils import _drop_all_indices_from_collections
|
||||
|
||||
|
||||
def _migrate_task_models(db: Database):
    """
    Convert the legacy single-model task fields into the "models" lists

    For every task that has "execution.model" and/or "output.model" the
    referenced model is placed first in "models.input" or last in
    "models.output" respectively. Existing list entries that carry the same
    name or the same model id are dropped. The legacy fields are unset for
    every matched task, whether or not they held a value.
    """
    task_collection: Collection = db["task"]
    models_root = "models"
    timestamp = datetime.utcnow()

    legacy_fields = {
        TaskModelTypes.input: "execution.model",
        TaskModelTypes.output: "output.model",
    }
    legacy_paths = list(legacy_fields.values())
    has_legacy_field = {"$or": [{path: {"$exists": True}} for path in legacy_paths]}

    for task in task_collection.find(
        filter=has_legacy_field, projection=[*legacy_paths, models_root]
    ):
        updates = {}
        for model_type, legacy_path in legacy_fields.items():
            model_id = nested_get(task, legacy_path.split("."))
            if not model_id:
                continue

            item_name = TaskModelNames[model_type]
            new_item = {"model": model_id, "name": item_name, "updated": timestamp}
            # Keep only the entries that do not clash with the migrated one
            kept = [
                entry
                for entry in nested_get(task, (models_root, model_type), default=[])
                if entry.get("name") != item_name and entry.get("model") != model_id
            ]
            if model_type == TaskModelTypes.input:
                merged = [new_item, *kept]
            else:
                merged = [*kept, new_item]
            updates[f"{models_root}.{model_type}"] = merged

        update_spec = {"$unset": {path: 1 for path in legacy_paths}}
        if updates:
            update_spec["$set"] = updates
        task_collection.update_one({"_id": task["_id"]}, update_spec)
|
||||
|
||||
|
||||
def _migrate_docker_cmd(db: Database):
    """
    Move the legacy "execution.docker_cmd" string into the "container" field

    The command is split on the first space: the part before it becomes the
    container image and the rest becomes the container arguments. The legacy
    field is unset for every task that has it, even when it is empty.
    """
    task_collection: Collection = db["task"]

    docker_cmd_field = "execution.docker_cmd"
    docker_cmd_path = tuple(docker_cmd_field.split("."))

    for task in task_collection.find(
        filter={docker_cmd_field: {"$exists": True}},
        projection=(docker_cmd_field,),
    ):
        update_spec = {"$unset": {docker_cmd_field: 1}}

        command = nested_get(task, docker_cmd_path)
        if command:
            image, _, arguments = command.partition(" ")
            update_spec["$set"] = {
                "container": {"image": image, "arguments": arguments}
            }

        task_collection.update_one({"_id": task["_id"]}, update_spec)
|
||||
|
||||
|
||||
def _migrate_model_labels(db: Database):
    """
    Store the escaped form of the task "execution.model_labels" and
    "container" dictionaries

    Each non-empty value is passed through escape_dict and written back only
    if the escaped result differs from the stored one. Tasks for which
    nothing changed are not updated.
    """
    task_collection: Collection = db["task"]

    escaped_fields = ("execution.model_labels", "container")
    non_empty = {"$or": [{name: {"$nin": [None, {}]}} for name in escaped_fields]}

    for task in task_collection.find(filter=non_empty, projection=escaped_fields):
        changes = {}
        for name in escaped_fields:
            current = nested_get(task, name.split("."))
            if current:
                escaped = escape_dict(current)
                if current != escaped:
                    changes[name] = escaped

        if not changes:
            continue
        task_collection.update_one({"_id": task["_id"]}, {"$set": changes})
|
||||
|
||||
|
||||
def _migrate_project_description(db: Database):
    """
    Remove the auto-generated descriptions from projects

    A description is removed when it matches (case-insensitively) either
    "Auto-generated at ..." and is shorter than 100 characters, or exactly
    "Auto-generated during move" / "Auto-generated while cloning".
    """

    def description_like(pattern: str) -> dict:
        # Case-insensitive regex condition on the description field
        return {"description": {"$regex": pattern, "$options": "i"}}

    project_collection: Collection = db["project"]

    short_generated_at = {
        "$expr": {"$lt": [{"$strLenCP": "$description"}, 100]},
        **description_like("^Auto-generated at "),
    }
    query = {
        "$or": [
            short_generated_at,
            description_like("^Auto-generated during move$"),
            description_like("^Auto-generated while cloning$"),
        ]
    }

    for project in project_collection.find(filter=query):
        project_collection.update_one(
            {"_id": project["_id"]}, {"$unset": {"description": 1}}
        )
|
||||
|
||||
|
||||
def _migrate_project_names(db: Database):
    """
    Replace the "/" characters in the names of the projects whose "path"
    field is missing or empty

    Each "/" is replaced with a run of underscores. If the new name collides
    with an existing one (DuplicateKeyError) the rename is retried with one
    more underscore per "/", up to the amount of tries configured through
    the CLEARML_MIGRATION_PROJECT_RENAME_MAX_TRIES environment variable
    (10 by default). A message is printed for the projects that could not
    be renamed within the allowed amount of tries.
    """
    projects: Collection = db["project"]

    regx = re.compile("/", re.IGNORECASE)
    for doc in projects.find(filter={"name": regx, "path": {"$in": [None, []]}}):
        name = doc.get("name")
        if not name:
            continue

        max_tries = int(os.getenv("CLEARML_MIGRATION_PROJECT_RENAME_MAX_TRIES", 10))
        for attempt in range(1, max_tries + 1):
            new_name = name.replace("/", "_" * attempt)
            try:
                projects.update_one({"_id": doc["_id"]}, {"$set": {"name": new_name}})
            except DuplicateKeyError:
                # The candidate name is taken, try a longer replacement
                continue
            break
        else:
            # Reached only when no attempt succeeded (or none was allowed).
            # Checking the loop counter instead would also report a rename
            # that succeeded on the very last allowed attempt.
            print(f"Could not upgrade the name {name} of the project {doc.get('_id')}")
|
||||
|
||||
|
||||
def migrate_backend(db: Database):
    """
    Run the data migrations on the passed database in a fixed order and
    then drop all the indices of the collections whose names start with "task"
    """
    data_migrations = (
        _migrate_task_models,
        _migrate_docker_cmd,
        _migrate_model_labels,
        _migrate_project_names,
        _migrate_project_description,
    )
    for migrate in data_migrations:
        migrate(db)

    _drop_all_indices_from_collections(db, ["task*"])
|
||||
20
apiserver/mongo/migrations/utils.py
Normal file
20
apiserver/mongo/migrations/utils.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from typing import Sequence
|
||||
|
||||
from boltons.iterutils import partition
|
||||
from pymongo.database import Database, Collection
|
||||
|
||||
|
||||
def _drop_all_indices_from_collections(db: Database, names: Sequence[str]):
    """
    Drop all indices for the existing collections from the specified list

    :param db: The database whose collections are inspected
    :param names: Collection names. A name that ends with "*" is treated
        as a prefix and matches every collection whose name starts with it
        (a bare "*" matches all the collections). Any other name should match
        the collection name exactly
    """
    exact_names = set()
    prefixes = set()
    for name in names:
        if name.endswith("*"):
            prefixes.add(name.rstrip("*"))
        else:
            exact_names.add(name)
    # str.startswith accepts a tuple of options; an empty tuple matches nothing
    # and an empty prefix matches everything
    prefixes = tuple(prefixes)

    for collection_name in db.list_collection_names():
        if collection_name in exact_names or collection_name.startswith(prefixes):
            collection: Collection = db[collection_name]
            collection.drop_indexes()
|
||||
@@ -11,8 +11,16 @@ from apiserver.config_repo import config
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
OVERRIDE_HOST_ENV_KEY = ("TRAINS_REDIS_SERVICE_HOST", "REDIS_SERVICE_HOST")
|
||||
OVERRIDE_PORT_ENV_KEY = ("TRAINS_REDIS_SERVICE_PORT", "REDIS_SERVICE_PORT")
|
||||
OVERRIDE_HOST_ENV_KEY = (
|
||||
"CLEARML_REDIS_SERVICE_HOST",
|
||||
"TRAINS_REDIS_SERVICE_HOST",
|
||||
"REDIS_SERVICE_HOST",
|
||||
)
|
||||
OVERRIDE_PORT_ENV_KEY = (
|
||||
"CLEARML_REDIS_SERVICE_PORT",
|
||||
"TRAINS_REDIS_SERVICE_PORT",
|
||||
"REDIS_SERVICE_PORT",
|
||||
)
|
||||
|
||||
OVERRIDE_HOST = first(filter(None, map(getenv, OVERRIDE_HOST_ENV_KEY)))
|
||||
if OVERRIDE_HOST:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
attrs>=19.1.0
|
||||
bcrypt>=3.1.4
|
||||
boltons>=19.1.0
|
||||
boto3==1.14.13
|
||||
dpath>=1.4.2,<2.0
|
||||
@@ -11,12 +12,13 @@ funcsigs==1.0.2
|
||||
furl>=2.0.0
|
||||
gunicorn>=19.7.1
|
||||
humanfriendly==4.18
|
||||
jinja2==2.10
|
||||
jinja2==2.11.3
|
||||
jsonmodels>=2.3
|
||||
jsonschema>=2.6.0
|
||||
luqum>=0.10.0
|
||||
mongoengine==0.19.1
|
||||
nested_dict>=1.61
|
||||
packaging==20.3
|
||||
psutil>=5.6.5
|
||||
pyhocon>=0.3.35
|
||||
pyjwt<2.0.0
|
||||
|
||||
@@ -1,3 +1,20 @@
|
||||
metadata_item {
|
||||
type: object
|
||||
properties {
|
||||
key {
|
||||
type: string
|
||||
description: The key uniquely identifying the metadata item inside the given entity
|
||||
}
|
||||
type {
|
||||
type: string
|
||||
description: The type of the metadata item
|
||||
}
|
||||
value {
|
||||
type: string
|
||||
description: The value stored in the metadata item
|
||||
}
|
||||
}
|
||||
}
|
||||
credentials {
|
||||
type: object
|
||||
properties {
|
||||
@@ -11,3 +28,61 @@ credentials {
|
||||
}
|
||||
}
|
||||
}
|
||||
batch_operation {
|
||||
request {
|
||||
type: object
|
||||
required: [ids]
|
||||
properties {
|
||||
ids {
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
succeeded {
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
properties {
|
||||
id: {
|
||||
description: ID of the succeeded entity
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
failed {
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
properties {
|
||||
id: {
|
||||
description: ID of the failed entity
|
||||
type: string
|
||||
}
|
||||
error: {
|
||||
description: Error info
|
||||
type: object
|
||||
properties {
|
||||
codes {
|
||||
type: array
|
||||
items {type: integer}
|
||||
}
|
||||
msg {
|
||||
type: string
|
||||
}
|
||||
data {
|
||||
type: object
|
||||
additionalProperties: True
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -187,16 +187,12 @@
|
||||
}
|
||||
task_metric {
|
||||
type: object
|
||||
required: [task, metric]
|
||||
required: [task]
|
||||
properties {
|
||||
task {
|
||||
description: "Task ID"
|
||||
type: string
|
||||
}
|
||||
metric {
|
||||
description: "Metric name"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
task_log_event {
|
||||
@@ -370,7 +366,7 @@
|
||||
}
|
||||
}
|
||||
"2.7" {
|
||||
description: "Get the debug image events for the requested amount of iterations per each task's metric"
|
||||
description: "Get the debug image events for the requested amount of iterations per each task"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
|
||||
@@ -85,7 +85,27 @@ supported_modes {
|
||||
}
|
||||
}
|
||||
}
|
||||
authenticated {
|
||||
description: "Is user authenticated"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logout {
|
||||
authorize: false
|
||||
allow_roles = [ "*" ]
|
||||
"2.13" {
|
||||
description: """ Logout (including SSO, if used) """
|
||||
request {
|
||||
type: object
|
||||
additionalProperties: false
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
additionalProperties: false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
_description: """This service provides a management interface for models (results of training tasks) stored in the system."""
|
||||
_definitions {
|
||||
include "_common.conf"
|
||||
multi_field_pattern_data {
|
||||
type: object
|
||||
properties {
|
||||
@@ -38,6 +39,11 @@ _definitions {
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
last_update {
|
||||
description: "Model last update time"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
task {
|
||||
description: "Task ID of task in which the model was created"
|
||||
type: string
|
||||
@@ -91,6 +97,37 @@ _definitions {
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
metadata {
|
||||
type: array
|
||||
description: "Model metadata"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
}
|
||||
}
|
||||
}
|
||||
published_task_item {
|
||||
description: "Result of publishing of the model's associated task (if exists). Returned only if the task was published successfully as part of the model publishing."
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "Task id"
|
||||
type: string
|
||||
}
|
||||
data {
|
||||
description: "Data returned from the task publishing operation."
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of tasks updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [ 0, 1 ]
|
||||
}
|
||||
fields {
|
||||
description: "Updated fields names and values"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -151,6 +188,17 @@ get_by_id_ex {
|
||||
get_all_ex {
|
||||
internal: true
|
||||
"2.1": ${get_all."2.1"}
|
||||
"2.13": ${get_all_ex."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
include_subprojects {
|
||||
description: "If set to 'true' and project field is set then models from the subprojects are searched too"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all {
|
||||
"2.1" {
|
||||
@@ -433,6 +481,13 @@ create {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${create."2.1"} {
|
||||
metadata {
|
||||
type: array
|
||||
description: "Model metadata"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
}
|
||||
}
|
||||
}
|
||||
edit {
|
||||
"2.1" {
|
||||
@@ -521,6 +576,13 @@ edit {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${edit."2.1"} {
|
||||
metadata {
|
||||
type: array
|
||||
description: "Model metadata"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
}
|
||||
}
|
||||
}
|
||||
update {
|
||||
"2.1" {
|
||||
@@ -597,6 +659,40 @@ update {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${update."2.1"} {
|
||||
metadata {
|
||||
type: array
|
||||
description: "Model metadata"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
}
|
||||
}
|
||||
}
|
||||
publish_many {
|
||||
"2.13": ${_definitions.batch_operation} {
|
||||
description: Publish models
|
||||
request {
|
||||
properties {
|
||||
ids.description: "IDs of models to publish"
|
||||
force_publish_task {
|
||||
description: "Publish the associated tasks (if exist) even if they are not in the 'stopped' state. Optional, the default value is False."
|
||||
type: boolean
|
||||
}
|
||||
publish_tasks {
|
||||
description: "Indicates that the associated tasks (if exist) should be published. Optional, the default value is True."
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
properties {
|
||||
succeeded.items.properties.updated {
|
||||
description: "Indicates whether the model was updated"
|
||||
type: boolean
|
||||
}
|
||||
succeeded.items.properties.published_task: ${_definitions.published_task_item}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
set_ready {
|
||||
"2.1" {
|
||||
@@ -627,39 +723,69 @@ set_ready {
|
||||
type: integer
|
||||
enum: [0, 1]
|
||||
}
|
||||
published_task {
|
||||
description: "Result of publishing of the model's associated task (if exists). Returned only if the task was published successfully as part of the model publishing."
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "Task id"
|
||||
type: string
|
||||
}
|
||||
data {
|
||||
description: "Data returned from the task publishing operation."
|
||||
type: object
|
||||
properties {
|
||||
committed_versions_results {
|
||||
description: "Committed versions results"
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
updated {
|
||||
description: "Number of tasks updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [ 0, 1 ]
|
||||
}
|
||||
fields {
|
||||
description: "Updated fields names and values"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
published_task: ${_definitions.published_task_item}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
archive_many {
|
||||
"2.13": ${_definitions.batch_operation} {
|
||||
description: Archive models
|
||||
request {
|
||||
properties {
|
||||
ids.description: "IDs of models to archive"
|
||||
}
|
||||
}
|
||||
response {
|
||||
properties {
|
||||
succeeded.items.properties.archived {
|
||||
description: "Indicates whether the model was archived"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
unarchive_many {
|
||||
"2.13": ${_definitions.batch_operation} {
|
||||
description: Unarchive models
|
||||
request {
|
||||
properties {
|
||||
ids.description: "IDs of the models to unarchive"
|
||||
}
|
||||
}
|
||||
response {
|
||||
properties {
|
||||
succeeded.items.properties.unarchived {
|
||||
description: "Indicates whether the model was unarchived"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
delete_many {
|
||||
"2.13": ${_definitions.batch_operation} {
|
||||
description: Delete models
|
||||
request {
|
||||
properties {
|
||||
ids.description: "IDs of models to delete"
|
||||
force {
|
||||
description: """Force. Required if there are tasks that use the model as an execution model, or if the model's creating task is published.
|
||||
"""
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
properties {
|
||||
succeeded.items.properties.deleted {
|
||||
description: "Indicates whether the model was deleted"
|
||||
type: boolean
|
||||
}
|
||||
succeeded.items.properties.url {
|
||||
description: "The url of the model file"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -697,6 +823,16 @@ delete {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${delete."2.1"} {
|
||||
response {
|
||||
properties {
|
||||
url {
|
||||
description: "The url of the model file"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
make_public {
|
||||
@@ -777,4 +913,63 @@ move {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
add_or_update_metadata {
|
||||
"2.13" {
|
||||
description: "Add or update model metadata"
|
||||
request {
|
||||
type: object
|
||||
required: [model, metadata]
|
||||
properties {
|
||||
model {
|
||||
description: "ID of the model"
|
||||
type: string
|
||||
}
|
||||
metadata {
|
||||
type: array
|
||||
description: "Metadata items to add or update"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of models updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [0, 1]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
delete_metadata {
|
||||
"2.13" {
|
||||
description: "Delete metadata from model"
|
||||
request {
|
||||
type: object
|
||||
required: [ model, keys ]
|
||||
properties {
|
||||
model {
|
||||
description: "ID of the model"
|
||||
type: string
|
||||
}
|
||||
keys {
|
||||
description: "The list of metadata keys to delete"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of models updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [0, 1]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -167,10 +167,27 @@ _definitions {
|
||||
type: string
|
||||
}
|
||||
// extra properties
|
||||
stats: {
|
||||
stats {
|
||||
description: "Additional project stats"
|
||||
"$ref": "#/definitions/stats"
|
||||
}
|
||||
sub_projects {
|
||||
description: "The list of sub projects"
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "Subproject ID"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Subproject name"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
metric_variant_result {
|
||||
@@ -242,6 +259,23 @@ _definitions {
|
||||
}
|
||||
}
|
||||
}
|
||||
urls {
|
||||
type: object
|
||||
properties {
|
||||
model_urls {
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
event_urls {
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
artifact_urls {
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
create {
|
||||
@@ -249,17 +283,14 @@ create {
|
||||
description: "Create a new project"
|
||||
request {
|
||||
type: object
|
||||
required :[
|
||||
name
|
||||
description
|
||||
]
|
||||
required :[name]
|
||||
properties {
|
||||
name {
|
||||
description: "Project name. Unique within the company."
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "Project description. "
|
||||
description: "Project description."
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
@@ -388,6 +419,17 @@ get_all {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${get_all."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
shallow_search {
|
||||
description: "If set to 'true' then the search with the specified criteria is performed among top level projects only (or if parents specified, among the direct children of these parents). Otherwise the search is performed among all the company projects (or among all of the descendants of the specified parents)."
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all_ex {
|
||||
internal: true
|
||||
@@ -413,6 +455,39 @@ get_all_ex {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${get_all_ex."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
active_users {
|
||||
description: "The list of users that were active in the project. If passed then the resulting projects are filtered to the ones that have tasks created by these users"
|
||||
type: array
|
||||
items: {type: string}
|
||||
}
|
||||
shallow_search {
|
||||
description: "If set to 'true' then the search with the specified criteria is performed among top level projects only (or if parents specified, among the direct children of these parents). Otherwise the search is performed among all the company projects (or among all of the descendants of the specified parents)."
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
check_own_contents {
|
||||
description: "If set to 'true' and project ids are passed to the query then for these projects their own tasks, models and dataviews are counted"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
properties {
|
||||
own_tasks {
|
||||
description: "The amount of tasks under this project (without children projects). Returned if 'check_own_contents' flag is set in the request"
|
||||
type: integer
|
||||
}
|
||||
own_models {
|
||||
description: "The amount of models under this project (without children projects). Returned if 'check_own_contents' flag is set in the request"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
update {
|
||||
"2.1" {
|
||||
@@ -470,6 +545,66 @@ update {
|
||||
}
|
||||
}
|
||||
}
|
||||
move {
|
||||
"2.13" {
|
||||
description: "Moves a project and all of its subprojects under a different location"
|
||||
request {
|
||||
type: object
|
||||
required: [project]
|
||||
properties {
|
||||
project {
|
||||
description: "Project id"
|
||||
type: string
|
||||
}
|
||||
new_location {
|
||||
description: "The new location for the project"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
moved {
|
||||
description: "The number of projects moved"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
merge {
|
||||
"2.13" {
|
||||
description: "Moves all the source project's contents to the destination project and removes the source project"
|
||||
request {
|
||||
type: object
|
||||
required: [project]
|
||||
properties {
|
||||
project {
|
||||
description: "Project id"
|
||||
type: string
|
||||
}
|
||||
destination_project {
|
||||
description: "The ID of the destination project"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
moved_entities {
|
||||
description: "The number of tasks, models and dataviews moved from the merged project into the destination"
|
||||
type: integer
|
||||
}
|
||||
moved_projects {
|
||||
description: "The number of child projects moved from the merged project into the destination"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
delete {
|
||||
"2.1" {
|
||||
description: "Deletes a project"
|
||||
@@ -504,6 +639,32 @@ delete {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${delete."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
delete_contents {
|
||||
description: "If set to 'true' then the project tasks and models will be deleted. Otherwise their project property will be unassigned. Default value is 'false'"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
properties {
|
||||
urls {
|
||||
description: "The urls of the files that were uploaded by the project tasks and models. Returned if the 'delete_contents' was set to 'true'"
|
||||
"$ref": "#/definitions/urls"
|
||||
}
|
||||
deleted_models {
|
||||
description: "Number of models deleted"
|
||||
type: integer
|
||||
}
|
||||
deleted_tasks {
|
||||
description: "Number of tasks deleted"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_unique_metric_variants {
|
||||
"2.1" {
|
||||
@@ -530,6 +691,64 @@ get_unique_metric_variants {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${get_unique_metric_variants."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
include_subprojects {
|
||||
description: "If set to 'true' and the project field is set then the result includes metrics/variants from the subproject tasks"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_hyperparam_values {
|
||||
"2.13" {
|
||||
description: """Get a list of distinct values for the chosen hyperparameter"""
|
||||
request {
|
||||
type: object
|
||||
required: [section, name]
|
||||
properties {
|
||||
projects {
|
||||
description: "Project IDs"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
section {
|
||||
description: "Hyperparameter section name"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Hyperparameter name"
|
||||
type: string
|
||||
}
|
||||
allow_public {
|
||||
description: "If set to 'true' then collect values from both company and public tasks otherwise company tasks only. The default is 'true'"
|
||||
type: boolean
|
||||
}
|
||||
include_subprojects {
|
||||
description: "If set to 'true' and the project field is set then the result includes hyper parameters values from the subproject tasks"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
total {
|
||||
description: "Total number of distinct parameter values"
|
||||
type: integer
|
||||
}
|
||||
values {
|
||||
description: "The list of the unique values for the parameter"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_hyper_parameters {
|
||||
"2.9" {
|
||||
@@ -572,6 +791,17 @@ get_hyper_parameters {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${get_hyper_parameters."2.9"} {
|
||||
request {
|
||||
properties {
|
||||
include_subprojects {
|
||||
description: "If set to 'true' and the project field is set then the result includes hyper parameters from the subproject tasks"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
get_task_tags {
|
||||
@@ -641,7 +871,7 @@ make_private {
|
||||
}
|
||||
get_task_parents {
|
||||
"2.12" {
|
||||
description: "Get unique parent tasks for the tasks in the specified pprojects"
|
||||
description: "Get unique parent tasks for the tasks in the specified projects"
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
@@ -692,4 +922,15 @@ get_task_parents {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${get_task_parents."2.12"} {
|
||||
request {
|
||||
properties {
|
||||
include_subprojects {
|
||||
description: "If set to 'true' and the projects field is not empty then the result includes task parents from the subproject tasks"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -26,6 +26,49 @@ _references {
|
||||
}
|
||||
_definitions {
|
||||
include "_common.conf"
|
||||
change_many_request: ${_definitions.batch_operation} {
|
||||
request {
|
||||
properties {
|
||||
status_reason {
|
||||
description: Reason for status change
|
||||
type: string
|
||||
}
|
||||
status_message {
|
||||
description: Extra information regarding status change
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
properties {
|
||||
succeeded.items.properties.updated {
|
||||
description: "Number of tasks updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [ 0, 1 ]
|
||||
}
|
||||
succeeded.items.properties.fields {
|
||||
description: "Updated fields names and values"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
update_response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of tasks updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [ 0, 1 ]
|
||||
}
|
||||
fields {
|
||||
description: "Updated fields names and values"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
}
|
||||
multi_field_pattern_data {
|
||||
type: object
|
||||
properties {
|
||||
@@ -40,6 +83,24 @@ _definitions {
|
||||
}
|
||||
}
|
||||
}
|
||||
model_type_enum {
|
||||
type: string
|
||||
enum: ["input", "output"]
|
||||
}
|
||||
task_model_item {
|
||||
type: object
|
||||
required: [ name, model]
|
||||
properties {
|
||||
name {
|
||||
description: "The task model name"
|
||||
type: string
|
||||
}
|
||||
model {
|
||||
description: "The model ID"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
script {
|
||||
type: object
|
||||
properties {
|
||||
@@ -207,6 +268,22 @@ _definitions {
|
||||
}
|
||||
}
|
||||
}
|
||||
task_models {
|
||||
type: object
|
||||
properties {
|
||||
input {
|
||||
description: "The list of task input models"
|
||||
type: array
|
||||
items {"$ref": "#/definitions/task_model_item"}
|
||||
|
||||
}
|
||||
output {
|
||||
description: "The list of task output models"
|
||||
type: array
|
||||
items {"$ref": "#/definitions/task_model_item"}
|
||||
}
|
||||
}
|
||||
}
|
||||
execution {
|
||||
type: object
|
||||
properties {
|
||||
@@ -454,6 +531,15 @@ _definitions {
|
||||
description: "Task execution params"
|
||||
"$ref": "#/definitions/execution"
|
||||
}
|
||||
container {
|
||||
description: "Docker container parameters"
|
||||
type: object
|
||||
additionalProperties { type: string }
|
||||
}
|
||||
models {
|
||||
description: "Task models"
|
||||
"$ref": "#/definitions/task_models"
|
||||
}
|
||||
// TODO: will be removed
|
||||
script {
|
||||
description: "Script info"
|
||||
@@ -531,6 +617,28 @@ _definitions {
|
||||
"$ref": "#/definitions/configuration_item"
|
||||
}
|
||||
}
|
||||
runtime {
|
||||
description: "Task runtime mapping"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
}
|
||||
task_urls {
|
||||
type: object
|
||||
properties {
|
||||
model_urls {
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
event_urls {
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
artifact_urls {
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -566,6 +674,17 @@ get_by_id_ex {
|
||||
get_all_ex {
|
||||
internal: true
|
||||
"2.1": ${get_all."2.1"}
|
||||
"2.13": ${get_all_ex."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
include_subprojects {
|
||||
description: "If set to 'true' and project field is set then tasks from the subprojects are searched too"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all {
|
||||
"2.1" {
|
||||
@@ -805,6 +924,106 @@ clone {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${clone."2.12"}{
|
||||
request {
|
||||
properties {
|
||||
new_task_input_models {
|
||||
description: "The list of input models for the cloned task. If not specified then copied from the original task"
|
||||
type: array
|
||||
items {"$ref": "#/definitions/task_model_item"}
|
||||
}
|
||||
new_task_container {
|
||||
description: "The docker container properties for the new task. If not provided then taken from the original task"
|
||||
type: object
|
||||
additionalProperties { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
add_or_update_model {
|
||||
"2.13" {
|
||||
description: "Add or update task model"
|
||||
request {
|
||||
type: object
|
||||
required: [task, name, model, type]
|
||||
properties {
|
||||
task {
|
||||
description: "ID of the task"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "The task model name"
|
||||
type: string
|
||||
}
|
||||
model {
|
||||
description: "The model ID"
|
||||
type: string
|
||||
}
|
||||
type {
|
||||
description: "The task model type"
|
||||
"$ref": "#/definitions/model_type_enum"
|
||||
}
|
||||
iteration {
|
||||
description: "Iteration (used to update task statistics)"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of tasks updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [0, 1]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
delete_models {
|
||||
"2.13" {
|
||||
description: "Delete models from task"
|
||||
request {
|
||||
type: object
|
||||
required: [ task, models ]
|
||||
properties {
|
||||
task {
|
||||
description: "ID of the task"
|
||||
type: string
|
||||
}
|
||||
models {
|
||||
description: "The list of models to delete"
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
required: [name, type]
|
||||
properties {
|
||||
name {
|
||||
description: "The task model name"
|
||||
type: string
|
||||
}
|
||||
type {
|
||||
description: "The task model type"
|
||||
"$ref": "#/definitions/model_type_enum"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of tasks updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [0, 1]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
create {
|
||||
"2.1" {
|
||||
@@ -884,6 +1103,21 @@ create {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${create."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
models {
|
||||
description: "Task models"
|
||||
"$ref": "#/definitions/task_models"
|
||||
}
|
||||
container {
|
||||
description: "Docker container parameters"
|
||||
type: object
|
||||
additionalProperties { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
validate {
|
||||
"2.1" {
|
||||
@@ -958,6 +1192,21 @@ validate {
|
||||
additionalProperties: false
|
||||
}
|
||||
}
|
||||
"2.13": ${validate."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
models {
|
||||
description: "Task models"
|
||||
"$ref": "#/definitions/task_models"
|
||||
}
|
||||
container {
|
||||
description: "Docker container parameters"
|
||||
type: object
|
||||
additionalProperties { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
update {
|
||||
"2.1" {
|
||||
@@ -1005,21 +1254,7 @@ update {
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of tasks updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [ 0, 1 ]
|
||||
}
|
||||
fields {
|
||||
description: "Updated fields names and values"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
}
|
||||
response: ${_definitions.update_response}
|
||||
}
|
||||
}
|
||||
update_batch {
|
||||
@@ -1117,16 +1352,22 @@ edit {
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
response: ${_definitions.update_response}
|
||||
}
|
||||
"2.13": ${edit."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of tasks updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [ 0, 1 ]
|
||||
models {
|
||||
description: "Task models"
|
||||
"$ref": "#/definitions/task_models"
|
||||
}
|
||||
fields {
|
||||
description: "Updated fields names and values"
|
||||
container {
|
||||
description: "Docker container parameters"
|
||||
type: object
|
||||
additionalProperties { type: string }
|
||||
}
|
||||
runtime {
|
||||
description: "Task runtime mapping"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
@@ -1151,24 +1392,13 @@ reset {
|
||||
default: false
|
||||
}
|
||||
} ${_references.status_change_request}
|
||||
response {
|
||||
type: object
|
||||
response: ${_definitions.update_response} {
|
||||
properties {
|
||||
deleted_indices {
|
||||
description: "List of deleted ES indices that were removed as part of the reset process"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
dequeued {
|
||||
description: "Response from queues.remove_task"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
frames {
|
||||
description: "Response from frames.rollback"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
events {
|
||||
description: "Response from events.delete_for_task"
|
||||
type: object
|
||||
@@ -1178,16 +1408,126 @@ reset {
|
||||
description: "Number of output models deleted by the reset"
|
||||
type: integer
|
||||
}
|
||||
updated {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${reset."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
return_file_urls {
|
||||
description: "If set to 'true' then return the urls of the files that were uploaded by this task. Default value is 'false'"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
properties {
|
||||
urls {
|
||||
description: "The urls of the files that were uploaded by this task. Returned if the 'return_file_urls' was set to 'true'"
|
||||
"$ref": "#/definitions/task_urls"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
reset_many {
|
||||
"2.13": ${_definitions.batch_operation} {
|
||||
description: Reset tasks
|
||||
request {
|
||||
properties {
|
||||
ids.description: "IDs of the tasks to reset"
|
||||
force = ${_references.force_arg} {
|
||||
description: "If not true, call fails if the task status is 'completed'"
|
||||
}
|
||||
clear_all {
|
||||
description: "Clear script and execution sections completely"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
return_file_urls {
|
||||
description: "If set to 'true' then return the urls of the files that were uploaded by the tasks. Default value is 'false'"
|
||||
type: boolean
|
||||
}
|
||||
delete_output_models {
|
||||
description: "If set to 'true' then delete output models of the tasks that are not referenced by other tasks. Default value is 'true'"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
properties {
|
||||
succeeded.items.properties.dequeued {
|
||||
description: "Indicates whether the task was dequeued"
|
||||
type: boolean
|
||||
}
|
||||
succeeded.items.properties.updated {
|
||||
description: "Number of tasks updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [ 0, 1 ]
|
||||
}
|
||||
fields {
|
||||
succeeded.items.properties.fields {
|
||||
description: "Updated fields names and values"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
succeeded.items.properties.deleted_models {
|
||||
description: "Number of output models deleted by the reset"
|
||||
type: integer
|
||||
}
|
||||
succeeded.items.properties.urls {
|
||||
description: "The urls of the files that were uploaded by the task. Returned if the 'return_file_urls' was set to 'true'"
|
||||
"$ref": "#/definitions/task_urls"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
delete_many {
|
||||
"2.13": ${_definitions.batch_operation} {
|
||||
description: Delete tasks
|
||||
request {
|
||||
properties {
|
||||
ids.description: "IDs of the tasks to delete"
|
||||
move_to_trash {
|
||||
description: "Move task to trash instead of deleting it. For internal use only, tasks in the trash are not visible from the API and cannot be restored!"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
force = ${_references.force_arg} {
|
||||
description: "If not true, call fails if the task status is 'in_progress'"
|
||||
}
|
||||
return_file_urls {
|
||||
description: "If set to 'true' then return the urls of the files that were uploaded by the tasks. Default value is 'false'"
|
||||
type: boolean
|
||||
}
|
||||
delete_output_models {
|
||||
description: "If set to 'true' then delete output models of the tasks that are not referenced by other tasks. Default value is 'true'"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
properties {
|
||||
succeeded.items.properties.deleted {
|
||||
description: "Indicates whether the task was deleted"
|
||||
type: boolean
|
||||
}
|
||||
succeeded.items.properties.updated_children {
|
||||
description: "Number of child tasks whose parent property was updated"
|
||||
type: integer
|
||||
}
|
||||
succeeded.items.properties.updated_models {
|
||||
description: "Number of models whose task property was updated"
|
||||
type: integer
|
||||
}
|
||||
succeeded.items.properties.deleted_models {
|
||||
description: "Number of deleted output models"
|
||||
type: integer
|
||||
}
|
||||
succeeded.items.properties.urls {
|
||||
description: "The urls of the files that were uploaded by the task. Returned if the 'return_file_urls' was set to 'true'"
|
||||
"$ref": "#/definitions/task_urls"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1229,15 +1569,6 @@ delete {
|
||||
description: "Number of models whose task property was updated"
|
||||
type: integer
|
||||
}
|
||||
updated_versions {
|
||||
description: "Number of dataset versions whose task property was updated"
|
||||
type: integer
|
||||
}
|
||||
frames {
|
||||
description: "Response from frames.rollback"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
events {
|
||||
description: "Response from events.delete_for_task"
|
||||
type: object
|
||||
@@ -1246,6 +1577,24 @@ delete {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${delete."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
return_file_urls {
|
||||
description: "If set to 'true' then return the urls of the files that were uploaded by this task. Default value is 'false'"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
properties {
|
||||
urls {
|
||||
description: "The urls of the files that were uploaded by this task. Returned if the 'return_file_urls' was set to 'true'"
|
||||
"$ref": "#/definitions/task_urls"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
archive {
|
||||
"2.12" {
|
||||
@@ -1284,6 +1633,58 @@ archive {
|
||||
}
|
||||
}
|
||||
}
|
||||
archive_many {
|
||||
"2.13": ${_definitions.batch_operation} {
|
||||
description: Archive tasks
|
||||
request {
|
||||
properties {
|
||||
ids.description: "IDs of the tasks to archive"
|
||||
status_reason {
|
||||
description: Reason for status change
|
||||
type: string
|
||||
}
|
||||
status_message {
|
||||
description: Extra information regarding status change
|
||||
type: string
|
||||
}
|
||||
}
|
||||
response {
|
||||
properties {
|
||||
succeeded.items.properties.archived {
|
||||
description: "Indicates whether the task was archived"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
unarchive_many {
|
||||
"2.13": ${_definitions.batch_operation} {
|
||||
description: Unarchive tasks
|
||||
request {
|
||||
properties {
|
||||
ids.description: "IDs of the tasks to unarchive"
|
||||
status_reason {
|
||||
description: Reason for status change
|
||||
type: string
|
||||
}
|
||||
status_message {
|
||||
description: Extra information regarding status change
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
properties {
|
||||
succeeded.items.properties.unarchived {
|
||||
description: "Indicates whether the task was unarchived"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
started {
|
||||
"2.1" {
|
||||
description: "Mark a task status as in_progress. Optionally allows to set the task's execution progress."
|
||||
@@ -1296,24 +1697,13 @@ started {
|
||||
description: "If not true, call fails if the task status is not 'not_started'"
|
||||
}
|
||||
} ${_references.status_change_request}
|
||||
response {
|
||||
type: object
|
||||
response: ${_definitions.update_response} {
|
||||
properties {
|
||||
started {
|
||||
description: "Number of tasks started (0 or 1)"
|
||||
type: integer
|
||||
enum: [ 0, 1 ]
|
||||
}
|
||||
updated {
|
||||
description: "Number of tasks updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [ 0, 1 ]
|
||||
}
|
||||
fields {
|
||||
description: "Updated fields names and values"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1330,18 +1720,17 @@ stop {
|
||||
description: "If not true, call fails if the task status is not 'in_progress'"
|
||||
}
|
||||
} ${_references.status_change_request}
|
||||
response {
|
||||
type: object
|
||||
response: ${_definitions.update_response}
|
||||
}
|
||||
}
|
||||
stop_many {
|
||||
"2.13": ${_definitions.change_many_request} {
|
||||
description: "Request to stop running tasks"
|
||||
request {
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of tasks updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [ 0, 1 ]
|
||||
}
|
||||
fields {
|
||||
description: "Updated fields names and values"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
ids.description: "IDs of the tasks to stop"
|
||||
force = ${_references.force_arg} {
|
||||
description: "If not true, call fails if the task status is not 'in_progress'"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1359,21 +1748,7 @@ stopped {
|
||||
description: "If not true, call fails if the task status is not 'stopped'"
|
||||
}
|
||||
} ${_references.status_change_request}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of tasks updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [ 0, 1 ]
|
||||
}
|
||||
fields {
|
||||
description: "Updated fields names and values"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
}
|
||||
response: ${_definitions.update_response}
|
||||
}
|
||||
}
|
||||
failed {
|
||||
@@ -1386,21 +1761,7 @@ failed {
|
||||
]
|
||||
properties.force = ${_references.force_arg}
|
||||
} ${_references.status_change_request}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of tasks updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [ 0, 1 ]
|
||||
}
|
||||
fields {
|
||||
description: "Updated fields names and values"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
}
|
||||
response: ${_definitions.update_response}
|
||||
}
|
||||
}
|
||||
close {
|
||||
@@ -1413,21 +1774,7 @@ close {
|
||||
]
|
||||
properties.force = ${_references.force_arg}
|
||||
} ${_references.status_change_request}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of tasks updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [ 0, 1 ]
|
||||
}
|
||||
fields {
|
||||
description: "Updated fields names and values"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
}
|
||||
response: ${_definitions.update_response}
|
||||
}
|
||||
}
|
||||
publish {
|
||||
@@ -1452,26 +1799,21 @@ publish {
|
||||
}
|
||||
}
|
||||
} ${_references.status_change_request}
|
||||
response {
|
||||
type: object
|
||||
response: ${_definitions.update_response}
|
||||
}
|
||||
}
|
||||
publish_many {
|
||||
"2.13": ${_definitions.change_many_request} {
|
||||
description: Publish tasks
|
||||
request {
|
||||
properties {
|
||||
committed_versions_results {
|
||||
description: "Committed versions results"
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
ids.description: "IDs of the tasks to publish"
|
||||
force = ${_references.force_arg} {
|
||||
description: "If not true, call fails if the task status is not 'stopped'"
|
||||
}
|
||||
updated {
|
||||
description: "Number of tasks updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [ 0, 1 ]
|
||||
}
|
||||
fields {
|
||||
description: "Updated fields names and values"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
publish_model {
|
||||
description: "Indicates that the task output model (if exists) should be published. Optional, the default value is True."
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1502,23 +1844,39 @@ Fails if the following parameters in the task were not filled:
|
||||
}
|
||||
|
||||
} ${_references.status_change_request}
|
||||
response {
|
||||
type: object
|
||||
response: ${_definitions.update_response} {
|
||||
properties {
|
||||
queued {
|
||||
description: "Number of tasks queued (0 or 1)"
|
||||
type: integer
|
||||
enum: [ 0, 1 ]
|
||||
}
|
||||
updated {
|
||||
description: "Number of tasks updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [ 0, 1 ]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
enqueue_many {
|
||||
"2.13": ${_definitions.change_many_request} {
|
||||
description: Enqueue tasks
|
||||
request {
|
||||
properties {
|
||||
ids.description: "IDs of the tasks to enqueue"
|
||||
queue {
|
||||
description: "Queue id. If not provided, tasks are added to the default queue."
|
||||
type: string
|
||||
}
|
||||
fields {
|
||||
description: "Updated fields names and values"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
validate_tasks {
|
||||
description: "If set then tasks are validated before enqueue"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
properties {
|
||||
succeeded.items.properties.queued {
|
||||
description: "Indicates whether the task was queued"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1534,23 +1892,30 @@ dequeue {
|
||||
task
|
||||
]
|
||||
} ${_references.status_change_request}
|
||||
response {
|
||||
type: object
|
||||
response: ${_definitions.update_response} {
|
||||
properties {
|
||||
dequeued {
|
||||
description: "Number of tasks dequeued (0 or 1)"
|
||||
type: integer
|
||||
enum: [ 0, 1 ]
|
||||
}
|
||||
updated {
|
||||
description: "Number of tasks updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [ 0, 1 ]
|
||||
}
|
||||
fields {
|
||||
description: "Updated fields names and values"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
dequeue_many {
|
||||
"2.13": ${_definitions.change_many_request} {
|
||||
description: Dequeue tasks
|
||||
request {
|
||||
properties {
|
||||
ids.description: "IDs of the tasks to dequeue"
|
||||
}
|
||||
}
|
||||
response {
|
||||
properties {
|
||||
succeeded.items.properties.dequeued {
|
||||
description: "Indicates whether the task was dequeued"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1576,21 +1941,7 @@ set_requirements {
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of tasks updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [ 0, 1 ]
|
||||
}
|
||||
fields {
|
||||
description: "Updated fields names and values"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
}
|
||||
response: ${_definitions.update_response}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1606,21 +1957,7 @@ completed {
|
||||
description: "If not true, call fails if the task status is not in_progress/stopped"
|
||||
}
|
||||
} ${_references.status_change_request}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of tasks updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [ 0, 1 ]
|
||||
}
|
||||
fields {
|
||||
description: "Updated fields names and values"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
}
|
||||
response: ${_definitions.update_response}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1932,6 +2269,11 @@ get_configuration_names {
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
skip_empty {
|
||||
description: If set to 'true' then the names for configurations with missing values are not returned
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
from functools import partial
|
||||
|
||||
from flask import request, Response, redirect
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from apiserver.apierrors import APIError
|
||||
from apiserver.apierrors.base import BaseError
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.service_repo import ServiceRepo, APICall
|
||||
@@ -25,7 +28,10 @@ class RequestHandlers:
|
||||
|
||||
try:
|
||||
call = self._create_api_call(request)
|
||||
content, content_type = ServiceRepo.handle_call(call)
|
||||
load_data_callback = partial(self._load_call_data, req=request)
|
||||
content, content_type = ServiceRepo.handle_call(
|
||||
call, load_data_callback=load_data_callback
|
||||
)
|
||||
|
||||
if call.result.redirect:
|
||||
response = redirect(call.result.redirect.url, call.result.redirect.code)
|
||||
@@ -137,9 +143,6 @@ class RequestHandlers:
|
||||
auth_cookie=auth_cookie,
|
||||
)
|
||||
|
||||
# Update call data from request
|
||||
self._update_call_data(call, req)
|
||||
|
||||
except PathParsingError as ex:
|
||||
call = self._call_or_empty_with_error(call, req, ex.args[0], 400)
|
||||
call.log_api = False
|
||||
@@ -156,3 +159,18 @@ class RequestHandlers:
|
||||
)
|
||||
|
||||
return call
|
||||
|
||||
def _load_call_data(self, call: APICall, req):
|
||||
"""Update call data from request"""
|
||||
try:
|
||||
self._update_call_data(call, req)
|
||||
except BadRequest as ex:
|
||||
call.set_error_result(msg=ex.description, code=400)
|
||||
except BaseError as ex:
|
||||
call.set_error_result(msg=ex.msg, code=ex.code, subcode=ex.subcode)
|
||||
except APIError as ex:
|
||||
call.set_error_result(
|
||||
msg=ex.msg, code=ex.code, subcode=ex.subcode, error_data=ex.error_data
|
||||
)
|
||||
except Exception as ex:
|
||||
call.set_error_result(msg=ex.args[0] if ex.args else type(ex).__name__)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Text, Sequence, Callable, Union, Type
|
||||
|
||||
from funcsigs import signature
|
||||
from inspect import signature
|
||||
from jsonmodels import models
|
||||
|
||||
from .apicall import APICall, APICallResult
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Type, Optional, Union, Tuple
|
||||
|
||||
import attr
|
||||
from jsonmodels import models
|
||||
from requests.structures import CaseInsensitiveDict
|
||||
from six import string_types
|
||||
|
||||
from apiserver import database
|
||||
@@ -313,6 +314,13 @@ class APICall(DataContainer):
|
||||
def HEADER_TRANSACTION(self):
|
||||
return self._transaction_headers[0]
|
||||
|
||||
_client_headers = _get_headers("Client")
|
||||
""" Client """
|
||||
|
||||
@property
|
||||
def HEADER_CLIENT(self):
|
||||
return self._client_headers[0]
|
||||
|
||||
_worker_headers = _get_headers("Worker")
|
||||
""" Worker (machine) ID """
|
||||
|
||||
@@ -366,7 +374,7 @@ class APICall(DataContainer):
|
||||
assert isinstance(endpoint_version, PartialVersion), endpoint_version
|
||||
self._requested_endpoint_version = endpoint_version
|
||||
self._actual_endpoint_version = None
|
||||
self._headers = {}
|
||||
self._headers = CaseInsensitiveDict()
|
||||
self._kpis = {}
|
||||
self._log_api = True
|
||||
if headers:
|
||||
@@ -379,6 +387,7 @@ class APICall(DataContainer):
|
||||
self._requires_authorization = True
|
||||
self._host = host
|
||||
self._auth_cookie = auth_cookie
|
||||
self._json_flags = {}
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
@@ -420,7 +429,7 @@ class APICall(DataContainer):
|
||||
:param header: Header name options (more than one supported, all will be cleared)
|
||||
"""
|
||||
for value in header if isinstance(header, (tuple, list)) else (header,):
|
||||
self.headers.pop(value, None)
|
||||
self._headers.pop(value, None)
|
||||
|
||||
def set_header(self, header, value):
|
||||
"""
|
||||
@@ -514,7 +523,7 @@ class APICall(DataContainer):
|
||||
|
||||
@property
|
||||
def headers(self):
|
||||
return self._headers
|
||||
return dict(self._headers.items())
|
||||
|
||||
@property
|
||||
def kpis(self):
|
||||
@@ -532,6 +541,10 @@ class APICall(DataContainer):
|
||||
def trx(self, value):
|
||||
self.set_header(self._transaction_headers, value)
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
return self.get_header(self._client_headers)
|
||||
|
||||
@property
|
||||
def worker(self):
|
||||
return self.get_worker(default="<unknown>")
|
||||
@@ -567,6 +580,10 @@ class APICall(DataContainer):
|
||||
def auth_cookie(self):
|
||||
return self._auth_cookie
|
||||
|
||||
@property
|
||||
def json_flags(self):
|
||||
return self._json_flags
|
||||
|
||||
def mark_end(self):
|
||||
self._end_ts = time.time()
|
||||
self._duration = int((self._end_ts - self._start_ts) * 1000)
|
||||
@@ -617,7 +634,8 @@ class APICall(DataContainer):
|
||||
}
|
||||
if self.content_type.lower() == JSON_CONTENT_TYPE:
|
||||
try:
|
||||
res = json.dumps(res)
|
||||
func = json.dumps if self._json_flags.pop("ensure_ascii", True) else json.dumps_notascii
|
||||
res = func(res, **(self._json_flags or {}))
|
||||
except Exception as ex:
|
||||
# JSON serialization may fail, probably problem with data or error_data so pop it and try again
|
||||
if not (self.result.data or self.result.error_data):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import base64
|
||||
from datetime import datetime
|
||||
|
||||
import bcrypt
|
||||
import jwt
|
||||
from mongoengine import Q
|
||||
|
||||
@@ -31,8 +32,8 @@ def get_auth_func(auth_type):
|
||||
|
||||
|
||||
def authorize_token(jwt_token, *_, **__):
|
||||
""" Validate token against service/endpoint and requests data (dicts).
|
||||
Returns a parsed token object (auth payload)
|
||||
"""Validate token against service/endpoint and requests data (dicts).
|
||||
Returns a parsed token object (auth payload)
|
||||
"""
|
||||
try:
|
||||
return Token.from_encoded_token(jwt_token)
|
||||
@@ -51,13 +52,15 @@ def authorize_token(jwt_token, *_, **__):
|
||||
|
||||
|
||||
def authorize_credentials(auth_data, service, action, call_data_items):
|
||||
""" Validate credentials against service/action and request data (dicts).
|
||||
Returns a new basic object (auth payload)
|
||||
"""Validate credentials against service/action and request data (dicts).
|
||||
Returns a new basic object (auth payload)
|
||||
"""
|
||||
try:
|
||||
access_key, _, secret_key = base64.b64decode(auth_data.encode()).decode('latin-1').partition(':')
|
||||
access_key, _, secret_key = (
|
||||
base64.b64decode(auth_data.encode()).decode("latin-1").partition(":")
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception('malformed credentials')
|
||||
log.exception("malformed credentials")
|
||||
raise errors.unauthorized.BadCredentials(str(e))
|
||||
|
||||
query = Q(credentials__match=Credentials(key=access_key, secret=secret_key))
|
||||
@@ -67,18 +70,32 @@ def authorize_credentials(auth_data, service, action, call_data_items):
|
||||
if FixedUser.enabled():
|
||||
fixed_user = FixedUser.get_by_username(access_key)
|
||||
if fixed_user:
|
||||
if secret_key != fixed_user.password:
|
||||
raise errors.unauthorized.InvalidCredentials('bad username or password')
|
||||
if FixedUser.pass_hashed():
|
||||
if not compare_secret_key_hash(secret_key, fixed_user.password):
|
||||
raise errors.unauthorized.InvalidCredentials(
|
||||
"bad username or password"
|
||||
)
|
||||
else:
|
||||
if secret_key != fixed_user.password:
|
||||
raise errors.unauthorized.InvalidCredentials(
|
||||
"bad username or password"
|
||||
)
|
||||
|
||||
if fixed_user.is_guest and not FixedUser.is_guest_endpoint(service, action):
|
||||
raise errors.unauthorized.InvalidCredentials('endpoint not allowed for guest')
|
||||
raise errors.unauthorized.InvalidCredentials(
|
||||
"endpoint not allowed for guest"
|
||||
)
|
||||
|
||||
query = Q(id=fixed_user.user_id)
|
||||
|
||||
with TimingContext("mongo", "user_by_cred"), translate_errors_context('authorizing request'):
|
||||
with TimingContext("mongo", "user_by_cred"), translate_errors_context(
|
||||
"authorizing request"
|
||||
):
|
||||
user = User.objects(query).first()
|
||||
if not user:
|
||||
raise errors.unauthorized.InvalidCredentials('failed to locate provided credentials')
|
||||
raise errors.unauthorized.InvalidCredentials(
|
||||
"failed to locate provided credentials"
|
||||
)
|
||||
|
||||
if not fixed_user:
|
||||
# In case these are proper credentials, update last used time
|
||||
@@ -87,13 +104,18 @@ def authorize_credentials(auth_data, service, action, call_data_items):
|
||||
)
|
||||
|
||||
with TimingContext("mongo", "company_by_id"):
|
||||
company = Company.objects(id=user.company).only('id', 'name').first()
|
||||
company = Company.objects(id=user.company).only("id", "name").first()
|
||||
|
||||
if not company:
|
||||
raise errors.unauthorized.InvalidCredentials('invalid user company')
|
||||
raise errors.unauthorized.InvalidCredentials("invalid user company")
|
||||
|
||||
identity = Identity(user=user.id, company=user.company, role=user.role,
|
||||
user_name=user.name, company_name=company.name)
|
||||
identity = Identity(
|
||||
user=user.id,
|
||||
company=user.company,
|
||||
role=user.role,
|
||||
user_name=user.name,
|
||||
company_name=company.name,
|
||||
)
|
||||
|
||||
basic = Basic(user_key=access_key, identity=identity)
|
||||
|
||||
@@ -110,3 +132,13 @@ def authorize_impersonation(user, identity, service, action, call):
|
||||
raise errors.unauthorized.InvalidCredentials("invalid user company")
|
||||
|
||||
return Payload(auth_type=None, identity=identity)
|
||||
|
||||
|
||||
def compare_secret_key_hash(secret_key: str, hashed_secret: str) -> bool:
|
||||
"""
|
||||
Compare hash for the passed secret key with the passed hash
|
||||
:return: True if equal. Otherwise False
|
||||
"""
|
||||
return bcrypt.checkpw(
|
||||
secret_key.encode(), base64.b64decode(hashed_secret.encode("ascii"))
|
||||
)
|
||||
|
||||
@@ -32,6 +32,10 @@ class FixedUser:
|
||||
def guest_enabled(cls):
|
||||
return cls.enabled() and config.get("services.auth.fixed_users.guest.enabled", False)
|
||||
|
||||
@classmethod
|
||||
def pass_hashed(cls):
|
||||
return config.get("apiserver.auth.fixed_users.pass_hashed", False)
|
||||
|
||||
@classmethod
|
||||
def validate(cls):
|
||||
if not cls.enabled():
|
||||
|
||||
@@ -13,7 +13,12 @@ from .apicall import APICall
|
||||
from .endpoint import Endpoint
|
||||
from .errors import MalformedPathError, InvalidVersionError, CallFailedError
|
||||
from .util import parse_return_stack_on_code
|
||||
from .validators import validate_all
|
||||
from .validators import (
|
||||
validate_data,
|
||||
validate_auth,
|
||||
validate_role,
|
||||
validate_impersonation,
|
||||
)
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
@@ -32,7 +37,7 @@ class ServiceRepo(object):
|
||||
"""If the check is set, parsing will fail for endpoint request with the version that is greater than the current
|
||||
maximum """
|
||||
|
||||
_max_version = PartialVersion("2.12")
|
||||
_max_version = PartialVersion("2.13")
|
||||
""" Maximum version number (the highest min_version value across all endpoints) """
|
||||
|
||||
_endpoint_exp = (
|
||||
@@ -227,18 +232,6 @@ class ServiceRepo(object):
|
||||
return True
|
||||
return subcode in subcode_list
|
||||
|
||||
@classmethod
|
||||
def _validate_call(cls, call: APICall) -> Optional[Endpoint]:
|
||||
endpoint = cls._resolve_endpoint_from_call(call)
|
||||
if call.failed:
|
||||
return
|
||||
validate_all(call, endpoint)
|
||||
return endpoint
|
||||
|
||||
@classmethod
|
||||
def validate_call(cls, call: APICall):
|
||||
cls._validate_call(call)
|
||||
|
||||
@classmethod
|
||||
def _get_company(
|
||||
cls, call: APICall, endpoint: Endpoint = None, ignore_error: bool = False
|
||||
@@ -252,7 +245,7 @@ class ServiceRepo(object):
|
||||
return call.identity.company
|
||||
|
||||
@classmethod
|
||||
def handle_call(cls, call: APICall):
|
||||
def handle_call(cls, call: APICall, load_data_callback: Callable = None):
|
||||
try:
|
||||
if call.failed:
|
||||
raise CallFailedError()
|
||||
@@ -262,7 +255,18 @@ class ServiceRepo(object):
|
||||
if call.failed:
|
||||
raise CallFailedError()
|
||||
|
||||
validate_all(call, endpoint)
|
||||
validate_auth(endpoint, call)
|
||||
validate_role(endpoint, call)
|
||||
if validate_impersonation(endpoint, call):
|
||||
# if impersonating, validate role again
|
||||
validate_role(endpoint, call)
|
||||
|
||||
if load_data_callback:
|
||||
load_data_callback(call)
|
||||
if call.failed:
|
||||
raise CallFailedError()
|
||||
|
||||
validate_data(call, endpoint)
|
||||
|
||||
if call.failed:
|
||||
raise CallFailedError()
|
||||
|
||||
@@ -14,17 +14,9 @@ from .errors import CallParsingError
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
def validate_all(call: APICall, endpoint: Endpoint):
|
||||
def validate_data(call: APICall, endpoint: Endpoint):
|
||||
""" Perform all required call/endpoint validation, update call result appropriately """
|
||||
try:
|
||||
validate_auth(endpoint, call)
|
||||
|
||||
validate_role(endpoint, call)
|
||||
|
||||
if validate_impersonation(endpoint, call):
|
||||
# if impersonating, validate role again
|
||||
validate_role(endpoint, call)
|
||||
|
||||
# todo: remove validate_required_fields once all endpoints have json schema
|
||||
validate_required_fields(endpoint, call)
|
||||
|
||||
|
||||
@@ -41,14 +41,12 @@ def login(call: APICall, *_, **__):
|
||||
)
|
||||
|
||||
# Add authorization cookie
|
||||
call.result.cookies[
|
||||
config.get("apiserver.auth.session_auth_cookie_name")
|
||||
] = call.result.data_model.token
|
||||
call.result.set_auth_cookie(call.result.data_model.token)
|
||||
|
||||
|
||||
@endpoint("auth.logout", min_version="2.2")
|
||||
def logout(call: APICall, *_, **__):
|
||||
call.result.cookies[config.get("apiserver.auth.session_auth_cookie_name")] = None
|
||||
call.result.set_auth_cookie(None)
|
||||
|
||||
|
||||
@endpoint(
|
||||
|
||||
@@ -3,4 +3,4 @@ from apiserver.service_repo import APICall, endpoint
|
||||
|
||||
@endpoint("debug.ping")
|
||||
def ping(call: APICall, _, __):
|
||||
call.result.data = {"msg": "Because it trains cats and dogs"}
|
||||
call.result.data = {"msg": "ClearML server"}
|
||||
|
||||
@@ -594,10 +594,16 @@ def get_debug_images_v1_8(call, company_id, _):
|
||||
response_data_model=DebugImageResponse,
|
||||
)
|
||||
def get_debug_images(call, company_id, request: DebugImagesRequest):
|
||||
task_ids = {m.task for m in request.metrics}
|
||||
task_metrics = defaultdict(set)
|
||||
for tm in request.metrics:
|
||||
task_metrics[tm.task].add(tm.metric)
|
||||
for metrics in task_metrics.values():
|
||||
if None in metrics:
|
||||
metrics.clear()
|
||||
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id,
|
||||
task_ids=task_ids,
|
||||
task_ids=list(task_metrics),
|
||||
allow_public=True,
|
||||
only=("company", "company_origin"),
|
||||
)
|
||||
@@ -610,7 +616,7 @@ def get_debug_images(call, company_id, request: DebugImagesRequest):
|
||||
|
||||
result = event_bll.debug_images_iterator.get_task_events(
|
||||
company_id=next(iter(companies)),
|
||||
metrics=[(m.task, m.metric) for m in request.metrics],
|
||||
task_metrics=task_metrics,
|
||||
iter_count=request.iters,
|
||||
navigate_earlier=request.navigate_earlier,
|
||||
refresh=request.refresh,
|
||||
@@ -622,13 +628,12 @@ def get_debug_images(call, company_id, request: DebugImagesRequest):
|
||||
metrics=[
|
||||
MetricEvents(
|
||||
task=task,
|
||||
metric=metric,
|
||||
iterations=[
|
||||
IterationEvents(iter=iteration["iter"], events=iteration["events"])
|
||||
for iteration in iterations
|
||||
],
|
||||
)
|
||||
for (task, metric, iterations) in result.metric_events
|
||||
for (task, iterations) in result.metric_events
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@@ -6,12 +6,12 @@ from apiserver.apimodels.login import (
|
||||
ServerErrors,
|
||||
)
|
||||
from apiserver.config import info
|
||||
from apiserver.service_repo import endpoint
|
||||
from apiserver.service_repo import endpoint, APICall
|
||||
from apiserver.service_repo.auth.fixed_user import FixedUser
|
||||
|
||||
|
||||
@endpoint("login.supported_modes", response_data_model=GetSupportedModesResponse)
|
||||
def supported_modes(_, __, ___: GetSupportedModesRequest):
|
||||
def supported_modes(call: APICall, _, __: GetSupportedModesRequest):
|
||||
guest_user = FixedUser.get_guest_user()
|
||||
if guest_user:
|
||||
guest = BasicGuestMode(
|
||||
@@ -31,4 +31,10 @@ def supported_modes(_, __, ___: GetSupportedModesRequest):
|
||||
missed_es_upgrade=info.missed_es_upgrade,
|
||||
es_connection_error=info.es_connection_error,
|
||||
),
|
||||
authenticated=call.auth is not None,
|
||||
)
|
||||
|
||||
|
||||
@endpoint("login.logout", min_version="2.13")
|
||||
def logout(call: APICall, _, __):
|
||||
call.result.set_auth_cookie(None)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import Sequence
|
||||
|
||||
from mongoengine import Q, EmbeddedDocument
|
||||
@@ -7,36 +8,55 @@ from apiserver import database
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apierrors.errors.bad_request import InvalidModelId
|
||||
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, MoveRequest
|
||||
from apiserver.apimodels.batch import BatchResponse, BatchRequest
|
||||
from apiserver.apimodels.models import (
|
||||
CreateModelRequest,
|
||||
CreateModelResponse,
|
||||
PublishModelRequest,
|
||||
PublishModelResponse,
|
||||
ModelTaskPublishResponse,
|
||||
GetFrameworksRequest,
|
||||
DeleteModelRequest,
|
||||
DeleteMetadataRequest,
|
||||
AddOrUpdateMetadataRequest,
|
||||
ModelsPublishManyRequest,
|
||||
ModelsDeleteManyRequest,
|
||||
)
|
||||
from apiserver.bll.model import ModelBLL
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.project import ProjectBLL, project_ids_with_children
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.bll.task.task_operations import publish_task
|
||||
from apiserver.bll.util import run_batch_operation
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model import validate_id
|
||||
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, TaskStatus
|
||||
from apiserver.database.model.task.task import (
|
||||
Task,
|
||||
TaskStatus,
|
||||
ModelItem,
|
||||
TaskModelNames,
|
||||
TaskModelTypes,
|
||||
)
|
||||
from apiserver.database.utils import (
|
||||
parse_from_call,
|
||||
get_company_or_none_constraint,
|
||||
filter_fields,
|
||||
)
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.services.utils import conform_tag_fields, conform_output_tags
|
||||
from apiserver.services.utils import (
|
||||
conform_tag_fields,
|
||||
conform_output_tags,
|
||||
ModelsBackwardsCompatibility,
|
||||
validate_metadata,
|
||||
get_metadata_from_api,
|
||||
)
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
log = config.logger(__file__)
|
||||
org_bll = OrgBLL()
|
||||
model_bll = ModelBLL()
|
||||
project_bll = ProjectBLL()
|
||||
|
||||
|
||||
@@ -61,19 +81,20 @@ def get_by_id(call: APICall, company_id, _):
|
||||
|
||||
@endpoint("models.get_by_task_id", required_fields=["task"])
|
||||
def get_by_task_id(call: APICall, company_id, _):
|
||||
if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
|
||||
raise errors.moved_permanently.NotSupported("use models.get_by_id/get_all apis")
|
||||
|
||||
task_id = call.data["task"]
|
||||
|
||||
with translate_errors_context():
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get(_only=["output"], **query)
|
||||
task = Task.get(_only=["models"], **query)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
if not task.output:
|
||||
raise errors.bad_request.MissingTaskFields(field="output")
|
||||
if not task.output.model:
|
||||
raise errors.bad_request.MissingTaskFields(field="output.model")
|
||||
if not task.models or not task.models.output:
|
||||
raise errors.bad_request.MissingTaskFields(field="models.output")
|
||||
|
||||
model_id = task.output.model
|
||||
model_id = task.models.output[-1].model
|
||||
model = Model.objects(
|
||||
Q(id=model_id) & get_company_or_none_constraint(company_id)
|
||||
).first()
|
||||
@@ -86,10 +107,22 @@ def get_by_task_id(call: APICall, company_id, _):
|
||||
call.result.data = {"model": model_dict}
|
||||
|
||||
|
||||
def _process_include_subprojects(call_data: dict):
|
||||
include_subprojects = call_data.pop("include_subprojects", False)
|
||||
project_ids = call_data.get("project")
|
||||
if not project_ids or not include_subprojects:
|
||||
return
|
||||
|
||||
if not isinstance(project_ids, list):
|
||||
project_ids = [project_ids]
|
||||
call_data["project"] = project_ids_with_children(project_ids)
|
||||
|
||||
|
||||
@endpoint("models.get_all_ex", required_fields=[])
|
||||
def get_all_ex(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
with translate_errors_context():
|
||||
_process_include_subprojects(call.data)
|
||||
with TimingContext("mongo", "models_get_all_ex"):
|
||||
models = Model.get_many_with_join(
|
||||
company=company_id, query_dict=call.data, allow_public=True
|
||||
@@ -129,7 +162,7 @@ def get_all(call: APICall, company_id, _):
|
||||
def get_frameworks(call: APICall, company_id, request: GetFrameworksRequest):
|
||||
call.result.data = {
|
||||
"frameworks": sorted(
|
||||
model_bll.get_frameworks(company_id, project_ids=request.projects)
|
||||
project_bll.get_model_frameworks(company_id, project_ids=request.projects)
|
||||
)
|
||||
}
|
||||
|
||||
@@ -147,12 +180,18 @@ create_fields = {
|
||||
"design": None,
|
||||
"labels": dict,
|
||||
"ready": None,
|
||||
"metadata": list,
|
||||
}
|
||||
|
||||
last_update_fields = ("uri", "framework", "design", "labels", "ready", "metadata")
|
||||
|
||||
|
||||
def parse_model_fields(call, valid_fields):
|
||||
fields = parse_from_call(call.data, valid_fields, Model.get_fields())
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
metadata = fields.get("metadata")
|
||||
if metadata:
|
||||
validate_metadata(metadata)
|
||||
return fields
|
||||
|
||||
|
||||
@@ -174,6 +213,9 @@ def _reset_cached_tags(company: str, projects: Sequence[str]):
|
||||
|
||||
@endpoint("models.update_for_task", required_fields=["task"])
|
||||
def update_for_task(call: APICall, company_id, _):
|
||||
if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
|
||||
raise errors.moved_permanently.NotSupported("use tasks.add_or_update_model")
|
||||
|
||||
task_id = call.data["task"]
|
||||
uri = call.data.get("uri")
|
||||
iteration = call.data.get("iteration")
|
||||
@@ -189,7 +231,7 @@ def update_for_task(call: APICall, company_id, _):
|
||||
task = Task.get_for_writing(
|
||||
id=task_id,
|
||||
company=company_id,
|
||||
_only=["output", "execution", "name", "status", "project"],
|
||||
_only=["models", "execution", "name", "status", "project"],
|
||||
)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
@@ -202,10 +244,9 @@ def update_for_task(call: APICall, company_id, _):
|
||||
)
|
||||
|
||||
if override_model_id:
|
||||
query = dict(company=company_id, id=override_model_id)
|
||||
model = Model.objects(**query).first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(**query)
|
||||
model = ModelBLL.get_company_model_by_id(
|
||||
company_id=company_id, model_id=override_model_id
|
||||
)
|
||||
else:
|
||||
if "name" not in call.data:
|
||||
# use task name if name not provided
|
||||
@@ -214,12 +255,11 @@ def update_for_task(call: APICall, company_id, _):
|
||||
if "comment" not in call.data:
|
||||
call.data["comment"] = f"Created by task `{task.name}` ({task.id})"
|
||||
|
||||
if task.output and task.output.model:
|
||||
if task.models and task.models.output:
|
||||
# model exists, update
|
||||
res = _update_model(
|
||||
call, company_id, model_id=task.output.model
|
||||
).to_struct()
|
||||
res.update({"id": task.output.model, "created": False})
|
||||
model_id = task.models.output[-1].model
|
||||
res = _update_model(call, company_id, model_id=model_id).to_struct()
|
||||
res.update({"id": model_id, "created": False})
|
||||
call.result.data = res
|
||||
return
|
||||
|
||||
@@ -227,14 +267,18 @@ def update_for_task(call: APICall, company_id, _):
|
||||
fields = parse_model_fields(call, create_fields)
|
||||
|
||||
# create and save model
|
||||
now = datetime.utcnow()
|
||||
model = Model(
|
||||
id=database.utils.id(),
|
||||
created=datetime.utcnow(),
|
||||
created=now,
|
||||
last_update=now,
|
||||
user=call.identity.user,
|
||||
company=company_id,
|
||||
project=task.project,
|
||||
framework=task.execution.framework,
|
||||
parent=task.execution.model,
|
||||
parent=task.models.input[0].model
|
||||
if task.models and task.models.input
|
||||
else None,
|
||||
design=task.execution.model_desc,
|
||||
labels=task.execution.model_labels,
|
||||
ready=(task.status == TaskStatus.published),
|
||||
@@ -247,7 +291,13 @@ def update_for_task(call: APICall, company_id, _):
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
last_iteration_max=iteration,
|
||||
output__model=model.id,
|
||||
models__output=[
|
||||
ModelItem(
|
||||
model=model.id,
|
||||
name=TaskModelNames[TaskModelTypes.output],
|
||||
updated=datetime.utcnow(),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
call.result.data = {"id": model.id, "created": True}
|
||||
@@ -277,12 +327,16 @@ def create(call: APICall, company_id, req_model: CreateModelRequest):
|
||||
fields = filter_fields(Model, req_data)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
|
||||
validate_metadata(fields.get("metadata"))
|
||||
|
||||
# create and save model
|
||||
now = datetime.utcnow()
|
||||
model = Model(
|
||||
id=database.utils.id(),
|
||||
user=call.identity.user,
|
||||
company=company_id,
|
||||
created=datetime.utcnow(),
|
||||
created=now,
|
||||
last_update=now,
|
||||
**fields,
|
||||
)
|
||||
model.save()
|
||||
@@ -335,10 +389,9 @@ def edit(call: APICall, company_id, _):
|
||||
model_id = call.data["model"]
|
||||
|
||||
with translate_errors_context():
|
||||
query = dict(id=model_id, company=company_id)
|
||||
model = Model.objects(**query).first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(**query)
|
||||
model = ModelBLL.get_company_model_by_id(
|
||||
company_id=company_id, model_id=model_id
|
||||
)
|
||||
|
||||
fields = parse_model_fields(call, create_fields)
|
||||
fields = prepare_update_fields(call, company_id, fields)
|
||||
@@ -363,6 +416,9 @@ def edit(call: APICall, company_id, _):
|
||||
)
|
||||
|
||||
if fields:
|
||||
if any(uf in fields for uf in last_update_fields):
|
||||
fields.update(last_update=datetime.utcnow())
|
||||
|
||||
updated = model.update(upsert=False, **fields)
|
||||
if updated:
|
||||
new_project = fields.get("project", model.project)
|
||||
@@ -384,11 +440,9 @@ def _update_model(call: APICall, company_id, model_id=None):
|
||||
model_id = model_id or call.data["model"]
|
||||
|
||||
with translate_errors_context():
|
||||
# get model by id
|
||||
query = dict(id=model_id, company=company_id)
|
||||
model = Model.objects(**query).first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(**query)
|
||||
model = ModelBLL.get_company_model_by_id(
|
||||
company_id=company_id, model_id=model_id
|
||||
)
|
||||
|
||||
data = prepare_update_fields(call, company_id, call.data)
|
||||
|
||||
@@ -399,8 +453,15 @@ def _update_model(call: APICall, company_id, model_id=None):
|
||||
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
|
||||
)
|
||||
|
||||
metadata = data.get("metadata")
|
||||
if metadata:
|
||||
validate_metadata(metadata)
|
||||
|
||||
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
|
||||
if updated_count:
|
||||
if any(uf in updated_fields for uf in last_update_fields):
|
||||
model.update(upsert=False, last_update=datetime.utcnow())
|
||||
|
||||
new_project = updated_fields.get("project", model.project)
|
||||
if new_project != model.project:
|
||||
_reset_cached_tags(company_id, projects=[new_project, model.project])
|
||||
@@ -424,84 +485,130 @@ def update(call, company_id, _):
|
||||
request_data_model=PublishModelRequest,
|
||||
response_data_model=PublishModelResponse,
|
||||
)
|
||||
def set_ready(call: APICall, company_id, req_model: PublishModelRequest):
|
||||
updated, published_task_data = TaskBLL.model_set_ready(
|
||||
model_id=req_model.model,
|
||||
def set_ready(call: APICall, company_id: str, request: PublishModelRequest):
|
||||
updated, published_task = ModelBLL.publish_model(
|
||||
model_id=request.model,
|
||||
company_id=company_id,
|
||||
publish_task=req_model.publish_task,
|
||||
force_publish_task=req_model.force_publish_task,
|
||||
force_publish_task=request.force_publish_task,
|
||||
publish_task_func=publish_task if request.publish_task else None,
|
||||
)
|
||||
|
||||
call.result.data_model = PublishModelResponse(
|
||||
updated=updated,
|
||||
published_task=ModelTaskPublishResponse(**published_task_data)
|
||||
if published_task_data
|
||||
else None,
|
||||
updated=updated, published_task=published_task
|
||||
)
|
||||
|
||||
|
||||
@endpoint("models.delete", required_fields=["model"])
|
||||
def update(call: APICall, company_id, _):
|
||||
model_id = call.data["model"]
|
||||
force = call.data.get("force", False)
|
||||
@endpoint(
|
||||
"models.publish_many",
|
||||
request_data_model=ModelsPublishManyRequest,
|
||||
response_data_model=BatchResponse,
|
||||
)
|
||||
def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest):
|
||||
results, failures = run_batch_operation(
|
||||
func=partial(
|
||||
ModelBLL.publish_model,
|
||||
company_id=company_id,
|
||||
force_publish_task=request.force_publish_task,
|
||||
publish_task_func=publish_task if request.publish_task else None,
|
||||
),
|
||||
ids=request.ids,
|
||||
)
|
||||
|
||||
with translate_errors_context():
|
||||
query = dict(id=model_id, company=company_id)
|
||||
model = Model.objects(**query).only("id", "task", "project").first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(**query)
|
||||
|
||||
deleted_model_id = f"__DELETED__{model_id}"
|
||||
|
||||
using_tasks = Task.objects(execution__model=model_id).only("id")
|
||||
if using_tasks:
|
||||
if not force:
|
||||
raise errors.bad_request.ModelInUse(
|
||||
"as execution model, use force=True to delete",
|
||||
num_tasks=len(using_tasks),
|
||||
)
|
||||
# update deleted model id in using tasks
|
||||
using_tasks.update(
|
||||
execution__model=deleted_model_id, upsert=False, multi=True
|
||||
call.result.data_model = BatchResponse(
|
||||
succeeded=[
|
||||
dict(
|
||||
id=_id,
|
||||
updated=bool(updated),
|
||||
published_task=published_task.to_struct() if published_task else None,
|
||||
)
|
||||
for _id, (updated, published_task) in results
|
||||
],
|
||||
failed=failures,
|
||||
)
|
||||
|
||||
if model.task:
|
||||
task = Task.objects(id=model.task).first()
|
||||
if task and task.status == TaskStatus.published:
|
||||
if not force:
|
||||
raise errors.bad_request.ModelCreatingTaskExists(
|
||||
"and published, use force=True to delete", task=model.task
|
||||
)
|
||||
now = datetime.utcnow()
|
||||
task.update(
|
||||
output__model=deleted_model_id,
|
||||
output__error=f"model deleted on {now.isoformat()}",
|
||||
last_change=now,
|
||||
upsert=False,
|
||||
)
|
||||
|
||||
del_count = Model.objects(**query).delete()
|
||||
if del_count:
|
||||
_reset_cached_tags(company_id, projects=[model.project])
|
||||
call.result.data = dict(deleted=del_count > 0)
|
||||
@endpoint("models.delete", request_data_model=DeleteModelRequest)
|
||||
def delete(call: APICall, company_id, request: DeleteModelRequest):
|
||||
del_count, model = ModelBLL.delete_model(
|
||||
model_id=request.model, company_id=company_id, force=request.force
|
||||
)
|
||||
if del_count:
|
||||
_reset_cached_tags(
|
||||
company_id, projects=[model.project] if model.project else []
|
||||
)
|
||||
|
||||
call.result.data = dict(deleted=bool(del_count), url=model.uri)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"models.delete_many",
|
||||
request_data_model=ModelsDeleteManyRequest,
|
||||
response_data_model=BatchResponse,
|
||||
)
|
||||
def delete(call: APICall, company_id, request: ModelsDeleteManyRequest):
|
||||
results, failures = run_batch_operation(
|
||||
func=partial(ModelBLL.delete_model, company_id=company_id, force=request.force),
|
||||
ids=request.ids,
|
||||
)
|
||||
|
||||
if results:
|
||||
projects = set(model.project for _, (_, model) in results)
|
||||
_reset_cached_tags(company_id, projects=list(projects))
|
||||
|
||||
call.result.data_model = BatchResponse(
|
||||
succeeded=[
|
||||
dict(id=_id, deleted=bool(deleted), url=model.uri)
|
||||
for _id, (deleted, model) in results
|
||||
],
|
||||
failed=failures,
|
||||
)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"models.archive_many",
|
||||
request_data_model=BatchRequest,
|
||||
response_data_model=BatchResponse,
|
||||
)
|
||||
def archive_many(call: APICall, company_id, request: BatchRequest):
|
||||
results, failures = run_batch_operation(
|
||||
func=partial(ModelBLL.archive_model, company_id=company_id), ids=request.ids,
|
||||
)
|
||||
call.result.data_model = BatchResponse(
|
||||
succeeded=[dict(id=_id, archived=bool(archived)) for _id, archived in results],
|
||||
failed=failures,
|
||||
)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"models.unarchive_many",
|
||||
request_data_model=BatchRequest,
|
||||
response_data_model=BatchResponse,
|
||||
)
|
||||
def unarchive_many(call: APICall, company_id, request: BatchRequest):
|
||||
results, failures = run_batch_operation(
|
||||
func=partial(ModelBLL.unarchive_model, company_id=company_id), ids=request.ids,
|
||||
)
|
||||
call.result.data_model = BatchResponse(
|
||||
succeeded=[
|
||||
dict(id=_id, unarchived=bool(unarchived)) for _id, unarchived in results
|
||||
],
|
||||
failed=failures,
|
||||
)
|
||||
|
||||
|
||||
@endpoint("models.make_public", min_version="2.9", request_data_model=MakePublicRequest)
|
||||
def make_public(call: APICall, company_id, request: MakePublicRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = Model.set_public(
|
||||
company_id, ids=request.ids, invalid_cls=InvalidModelId, enabled=True
|
||||
)
|
||||
call.result.data = Model.set_public(
|
||||
company_id, ids=request.ids, invalid_cls=InvalidModelId, enabled=True
|
||||
)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"models.make_private", min_version="2.9", request_data_model=MakePublicRequest
|
||||
)
|
||||
def make_public(call: APICall, company_id, request: MakePublicRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = Model.set_public(
|
||||
company_id, request.ids, invalid_cls=InvalidModelId, enabled=False
|
||||
)
|
||||
call.result.data = Model.set_public(
|
||||
company_id, request.ids, invalid_cls=InvalidModelId, enabled=False
|
||||
)
|
||||
|
||||
|
||||
@endpoint("models.move", request_data_model=MoveRequest)
|
||||
@@ -511,14 +618,43 @@ def move(call: APICall, company_id: str, request: MoveRequest):
|
||||
"project or project_name is required"
|
||||
)
|
||||
|
||||
with translate_errors_context():
|
||||
return {
|
||||
"project_id": project_bll.move_under_project(
|
||||
entity_cls=Model,
|
||||
user=call.identity.user,
|
||||
company=company_id,
|
||||
ids=request.ids,
|
||||
project=request.project,
|
||||
project_name=request.project_name,
|
||||
)
|
||||
}
|
||||
return {
|
||||
"project_id": project_bll.move_under_project(
|
||||
entity_cls=Model,
|
||||
user=call.identity.user,
|
||||
company=company_id,
|
||||
ids=request.ids,
|
||||
project=request.project,
|
||||
project_name=request.project_name,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint("models.add_or_update_metadata", min_version="2.13")
def add_or_update_metadata(
    _: APICall, company_id: str, request: AddOrUpdateMetadataRequest
):
    """Add or update metadata items on a model.

    :param _: the API call (unused)
    :param company_id: company used when looking up the model
    :param request: request carrying the model id and the metadata items
    :return: ``{"updated": <value returned by metadata_add_or_update>}``
    """
    model_id = request.model
    # The lookup result is discarded, so only the id field is requested, the
    # same way the models.delete_metadata handler performs this lookup.
    # NOTE(review): presumably the lookup raises when the model is not
    # available to this company - confirm against ModelBLL.
    ModelBLL.get_company_model_by_id(
        company_id=company_id, model_id=model_id, only_fields=("id",)
    )

    updated = metadata_add_or_update(
        cls=Model, _id=model_id, items=get_metadata_from_api(request.metadata),
    )
    if updated:
        # Stamp the model only when something actually changed
        Model.objects(id=model_id).update_one(last_update=datetime.utcnow())

    return {"updated": updated}
|
||||
|
||||
|
||||
@endpoint("models.delete_metadata", min_version="2.13")
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
    """Delete the metadata keys listed in ``request.keys`` from a model.

    :param _: the API call (unused)
    :param company_id: company used when looking up the model
    :param request: request carrying the model id and the keys to delete
    :return: ``{"updated": <value returned by metadata_delete>}``
    """
    target_id = request.model
    # The lookup result is not used; only the id field is requested.
    # NOTE(review): presumably this raises for a model the company cannot
    # access - confirm against ModelBLL.
    ModelBLL.get_company_model_by_id(
        company_id=company_id, model_id=target_id, only_fields=("id",)
    )

    removed = metadata_delete(cls=Model, _id=target_id, keys=request.keys)
    if not removed:
        return {"updated": removed}

    # Something was deleted: refresh the model's last_update timestamp
    Model.objects(id=target_id).update_one(last_update=datetime.utcnow())
    return {"updated": removed}
|
||||
|
||||
@@ -1,31 +1,30 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from itertools import groupby
|
||||
from operator import itemgetter
|
||||
from typing import Sequence
|
||||
|
||||
import dpath
|
||||
import attr
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apierrors.errors.bad_request import InvalidProjectId
|
||||
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse
|
||||
from apiserver.apimodels.projects import (
|
||||
GetHyperParamReq,
|
||||
ProjectReq,
|
||||
GetHyperParamRequest,
|
||||
ProjectTagsRequest,
|
||||
ProjectTaskParentsRequest,
|
||||
ProjectHyperparamValuesRequest,
|
||||
ProjectsGetRequest,
|
||||
DeleteRequest,
|
||||
MoveRequest,
|
||||
MergeRequest,
|
||||
ProjectOrNoneRequest,
|
||||
)
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.project.project_cleanup import delete_project
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, TaskStatus
|
||||
from apiserver.database.utils import (
|
||||
parse_from_call,
|
||||
get_options,
|
||||
get_company_or_none_constraint,
|
||||
)
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
@@ -39,7 +38,7 @@ from apiserver.timing_context import TimingContext
|
||||
|
||||
org_bll = OrgBLL()
|
||||
task_bll = TaskBLL()
|
||||
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}
|
||||
project_bll = ProjectBLL()
|
||||
|
||||
create_fields = {
|
||||
"name": None,
|
||||
@@ -49,10 +48,6 @@ create_fields = {
|
||||
"default_output_destination": None,
|
||||
}
|
||||
|
||||
get_all_query_options = Project.QueryParameterOptions(
|
||||
pattern_fields=("name", "description"), list_fields=("tags", "system_tags", "id"),
|
||||
)
|
||||
|
||||
|
||||
@endpoint("projects.get_by_id", required_fields=["project"])
|
||||
def get_by_id(call):
|
||||
@@ -74,210 +69,88 @@ def get_by_id(call):
|
||||
call.result.data = {"project": project_dict}
|
||||
|
||||
|
||||
def make_projects_get_all_pipelines(company_id, project_ids, specific_state=None):
|
||||
archived = EntityVisibility.archived.value
|
||||
def _adjust_search_parameters(data: dict, shallow_search: bool):
|
||||
"""
|
||||
1. Make sure that there is no external query on path
|
||||
2. If not shallow_search and parent is provided then parent can be at any place in path
|
||||
3. If shallow_search and no parent provided then use a top level parent
|
||||
"""
|
||||
data.pop("path", None)
|
||||
if not shallow_search:
|
||||
if "parent" in data:
|
||||
data["path"] = data.pop("parent")
|
||||
return
|
||||
|
||||
def ensure_valid_fields():
|
||||
"""
|
||||
Make sure system tags is always an array (required by subsequent $in in archived_tasks_cond
|
||||
"""
|
||||
return {
|
||||
"$addFields": {
|
||||
"system_tags": {
|
||||
"$cond": {
|
||||
"if": {"$ne": [{"$type": "$system_tags"}, "array"]},
|
||||
"then": [],
|
||||
"else": "$system_tags",
|
||||
}
|
||||
},
|
||||
"status": {"$ifNull": ["$status", "unknown"]},
|
||||
}
|
||||
}
|
||||
|
||||
status_count_pipeline = [
|
||||
# count tasks per project per status
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"project": {"$in": project_ids},
|
||||
}
|
||||
},
|
||||
ensure_valid_fields(),
|
||||
{
|
||||
"$group": {
|
||||
"_id": {
|
||||
"project": "$project",
|
||||
"status": "$status",
|
||||
archived: archived_tasks_cond,
|
||||
},
|
||||
"count": {"$sum": 1},
|
||||
}
|
||||
},
|
||||
# for each project, create a list of (status, count, archived)
|
||||
{
|
||||
"$group": {
|
||||
"_id": "$_id.project",
|
||||
"counts": {
|
||||
"$push": {
|
||||
"status": "$_id.status",
|
||||
"count": "$count",
|
||||
archived: "$_id.%s" % archived,
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
def runtime_subquery(additional_cond):
|
||||
return {
|
||||
# the sum of
|
||||
"$sum": {
|
||||
# for each task
|
||||
"$cond": {
|
||||
# if completed and started and completed > started
|
||||
"if": {
|
||||
"$and": [
|
||||
"$started",
|
||||
"$completed",
|
||||
{"$gt": ["$completed", "$started"]},
|
||||
additional_cond,
|
||||
]
|
||||
},
|
||||
# then: floor((completed - started) / 1000)
|
||||
"then": {
|
||||
"$floor": {
|
||||
"$divide": [
|
||||
{"$subtract": ["$completed", "$started"]},
|
||||
1000.0,
|
||||
]
|
||||
}
|
||||
},
|
||||
"else": 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
group_step = {"_id": "$project"}
|
||||
|
||||
for state in EntityVisibility:
|
||||
if specific_state and state != specific_state:
|
||||
continue
|
||||
if state == EntityVisibility.active:
|
||||
group_step[state.value] = runtime_subquery({"$not": archived_tasks_cond})
|
||||
elif state == EntityVisibility.archived:
|
||||
group_step[state.value] = runtime_subquery(archived_tasks_cond)
|
||||
|
||||
runtime_pipeline = [
|
||||
# only count run time for these types of tasks
|
||||
{
|
||||
"$match": {
|
||||
"type": {"$in": ["training", "testing"]},
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"project": {"$in": project_ids},
|
||||
}
|
||||
},
|
||||
ensure_valid_fields(),
|
||||
{
|
||||
# for each project
|
||||
"$group": group_step
|
||||
},
|
||||
]
|
||||
|
||||
return status_count_pipeline, runtime_pipeline
|
||||
if "parent" not in data:
|
||||
data["parent"] = [None]
|
||||
|
||||
|
||||
@endpoint("projects.get_all_ex")
|
||||
def get_all_ex(call: APICall):
|
||||
include_stats = call.data.get("include_stats")
|
||||
stats_for_state = call.data.get("stats_for_state", EntityVisibility.active.value)
|
||||
allow_public = not call.data.get("non_public", False)
|
||||
|
||||
if stats_for_state:
|
||||
try:
|
||||
specific_state = EntityVisibility(stats_for_state)
|
||||
except ValueError:
|
||||
raise errors.bad_request.FieldsValueError(stats_for_state=stats_for_state)
|
||||
else:
|
||||
specific_state = None
|
||||
|
||||
@endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest)
|
||||
def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
conform_tag_fields(call, call.data)
|
||||
with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
|
||||
projects = Project.get_many_with_join(
|
||||
company=call.identity.company,
|
||||
query_dict=call.data,
|
||||
query_options=get_all_query_options,
|
||||
allow_public=allow_public,
|
||||
)
|
||||
conform_output_tags(call, projects)
|
||||
allow_public = not request.non_public
|
||||
data = call.data
|
||||
requested_ids = data.get("id")
|
||||
with TimingContext("mongo", "projects_get_all"):
|
||||
data = call.data
|
||||
if request.active_users:
|
||||
ids = project_bll.get_projects_with_active_user(
|
||||
company=company_id,
|
||||
users=request.active_users,
|
||||
project_ids=requested_ids,
|
||||
allow_public=allow_public,
|
||||
)
|
||||
if not ids:
|
||||
call.result.data = {"projects": []}
|
||||
return
|
||||
data["id"] = ids
|
||||
|
||||
if not include_stats:
|
||||
_adjust_search_parameters(data, shallow_search=request.shallow_search)
|
||||
|
||||
projects = Project.get_many_with_join(
|
||||
company=company_id, query_dict=data, allow_public=allow_public,
|
||||
)
|
||||
|
||||
if request.check_own_contents and requested_ids:
|
||||
existing_requested_ids = {
|
||||
project["id"] for project in projects if project["id"] in requested_ids
|
||||
}
|
||||
if existing_requested_ids:
|
||||
contents = project_bll.calc_own_contents(
|
||||
company=company_id, project_ids=list(existing_requested_ids)
|
||||
)
|
||||
for project in projects:
|
||||
project.update(**contents.get(project["id"], {}))
|
||||
|
||||
conform_output_tags(call, projects)
|
||||
if not request.include_stats:
|
||||
call.result.data = {"projects": projects}
|
||||
return
|
||||
|
||||
ids = [project["id"] for project in projects]
|
||||
status_count_pipeline, runtime_pipeline = make_projects_get_all_pipelines(
|
||||
call.identity.company, ids, specific_state=specific_state
|
||||
project_ids = {project["id"] for project in projects}
|
||||
stats, children = project_bll.get_project_stats(
|
||||
company=company_id,
|
||||
project_ids=list(project_ids),
|
||||
specific_state=request.stats_for_state,
|
||||
)
|
||||
|
||||
default_counts = dict.fromkeys(get_options(TaskStatus), 0)
|
||||
for project in projects:
|
||||
project["stats"] = stats[project["id"]]
|
||||
project["sub_projects"] = children[project["id"]]
|
||||
|
||||
def set_default_count(entry):
|
||||
return dict(default_counts, **entry)
|
||||
|
||||
status_count = defaultdict(lambda: {})
|
||||
key = itemgetter(EntityVisibility.archived.value)
|
||||
for result in Task.aggregate(status_count_pipeline):
|
||||
for k, group in groupby(sorted(result["counts"], key=key), key):
|
||||
section = (
|
||||
EntityVisibility.archived if k else EntityVisibility.active
|
||||
).value
|
||||
status_count[result["_id"]][section] = set_default_count(
|
||||
{
|
||||
count_entry["status"]: count_entry["count"]
|
||||
for count_entry in group
|
||||
}
|
||||
)
|
||||
|
||||
runtime = {
|
||||
result["_id"]: {k: v for k, v in result.items() if k != "_id"}
|
||||
for result in Task.aggregate(runtime_pipeline)
|
||||
}
|
||||
|
||||
def safe_get(obj, path, default=None):
|
||||
try:
|
||||
return dpath.get(obj, path)
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def get_status_counts(project_id, section):
|
||||
path = "/".join((project_id, section))
|
||||
return {
|
||||
"total_runtime": safe_get(runtime, path, 0),
|
||||
"status_count": safe_get(status_count, path, default_counts),
|
||||
}
|
||||
|
||||
report_for_states = [
|
||||
s for s in EntityVisibility if not specific_state or specific_state == s
|
||||
]
|
||||
|
||||
for project in projects:
|
||||
project["stats"] = {
|
||||
task_state.value: get_status_counts(project["id"], task_state.value)
|
||||
for task_state in report_for_states
|
||||
}
|
||||
|
||||
call.result.data = {"projects": projects}
|
||||
call.result.data = {"projects": projects}
|
||||
|
||||
|
||||
@endpoint("projects.get_all")
|
||||
def get_all(call: APICall):
|
||||
conform_tag_fields(call, call.data)
|
||||
data = call.data
|
||||
_adjust_search_parameters(data, shallow_search=data.get("shallow_search", False))
|
||||
with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
|
||||
projects = Project.get_many(
|
||||
company=call.identity.company,
|
||||
query_dict=call.data,
|
||||
query_options=get_all_query_options,
|
||||
parameters=call.data,
|
||||
query_dict=data,
|
||||
parameters=data,
|
||||
allow_public=True,
|
||||
)
|
||||
conform_output_tags(call, projects)
|
||||
@@ -286,9 +159,7 @@ def get_all(call: APICall):
|
||||
|
||||
|
||||
@endpoint(
|
||||
"projects.create",
|
||||
required_fields=["name", "description"],
|
||||
response_data_model=IdResponse,
|
||||
"projects.create", required_fields=["name"], response_data_model=IdResponse,
|
||||
)
|
||||
def create(call: APICall):
|
||||
identity = call.identity
|
||||
@@ -316,61 +187,72 @@ def update(call: APICall):
|
||||
:return: updated - `int` - number of projects updated
|
||||
fields - `[string]` - updated fields
|
||||
"""
|
||||
project_id = call.data["project"]
|
||||
|
||||
with translate_errors_context():
|
||||
project = Project.get_for_writing(company=call.identity.company, id=project_id)
|
||||
if not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||
|
||||
fields = parse_from_call(
|
||||
call.data, create_fields, Project.get_fields(), discard_none_values=False
|
||||
)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
fields["last_update"] = datetime.utcnow()
|
||||
with TimingContext("mongo", "projects_update"):
|
||||
updated = project.update(upsert=False, **fields)
|
||||
conform_output_tags(call, fields)
|
||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||
fields = parse_from_call(
|
||||
call.data, create_fields, Project.get_fields(), discard_none_values=False
|
||||
)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
updated = ProjectBLL.update(
|
||||
company=call.identity.company, project_id=call.data["project"], **fields
|
||||
)
|
||||
conform_output_tags(call, fields)
|
||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||
|
||||
|
||||
@endpoint("projects.delete", required_fields=["project"])
|
||||
def delete(call):
|
||||
assert isinstance(call, APICall)
|
||||
project_id = call.data["project"]
|
||||
force = call.data.get("force", False)
|
||||
|
||||
with translate_errors_context():
|
||||
project = Project.get_for_writing(company=call.identity.company, id=project_id)
|
||||
if not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||
|
||||
# NOTE: from this point on we'll use the project ID and won't check for company, since we assume we already
|
||||
# have the correct project ID.
|
||||
|
||||
# Find the tasks which belong to the project
|
||||
for cls, error in (
|
||||
(Task, errors.bad_request.ProjectHasTasks),
|
||||
(Model, errors.bad_request.ProjectHasModels),
|
||||
):
|
||||
res = cls.objects(
|
||||
project=project_id, system_tags__nin=[EntityVisibility.archived.value]
|
||||
).only("id")
|
||||
if res and not force:
|
||||
raise error("use force=true to delete", id=project_id)
|
||||
|
||||
updated_count = res.update(project=None)
|
||||
|
||||
project.delete()
|
||||
|
||||
call.result.data = {"deleted": 1, "disassociated_tasks": updated_count}
|
||||
def _reset_cached_tags(company: str, projects: Sequence[str]):
    """Reset the task tags and then the model tags held by ``org_bll`` for the given projects.

    :param company: company id passed to ``org_bll.reset_tags``
    :param projects: ids of the projects whose tags are reset
    """
    # Same order as before: task tags first, model tags second
    for tags_kind in (Tags.Task, Tags.Model):
        org_bll.reset_tags(company, tags_kind, projects=projects)
|
||||
|
||||
|
||||
@endpoint("projects.get_unique_metric_variants", request_data_model=ProjectReq)
|
||||
def get_unique_metric_variants(call: APICall, company_id: str, request: ProjectReq):
|
||||
@endpoint("projects.move", request_data_model=MoveRequest)
def move(call: APICall, company: str, request: MoveRequest):
    """Move a project to a new location.

    Delegates to ``ProjectBLL.move_project``, resets the cached tags of the
    affected projects and reports the ``moved`` value to the caller.

    :param call: the API call; supplies the acting user and receives the result
    :param company: company id passed to the move and to the tags reset
    :param request: request carrying the project id and its new location
    """
    outcome = ProjectBLL.move_project(
        company=company,
        user=call.identity.user,
        project_id=request.project,
        new_location=request.new_location,
    )
    # move_project yields a pair: the moved value and the affected projects
    moved, affected_projects = outcome

    _reset_cached_tags(company, projects=list(affected_projects))

    call.result.data = dict(moved=moved)
|
||||
|
||||
|
||||
@endpoint("projects.merge", request_data_model=MergeRequest)
def merge(call: APICall, company: str, request: MergeRequest):
    """Merge one project into another.

    Delegates to ``ProjectBLL.merge_project``, resets the cached tags of the
    affected projects and reports the moved entities and moved projects
    values to the caller.

    :param call: the API call whose ``result.data`` receives the response
    :param company: company id passed to the merge and to the tags reset
    :param request: request carrying the source and destination project ids
    """
    merge_outcome = ProjectBLL.merge_project(
        company, source_id=request.project, destination_id=request.destination_project
    )
    # merge_project yields a triple; the last item drives the tags reset
    moved_entities, moved_projects, affected_projects = merge_outcome

    _reset_cached_tags(company, projects=list(affected_projects))

    response = {}
    response["moved_entities"] = moved_entities
    response["moved_projects"] = moved_projects
    call.result.data = response
|
||||
|
||||
|
||||
@endpoint("projects.delete", request_data_model=DeleteRequest)
def delete(call: APICall, company_id: str, request: DeleteRequest):
    """Delete a project.

    Delegates to ``delete_project``, resets the cached tags of the affected
    projects and returns the deletion result converted to a dict with
    ``attr.asdict``.

    :param call: the API call whose ``result.data`` receives the response
    :param company_id: company id passed to the deletion and to the tags reset
    :param request: request carrying the project id and the ``force`` /
        ``delete_contents`` flags
    """
    deletion = delete_project(
        company=company_id,
        project_id=request.project,
        force=request.force,
        delete_contents=request.delete_contents,
    )
    # delete_project yields a pair: the result object and the affected projects
    res, affected_projects = deletion

    _reset_cached_tags(company_id, projects=list(affected_projects))

    # A fresh dict built from the attrs result object
    call.result.data = dict(attr.asdict(res))
|
||||
|
||||
|
||||
@endpoint(
    "projects.get_unique_metric_variants", request_data_model=ProjectOrNoneRequest
)
def get_unique_metric_variants(
    call: APICall, company_id: str, request: ProjectOrNoneRequest
):
    """Return the unique metric variants reported by ``task_bll``.

    :param call: the API call whose ``result.data`` receives the response
    :param company_id: company id passed to ``task_bll``
    :param request: request with an optional project id and the
        ``include_subprojects`` flag
    """
    # No project in the request means no project filter (None), otherwise a
    # single-item list holding the requested project id
    project_ids = [request.project] if request.project else None

    # One argument set only: company, project filter and the subprojects flag
    metrics = task_bll.get_unique_metric_variants(
        company_id,
        project_ids,
        include_subprojects=request.include_subprojects,
    )

    call.result.data = {"metrics": metrics}
|
||||
@@ -379,13 +261,14 @@ def get_unique_metric_variants(call: APICall, company_id: str, request: ProjectR
|
||||
@endpoint(
|
||||
"projects.get_hyper_parameters",
|
||||
min_version="2.9",
|
||||
request_data_model=GetHyperParamReq,
|
||||
request_data_model=GetHyperParamRequest,
|
||||
)
|
||||
def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamReq):
|
||||
def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamRequest):
|
||||
|
||||
total, remaining, parameters = TaskBLL.get_aggregated_project_parameters(
|
||||
company_id,
|
||||
project_ids=[request.project] if request.project else None,
|
||||
include_subprojects=request.include_subprojects,
|
||||
page=request.page,
|
||||
page_size=request.page_size,
|
||||
)
|
||||
@@ -397,6 +280,28 @@ def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamR
|
||||
}
|
||||
|
||||
|
||||
@endpoint(
    "projects.get_hyperparam_values",
    min_version="2.13",
    request_data_model=ProjectHyperparamValuesRequest,
)
def get_hyperparam_values(
    call: APICall, company_id: str, request: ProjectHyperparamValuesRequest
):
    """Return the distinct values of one hyperparameter.

    Delegates to ``task_bll.get_hyperparam_distinct_values`` and reports the
    ``total`` and ``values`` it returns.

    :param call: the API call whose ``result.data`` receives the response
    :param company_id: company id passed to ``task_bll``
    :param request: request carrying the projects, the hyperparameter section
        and name, and the ``include_subprojects`` / ``allow_public`` flags
    """
    distinct = task_bll.get_hyperparam_distinct_values(
        company_id,
        project_ids=request.projects,
        section=request.section,
        name=request.name,
        include_subprojects=request.include_subprojects,
        allow_public=request.allow_public,
    )
    # The BLL call yields a pair: the total and the values
    total, values = distinct

    call.result.data = dict(total=total, values=values)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"projects.get_task_tags", min_version="2.8", request_data_model=ProjectTagsRequest
|
||||
)
|
||||
@@ -452,7 +357,10 @@ def get_task_parents(
|
||||
call: APICall, company_id: str, request: ProjectTaskParentsRequest
|
||||
):
|
||||
call.result.data = {
|
||||
"parents": org_bll.get_parent_tasks(
|
||||
company_id, projects=request.projects, state=request.tasks_state
|
||||
"parents": project_bll.get_task_parents(
|
||||
company_id,
|
||||
projects=request.projects,
|
||||
include_subprojects=request.include_subprojects,
|
||||
state=request.tasks_state,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -11,12 +11,21 @@ from apiserver.apimodels.queues import (
|
||||
GetMetricsRequest,
|
||||
GetMetricsResponse,
|
||||
QueueMetrics,
|
||||
AddOrUpdateMetadataRequest,
|
||||
DeleteMetadataRequest,
|
||||
)
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.bll.util import extract_properties_to_lists
|
||||
from apiserver.bll.workers import WorkerBLL
|
||||
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
|
||||
from apiserver.database.model.queue import Queue
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.services.utils import conform_tag_fields, conform_output_tags, conform_tags
|
||||
from apiserver.services.utils import (
|
||||
conform_tag_fields,
|
||||
conform_output_tags,
|
||||
conform_tags,
|
||||
get_metadata_from_api,
|
||||
)
|
||||
from apiserver.utilities import extract_properties_to_lists
|
||||
|
||||
worker_bll = WorkerBLL()
|
||||
queue_bll = QueueBLL(worker_bll)
|
||||
@@ -62,7 +71,11 @@ def create(call: APICall, company_id, request: CreateRequest):
|
||||
call, request.tags, request.system_tags, validate=True
|
||||
)
|
||||
queue = queue_bll.create(
|
||||
company_id=company_id, name=request.name, tags=tags, system_tags=system_tags
|
||||
company_id=company_id,
|
||||
name=request.name,
|
||||
tags=tags,
|
||||
system_tags=system_tags,
|
||||
metadata=get_metadata_from_api(request.metadata),
|
||||
)
|
||||
call.result.data = {"id": queue.id}
|
||||
|
||||
@@ -220,3 +233,25 @@ def get_queue_metrics(
|
||||
for queue, data in queue_dicts.items()
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@endpoint("queues.add_or_update_metadata", min_version="2.13")
def add_or_update_metadata(
    _: APICall, company_id: str, request: AddOrUpdateMetadataRequest
):
    """Add or update metadata items on a queue.

    :param _: the API call (unused)
    :param company_id: company used when looking up the queue
    :param request: request carrying the queue id and the metadata items
    :return: ``{"updated": <value returned by metadata_add_or_update>}``
    """
    target_queue = request.queue
    # The lookup result is not used; only the id field is requested.
    # NOTE(review): presumably this raises for a queue the company cannot
    # access - confirm against QueueBLL.
    queue_bll.get_by_id(company_id=company_id, queue_id=target_queue, only=("id",))

    items = get_metadata_from_api(request.metadata)
    updated = metadata_add_or_update(cls=Queue, _id=target_queue, items=items)
    return {"updated": updated}
|
||||
|
||||
|
||||
@endpoint("queues.delete_metadata", min_version="2.13")
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
    """Delete the metadata keys listed in ``request.keys`` from a queue.

    :param _: the API call (unused)
    :param company_id: company used when looking up the queue
    :param request: request carrying the queue id and the keys to delete
    :return: ``{"updated": <value returned by metadata_delete>}``
    """
    target_queue = request.queue
    # The lookup result is not used; only the id field is requested.
    # NOTE(review): presumably this raises for a queue the company cannot
    # access - confirm against QueueBLL.
    queue_bll.get_by_id(company_id=company_id, queue_id=target_queue, only=("id",))

    updated = metadata_delete(cls=Queue, _id=target_queue, keys=request.keys)
    return {"updated": updated}
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Callable, Type, TypeVar, Union, Tuple
|
||||
from functools import partial
|
||||
from typing import Sequence, Union, Tuple
|
||||
|
||||
import attr
|
||||
import dpath
|
||||
import mongoengine
|
||||
from mongoengine import EmbeddedDocument, Q
|
||||
from mongoengine.queryset.transform import COMPARISON_OPERATORS
|
||||
from pymongo import UpdateOne
|
||||
|
||||
from apiserver.apierrors import errors, APIError
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apierrors.errors.bad_request import InvalidTaskId
|
||||
from apiserver.apimodels.base import (
|
||||
UpdateResponse,
|
||||
@@ -18,11 +17,15 @@ from apiserver.apimodels.base import (
|
||||
MakePublicRequest,
|
||||
MoveRequest,
|
||||
)
|
||||
from apiserver.apimodels.batch import (
|
||||
BatchResponse,
|
||||
UpdateBatchResponse,
|
||||
UpdateBatchItem,
|
||||
)
|
||||
from apiserver.apimodels.tasks import (
|
||||
StartedResponse,
|
||||
ResetResponse,
|
||||
PublishRequest,
|
||||
PublishResponse,
|
||||
CreateRequest,
|
||||
UpdateRequest,
|
||||
SetRequirementsRequest,
|
||||
@@ -46,16 +49,30 @@ from apiserver.apimodels.tasks import (
|
||||
DeleteArtifactsRequest,
|
||||
ArchiveResponse,
|
||||
ArchiveRequest,
|
||||
AddUpdateModelRequest,
|
||||
DeleteModelsRequest,
|
||||
StopManyRequest,
|
||||
EnqueueManyRequest,
|
||||
ResetManyRequest,
|
||||
DeleteManyRequest,
|
||||
PublishManyRequest,
|
||||
TaskBatchRequest,
|
||||
EnqueueManyResponse,
|
||||
EnqueueBatchItem,
|
||||
DequeueBatchItem,
|
||||
DequeueManyResponse,
|
||||
ResetManyResponse,
|
||||
ResetBatchItem,
|
||||
)
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.model import ModelBLL
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.project import ProjectBLL, project_ids_with_children
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.bll.task import (
|
||||
TaskBLL,
|
||||
ChangeStatusRequest,
|
||||
update_project_time,
|
||||
split_by,
|
||||
)
|
||||
from apiserver.bll.task.artifacts import (
|
||||
artifacts_prepare_for_save,
|
||||
@@ -69,25 +86,36 @@ from apiserver.bll.task.param_utils import (
|
||||
params_unprepare_from_saved,
|
||||
escape_paths,
|
||||
)
|
||||
from apiserver.bll.task.utils import update_task
|
||||
from apiserver.bll.util import SetFieldsResolver
|
||||
from apiserver.bll.task.task_operations import (
|
||||
stop_task,
|
||||
enqueue_task,
|
||||
dequeue_task,
|
||||
reset_task,
|
||||
archive_task,
|
||||
delete_task,
|
||||
publish_task,
|
||||
unarchive_task,
|
||||
)
|
||||
from apiserver.bll.task.utils import update_task, get_task_for_update, deleted_prefix
|
||||
from apiserver.bll.util import SetFieldsResolver, run_batch_operation
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.output import Output
|
||||
from apiserver.database.model.task.task import (
|
||||
Task,
|
||||
TaskStatus,
|
||||
Script,
|
||||
DEFAULT_LAST_ITERATION,
|
||||
Execution,
|
||||
ArtifactModes,
|
||||
ModelItem,
|
||||
TaskModelTypes,
|
||||
)
|
||||
from apiserver.database.utils import get_fields_attr, parse_from_call
|
||||
from apiserver.database.utils import get_fields_attr, parse_from_call, get_options
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.services.utils import (
|
||||
conform_tag_fields,
|
||||
conform_output_tags,
|
||||
ModelsBackwardsCompatibility,
|
||||
DockerCmdBackwardsCompatibility,
|
||||
escape_dict_field,
|
||||
unescape_dict_field,
|
||||
)
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.partial_version import PartialVersion
|
||||
@@ -155,28 +183,48 @@ def get_by_id(call: APICall, company_id, req_model: TaskRequest):
|
||||
call.result.data = {"task": task_dict}
|
||||
|
||||
|
||||
def escape_execution_parameters(call: APICall):
|
||||
projection = Task.get_projection(call.data)
|
||||
if projection:
|
||||
Task.set_projection(call.data, escape_paths(projection))
|
||||
def escape_execution_parameters(call: APICall) -> dict:
|
||||
if not call.data:
|
||||
return call.data
|
||||
|
||||
ordering = Task.get_ordering(call.data)
|
||||
keys = list(call.data)
|
||||
call_data = {
|
||||
safe_key: call.data[key] for key, safe_key in zip(keys, escape_paths(keys))
|
||||
}
|
||||
|
||||
projection = Task.get_projection(call_data)
|
||||
if projection:
|
||||
Task.set_projection(call_data, escape_paths(projection))
|
||||
|
||||
ordering = Task.get_ordering(call_data)
|
||||
if ordering:
|
||||
Task.set_ordering(call.data, escape_paths(ordering))
|
||||
Task.set_ordering(call_data, escape_paths(ordering))
|
||||
|
||||
return call_data
|
||||
|
||||
|
||||
def _process_include_subprojects(call_data: dict):
|
||||
include_subprojects = call_data.pop("include_subprojects", False)
|
||||
project_ids = call_data.get("project")
|
||||
if not project_ids or not include_subprojects:
|
||||
return
|
||||
|
||||
if not isinstance(project_ids, list):
|
||||
project_ids = [project_ids]
|
||||
call_data["project"] = project_ids_with_children(project_ids)
|
||||
|
||||
|
||||
@endpoint("tasks.get_all_ex", required_fields=[])
|
||||
def get_all_ex(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
|
||||
escape_execution_parameters(call)
|
||||
call_data = escape_execution_parameters(call)
|
||||
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "task_get_all_ex"):
|
||||
_process_include_subprojects(call_data)
|
||||
tasks = Task.get_many_with_join(
|
||||
company=company_id,
|
||||
query_dict=call.data,
|
||||
allow_public=True, # required in case projection is requested for public dataset/versions
|
||||
company=company_id, query_dict=call_data, allow_public=True,
|
||||
)
|
||||
unprepare_from_saved(call, tasks)
|
||||
call.result.data = {"tasks": tasks}
|
||||
@@ -186,12 +234,12 @@ def get_all_ex(call: APICall, company_id, _):
|
||||
def get_by_id_ex(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
|
||||
escape_execution_parameters(call)
|
||||
call_data = escape_execution_parameters(call)
|
||||
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "task_get_by_id_ex"):
|
||||
tasks = Task.get_many_with_join(
|
||||
company=company_id, query_dict=call.data, allow_public=True,
|
||||
company=company_id, query_dict=call_data, allow_public=True,
|
||||
)
|
||||
|
||||
unprepare_from_saved(call, tasks)
|
||||
@@ -202,15 +250,15 @@ def get_by_id_ex(call: APICall, company_id, _):
|
||||
def get_all(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
|
||||
escape_execution_parameters(call)
|
||||
call_data = escape_execution_parameters(call)
|
||||
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "task_get_all"):
|
||||
tasks = Task.get_many(
|
||||
company=company_id,
|
||||
parameters=call.data,
|
||||
query_dict=call.data,
|
||||
allow_public=True, # required in case projection is requested for public dataset/versions
|
||||
parameters=call_data,
|
||||
query_dict=call_data,
|
||||
allow_public=True,
|
||||
)
|
||||
unprepare_from_saved(call, tasks)
|
||||
call.result.data = {"tasks": tasks}
|
||||
@@ -219,7 +267,9 @@ def get_all(call: APICall, company_id, _):
|
||||
@endpoint("tasks.get_types", request_data_model=GetTypesRequest)
|
||||
def get_types(call: APICall, company_id, request: GetTypesRequest):
|
||||
call.result.data = {
|
||||
"types": list(task_bll.get_types(company_id, project_ids=request.projects))
|
||||
"types": list(
|
||||
project_bll.get_task_types(company_id, project_ids=request.projects)
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -236,7 +286,7 @@ def stop(call: APICall, company_id, req_model: UpdateRequest):
|
||||
|
||||
"""
|
||||
call.result.data_model = UpdateResponse(
|
||||
**TaskBLL.stop_task(
|
||||
**stop_task(
|
||||
task_id=req_model.task,
|
||||
company_id=company_id,
|
||||
user_name=call.identity.user_name,
|
||||
@@ -246,6 +296,28 @@ def stop(call: APICall, company_id, req_model: UpdateRequest):
|
||||
)
|
||||
|
||||
|
||||
@endpoint(
    "tasks.stop_many",
    request_data_model=StopManyRequest,
    response_data_model=UpdateBatchResponse,
)
def stop_many(call: APICall, company_id, request: StopManyRequest):
    """Stop every task listed in ``request.ids``.

    ``stop_task`` is bound to the company, the calling user's name and the
    request's ``status_reason`` / ``force`` values, then handed to
    ``run_batch_operation`` together with the requested ids. Per-id results
    and collected failures are reported in an ``UpdateBatchResponse``.

    :param call: the API call; supplies the user name and receives the response
    :param company_id: company passed to ``stop_task``
    :param request: batch request carrying the ids, status reason and force flag
    """
    stop_one = partial(
        stop_task,
        company_id=company_id,
        user_name=call.identity.user_name,
        status_reason=request.status_reason,
        force=request.force,
    )
    results, failures = run_batch_operation(func=stop_one, ids=request.ids)

    # Each result is an (id, mapping) pair; the mapping is expanded into the
    # fields of the batch item
    succeeded = [
        UpdateBatchItem(id=task_id, **outcome) for task_id, outcome in results
    ]
    call.result.data_model = UpdateBatchResponse(succeeded=succeeded, failed=failures)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.stopped",
|
||||
request_data_model=UpdateRequest,
|
||||
@@ -308,18 +380,27 @@ create_fields = {
|
||||
"parent": Task,
|
||||
"project": None,
|
||||
"input": None,
|
||||
"models": None,
|
||||
"container": None,
|
||||
"output_dest": None,
|
||||
"execution": None,
|
||||
"hyperparams": None,
|
||||
"configuration": None,
|
||||
"script": None,
|
||||
"runtime": None,
|
||||
}
|
||||
|
||||
dict_fields_paths = [("execution", "model_labels"), "container"]
|
||||
|
||||
|
||||
def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
params_prepare_for_save(fields, previous_task=previous_task)
|
||||
artifacts_prepare_for_save(fields)
|
||||
ModelsBackwardsCompatibility.prepare_for_save(call, fields)
|
||||
DockerCmdBackwardsCompatibility.prepare_for_save(call, fields)
|
||||
for path in dict_fields_paths:
|
||||
escape_dict_field(fields, path)
|
||||
|
||||
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
|
||||
for field in task_script_stripped_fields:
|
||||
@@ -342,7 +423,15 @@ def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict])
|
||||
conform_output_tags(call, tasks_data)
|
||||
|
||||
for data in tasks_data:
|
||||
need_legacy_params = call.requested_endpoint_version < PartialVersion("2.9")
|
||||
for path in dict_fields_paths:
|
||||
unescape_dict_field(data, path)
|
||||
|
||||
ModelsBackwardsCompatibility.unprepare_from_saved(call, tasks_data)
|
||||
DockerCmdBackwardsCompatibility.unprepare_from_saved(call, tasks_data)
|
||||
|
||||
need_legacy_params = call.requested_endpoint_version < PartialVersion("2.9")
|
||||
|
||||
for data in tasks_data:
|
||||
params_unprepare_from_saved(
|
||||
fields=data, copy_to_legacy=need_legacy_params,
|
||||
)
|
||||
@@ -368,6 +457,17 @@ def prepare_create_fields(
|
||||
output = Output(destination=output_dest)
|
||||
fields["output"] = output
|
||||
|
||||
# Add models updated time
|
||||
models = fields.get("models")
|
||||
if models:
|
||||
now = datetime.utcnow()
|
||||
for field in (TaskModelTypes.input, TaskModelTypes.output):
|
||||
field_models = models.get(field)
|
||||
if not field_models:
|
||||
continue
|
||||
for model in field_models:
|
||||
model["updated"] = now
|
||||
|
||||
return prepare_for_save(call, fields, previous_task=previous_task)
|
||||
|
||||
|
||||
@@ -386,6 +486,9 @@ def _validate_and_get_task_from_call(call: APICall, **kwargs) -> Tuple[Task, dic
|
||||
|
||||
@endpoint("tasks.validate", request_data_model=CreateRequest)
def validate(call: APICall, company_id, req_model: CreateRequest):
    """
    Validate task creation data without creating the task.

    A "parent" value that starts with deleted_prefix is removed from the
    call data before the validation is run.
    """
    parent_id = call.data.get("parent")
    if parent_id and parent_id.startswith(deleted_prefix):
        # The parent reference carries the deleted prefix: drop it before validating
        call.data.pop("parent")

    _validate_and_get_task_from_call(call)
|
||||
|
||||
|
||||
@@ -431,7 +534,9 @@ def clone_task(call: APICall, company_id, request: CloneRequest):
|
||||
system_tags=request.new_task_system_tags,
|
||||
hyperparams=request.new_task_hyperparams,
|
||||
configuration=request.new_task_configuration,
|
||||
container=request.new_task_container,
|
||||
execution_overrides=request.execution_overrides,
|
||||
input_models=request.new_task_input_models,
|
||||
validate_references=request.validate_references,
|
||||
new_project_name=request.new_project_name,
|
||||
)
|
||||
@@ -698,7 +803,7 @@ def get_configuration_names(
|
||||
):
|
||||
with translate_errors_context():
|
||||
tasks_params = HyperParams.get_configuration_names(
|
||||
company_id, task_ids=request.tasks
|
||||
company_id, task_ids=request.tasks, skip_empty=request.skip_empty
|
||||
)
|
||||
|
||||
call.result.data = {
|
||||
@@ -742,61 +847,41 @@ def delete_configuration(
|
||||
request_data_model=EnqueueRequest,
|
||||
response_data_model=EnqueueResponse,
|
||||
)
|
||||
def enqueue(call: APICall, company_id, req_model: EnqueueRequest):
|
||||
task_id = req_model.task
|
||||
queue_id = req_model.queue
|
||||
status_message = req_model.status_message
|
||||
status_reason = req_model.status_reason
|
||||
def enqueue(call: APICall, company_id, request: EnqueueRequest):
|
||||
queued, res = enqueue_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
queue_id=request.queue,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
)
|
||||
call.result.data_model = EnqueueResponse(queued=queued, **res)
|
||||
|
||||
if not queue_id:
|
||||
# try to get default queue
|
||||
queue_id = queue_bll.get_default(company_id).id
|
||||
|
||||
with translate_errors_context():
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get_for_writing(
|
||||
_only=("type", "script", "execution", "status", "project", "id"), **query
|
||||
)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
|
||||
res = EnqueueResponse(
|
||||
**ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=TaskStatus.queued,
|
||||
status_reason=status_reason,
|
||||
status_message=status_message,
|
||||
allow_same_state_transition=False,
|
||||
).execute()
|
||||
)
|
||||
|
||||
try:
|
||||
queue_bll.add_task(
|
||||
company_id=company_id, queue_id=queue_id, task_id=task.id
|
||||
)
|
||||
except Exception:
|
||||
# failed enqueueing, revert to previous state
|
||||
ChangeStatusRequest(
|
||||
task=task,
|
||||
current_status_override=TaskStatus.queued,
|
||||
new_status=task.status,
|
||||
force=True,
|
||||
status_reason="failed enqueueing",
|
||||
).execute()
|
||||
raise
|
||||
|
||||
# set the current queue ID in the task
|
||||
if task.execution:
|
||||
Task.objects(**query).update(execution__queue=queue_id, multi=False)
|
||||
else:
|
||||
Task.objects(**query).update(
|
||||
execution=Execution(queue=queue_id), multi=False
|
||||
)
|
||||
|
||||
res.queued = 1
|
||||
res.fields.update(**{"execution.queue": queue_id})
|
||||
|
||||
call.result.data_model = res
|
||||
@endpoint(
    "tasks.enqueue_many",
    request_data_model=EnqueueManyRequest,
    response_data_model=EnqueueManyResponse,
)
def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
    """
    Batch variant of tasks.enqueue: apply enqueue_task to every ID in request.ids.

    :param call: API call that receives the response model
    :param company_id: company passed to enqueue_task for each task
    :param request: batch request carrying ids, queue, status fields and
        the validate_tasks flag
    """
    # Everything except the task ID is the same for all tasks in the batch
    enqueue_one = partial(
        enqueue_task,
        company_id=company_id,
        queue_id=request.queue,
        status_message=request.status_message,
        status_reason=request.status_reason,
        validate=request.validate_tasks,
    )
    results, failures = run_batch_operation(func=enqueue_one, ids=request.ids)

    # Each successful result is (id, (queued, update fields)); queued is reported as a bool
    succeeded = []
    for _id, (queued, res) in results:
        succeeded.append(EnqueueBatchItem(id=_id, queued=bool(queued), **res))

    call.result.data_model = EnqueueManyResponse(
        succeeded=succeeded, failed=failures,
    )
|
||||
|
||||
|
||||
@endpoint(
|
||||
@@ -805,333 +890,250 @@ def enqueue(call: APICall, company_id, req_model: EnqueueRequest):
|
||||
response_data_model=DequeueResponse,
|
||||
)
|
||||
def dequeue(call: APICall, company_id, request: UpdateRequest):
|
||||
task = TaskBLL.get_task_with_access(
|
||||
request.task,
|
||||
dequeued, res = dequeue_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
only=("id", "execution", "status", "project"),
|
||||
requires_write_access=True,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
)
|
||||
res = DequeueResponse(
|
||||
**TaskBLL.dequeue_and_change_status(
|
||||
task,
|
||||
company_id,
|
||||
call.result.data_model = DequeueResponse(dequeued=dequeued, **res)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.dequeue_many",
|
||||
request_data_model=TaskBatchRequest,
|
||||
response_data_model=DequeueManyResponse,
|
||||
)
|
||||
def dequeue_many(call: APICall, company_id, request: TaskBatchRequest):
|
||||
results, failures = run_batch_operation(
|
||||
func=partial(
|
||||
dequeue_task,
|
||||
company_id=company_id,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
)
|
||||
),
|
||||
ids=request.ids,
|
||||
)
|
||||
call.result.data_model = DequeueManyResponse(
|
||||
succeeded=[
|
||||
DequeueBatchItem(id=_id, dequeued=bool(dequeued), **res)
|
||||
for _id, (dequeued, res) in results
|
||||
],
|
||||
failed=failures,
|
||||
)
|
||||
|
||||
res.dequeued = 1
|
||||
call.result.data_model = res
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.reset", request_data_model=ResetRequest, response_data_model=ResetResponse
|
||||
)
|
||||
def reset(call: APICall, company_id, request: ResetRequest):
|
||||
task = TaskBLL.get_task_with_access(
|
||||
request.task, company_id=company_id, requires_write_access=True
|
||||
)
|
||||
|
||||
force = request.force
|
||||
|
||||
if not force and task.status == TaskStatus.published:
|
||||
raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)
|
||||
|
||||
api_results = {}
|
||||
updates = {}
|
||||
|
||||
try:
|
||||
dequeued = TaskBLL.dequeue(task, company_id, silent_fail=True)
|
||||
except APIError:
|
||||
# dequeue may fail if the task was not enqueued
|
||||
pass
|
||||
else:
|
||||
if dequeued:
|
||||
api_results.update(dequeued=dequeued)
|
||||
|
||||
cleaned_up = cleanup_task(task, force)
|
||||
api_results.update(attr.asdict(cleaned_up))
|
||||
|
||||
updates.update(
|
||||
set__last_iteration=DEFAULT_LAST_ITERATION,
|
||||
set__last_metrics={},
|
||||
set__metric_stats={},
|
||||
unset__output__result=1,
|
||||
unset__output__model=1,
|
||||
unset__output__error=1,
|
||||
unset__last_worker=1,
|
||||
unset__last_worker_report=1,
|
||||
)
|
||||
|
||||
if request.clear_all:
|
||||
updates.update(
|
||||
set__execution=Execution(), unset__script=1,
|
||||
)
|
||||
else:
|
||||
updates.update(unset__execution__queue=1)
|
||||
if task.execution and task.execution.artifacts:
|
||||
updates.update(
|
||||
set__execution__artifacts={
|
||||
key: artifact
|
||||
for key, artifact in task.execution.artifacts.items()
|
||||
if artifact.mode == ArtifactModes.input
|
||||
}
|
||||
)
|
||||
|
||||
res = ResetResponse(
|
||||
**ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=TaskStatus.created,
|
||||
force=force,
|
||||
status_reason="reset",
|
||||
status_message="reset",
|
||||
).execute(
|
||||
started=None,
|
||||
completed=None,
|
||||
published=None,
|
||||
active_duration=None,
|
||||
**updates,
|
||||
)
|
||||
dequeued, cleanup_res, updates = reset_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
force=request.force,
|
||||
return_file_urls=request.return_file_urls,
|
||||
delete_output_models=request.delete_output_models,
|
||||
clear_all=request.clear_all,
|
||||
)
|
||||
|
||||
res = ResetResponse(**updates, dequeued=dequeued)
|
||||
# do not return artifacts since they are not serializable
|
||||
res.fields.pop("execution.artifacts", None)
|
||||
|
||||
for key, value in api_results.items():
|
||||
for key, value in attr.asdict(cleanup_res).items():
|
||||
setattr(res, key, value)
|
||||
|
||||
call.result.data_model = res
|
||||
|
||||
|
||||
@endpoint(
    "tasks.reset_many",
    request_data_model=ResetManyRequest,
    response_data_model=ResetManyResponse,
)
def reset_many(call: APICall, company_id, request: ResetManyRequest):
    """
    Batch variant of tasks.reset: apply reset_task to every ID in request.ids.

    :param call: API call that receives the response model
    :param company_id: company passed to reset_task for each task
    :param request: batch request carrying ids, force, return_file_urls,
        delete_output_models and clear_all
    """
    reset_one = partial(
        reset_task,
        company_id=company_id,
        force=request.force,
        return_file_urls=request.return_file_urls,
        delete_output_models=request.delete_output_models,
        clear_all=request.clear_all,
    )
    results, failures = run_batch_operation(func=reset_one, ids=request.ids)

    def without_artifacts(res: dict) -> dict:
        # do not return artifacts since they are not serializable
        updated_fields = res.get("fields")
        if updated_fields:
            updated_fields.pop("execution.artifacts", None)
        return res

    # Each successful result is (id, (dequeue info, cleanup result, update fields))
    succeeded = []
    for _id, (dequeued, cleanup, res) in results:
        was_dequeued = bool(dequeued.get("removed")) if dequeued else False
        succeeded.append(
            ResetBatchItem(
                id=_id,
                dequeued=was_dequeued,
                **attr.asdict(cleanup),
                **without_artifacts(res),
            )
        )

    call.result.data_model = ResetManyResponse(
        succeeded=succeeded, failed=failures,
    )
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.archive",
|
||||
request_data_model=ArchiveRequest,
|
||||
response_data_model=ArchiveResponse,
|
||||
)
|
||||
def archive(call: APICall, company_id, request: ArchiveRequest):
|
||||
archived = 0
|
||||
tasks = TaskBLL.assert_exists(
|
||||
company_id,
|
||||
task_ids=request.tasks,
|
||||
only=("id", "execution", "status", "project", "system_tags"),
|
||||
only=("id", "execution", "status", "project", "system_tags", "enqueue_status"),
|
||||
)
|
||||
archived = 0
|
||||
for task in tasks:
|
||||
try:
|
||||
TaskBLL.dequeue_and_change_status(
|
||||
task, company_id, request.status_message, request.status_reason,
|
||||
)
|
||||
except APIError:
|
||||
# dequeue may fail if the task was not enqueued
|
||||
pass
|
||||
task.update(
|
||||
archived += archive_task(
|
||||
company_id=company_id,
|
||||
task=task,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
system_tags=sorted(
|
||||
set(task.system_tags) | {EntityVisibility.archived.value}
|
||||
),
|
||||
last_change=datetime.utcnow(),
|
||||
)
|
||||
|
||||
archived += 1
|
||||
|
||||
call.result.data_model = ArchiveResponse(archived=archived)
|
||||
|
||||
|
||||
class DocumentGroup(list):
    """
    A list of documents that can also be queried as if it were a query result
    """

    def __init__(self, document_type, documents):
        """
        :param document_type: class whose objects() is used for querying
        :param documents: documents to hold; each is expected to expose an id
        """
        super().__init__(documents)
        self.type = document_type

    def objects(self, *args, **kwargs):
        """
        Query the document type, restricted to the IDs of the held documents.
        Extra positional and keyword arguments are passed through as given.
        """
        member_ids = [document.id for document in self]
        return self.type.objects(*args, id__in=member_ids, **kwargs)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class TaskOutputs(object):
    """
    Task outputs of a single type, split in two groups by their ready state
    """

    published = None  # type: DocumentGroup
    draft = None  # type: DocumentGroup

    def __init__(self, is_published, document_type, children):
        # type: (Callable[[T], bool], Type[mongoengine.Document], Sequence[T]) -> ()
        """
        :param is_published: predicate telling whether an item counts as published
        :param document_type: type of the output documents
        :param children: the output documents to split
        """
        # split_by yields two groups; the first is stored as published, the second as draft
        published_docs, draft_docs = split_by(is_published, children)
        self.published = DocumentGroup(document_type, published_docs)
        self.draft = DocumentGroup(document_type, draft_docs)
|
||||
|
||||
|
||||
@attr.s
|
||||
class CleanupResult(object):
|
||||
"""
|
||||
Counts of objects modified in task cleanup operation
|
||||
"""
|
||||
|
||||
updated_children = attr.ib(type=int)
|
||||
updated_models = attr.ib(type=int)
|
||||
deleted_models = attr.ib(type=int)
|
||||
|
||||
|
||||
def cleanup_task(task: Task, force: bool = False):
|
||||
"""
|
||||
Validate task deletion and delete/modify all its output.
|
||||
:param task: task object
|
||||
:param force: whether to delete task with published outputs
|
||||
:return: count of delete and modified items
|
||||
"""
|
||||
models, child_tasks = get_outputs_for_deletion(task, force)
|
||||
deleted_task_id = trash_task_id(task.id)
|
||||
if child_tasks:
|
||||
with TimingContext("mongo", "update_task_children"):
|
||||
updated_children = child_tasks.update(parent=deleted_task_id)
|
||||
else:
|
||||
updated_children = 0
|
||||
|
||||
if models.draft:
|
||||
with TimingContext("mongo", "delete_models"):
|
||||
deleted_models = models.draft.objects().delete()
|
||||
else:
|
||||
deleted_models = 0
|
||||
|
||||
if models.published:
|
||||
with TimingContext("mongo", "update_task_models"):
|
||||
updated_models = models.published.objects().update(task=deleted_task_id)
|
||||
else:
|
||||
updated_models = 0
|
||||
|
||||
event_bll.delete_task_events(task.company, task.id, allow_locked=force)
|
||||
|
||||
return CleanupResult(
|
||||
deleted_models=deleted_models,
|
||||
updated_children=updated_children,
|
||||
updated_models=updated_models,
|
||||
@endpoint(
    "tasks.archive_many",
    request_data_model=TaskBatchRequest,
    response_data_model=BatchResponse,
)
def archive_many(call: APICall, company_id, request: TaskBatchRequest):
    """
    Batch variant of tasks.archive: apply archive_task to every ID in request.ids.

    :param call: API call that receives the response model
    :param company_id: company passed to archive_task for each task
    :param request: batch request carrying ids, status_message and status_reason
    """
    archive_one = partial(
        archive_task,
        company_id=company_id,
        status_message=request.status_message,
        status_reason=request.status_reason,
    )
    results, failures = run_batch_operation(func=archive_one, ids=request.ids)

    # Each successful result is an (id, archived) pair; archived is reported as a bool
    succeeded = [
        {"id": _id, "archived": bool(archived)} for _id, archived in results
    ]
    call.result.data_model = BatchResponse(succeeded=succeeded, failed=failures)
|
||||
|
||||
|
||||
def get_outputs_for_deletion(task, force=False):
|
||||
with TimingContext("mongo", "get_task_models"):
|
||||
models = TaskOutputs(
|
||||
attrgetter("ready"),
|
||||
Model,
|
||||
Model.objects(task=task.id).only("id", "task", "ready"),
|
||||
)
|
||||
if not force and models.published:
|
||||
raise errors.bad_request.TaskCannotBeDeleted(
|
||||
"has output models, use force=True",
|
||||
task=task.id,
|
||||
models=len(models.published),
|
||||
)
|
||||
|
||||
if task.output.model:
|
||||
output_model = get_output_model(task, force)
|
||||
if output_model:
|
||||
if output_model.ready:
|
||||
models.published.append(output_model)
|
||||
else:
|
||||
models.draft.append(output_model)
|
||||
|
||||
if models.draft:
|
||||
with TimingContext("mongo", "get_execution_models"):
|
||||
model_ids = [m.id for m in models.draft]
|
||||
dependent_tasks = Task.objects(execution__model__in=model_ids).only(
|
||||
"id", "execution.model"
|
||||
)
|
||||
busy_models = [t.execution.model for t in dependent_tasks]
|
||||
models.draft[:] = [m for m in models.draft if m.id not in busy_models]
|
||||
|
||||
with TimingContext("mongo", "get_task_children"):
|
||||
tasks = Task.objects(parent=task.id).only("id", "parent", "status")
|
||||
published_tasks = [
|
||||
task for task in tasks if task.status == TaskStatus.published
|
||||
]
|
||||
if not force and published_tasks:
|
||||
raise errors.bad_request.TaskCannotBeDeleted(
|
||||
"has children, use force=True", task=task.id, children=published_tasks
|
||||
)
|
||||
return models, tasks
|
||||
|
||||
|
||||
def get_output_model(task, force=False):
|
||||
with TimingContext("mongo", "get_task_output_model"):
|
||||
output_model = Model.objects(id=task.output.model).first()
|
||||
if output_model and output_model.ready and not force:
|
||||
raise errors.bad_request.TaskCannotBeDeleted(
|
||||
"has output model, use force=True", task=task.id, model=task.output.model
|
||||
)
|
||||
return output_model
|
||||
|
||||
|
||||
def trash_task_id(task_id):
    """
    Return the given task ID prefixed with the "__DELETED__" marker.
    """
    return f"__DELETED__{task_id}"
|
||||
@endpoint(
    "tasks.unarchive_many",
    request_data_model=TaskBatchRequest,
    response_data_model=BatchResponse,
)
def unarchive_many(call: APICall, company_id, request: TaskBatchRequest):
    """
    Apply unarchive_task to every ID in request.ids and report per-task results.

    :param call: API call that receives the response model
    :param company_id: company passed to unarchive_task for each task
    :param request: batch request carrying ids, status_message and status_reason
    """
    unarchive_one = partial(
        unarchive_task,
        company_id=company_id,
        status_message=request.status_message,
        status_reason=request.status_reason,
    )
    results, failures = run_batch_operation(func=unarchive_one, ids=request.ids)

    # Each successful result is an (id, unarchived) pair; unarchived is reported as a bool
    succeeded = [
        {"id": _id, "unarchived": bool(unarchived)} for _id, unarchived in results
    ]
    call.result.data_model = BatchResponse(succeeded=succeeded, failed=failures)
|
||||
|
||||
|
||||
@endpoint("tasks.delete", request_data_model=DeleteRequest)
|
||||
def delete(call: APICall, company_id, req_model: DeleteRequest):
|
||||
task = TaskBLL.get_task_with_access(
|
||||
req_model.task, company_id=company_id, requires_write_access=True
|
||||
def delete(call: APICall, company_id, request: DeleteRequest):
|
||||
deleted, task, cleanup_res = delete_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
move_to_trash=request.move_to_trash,
|
||||
force=request.force,
|
||||
return_file_urls=request.return_file_urls,
|
||||
delete_output_models=request.delete_output_models,
|
||||
)
|
||||
if deleted:
|
||||
_reset_cached_tags(company_id, projects=[task.project] if task.project else [])
|
||||
call.result.data = dict(deleted=bool(deleted), **attr.asdict(cleanup_res))
|
||||
|
||||
|
||||
@endpoint("tasks.delete_many", request_data_model=DeleteManyRequest)
|
||||
def delete_many(call: APICall, company_id, request: DeleteManyRequest):
|
||||
results, failures = run_batch_operation(
|
||||
func=partial(
|
||||
delete_task,
|
||||
company_id=company_id,
|
||||
move_to_trash=request.move_to_trash,
|
||||
force=request.force,
|
||||
return_file_urls=request.return_file_urls,
|
||||
delete_output_models=request.delete_output_models,
|
||||
),
|
||||
ids=request.ids,
|
||||
)
|
||||
|
||||
move_to_trash = req_model.move_to_trash
|
||||
force = req_model.force
|
||||
if results:
|
||||
projects = set(task.project for _, (_, task, _) in results)
|
||||
_reset_cached_tags(company_id, projects=list(projects))
|
||||
|
||||
if task.status != TaskStatus.created and not force:
|
||||
raise errors.bad_request.TaskCannotBeDeleted(
|
||||
"due to status, use force=True",
|
||||
task=task.id,
|
||||
expected=TaskStatus.created,
|
||||
current=task.status,
|
||||
)
|
||||
|
||||
with translate_errors_context():
|
||||
result = cleanup_task(task, force)
|
||||
|
||||
if move_to_trash:
|
||||
collection_name = task._get_collection_name()
|
||||
archived_collection = "{}__trash".format(collection_name)
|
||||
task.switch_collection(archived_collection)
|
||||
try:
|
||||
# A simple save() won't do due to mongoengine caching (nothing will be saved), so we have to force
|
||||
# an insert. However, if for some reason such an ID exists, let's make sure we'll keep going.
|
||||
with TimingContext("mongo", "save_task"):
|
||||
task.save(force_insert=True)
|
||||
except Exception:
|
||||
pass
|
||||
task.switch_collection(collection_name)
|
||||
|
||||
task.delete()
|
||||
_reset_cached_tags(company_id, projects=[task.project])
|
||||
update_project_time(task.project)
|
||||
|
||||
call.result.data = dict(deleted=True, **attr.asdict(result))
|
||||
call.result.data = dict(
|
||||
succeeded=[
|
||||
dict(id=_id, deleted=bool(deleted), **attr.asdict(cleanup_res))
|
||||
for _id, (deleted, _, cleanup_res) in results
|
||||
],
|
||||
failed=failures,
|
||||
)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.publish",
|
||||
request_data_model=PublishRequest,
|
||||
response_data_model=PublishResponse,
|
||||
response_data_model=UpdateResponse,
|
||||
)
|
||||
def publish(call: APICall, company_id, req_model: PublishRequest):
|
||||
call.result.data_model = PublishResponse(
|
||||
**TaskBLL.publish_task(
|
||||
task_id=req_model.task,
|
||||
def publish(call: APICall, company_id, request: PublishRequest):
|
||||
updates = publish_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
force=request.force,
|
||||
publish_model_func=ModelBLL.publish_model if request.publish_model else None,
|
||||
status_reason=request.status_reason,
|
||||
status_message=request.status_message,
|
||||
)
|
||||
call.result.data_model = UpdateResponse(**updates)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.publish_many",
|
||||
request_data_model=PublishManyRequest,
|
||||
response_data_model=UpdateBatchResponse,
|
||||
)
|
||||
def publish_many(call: APICall, company_id, request: PublishManyRequest):
|
||||
results, failures = run_batch_operation(
|
||||
func=partial(
|
||||
publish_task,
|
||||
company_id=company_id,
|
||||
publish_model=req_model.publish_model,
|
||||
force=req_model.force,
|
||||
status_reason=req_model.status_reason,
|
||||
status_message=req_model.status_message,
|
||||
)
|
||||
force=request.force,
|
||||
publish_model_func=ModelBLL.publish_model
|
||||
if request.publish_model
|
||||
else None,
|
||||
status_reason=request.status_reason,
|
||||
status_message=request.status_message,
|
||||
),
|
||||
ids=request.ids,
|
||||
)
|
||||
|
||||
call.result.data_model = UpdateBatchResponse(
|
||||
succeeded=[UpdateBatchItem(id=_id, **res) for _id, res in results],
|
||||
failed=failures,
|
||||
)
|
||||
|
||||
|
||||
@@ -1235,3 +1237,40 @@ def move(call: APICall, company_id: str, request: MoveRequest):
|
||||
update_project_time(projects)
|
||||
|
||||
return {"project_id": project_id}
|
||||
|
||||
|
||||
@endpoint("tasks.add_or_update_model", min_version="2.13")
def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequest):
    """
    Set the model registered under request.name in the task's models list
    of the requested type, adding a new entry when no such name exists.

    :return: dict with the "updated" result of the statistics update
    """
    # Called for its access check only; the returned task is not used here
    get_task_for_update(company_id=company_id, task_id=request.task, force=True)

    models_field = f"models__{request.type}"
    model_item = ModelItem(
        name=request.name, model=request.model, updated=datetime.utcnow()
    )

    # First try to overwrite an existing entry with the same name in place
    # ("__S" is the mongoengine positional operator)
    same_name_query = {"id": request.task, f"{models_field}__name": request.name}
    replaced = Task.objects(**same_name_query).update_one(
        **{f"set__{models_field}__S": model_item}
    )

    # Nothing was replaced: push the entry as part of the statistics update
    extra_updates = {} if replaced else {f"push__{models_field}": model_item}
    updated = TaskBLL.update_statistics(
        task_id=request.task,
        company_id=company_id,
        last_iteration_max=request.iteration,
        **extra_updates,
    )

    return {"updated": updated}
|
||||
|
||||
|
||||
@endpoint("tasks.delete_models", min_version="2.13")
def delete_models(_: APICall, company_id: str, request: DeleteModelsRequest):
    """
    Remove the requested models, matched by type and name, from the task's
    models lists.

    :return: dict with the "updated" result of the task update
    """
    task = get_task_for_update(company_id=company_id, task_id=request.task, force=True)

    # One pull command per model type that has at least one name to remove
    pull_commands = {}
    for model_type in get_options(TaskModelTypes):
        names = [m.name for m in request.models if m.type == model_type]
        if names:
            pull_commands[f"pull__models__{model_type}__name__in"] = names

    updated = task.update(last_change=datetime.utcnow(), **pull_commands)
    return {"updated": updated}
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
from datetime import datetime
|
||||
from typing import Union, Sequence, Tuple
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.metadata import MetadataItem as ApiMetadataItem
|
||||
from apiserver.apimodels.organization import Filter
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.model.task.task import TaskModelTypes, TaskModelNames
|
||||
from apiserver.database.utils import partition_tags
|
||||
from apiserver.service_repo import APICall
|
||||
from apiserver.utilities.dicts import nested_set, nested_get, nested_delete
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from apiserver.utilities.partial_version import PartialVersion
|
||||
|
||||
|
||||
@@ -25,16 +30,23 @@ def get_tags_response(ret: dict) -> dict:
|
||||
|
||||
def conform_output_tags(call: APICall, documents: Union[dict, Sequence[dict]]):
|
||||
"""
|
||||
Make sure that tags are always returned sorted
|
||||
For old clients both tags and system tags are returned in 'tags' field
|
||||
"""
|
||||
if call.requested_endpoint_version >= PartialVersion("2.3"):
|
||||
return
|
||||
if isinstance(documents, dict):
|
||||
documents = [documents]
|
||||
|
||||
merge_tags = call.requested_endpoint_version < PartialVersion("2.3")
|
||||
for doc in documents:
|
||||
system_tags = doc.get("system_tags")
|
||||
if system_tags:
|
||||
doc["tags"] = list(set(doc.get("tags", [])) | set(system_tags))
|
||||
if merge_tags:
|
||||
system_tags = doc.get("system_tags")
|
||||
if system_tags:
|
||||
doc["tags"] = list(set(doc.get("tags", [])) | set(system_tags))
|
||||
|
||||
for field in ("system_tags", "tags"):
|
||||
tags = doc.get(field)
|
||||
if tags:
|
||||
doc[field] = sorted(tags)
|
||||
|
||||
|
||||
def conform_tag_fields(call: APICall, document: dict, validate=False):
|
||||
@@ -84,3 +96,148 @@ def validate_tags(tags: Sequence[str], system_tags: Sequence[str]):
|
||||
raise errors.bad_request.FieldsValueError(
|
||||
"unsupported tag prefix", values=unsupported
|
||||
)
|
||||
|
||||
|
||||
def escape_dict(data: dict) -> dict:
    """
    Return a copy of data whose keys are passed through ParameterKeyEscaper.escape.
    Values are kept as they are. Empty or None input is returned unchanged.
    """
    if data:
        escaped = {}
        for key, value in data.items():
            escaped[ParameterKeyEscaper.escape(key)] = value
        return escaped

    return data
|
||||
|
||||
|
||||
def unescape_dict(data: dict) -> dict:
    """
    Return a copy of data whose keys are passed through ParameterKeyEscaper.unescape.
    Values are kept as they are. Empty or None input is returned unchanged.
    """
    if data:
        unescaped = {}
        for key, value in data.items():
            unescaped[ParameterKeyEscaper.unescape(key)] = value
        return unescaped

    return data
|
||||
|
||||
|
||||
def escape_dict_field(fields: dict, path: Union[str, Sequence[str]]):
    """
    Escape, in place, the keys of the dictionary found under path in fields.

    :param fields: nested dictionary to modify
    :param path: a single key or a sequence of keys leading to the dictionary
    Nothing is changed when the value at path is empty or is not a dict.
    """
    key_path = (path,) if isinstance(path, str) else path

    current = nested_get(fields, key_path)
    if current and isinstance(current, dict):
        nested_set(fields, key_path, escape_dict(current))
|
||||
|
||||
|
||||
def unescape_dict_field(fields: dict, path: Union[str, Sequence[str]]):
    """
    Unescape, in place, the keys of the dictionary found under path in fields.

    :param fields: nested dictionary to modify
    :param path: a single key or a sequence of keys leading to the dictionary
    Nothing is changed when the value at path is empty or is not a dict.
    """
    key_path = (path,) if isinstance(path, str) else path

    current = nested_get(fields, key_path)
    if current and isinstance(current, dict):
        nested_set(fields, key_path, unescape_dict(current))
|
||||
|
||||
|
||||
class ModelsBackwardsCompatibility:
    """
    Converts between the legacy single-model fields (execution.model and
    output.model) and the "models" lists, for calls made with an endpoint
    version below max_version
    """

    max_version = PartialVersion("2.13")
    # Legacy field path for each models list
    mode_to_fields = {
        TaskModelTypes.input: ("execution", "model"),
        TaskModelTypes.output: ("output", "model"),
    }
    models_field = "models"

    @classmethod
    def prepare_for_save(cls, call: APICall, fields: dict):
        """
        Move the legacy model fields found in fields into the models lists.
        Does nothing for calls at or above max_version.
        """
        if call.requested_endpoint_version >= cls.max_version:
            return

        for mode, legacy_path in cls.mode_to_fields.items():
            model_id = nested_get(fields, legacy_path)
            if model_id is None:
                continue

            # A falsy, non-None value results in an empty models list
            if model_id:
                entries = [
                    dict(
                        name=TaskModelNames[mode],
                        model=model_id,
                        updated=datetime.utcnow(),
                    )
                ]
            else:
                entries = []
            nested_set(fields, (cls.models_field, mode), value=entries)

            nested_delete(fields, legacy_path)

    @classmethod
    def unprepare_from_saved(
        cls, call: APICall, tasks_data: Union[Sequence[dict], dict]
    ):
        """
        Fill the legacy model fields of the returned task(s) from the models
        lists. Does nothing for calls at or above max_version.
        """
        if call.requested_endpoint_version >= cls.max_version:
            return

        task_docs = [tasks_data] if isinstance(tasks_data, dict) else tasks_data

        for task_doc in task_docs:
            for mode, legacy_path in cls.mode_to_fields.items():
                entries = nested_get(task_doc, (cls.models_field, mode))
                if not entries:
                    continue

                # The first entry is used for input, the last one for any other type
                position = 0 if mode == TaskModelTypes.input else -1
                chosen = entries[position]
                if chosen:
                    nested_set(task_doc, legacy_path, chosen.get("model"))
|
||||
|
||||
|
||||
class DockerCmdBackwardsCompatibility:
    """
    Converts between the legacy execution.docker_cmd string and the
    container image/arguments fields, for calls made with an endpoint
    version below max_version
    """

    max_version = PartialVersion("2.13")
    field = ("execution", "docker_cmd")

    @classmethod
    def prepare_for_save(cls, call: APICall, fields: dict):
        """
        Split the legacy docker_cmd found in fields into container image and
        arguments, then remove the legacy field.
        Does nothing for calls at or above max_version.
        """
        if call.requested_endpoint_version >= cls.max_version:
            return

        legacy_cmd = nested_get(fields, cls.field)
        if legacy_cmd is not None:
            # Everything up to the first space is the image, the rest are the arguments
            image, _, arguments = legacy_cmd.partition(" ")
            nested_set(fields, ("container", "image"), value=image)
            nested_set(fields, ("container", "arguments"), value=arguments)

        nested_delete(fields, cls.field)

    @classmethod
    def unprepare_from_saved(
        cls, call: APICall, tasks_data: Union[Sequence[dict], dict]
    ):
        """
        Rebuild the legacy docker_cmd of the returned task(s) from the
        container image and arguments. Tasks without a container image are
        skipped. Does nothing for calls at or above max_version.
        """
        if call.requested_endpoint_version >= cls.max_version:
            return

        task_docs = [tasks_data] if isinstance(tasks_data, dict) else tasks_data

        for task_doc in task_docs:
            container = task_doc.get("container")
            if not container or not container.get("image"):
                continue

            # Join the non-empty parts with a single space
            parts = [container.get(key) for key in ("image", "arguments")]
            docker_cmd = " ".join(part for part in parts if part)
            if docker_cmd:
                nested_set(task_doc, cls.field, docker_cmd)
|
||||
|
||||
|
||||
def validate_metadata(metadata: Sequence[dict]):
    """
    Make sure that no key appears more than once in the metadata items.

    :param metadata: sequence of metadata dictionaries; empty or None is accepted
    :raise errors.bad_request.ValidationError: if two items share the same key
    """
    if not metadata:
        return

    # An item without a "key" entry contributes None, so two or more such
    # items are reported as duplicates as well
    keys = [item.get("key") for item in metadata]
    if len(keys) != len(set(keys)):
        raise errors.bad_request.ValidationError("Metadata keys should be unique")
|
||||
|
||||
|
||||
def get_metadata_from_api(api_metadata: Sequence[ApiMetadataItem]) -> Sequence:
    """
    Convert API metadata items to plain structures and validate them.

    An empty or None input is returned unchanged. Otherwise each item is
    converted with ``to_struct()`` and the result is passed through
    ``validate_metadata`` before being returned.
    """
    if api_metadata:
        converted = [item.to_struct() for item in api_metadata]
        validate_metadata(converted)
        return converted

    return api_metadata
|
||||
|
||||
@@ -23,10 +23,10 @@ from apiserver.apimodels.workers import (
|
||||
GetActivityReportResponse,
|
||||
ActivityReportSeries,
|
||||
)
|
||||
from apiserver.bll.util import extract_properties_to_lists
|
||||
from apiserver.bll.workers import WorkerBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.utilities import extract_properties_to_lists
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from time import sleep
|
||||
@@ -18,10 +19,12 @@ def distributed_lock(name: str, timeout: int, max_wait: int = 0):
|
||||
lock_name = f"dist_lock_{name}"
|
||||
start = time.time()
|
||||
max_wait = max_wait or timeout * 2
|
||||
while not _redis.set(lock_name, value="", ex=timeout, nx=True):
|
||||
pid = os.getpid()
|
||||
while _redis.set(lock_name, value=pid, ex=timeout, nx=True) is None:
|
||||
sleep(1)
|
||||
if time.time() - start > max_wait:
|
||||
raise Exception(f"Could not acquire {name} lock for {max_wait} seconds")
|
||||
holder = _redis.get(lock_name)
|
||||
raise Exception(f"Could not acquire {name} lock for {max_wait} seconds. The lock is hold by {holder}")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
@@ -10,16 +9,14 @@ import requests
|
||||
import six
|
||||
from boltons.iterutils import remap
|
||||
from boltons.typeutils import issubclass
|
||||
from pyhocon import ConfigFactory
|
||||
from requests.adapters import HTTPAdapter
|
||||
from requests.auth import HTTPBasicAuth
|
||||
from requests.packages.urllib3.util.retry import Retry
|
||||
|
||||
from apiserver.apierrors.base import BaseError
|
||||
from apiserver.config_repo import config
|
||||
|
||||
config = ConfigFactory.parse_file("api_client.conf")
|
||||
|
||||
log = logging.getLogger("api_client")
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class APICallResult:
|
||||
@@ -111,7 +108,7 @@ class APIClient:
|
||||
self.api_key = (
|
||||
api_key
|
||||
or os.environ.get("SM_API_KEY")
|
||||
or config.get("api_key")
|
||||
or config.get("apiclient.api_key")
|
||||
)
|
||||
if not self.api_key:
|
||||
raise ValueError("APIClient requires api_key in constructor or config")
|
||||
@@ -119,7 +116,7 @@ class APIClient:
|
||||
self.secret_key = (
|
||||
secret_key
|
||||
or os.environ.get("SM_API_SECRET")
|
||||
or config.get("secret_key")
|
||||
or config.get("apiclient.secret_key")
|
||||
)
|
||||
if not self.secret_key:
|
||||
raise ValueError(
|
||||
@@ -127,7 +124,7 @@ class APIClient:
|
||||
)
|
||||
|
||||
self.base_url = (
|
||||
base_url or os.environ.get("SM_API_URL") or config.get("base_url")
|
||||
base_url or os.environ.get("SM_API_URL") or config.get("apiclient.base_url")
|
||||
)
|
||||
if not self.base_url:
|
||||
raise ValueError("APIClient requires base_url in constructor or config")
|
||||
@@ -139,9 +136,9 @@ class APIClient:
|
||||
|
||||
# create http session
|
||||
self.http_session = requests.session()
|
||||
retries = config.get("retries", 7)
|
||||
backoff_factor = config.get("backoff_factor", 0.3)
|
||||
status_forcelist = config.get("status_forcelist", (500, 502, 504))
|
||||
retries = config.get("apiclient.retries", 7)
|
||||
backoff_factor = config.get("apiclient.backoff_factor", 0.3)
|
||||
status_forcelist = config.get("apiclient.status_forcelist", (500, 502, 504))
|
||||
retry = Retry(
|
||||
total=retries,
|
||||
read=retries,
|
||||
@@ -154,7 +151,7 @@ class APIClient:
|
||||
self.http_session.mount("https://", adapter)
|
||||
|
||||
if impersonated_user_id:
|
||||
self.http_session.headers["X-Trains-Impersonate-As"] = impersonated_user_id
|
||||
self.http_session.headers["X-ClearML-Impersonate-As"] = impersonated_user_id
|
||||
|
||||
if not self.session_token:
|
||||
self.login()
|
||||
@@ -211,7 +208,7 @@ class APIClient:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
headers.update(headers_overrides)
|
||||
if is_async:
|
||||
headers["X-Trains-Async"] = "1"
|
||||
headers["X-ClearML-Async"] = "1"
|
||||
|
||||
if not isinstance(data, six.string_types):
|
||||
data = json.dumps(data)
|
||||
@@ -241,7 +238,7 @@ class APIClient:
|
||||
call_id = res.meta.call_id
|
||||
async_res_url = "%s/async.result?id=%s" % (self.base_url, call_id)
|
||||
async_res_headers = headers.copy()
|
||||
async_res_headers.pop("X-Trains-Async")
|
||||
async_res_headers.pop("X-ClearML-Async")
|
||||
while not got_result:
|
||||
log.info("Got 202. Checking async result for %s (%s)" % (url, call_id))
|
||||
http_res = self.http_session.get(
|
||||
|
||||
@@ -5,6 +5,8 @@ from functools import partial
|
||||
from typing import Iterable
|
||||
from unittest import TestCase
|
||||
|
||||
from packaging.version import parse
|
||||
|
||||
from apiserver.tests.api_client import APIClient
|
||||
from apiserver.config_repo import config
|
||||
|
||||
@@ -69,19 +71,10 @@ class TestService(TestCase, TestServiceInterface):
|
||||
delete_params=delete_params,
|
||||
)
|
||||
|
||||
def create_temp_version(self, *, client=None, **kwargs) -> str:
|
||||
return self._create_temp_helper(
|
||||
service="datasets",
|
||||
create_endpoint="create_version",
|
||||
delete_endpoint="delete_version",
|
||||
object_name="version",
|
||||
create_params=kwargs,
|
||||
client=client,
|
||||
)
|
||||
|
||||
def setUp(self, version="1.7"):
|
||||
self._api = APIClient(base_url=f"http://localhost:8008/v{version}")
|
||||
self._deferred = []
|
||||
self._version = parse(version)
|
||||
header(self.id())
|
||||
|
||||
def tearDown(self):
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
parameterized
|
||||
141
apiserver/tests/automated/test_batch_operations.py
Normal file
141
apiserver/tests/automated/test_batch_operations.py
Normal file
@@ -0,0 +1,141 @@
|
||||
from apiserver.database.utils import id as db_id
|
||||
from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestBatchOperations(TestService):
    """
    Tests for the ``*_many`` batch endpoints of the tasks and models services.

    Each batch call receives the ids of temporary entities plus one freshly
    generated id, and both the ``succeeded`` and ``failed`` parts of the
    response are verified.
    """

    name = "batch operation test"
    comment = "this is a comment"
    delete_params = dict(can_fail=True, force=True)

    def setUp(self, version="2.13"):
        super().setUp(version=version)

    def test_tasks(self):
        """Run two tasks through enqueue, stop, publish, reset, archive and delete"""
        tasks = [self._temp_task() for _ in range(2)]
        # one output model per task; the model uris are checked after reset
        for idx, task_id in enumerate(tasks):
            self._temp_task_model(task=task_id, uri=f"uri_{idx}")
        missing_id = db_id()
        ids = [*tasks, missing_id]

        # enqueue
        result = self.api.tasks.enqueue_many(ids=ids)
        self._assert_succeeded(result, tasks)
        self._assert_failed(result, [missing_id])
        fetched = self.api.tasks.get_all_ex(id=ids).tasks
        self.assertEqual(set(t.status for t in fetched), {"queued"})

        # stop
        for task_id in tasks:
            self.api.tasks.started(task=task_id)
        result = self.api.tasks.stop_many(ids=ids)
        self._assert_succeeded(result, tasks)
        self._assert_failed(result, [missing_id])
        fetched = self.api.tasks.get_all_ex(id=ids).tasks
        self.assertEqual(set(t.status for t in fetched), {"stopped"})

        # publish
        result = self.api.tasks.publish_many(ids=ids, publish_model=False)
        self._assert_succeeded(result, tasks)
        self._assert_failed(result, [missing_id])
        fetched = self.api.tasks.get_all_ex(id=ids).tasks
        self.assertEqual(set(t.status for t in fetched), {"published"})

        # reset
        result = self.api.tasks.reset_many(
            ids=ids, delete_output_models=True, return_file_urls=True, force=True
        )
        self._assert_succeeded(result, tasks)
        deleted_total = sum(t.deleted_models for t in result.succeeded)
        self.assertEqual(deleted_total, 2)
        returned_urls = {
            url for t in result.succeeded for url in t.urls.model_urls
        }
        self.assertEqual(returned_urls, {"uri_0", "uri_1"})
        self._assert_failed(result, [missing_id])
        fetched = self.api.tasks.get_all_ex(id=ids).tasks
        self.assertEqual(set(t.status for t in fetched), {"created"})

        # archive/unarchive
        result = self.api.tasks.archive_many(ids=ids)
        self._assert_succeeded(result, tasks)
        self._assert_failed(result, [missing_id])
        fetched = self.api.tasks.get_all_ex(id=ids).tasks
        self.assertTrue(all("archived" in t.system_tags for t in fetched))

        result = self.api.tasks.unarchive_many(ids=ids)
        self._assert_succeeded(result, tasks)
        self._assert_failed(result, [missing_id])
        fetched = self.api.tasks.get_all_ex(id=ids).tasks
        self.assertFalse(any("archived" in t.system_tags for t in fetched))

        # delete
        result = self.api.tasks.delete_many(
            ids=ids, delete_output_models=True, return_file_urls=True
        )
        self._assert_succeeded(result, tasks)
        self._assert_failed(result, [missing_id])
        fetched = self.api.tasks.get_all_ex(id=ids).tasks
        self.assertEqual(fetched, [])

    def test_models(self):
        """Run two models through publish, archive/unarchive and delete"""
        uris = [f"file:///{i}" for i in range(2)]
        models = [self._temp_model(uri=uri) for uri in uris]
        missing_id = db_id()
        ids = [*models, missing_id]
        first_model = ids[0]

        # publish
        task = self._temp_task()
        self.api.models.edit(model=first_model, ready=False, task=task)
        self.api.tasks.add_or_update_model(
            task=task, name="output", type="input", model=first_model
        )
        result = self.api.models.publish_many(
            ids=ids, publish_task=True, force_publish_task=True
        )
        self._assert_succeeded(result, [first_model])
        self.assertEqual(result.succeeded[0].published_task.id, task)
        self._assert_failed(result, [ids[1], missing_id])

        # archive/unarchive
        result = self.api.models.archive_many(ids=ids)
        self._assert_succeeded(result, models)
        self._assert_failed(result, [missing_id])
        fetched = self.api.models.get_all_ex(id=ids).models
        self.assertTrue(all("archived" in m.system_tags for m in fetched))

        result = self.api.models.unarchive_many(ids=ids)
        self._assert_succeeded(result, models)
        self._assert_failed(result, [missing_id])
        fetched = self.api.models.get_all_ex(id=ids).models
        self.assertFalse(any("archived" in m.system_tags for m in fetched))

        # delete
        result = self.api.models.delete_many(ids=[*models, missing_id], force=True)
        self._assert_succeeded(result, models)
        self.assertEqual({m.url for m in result.succeeded}, set(uris))
        self._assert_failed(result, [missing_id])
        fetched = self.api.models.get_all_ex(id=ids).models
        self.assertEqual(fetched, [])

    def _assert_succeeded(self, res, succeeded_ids):
        """Check that exactly the passed ids are reported as succeeded"""
        reported = {item.id for item in res.succeeded}
        self.assertEqual(reported, set(succeeded_ids))

    def _assert_failed(self, res, failed_ids):
        """Check that exactly the passed ids are reported as failed"""
        reported = {item.id for item in res.failed}
        self.assertEqual(reported, set(failed_ids))

    def _temp_model(self, **kwargs):
        """Create a temporary model, filling in default name, uri and labels"""
        self.update_missing(kwargs, name=self.name, uri="file:///a/b", labels={})
        return self.create_temp("models", delete_params=self.delete_params, **kwargs)

    def _temp_task(self):
        """Create a temporary task of type "testing" """
        return self.create_temp(
            service="tasks", type="testing", name=self.name, input=dict(view={}),
        )

    def _temp_task_model(self, task, **kwargs) -> str:
        """Create a temporary not-ready model and register it as the task output"""
        model_id = self._temp_model(ready=False, task=task, **kwargs)
        self.api.tasks.add_or_update_model(
            task=task, name="output", type="output", model=model_id
        )
        return model_id
|
||||
@@ -28,7 +28,9 @@ class TestEntityOrdering(TestService):
|
||||
self._assertGetTasksWithOrdering(order_by="comment")
|
||||
|
||||
# sort by parameter which type is not part of db schema
|
||||
self._assertGetTasksWithOrdering(order_by="execution.parameters.test")
|
||||
self._assertGetTasksWithOrdering(
|
||||
order_by="execution.parameters.test", valid_order=False
|
||||
)
|
||||
|
||||
def test_order_with_paging(self):
|
||||
order_field = "started"
|
||||
@@ -97,7 +99,9 @@ class TestEntityOrdering(TestService):
|
||||
|
||||
return val
|
||||
|
||||
def _assertGetTasksWithOrdering(self, order_by: str = None, **kwargs):
|
||||
def _assertGetTasksWithOrdering(
|
||||
self, order_by: str = None, valid_order=True, **kwargs
|
||||
):
|
||||
tasks = self.api.tasks.get_all_ex(
|
||||
only_fields=self.only_fields,
|
||||
order_by=[order_by] if isinstance(order_by, str) else order_by,
|
||||
@@ -105,14 +109,16 @@ class TestEntityOrdering(TestService):
|
||||
**kwargs,
|
||||
).tasks
|
||||
self.assertLessEqual(set(self.task_ids), set(t.id for t in tasks))
|
||||
if order_by:
|
||||
if order_by and valid_order:
|
||||
# test that the output is correctly ordered
|
||||
field_name = order_by if not order_by.startswith("-") else order_by[1:]
|
||||
field_vals = [self._get_value_for_path(t, field_name.split(".")) for t in tasks]
|
||||
field_vals = [
|
||||
self._get_value_for_path(t, field_name.split(".")) for t in tasks
|
||||
]
|
||||
self._assertSorted(
|
||||
field_vals,
|
||||
ascending=not order_by.startswith("-"),
|
||||
is_numeric=field_name.startswith("execution.parameters.")
|
||||
is_numeric=field_name.startswith("execution.parameters."),
|
||||
)
|
||||
|
||||
def _create_tasks(self):
|
||||
|
||||
74
apiserver/tests/automated/test_queue_model_metadata.py
Normal file
74
apiserver/tests/automated/test_queue_model_metadata.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from functools import partial
|
||||
from typing import Sequence
|
||||
|
||||
from apiserver.tests.api_client import APIClient
|
||||
from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestQueueAndModelMetadata(TestService):
    """
    Tests for the metadata operations (update, add_or_update_metadata,
    delete_metadata) of the queues and models services.
    """

    # metadata attached to the temporary entities on creation
    meta1 = [{"key": "test_key", "type": "str", "value": "test_value"}]

    def setUp(self, version="2.13"):
        super().setUp(version=version)

    def test_queue_metas(self):
        queue_id = self._temp_queue("TestMetadata", metadata=self.meta1)
        self._test_meta_operations(
            service=self.api.queues, entity="queue", _id=queue_id
        )

    def test_models_metas(self):
        service = self.api.models
        entity = "model"
        model_id = self._temp_model("TestMetadata", metadata=self.meta1)
        self._test_meta_operations(
            service=self.api.models, entity="model", _id=model_id
        )

        # metadata can also be set through models.edit
        edited_id = self._temp_model("TestMetadata1")
        self.api.models.edit(model=edited_id, metadata=[self.meta1[0]])
        self._assertMeta(
            service=service, entity=entity, _id=edited_id, meta=self.meta1
        )

    def _test_meta_operations(
        self, service: APIClient.Service, entity: str, _id: str,
    ):
        """
        Replace, partially update and partially delete the metadata of the
        entity, checking the stored metadata after every step
        """

        def assert_meta(_id, meta):
            self._assertMeta(service=service, entity=entity, _id=_id, meta=meta)

        assert_meta(_id=_id, meta=self.meta1)

        # full replace through update
        meta2 = [
            {"key": "test1", "type": "str", "value": "data1"},
            {"key": "test2", "type": "str", "value": "data2"},
            {"key": "test3", "type": "str", "value": "data3"},
        ]
        service.update(**{entity: _id, "metadata": meta2})
        assert_meta(_id=_id, meta=meta2)

        # test2 and test3 are overwritten, test4 and test5 are new
        updates = [
            {"key": "test2", "type": "int", "value": "10"},
            {"key": "test3", "type": "int", "value": "20"},
            {"key": "test4", "type": "array", "value": "xxx,yyy"},
            {"key": "test5", "type": "array", "value": "zzz"},
        ]
        response = service.add_or_update_metadata(
            **{entity: _id, "metadata": updates}
        )
        self.assertEqual(response.updated, 1)
        assert_meta(_id=_id, meta=[meta2[0], *updates])

        # delete everything except the first item of meta2
        keys_to_delete = [f"test{idx}" for idx in range(2, 6)]
        response = service.delete_metadata(**{entity: _id, "keys": keys_to_delete})
        self.assertEqual(response.updated, 1)
        assert_meta(_id=_id, meta=meta2[:1])

    def _assertMeta(
        self, service: APIClient.Service, entity: str, _id: str, meta: Sequence[dict]
    ):
        """Fetch the entity through get_all_ex and compare its metadata to ``meta``"""
        fetched = service.get_all_ex(id=[_id])[f"{entity}s"][0]
        self.assertEqual(fetched.metadata, meta)

    def _temp_queue(self, name, **kwargs):
        return self.create_temp("queues", name=name, **kwargs)

    def _temp_model(self, name: str, **kwargs):
        return self.create_temp(
            "models", uri="file://test", name=name, labels={}, **kwargs
        )
|
||||
267
apiserver/tests/automated/test_subprojects.py
Normal file
267
apiserver/tests/automated/test_subprojects.py
Normal file
@@ -0,0 +1,267 @@
|
||||
from time import sleep
|
||||
from typing import Sequence, Optional, Tuple
|
||||
|
||||
from boltons.iterutils import first
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.utils import id as db_id
|
||||
from apiserver.tests.api_client import APIClient
|
||||
from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestSubProjects(TestService):
    """
    Tests for projects with slash-separated names ("Parent/Child"), run
    against API version 2.13: aggregations over sub projects, project
    update/move/merge/delete, search options, own contents and stats.
    """

    # parameters passed to create_temp for every temporary entity of this suite
    delete_params = dict(can_fail=True, force=True)

    def setUp(self, **kwargs):
        super().setUp(version="2.13")

    def test_project_aggregations(self):
        """This test requires user with user_auth_only... credentials in db"""
        user2_client = APIClient(
            api_key=config.get("apiclient.user_auth_only"),
            secret_key=config.get("apiclient.user_auth_only_secret"),
            base_url=f"http://localhost:8008/v2.13",
        )

        child = self._temp_project(name="Aggregation/Pr1", client=user2_client)
        project = self.api.projects.get_all_ex(name="^Aggregation$").projects[0].id
        child_project = self.api.projects.get_all_ex(id=[child]).projects[0]
        self.assertEqual(child_project.parent.id, project)

        user = self.api.users.get_current_user().user.id

        # test aggregations on project with empty subprojects
        users_res = self.api.users.get_all_ex(active_in_projects=[project])
        self.assertEqual(users_res.users, [])
        projects_res = self.api.projects.get_all_ex(
            id=[project], active_users=[user]
        )
        self.assertEqual(projects_res.projects, [])
        frameworks_res = self.api.models.get_frameworks(projects=[project])
        self.assertEqual(frameworks_res.frameworks, [])
        types_res = self.api.tasks.get_types(projects=[project])
        self.assertEqual(types_res.types, [])
        parents_res = self.api.projects.get_task_parents(projects=[project])
        self.assertEqual(parents_res.parents, [])

        # test aggregations with non-empty subprojects
        task1 = self._temp_task(project=child)
        self._temp_task(project=child, parent=task1)
        framework = "Test framework"
        self._temp_model(project=child, framework=framework)
        users_res = self.api.users.get_all_ex(active_in_projects=[project])
        self._assert_ids(users_res.users, [user])
        projects_res = self.api.projects.get_all_ex(
            id=[project], active_users=[user]
        )
        self._assert_ids(projects_res.projects, [project])
        parents_res = self.api.projects.get_task_parents(projects=[project])
        self._assert_ids(parents_res.parents, [task1])
        frameworks_res = self.api.models.get_frameworks(projects=[project])
        self.assertEqual(frameworks_res.frameworks, [framework])
        types_res = self.api.tasks.get_types(projects=[project])
        self.assertEqual(types_res.types, ["testing"])

    def _assert_ids(self, actual: Sequence[dict], expected: Sequence[str]):
        """Check that the "id" values of ``actual`` equal ``expected``, in order"""
        actual_ids = [item["id"] for item in actual]
        self.assertEqual(actual_ids, expected)

    def test_project_operations(self):
        """Create, rename, move, merge and delete nested projects"""

        def project_name(project_id):
            return self.api.projects.get_by_id(project=project_id).project.name

        # create
        with self.api.raises(errors.bad_request.InvalidProjectName):
            self._temp_project(name="/")
        project1 = self._temp_project(name="Root1/Pr1")
        project1_child = self._temp_project(name="Root1/Pr1/Pr2")
        with self.api.raises(errors.bad_request.ExpectedUniqueData):
            self._temp_project(name="Root1/Pr1/Pr2")

        # update
        with self.api.raises(errors.bad_request.CannotUpdateProjectLocation):
            self.api.projects.update(project=project1, name="Root2/Pr2")
        updated = self.api.projects.update(project=project1, name="Root1/Pr2")
        self.assertEqual(updated.updated, 1)
        self.assertEqual(project_name(project1_child), "Root1/Pr2/Pr2")

        # move
        moved = self.api.projects.move(project=project1, new_location="Root2")
        self.assertEqual(moved.moved, 2)
        self.assertEqual(project_name(project1_child), "Root2/Pr2/Pr2")

        # merge
        project_with_task, (active, _archived) = self._temp_project_with_tasks(
            "Root1/Pr3/Pr4"
        )
        project1_parent = self._getProjectParent(project1)
        self._assertTags(project1_parent, tags=[], system_tags=[])
        # NOTE(review): this repeats the assertion above with the same project;
        # possibly a different project was intended - confirm
        self._assertTags(project1_parent, tags=[], system_tags=[])
        project_with_task_parent = self._getProjectParent(project_with_task)
        self._assertTags(project_with_task_parent)
        # self._assertTags(project_id=None)

        merge_source = self.api.projects.get_by_id(
            project=project_with_task
        ).project.parent
        merged = self.api.projects.merge(
            project=merge_source, destination_project=project1
        )
        self.assertEqual(merged.moved_entities, 0)
        self.assertEqual(merged.moved_projects, 1)
        self.assertEqual(project_name(project_with_task), "Root2/Pr2/Pr4")
        with self.api.raises(errors.bad_request.InvalidProjectId):
            self.api.projects.get_by_id(project=merge_source)

        self._assertTags(project1_parent)
        self._assertTags(project1)
        self._assertTags(project_with_task_parent, tags=[], system_tags=[])
        # self._assertTags(project_id=None)

        # delete
        with self.api.raises(errors.bad_request.ProjectHasTasks):
            self.api.projects.delete(project=project1)
        deleted = self.api.projects.delete(project=project1, force=True)
        self.assertEqual(deleted.deleted, 3)
        self.assertEqual(deleted.disassociated_tasks, 2)
        active_task = self.api.tasks.get_by_id(task=active).task
        self.assertIsNone(active_task.get("project"))
        for p_id in (project1, project1_child, project_with_task):
            with self.api.raises(errors.bad_request.InvalidProjectId):
                self.api.projects.get_by_id(project=p_id)

        self._assertTags(project1_parent, tags=[], system_tags=[])
        # self._assertTags(project_id=None, tags=[], system_tags=[])

    def _getProjectParent(self, project_id: str):
        """Return the id of the parent of the passed project"""
        found = self.api.projects.get_all_ex(id=[project_id]).projects
        return found[0].parent.id

    def _assertTags(
        self,
        project_id: Optional[str],
        tags: Sequence[str] = ("test",),
        system_tags: Sequence[str] = (EntityVisibility.archived.value,),
    ):
        """
        Check the task tags and system tags reported for the project.
        Without a project id the organization level tags are checked instead
        """
        if not project_id:
            res = self.api.organization.get_tags(include_system=True)
        else:
            res = self.api.projects.get_task_tags(
                projects=[project_id], include_system=True
            )

        self.assertEqual(set(res.tags), set(tags))
        self.assertEqual(set(res.system_tags), set(system_tags))

    def test_get_all_search_options(self):
        project1 = self._temp_project(name="project1")
        project2 = self._temp_project(name="project1/project2")
        self._temp_project(name="project3")

        # local search finds only at the specified level
        found = self.api.projects.get_all_ex(
            name="project1", shallow_search=True
        ).projects
        self.assertEqual([p.id for p in found], [project1])
        found = self.api.projects.get_all_ex(
            name="project1", parent=[project1]
        ).projects
        self.assertEqual([p.id for p in found], [project2])

        # global search finds all or below the specified level
        found = self.api.projects.get_all_ex(name="project1").projects
        self.assertEqual({p.id for p in found}, {project1, project2})
        project4 = self._temp_project(name="project1/project2/project1")
        found = self.api.projects.get_all_ex(
            name="project1", parent=[project2]
        ).projects
        self.assertEqual([p.id for p in found], [project4])

        self.api.projects.delete(project=project1, force=True)

    def test_get_all_with_check_own_contents(self):
        project1, _ = self._temp_project_with_tasks(name="project1x")
        project2 = self._temp_project(name="project2x")
        self._temp_project_with_tasks(name="project2x/project22")
        self._temp_model(project=project1)

        projects = self.api.projects.get_all_ex(
            id=[project1, project2], check_own_contents=True
        ).projects

        # project1 directly holds two tasks and one model
        res1 = next(p for p in projects if p.id == project1)
        self.assertEqual(res1.own_tasks, 2)
        self.assertEqual(res1.own_models, 1)

        # the contents of project2 belong to its sub project only
        res2 = next(p for p in projects if p.id == project2)
        self.assertEqual(res2.own_tasks, 0)
        self.assertEqual(res2.own_models, 0)

    def test_get_all_with_stats(self):
        project4, _ = self._temp_project_with_tasks(name="project1/project3/project4")
        project5, _ = self._temp_project_with_tasks(name="project1/project3/project5")
        project2 = self._temp_project(name="project2")
        top_level = self.api.projects.get_all(shallow_search=True).projects
        self.assertTrue(any(p for p in top_level if p.id == project2))
        self.assertFalse(any(p for p in top_level if p.id in [project4, project5]))

        project1 = first(p.id for p in top_level if p.name == "project1")
        with_stats = self.api.projects.get_all_ex(
            id=[project1, project2], include_stats=True
        ).projects
        self.assertEqual({p.id for p in with_stats}, {project1, project2})

        res1 = next(p for p in with_stats if p.id == project1)
        active_stats = res1.stats["active"]
        self.assertEqual(active_stats["status_count"]["created"], 0)
        self.assertEqual(active_stats["status_count"]["stopped"], 2)
        self.assertEqual(active_stats["total_runtime"], 2)
        self.assertEqual(
            {sp.name for sp in res1.sub_projects},
            {
                "project1/project3",
                "project1/project3/project4",
                "project1/project3/project5",
            },
        )

        res2 = next(p for p in with_stats if p.id == project2)
        active_stats = res2.stats["active"]
        self.assertEqual(active_stats["status_count"]["created"], 0)
        self.assertEqual(active_stats["status_count"]["stopped"], 0)
        self.assertEqual(active_stats["total_runtime"], 0)
        self.assertEqual(res2.sub_projects, [])

    def _run_tasks(self, *tasks):
        """Imitate 1 second of running"""
        for task_id in tasks:
            self.api.tasks.started(task=task_id)
        sleep(1)
        for task_id in tasks:
            self.api.tasks.stopped(task=task_id)

    def _temp_project_with_tasks(self, name) -> Tuple[str, Tuple[str, str]]:
        """
        Create a project holding one regular task and one archived task
        tagged "test", run both and return (project id, (active, archived))
        """
        pr_id = self._temp_project(name=name)
        task_active = self._temp_task(project=pr_id)
        task_archived = self._temp_task(
            project=pr_id, system_tags=[EntityVisibility.archived.value], tags=["test"]
        )
        self._run_tasks(task_active, task_archived)
        return pr_id, (task_active, task_archived)

    def _temp_project(self, name, client=None, **kwargs):
        return self.create_temp(
            "projects",
            delete_params=self.delete_params,
            name=name,
            description="",
            client=client,
            **kwargs,
        )

    def _temp_task(self, **kwargs):
        return self.create_temp(
            "tasks",
            delete_params=self.delete_params,
            type="testing",
            name=db_id(),
            input=dict(view=dict()),
            **kwargs,
        )

    def _temp_model(self, **kwargs):
        return self.create_temp(
            service="models",
            delete_params=self.delete_params,
            name="test",
            uri="file:///a",
            labels={},
            **kwargs,
        )
|
||||
@@ -47,9 +47,9 @@ class TestTasksArtifacts(TestService):
|
||||
self._assertTaskArtifacts([a for a in artifacts if a["mode"] != "output"], res)
|
||||
|
||||
new_artifacts = [
|
||||
dict(key="x", type="str", uri="x_test"),
|
||||
dict(key="y", type="int", uri="y_test"),
|
||||
dict(key="z", type="int", uri="y_test"),
|
||||
dict(key="x", type="str", uri="x_test", mode="input"),
|
||||
dict(key="y", type="int", uri="y_test", mode="input"),
|
||||
dict(key="z", type="int", uri="y_test", mode="input"),
|
||||
]
|
||||
new_task = self.api.tasks.clone(
|
||||
task=task, execution_overrides={"artifacts": new_artifacts}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from functools import partial
|
||||
from typing import Sequence
|
||||
|
||||
from typing import Sequence, Mapping
|
||||
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.tests.automated import TestService
|
||||
@@ -128,6 +127,96 @@ class TestTaskDebugImages(TestService):
|
||||
|
||||
def test_task_debug_images(self):
    """
    Check debug images retrieval for a task without events, after the
    first batch of events, and after a second batch with and without refresh
    """
    task = self._temp_task()

    def send_events(iteration, layout):
        # one event per (metric, variant) pair of the layout
        batch = []
        for metric, variants in layout.items():
            for variant in variants:
                batch.append(
                    self._create_task_event(
                        task=task,
                        iteration=iteration,
                        metric=metric,
                        variant=variant,
                        url=f"{metric}_{variant}_{iteration}",
                    )
                )
        self.send_batch(batch)

    # test empty
    res = self.api.events.debug_images(metrics=[{"task": task}], iters=5)
    self.assertFalse(res.metrics[0].iterations)
    res = self.api.events.debug_images(
        metrics=[{"task": task}], iters=5, scroll_id=res.scroll_id, refresh=True
    )
    self.assertFalse(res.metrics[0].iterations)

    # test not empty
    metrics = {
        "Metric1": ["Variant1", "Variant2"],
        "Metric2": ["Variant3", "Variant4"],
    }
    send_events(1, metrics)
    scroll_id = self._assertTaskMetrics(
        task=task, expected_metrics=metrics, iterations=1
    )

    # test refresh
    update = {
        "Metric2": ["Variant3", "Variant4", "Variant5"],
        "Metric3": ["VariantA", "VariantB"],
    }
    send_events(2, update)
    # without refresh the metric states are not updated
    scroll_id = self._assertTaskMetrics(
        task=task, expected_metrics=metrics, iterations=0, scroll_id=scroll_id
    )

    # with refresh there are new metrics and existing ones are updated
    self._assertTaskMetrics(
        task=task,
        expected_metrics=update,
        iterations=1,
        scroll_id=scroll_id,
        refresh=True,
    )
|
||||
|
||||
def _assertTaskMetrics(
    self,
    task: str,
    expected_metrics: Mapping[str, Sequence[str]],
    iterations,
    scroll_id: str = None,
    refresh=False,
) -> str:
    """
    Request one iteration of debug images for the task and verify the result.

    With a falsy ``iterations`` every returned metric must have no
    iterations. Otherwise every metric must have exactly ``iterations``
    iterations, each holding events for exactly the expected
    (metric, variant) pairs. Returns the scroll id of the response.
    """
    res = self.api.events.debug_images(
        metrics=[{"task": task}], iters=1, scroll_id=scroll_id, refresh=refresh
    )
    if iterations:
        expected_variants = {
            (metric, variant)
            for metric, variants in expected_metrics.items()
            for variant in variants
        }
        for metric_data in res.metrics:
            self.assertEqual(len(metric_data.iterations), iterations)
            for it_data in metric_data.iterations:
                reported = set((e.metric, e.variant) for e in it_data.events)
                self.assertEqual(reported, expected_variants)
    else:
        self.assertTrue(all(m.iterations == [] for m in res.metrics))

    return res.scroll_id
|
||||
|
||||
def test_get_debug_images_navigation(self):
|
||||
task = self._temp_task()
|
||||
metric = "Metric1"
|
||||
variants = [("Variant1", 7), ("Variant2", 4)]
|
||||
iterations = 10
|
||||
@@ -136,7 +225,7 @@ class TestTaskDebugImages(TestService):
|
||||
res = self.api.events.debug_images(
|
||||
metrics=[{"task": task, "metric": metric}], iters=5,
|
||||
)
|
||||
self.assertFalse(res.metrics)
|
||||
self.assertFalse(res.metrics[0].iterations)
|
||||
|
||||
# create events
|
||||
events = [
|
||||
@@ -195,7 +284,7 @@ class TestTaskDebugImages(TestService):
|
||||
expected_page: int,
|
||||
iters: int = 5,
|
||||
**extra_params,
|
||||
):
|
||||
) -> str:
|
||||
res = self.api.events.debug_images(
|
||||
metrics=[{"task": task, "metric": metric}],
|
||||
iters=iters,
|
||||
@@ -204,7 +293,6 @@ class TestTaskDebugImages(TestService):
|
||||
)
|
||||
data = res["metrics"][0]
|
||||
self.assertEqual(data["task"], task)
|
||||
self.assertEqual(data["metric"], metric)
|
||||
left_iterations = max(0, max(unique_images) - expected_page * iters)
|
||||
self.assertEqual(len(data["iterations"]), min(iters, left_iterations))
|
||||
for it in data["iterations"]:
|
||||
|
||||
@@ -204,7 +204,9 @@ class TestTaskEvents(TestService):
|
||||
self.send_batch(events)
|
||||
for key in None, "iter", "timestamp", "iso_time":
|
||||
with self.subTest(key=key):
|
||||
data = self.api.events.scalar_metrics_iter_histogram(task=task, key=key)
|
||||
data = self.api.events.scalar_metrics_iter_histogram(
|
||||
task=task, **(dict(key=key) if key is not None else {})
|
||||
)
|
||||
self.assertIn(metric, data)
|
||||
self.assertIn(variant, data[metric])
|
||||
self.assertIn("x", data[metric][variant])
|
||||
|
||||
@@ -105,7 +105,9 @@ class TestTasksHyperparams(TestService):
|
||||
)
|
||||
|
||||
# clone task
|
||||
new_task = self.api.tasks.clone(task=task, new_task_hyperparams=new_params_dict).id
|
||||
new_task = self.api.tasks.clone(
|
||||
task=task, new_task_hyperparams=new_params_dict
|
||||
).id
|
||||
try:
|
||||
res = self.api.tasks.get_hyper_params(tasks=[new_task]).params[0]
|
||||
self.assertEqual(new_params, res.hyperparams)
|
||||
@@ -123,7 +125,9 @@ class TestTasksHyperparams(TestService):
|
||||
task=task, hyperparams=[dict(section="test")]
|
||||
)
|
||||
self.api.tasks.edit_hyper_params(
|
||||
task=task, hyperparams=[dict(section="test", name="x", value="123")], force=True
|
||||
task=task,
|
||||
hyperparams=[dict(section="test", name="x", value="123")],
|
||||
force=True,
|
||||
)
|
||||
self.api.tasks.delete_hyper_params(
|
||||
task=task, hyperparams=[dict(section="test")], force=True
|
||||
@@ -146,7 +150,12 @@ class TestTasksHyperparams(TestService):
|
||||
return [
|
||||
dict(section="Args", name=k, value=str(v), type="legacy")
|
||||
if not k.startswith("TF_DEFINE/")
|
||||
else dict(section="TF_DEFINE", name=k[len("TF_DEFINE/"):], value=str(v), type="legacy")
|
||||
else dict(
|
||||
section="TF_DEFINE",
|
||||
name=k[len("TF_DEFINE/") :],
|
||||
value=str(v),
|
||||
type="legacy",
|
||||
)
|
||||
for k, v in legacy.items()
|
||||
]
|
||||
|
||||
@@ -168,6 +177,7 @@ class TestTasksHyperparams(TestService):
|
||||
new_config = [
|
||||
dict(name="param$1", type="type1", value="10"),
|
||||
dict(name="param/2", type="type1", value="20"),
|
||||
dict(name="param_empty", type="type1", value=""),
|
||||
]
|
||||
new_config_dict = self._config_dict_from_list(new_config)
|
||||
task, _ = self.new_task(
|
||||
@@ -188,7 +198,14 @@ class TestTasksHyperparams(TestService):
|
||||
# names
|
||||
res = self.api.tasks.get_configuration_names(tasks=[task]).configurations[0]
|
||||
self.assertEqual(task, res.task)
|
||||
self.assertEqual(["design", "param$1", "param/2"], res.names)
|
||||
self.assertEqual(
|
||||
["design", *[c["name"] for c in new_config if c["value"]]], res.names
|
||||
)
|
||||
res = self.api.tasks.get_configuration_names(
|
||||
tasks=[task], skip_empty=False
|
||||
).configurations[0]
|
||||
self.assertEqual(task, res.task)
|
||||
self.assertEqual(["design", *[c["name"] for c in new_config]], res.names)
|
||||
|
||||
# returned as one list with names filtering
|
||||
res = self.api.tasks.get_configurations(
|
||||
@@ -216,14 +233,14 @@ class TestTasksHyperparams(TestService):
|
||||
|
||||
# delete
|
||||
new_to_delete = self._get_config_keys(new_config[1:])
|
||||
self.api.tasks.delete_configuration(
|
||||
task=task, configuration=new_to_delete
|
||||
)
|
||||
self.api.tasks.delete_configuration(task=task, configuration=new_to_delete)
|
||||
res = self.api.tasks.get_configurations(tasks=[task]).configurations[0]
|
||||
self.assertEqual(old_config + new_config[:1], res.configuration)
|
||||
|
||||
# clone task
|
||||
new_task = self.api.tasks.clone(task=task, new_task_configuration=new_config_dict).id
|
||||
new_task = self.api.tasks.clone(
|
||||
task=task, new_task_configuration=new_config_dict
|
||||
).id
|
||||
try:
|
||||
res = self.api.tasks.get_configurations(tasks=[new_task]).configurations[0]
|
||||
self.assertEqual(new_config, res.configuration)
|
||||
@@ -233,13 +250,9 @@ class TestTasksHyperparams(TestService):
|
||||
# edit/delete of running task
|
||||
self.api.tasks.started(task=task)
|
||||
with self.api.raises(InvalidTaskStatus):
|
||||
self.api.tasks.edit_configuration(
|
||||
task=task, configuration=new_config
|
||||
)
|
||||
self.api.tasks.edit_configuration(task=task, configuration=new_config)
|
||||
with self.api.raises(InvalidTaskStatus):
|
||||
self.api.tasks.delete_configuration(
|
||||
task=task, configuration=new_to_delete
|
||||
)
|
||||
self.api.tasks.delete_configuration(task=task, configuration=new_to_delete)
|
||||
self.api.tasks.edit_configuration(
|
||||
task=task, configuration=new_config, force=True
|
||||
)
|
||||
@@ -292,7 +305,9 @@ class TestTasksHyperparams(TestService):
|
||||
task_id, _ = self.new_task(
|
||||
execution={"parameters": legacy_params, "model_desc": legacy_config}
|
||||
)
|
||||
config = self._config_dict_from_list(self._new_config_from_legacy(legacy_config))
|
||||
config = self._config_dict_from_list(
|
||||
self._new_config_from_legacy(legacy_config)
|
||||
)
|
||||
params = self._param_dict_from_list(self._new_params_from_legacy(legacy_params))
|
||||
|
||||
old_api = APIClient(base_url="http://localhost:8008/v2.8")
|
||||
@@ -304,7 +319,10 @@ class TestTasksHyperparams(TestService):
|
||||
|
||||
modified_params = {"legacy.2": "val2"}
|
||||
modified_config = {"design": "by"}
|
||||
old_api.tasks.edit(task=task_id, execution=dict(parameters=modified_params, model_desc=modified_config))
|
||||
old_api.tasks.edit(
|
||||
task=task_id,
|
||||
execution=dict(parameters=modified_params, model_desc=modified_config),
|
||||
)
|
||||
task = old_api.tasks.get_all_ex(id=[task_id]).tasks[0]
|
||||
self.assertEqual(modified_params, task.execution.parameters)
|
||||
self.assertEqual(modified_config, task.execution.model_desc)
|
||||
|
||||
114
apiserver/tests/automated/test_task_models.py
Normal file
114
apiserver/tests/automated/test_task_models.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from copy import deepcopy
|
||||
from typing import Sequence, Optional
|
||||
|
||||
from packaging.version import parse
|
||||
|
||||
from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestTaskModels(TestService):
    """Tests for the task models API (input/output model lists on a task)."""

    def setUp(self, version="2.13"):
        super().setUp(version=version)

    def test_new_apis(self):
        # a task created without models reports empty input/output lists
        task_without_models = self.new_task()
        self.assertModels(task_without_models, [], [])

        first_model = self.new_model("model1")
        second_model = self.new_model("model2")
        input_models = [
            {"name": "input1", "model": first_model},
            {"name": "input2", "model": second_model},
        ]
        output_models = [
            {"name": "output1", "model": "id3"},
            {"name": "output2", "model": "id4"},
        ]
        all_models = {"input": input_models, "output": output_models}

        # task creation with models
        task = self.new_task(models=all_models)
        self.assertModels(task, input_models, output_models)

        # add_or_update existing model
        update_res = self.api.tasks.add_or_update_model(
            task=task, name="input1", type="input", model="Test"
        )
        self.assertEqual(update_res.updated, 1)
        modified_input = deepcopy(input_models)
        modified_input[0]["model"] = "Test"
        self.assertModels(task, modified_input, output_models)

        # add_or_update new model
        add_res = self.api.tasks.add_or_update_model(
            task=task, name="output3", type="output", model="TestOutput"
        )
        self.assertEqual(add_res.updated, 1)
        modified_output = deepcopy(output_models) + [
            {"name": "output3", "model": "TestOutput"}
        ]
        self.assertModels(task, modified_input, modified_output)

        # task editing restores the original model lists
        self.api.tasks.edit(
            task=task, models={"input": input_models, "output": output_models}
        )
        self.assertModels(task, input_models, output_models)

        # delete models (the request also names one model that does not exist)
        models_to_delete = [
            {"name": "input1", "type": "input"},
            {"name": "input2", "type": "input"},
            {"name": "output1", "type": "output"},
            {"name": "not_existing", "type": "output"},
        ]
        delete_res = self.api.tasks.delete_models(task=task, models=models_to_delete)
        self.assertEqual(delete_res.updated, 1)
        self.assertModels(task, [], output_models[1:])

    def assertModels(
        self, task_id: str, input_models: Sequence[dict], output_models: Sequence[dict],
    ):
        """
        Assert that the task reports the given input and output models.

        The task is fetched through ``get_all_ex``, ``get_all`` and
        ``get_by_id`` and each result is checked. For API versions below
        2.13 the ``execution`` and ``output`` sections are also compared
        with the first input model and the last output model respectively.
        """

        def extract_model_id(holder: dict) -> Optional[str]:
            # the "model" entry is either a plain id string or a mapping with an "id" key
            ref = holder.get("model") if holder else None
            if ref is None or isinstance(ref, str):
                return ref
            return None if ref == {} else ref.get("id")

        def check_models(actual: Sequence[dict], expected: Sequence[dict]):
            actual_pairs = [(item["name"], extract_model_id(item)) for item in actual]
            expected_pairs = [(item["name"], item["model"]) for item in expected]
            self.assertEqual(actual_pairs, expected_pairs)

        # fetch through all three endpoints up front, then validate each result
        from_get_all_ex = self.api.tasks.get_all_ex(id=task_id).tasks[0]
        from_get_all = self.api.tasks.get_all(id=task_id).tasks[0]
        from_get_by_id = self.api.tasks.get_by_id(task=task_id).task

        for fetched in (from_get_all_ex, from_get_all, from_get_by_id):
            check_models(fetched.models.input, input_models)
            check_models(fetched.models.output, output_models)
            if self._version < parse("2.13"):
                expected_input = input_models[0]["model"] if input_models else None
                self.assertEqual(extract_model_id(fetched.execution), expected_input)
                expected_output = output_models[-1]["model"] if output_models else None
                self.assertEqual(extract_model_id(fetched.output), expected_output)

    def new_task(self, **kwargs):
        """Create a temporary task, filling in default fields the caller left out."""
        self.update_missing(
            kwargs,
            type="testing",
            name="test task models",
            input={"view": {}},
        )
        return self.create_temp("tasks", **kwargs)

    def new_model(self, name: str, **kwargs):
        """Create a temporary model with the given name."""
        model_fields = dict(uri="file://test", name=name, labels={}, **kwargs)
        return self.create_temp("models", **model_fields)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user