Compare commits

81 Commits

Author SHA1 Message Date
allegroai
2216bfe875 Version bump 2021-05-11 16:12:48 +03:00
allegroai
9beefa7473 Add missing login.logout endpoint 2021-05-11 16:12:27 +03:00
allegroai
8ebc334889 Fix broken config dir backwards compatibility (/opt/trains/config should still be supported) 2021-05-11 16:12:13 +03:00
allegroai
e662c850af Update config file in docs 2021-05-04 11:07:38 +03:00
allegroai
1e5163e530 Upgrade jinja2 version due to CVE-2020-28493 2021-05-03 23:23:06 +03:00
allegroai
1567774765 Version bump 2021-05-03 18:20:32 +03:00
allegroai
babfcbb707 Update migration script 2021-05-03 18:15:43 +03:00
allegroai
027edd86bb Fix actual file path reported in error/success message 2021-05-03 18:14:56 +03:00
allegroai
cc83aadae6 Fix file delete (bad merge) 2021-05-03 18:14:30 +03:00
allegroai
8c18660a82 Fix inconsistency in accessing files between download and delete 2021-05-03 18:14:08 +03:00
allegroai
4fe61ee25c Fix running migration scripts calling other files 2021-05-03 18:13:49 +03:00
allegroai
e18b21639c Fix regex query for fields containing "_" 2021-05-03 18:13:00 +03:00
allegroai
1cef03b8c2 Add check_contents flag for projects.get_all_ex 2021-05-03 18:12:44 +03:00
allegroai
d60d6dfe99 Move to clearml in docker-compose files 2021-05-03 18:12:21 +03:00
allegroai
27d086bca2 Fix schema for Task.runtime
Add infrastructure for API calls limits handling
2021-05-03 18:11:46 +03:00
allegroai
add3f011a0 Add runtime to tasks.edit 2021-05-03 18:10:48 +03:00
allegroai
ee90b0b024 Remove "Auto-generated while cloning" project description 2021-05-03 18:10:32 +03:00
allegroai
9bf107866f Fix crash in models publish_many without model task 2021-05-03 18:10:09 +03:00
allegroai
4d2f282950 Add Model.last_update to schema 2021-05-03 18:09:54 +03:00
allegroai
b55fad1b59 Remove "Auto-generated during move" project description 2021-05-03 18:09:31 +03:00
allegroai
ba77ff11e9 Fix missing custom metric values turn up first in sorting 2021-05-03 18:08:39 +03:00
allegroai
b67aa05d6f Return results per task iterations in debug images request 2021-05-03 18:08:14 +03:00
allegroai
6b0c45a861 Fix batch operations results 2021-05-03 18:07:37 +03:00
allegroai
dc9623e964 Fix docker_cmd projection in backwards compatibility
Fix support to clear input/output models and docker_cmd in backwards compatibility mode
Fix schema
2021-05-03 18:06:39 +03:00
allegroai
3d73d60826 Better handling of invalid iterations on add_batch 2021-05-03 18:05:24 +03:00
allegroai
9f0c9c3690 Fix open ranges 2021-05-03 18:05:03 +03:00
allegroai
1a3d3494ce Fix numeric locale 2021-05-03 18:04:45 +03:00
allegroai
b99f620073 Added unarchive APIs 2021-05-03 18:04:17 +03:00
allegroai
e2f265b4bc Unify batch operations 2021-05-03 18:03:54 +03:00
allegroai
251ee57ffd Fix rapidjson dumps does not support ensure_ascii, only Encoder initialization does
Add task enqueue status
2021-05-03 18:03:17 +03:00
allegroai
7e03104f1c Add Model last_update field 2021-05-03 18:02:25 +03:00
allegroai
f1a258208e Disable backwards compatibility for 2.13 clients 2021-05-03 18:01:59 +03:00
allegroai
66cc49313b Fix schema 2021-05-03 18:01:29 +03:00
allegroai
9ae2943f7d Fix crash in tasks.reset 2021-05-03 17:59:44 +03:00
allegroai
54326f707b Add JSON flags support to APICall 2021-05-03 17:58:57 +03:00
allegroai
3a3b57c15f Support mongodb authentication 2021-05-03 17:57:53 +03:00
allegroai
8ea8ad34e6 Remove collecting task output models from Models collection during migration 2021-05-03 17:57:27 +03:00
allegroai
179661a0d4 Rename default input and output models
Better handling of backwards compatibility in task models
Code cleanup
2021-05-03 17:56:50 +03:00
allegroai
3d22ca1888 Escape task.container and task.execution.model_labels fields in DB 2021-05-03 17:56:17 +03:00
allegroai
fdf6798d0c Don't unset Task's execution.queue on dequeue 2021-05-03 17:54:16 +03:00
allegroai
9d9a44b927 Add skip_empty parameter in get_configuration_names 2021-05-03 17:53:56 +03:00
allegroai
dad935e81d Remove webserver project 2021-05-03 17:53:24 +03:00
allegroai
a75534ec34 Add batch operations support 2021-05-03 17:52:54 +03:00
allegroai
eab33de97e Add bcrypt support to fixed user password 2021-05-03 17:52:25 +03:00
allegroai
29de110abb Add support for queue and model metadata 2021-05-03 17:50:25 +03:00
allegroai
2e7f418ee2 Fix Task.container backwards-compatibility
Fix sub-projects
2021-05-03 17:49:48 +03:00
allegroai
dadb996d22 Refactor es_factory to better support override host/port 2021-05-03 17:48:41 +03:00
allegroai
174f692edf Code cleanup 2021-05-03 17:48:24 +03:00
allegroai
f4d5168a20 Add Task.container support 2021-05-03 17:48:01 +03:00
allegroai
5a438e8435 Fix projects.move 2021-05-03 17:47:11 +03:00
allegroai
ce4814dc47 Add field override support in config (using "-" prefix) 2021-05-03 17:46:36 +03:00
allegroai
ef42d0265d Add multi-models support 2021-05-03 17:46:00 +03:00
allegroai
3c5195028e More sub-projects support and fixes 2021-05-03 17:44:54 +03:00
allegroai
0d5174c453 Support iterating over all task metrics in task debug images 2021-05-03 17:43:02 +03:00
allegroai
c034c1a986 Add sub-projects support 2021-05-03 17:42:10 +03:00
allegroai
1b49da8748 Revoke tests account in fixed mode, cleanup 2021-05-03 17:40:41 +03:00
allegroai
26bda01a28 Add missing errors 2021-05-03 17:39:49 +03:00
allegroai
f5008d80ad Optimize and improve tasks/models/projects.delete 2021-05-03 17:39:13 +03:00
allegroai
8b464e7ae6 Return file urls for tasks.delete/reset and models.delete 2021-05-03 17:38:09 +03:00
allegroai
78e4a58c91 Fix API enum fields and add last_iteration to range queries 2021-05-03 17:37:49 +03:00
allegroai
7a4a5eb03e Fix dropping index by name during the migration fails if the index does not exist 2021-05-03 17:36:49 +03:00
allegroai
d029d56508 Support active users in projects 2021-05-03 17:36:04 +03:00
allegroai
6411954002 Improve visibility for distributed lock hanging 2021-05-03 17:35:17 +03:00
allegroai
7f4ad0d1ca Support projects.get_hyperparam_values 2021-05-03 17:34:40 +03:00
allegroai
4cd4b2914d Add range queries
Switch from sematic_version to packaging.version in db migrations
2021-05-03 17:33:47 +03:00
allegroai
1d55710a0b Update max API version 2021-05-03 17:33:12 +03:00
allegroai
8f646043bb Allow enqueueing stopped tasks
More clearml stuff
2021-05-03 17:31:02 +03:00
allegroai
4b11a6efcd Move apiserver to clearml 2021-05-03 17:26:44 +03:00
allegroai
cb3a7c90a8 Move fileserver to clearml 2021-05-03 17:00:38 +03:00
allegroai
074842a122 Improve fileserver delete code 2021-05-03 16:58:11 +03:00
allegroai
749ff4a44f Fix Tasks.reset does not mark children's parent as deleted 2021-05-03 16:57:06 +03:00
allegroai
7d6918ecb0 Fix large plots comparison 2021-05-03 16:55:59 +03:00
allegroai
47184c2833 Fix querying by task parent 2021-05-03 16:55:03 +03:00
allegroai
6434f1028e Update docker-compose files 2021-01-14 12:37:25 +02:00
allegroai
daade08940 Update docker-compose-win10.yml
Remove deprecated docker-compose-unified.yml
2021-01-07 00:21:24 +02:00
Allegro AI
a1d289822f Update docker-compose-unified.yml
Reduce ES watermark
2021-01-06 17:46:09 +02:00
Allegro AI
1ce34f2c74 Update docker-compose-win10.yml
Reduce ES watermark
2021-01-06 17:45:27 +02:00
Allegro AI
c2dc73a71f Update docker-compose.yml
Reduce ES watermark
2021-01-06 17:44:45 +02:00
allegroai
07bb3b5df8 Update README 2021-01-06 00:32:52 +02:00
allegroai
067ef82576 Update README 2021-01-05 22:56:43 +02:00
allegroai
59fc98e0c4 Upgrade Jinja2 version (vulnerability found in older versions) 2021-01-05 20:18:09 +02:00
123 changed files with 7362 additions and 2848 deletions

View File

@@ -20,7 +20,7 @@
In v0.16, the Elasticsearch subsystem of ClearML Server has been upgraded from version 5.6 to version 7.6. This change necessitates the migration of the database contents to accommodate the change in index structure across the different versions.
Follow [this procedure](https://allegro.ai/docs/deploying_trains/trains_server_es7_migration/) to migrate existing data.
Follow [this procedure](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_es7_migration.html) to migrate existing data.
---
@@ -78,15 +78,15 @@ For example, to see if port `8080` is in use:
Launch The **ClearML Server** in any of the following formats:
- Pre-built [AWS EC2 AMI](https://allegro.ai/docs/deploying_trains/trains_server_aws_ec2_ami/)
- Pre-built [GCP Custom Image](https://allegro.ai/docs/deploying_trains/trains_server_gcp/)
- Pre-built [AWS EC2 AMI](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_aws_ec2_ami.html)
- Pre-built [GCP Custom Image](hhttps://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_gcp.html)
- Pre-built Docker Image
- [Linux](https://allegro.ai/docs/deploying_trains/trains_server_linux_mac/)
- [macOS](https://allegro.ai/docs/deploying_trains/trains_server_linux_mac/)
- [Windows 10](https://allegro.ai/docs/deploying_trains/trains_server_win/)
- [Linux](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_linux_mac.html)
- [macOS](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_linux_mac.html)
- [Windows 10](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_win.html)
- Kubernetes
- [Kubernetes Helm](https://allegro.ai/docs/deploying_trains/trains_server_kubernetes_helm/)
- Manual [Kubernetes installation](https://allegro.ai/docs/deploying_trains/trains_server_kubernetes/)
- [Kubernetes Helm](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_kubernetes_helm.html)
- Manual [Kubernetes installation](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_kubernetes.html)
## Connecting ClearML to your ClearML Server
@@ -211,9 +211,9 @@ To upgrade your existing **ClearML Server** deployment:
## Community & Support
If you have any questions, look to the ClearML [FAQ](https://allegro.ai/clearml/docs/docs/faq/faq.html), or
tag your questions on [stackoverflow](https://stackoverflow.com/questions/tagged/trains) with '**trains**' tag.
tag your questions on [stackoverflow](https://stackoverflow.com/questions/tagged/clearml) with '**clearml**' tag.
For feature requests or bug reports, please use [GitHub issues](https://github.com/allegroai/trains-server/issues).
For feature requests or bug reports, please use [GitHub issues](https://github.com/allegroai/clearml-server/issues).
Additionally, you can always find us at *clearml@allegro.ai*

View File

@@ -1,3 +1,8 @@
301 {
_: "moved_permanently"
1: ["not_supported", "this endpoint is no longer supported for the requested API version"]
}
400 {
_: "bad_request"
1: ["not_supported", "endpoint is not supported"]
@@ -62,6 +67,10 @@
402: ["project_has_tasks", "project has associated tasks"]
403: ["project_not_found", "project not found"]
405: ["project_has_models", "project has associated models"]
407: ["invalid_project_name", "invalid project name"]
408: ["cannot_update_project_location", "Cannot update project location. Use projects.move instead"]
409: ["project_path_exceeds_max", "Project path exceed the maximum allowed depth"]
410: ["project_source_and_destination_are_the_same", "Project has the same source and destination paths"]
# Queues
701: ["invalid_queue_id", "invalid queue id"]
@@ -108,6 +117,11 @@
21: ["no_write_permission", "forbidden (modification not allowed)"]
}
410: {
_: "gone"
1: ["not_supported", "thus endpoint is not supported any more"]
}
500 {
_: "server_error"
0: ["general_error", "general server error"]

View File

@@ -218,7 +218,7 @@ class ActualEnumField(fields.StringField):
)
def parse_value(self, value):
if value is None and not self.required:
if value is NotSet and not self.required:
return self.get_default_value()
try:
# noinspection PyArgumentList

View File

@@ -0,0 +1,25 @@
from typing import Sequence
from jsonmodels.fields import StringField
from jsonmodels.models import Base
from jsonmodels.validators import Length
from apiserver.apimodels import ListField
from apiserver.apimodels.base import UpdateResponse
class BatchRequest(Base):
ids: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
class BatchResponse(Base):
succeeded: Sequence[dict] = ListField([dict])
failed: Sequence[dict] = ListField([dict])
class UpdateBatchItem(UpdateResponse):
id: str = StringField()
class UpdateBatchResponse(BatchResponse):
succeeded: Sequence[UpdateBatchItem] = ListField(UpdateBatchItem)

View File

@@ -38,7 +38,7 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
class TaskMetric(Base):
task: str = StringField(required=True)
metric: str = StringField(required=True)
metric: str = StringField(default=None)
class DebugImagesRequest(Base):
@@ -89,7 +89,6 @@ class IterationEvents(Base):
class MetricEvents(Base):
task: str = StringField()
metric: str = StringField()
iterations: Sequence[IterationEvents] = ListField(items_types=IterationEvents)

View File

@@ -31,3 +31,4 @@ class GetSupportedModesResponse(Base):
server_errors = EmbeddedField(ServerErrors)
sso = DictField([str, type(None)])
sso_providers = ListField([dict])
authenticated = BoolField(default=False)

View File

@@ -0,0 +1,23 @@
from typing import Sequence
from jsonmodels import validators
from jsonmodels.fields import StringField
from jsonmodels.models import Base
from apiserver.apimodels import ListField
class MetadataItem(Base):
key = StringField(required=True)
type = StringField(required=True)
value = StringField(required=True)
class DeleteMetadata(Base):
keys: Sequence[str] = ListField(str, validators=validators.Length(minimum_value=1))
class AddOrUpdateMetadata(Base):
metadata: Sequence[MetadataItem] = ListField(
[MetadataItem], validators=validators.Length(minimum_value=1)
)

View File

@@ -3,7 +3,12 @@ from six import string_types
from apiserver.apimodels import ListField, DictField
from apiserver.apimodels.base import UpdateResponse
from apiserver.apimodels.tasks import PublishResponse as TaskPublishResponse
from apiserver.apimodels.batch import BatchRequest
from apiserver.apimodels.metadata import (
MetadataItem,
DeleteMetadata,
AddOrUpdateMetadata,
)
class GetFrameworksRequest(models.Base):
@@ -13,7 +18,7 @@ class GetFrameworksRequest(models.Base):
class CreateModelRequest(models.Base):
name = fields.StringField(required=True)
uri = fields.StringField(required=True)
labels = DictField(value_types=string_types+(int,))
labels = DictField(value_types=string_types + (int,))
tags = ListField(items_types=string_types)
system_tags = ListField(items_types=string_types)
comment = fields.StringField()
@@ -25,6 +30,7 @@ class CreateModelRequest(models.Base):
ready = fields.BoolField(default=True)
ui_cache = DictField()
task = fields.StringField()
metadata = ListField(items_types=[MetadataItem])
class CreateModelResponse(models.Base):
@@ -32,17 +38,40 @@ class CreateModelResponse(models.Base):
created = fields.BoolField(required=True)
class PublishModelRequest(models.Base):
class ModelRequest(models.Base):
model = fields.StringField(required=True)
class DeleteModelRequest(ModelRequest):
force = fields.BoolField(default=False)
class ModelsDeleteManyRequest(BatchRequest):
force = fields.BoolField(default=False)
class PublishModelRequest(ModelRequest):
force_publish_task = fields.BoolField(default=False)
publish_task = fields.BoolField(default=True)
class ModelTaskPublishResponse(models.Base):
id = fields.StringField(required=True)
data = fields.EmbeddedField(TaskPublishResponse)
data = fields.EmbeddedField(UpdateResponse)
class PublishModelResponse(UpdateResponse):
published_task = fields.EmbeddedField(ModelTaskPublishResponse)
updated = fields.IntField()
class ModelsPublishManyRequest(BatchRequest):
force_publish_task = fields.BoolField(default=False)
publish_task = fields.BoolField(default=True)
class DeleteMetadataRequest(DeleteMetadata):
model = fields.StringField(required=True)
class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
model = fields.StringField(required=True)

View File

@@ -5,11 +5,29 @@ from apiserver.apimodels.organization import TagsRequest
from apiserver.database.model import EntityVisibility
class ProjectReq(models.Base):
class ProjectRequest(models.Base):
project = fields.StringField(required=True)
class MergeRequest(ProjectRequest):
destination_project = fields.StringField()
class MoveRequest(ProjectRequest):
new_location = fields.StringField()
class DeleteRequest(ProjectRequest):
force = fields.BoolField(default=False)
delete_contents = fields.BoolField(default=False)
class ProjectOrNoneRequest(models.Base):
project = fields.StringField()
include_subprojects = fields.BoolField(default=True)
class GetHyperParamReq(ProjectReq):
class GetHyperParamRequest(ProjectOrNoneRequest):
page = fields.IntField(default=0)
page_size = fields.IntField(default=500)
@@ -18,7 +36,25 @@ class ProjectTagsRequest(TagsRequest):
projects = ListField(str)
class ProjectTaskParentsRequest(ProjectReq):
projects = ListField(str)
class MultiProjectRequest(models.Base):
projects = fields.ListField(str)
include_subprojects = fields.BoolField(default=True)
class ProjectTaskParentsRequest(MultiProjectRequest):
tasks_state = ActualEnumField(EntityVisibility)
class ProjectHyperparamValuesRequest(MultiProjectRequest):
section = fields.StringField(required=True)
name = fields.StringField(required=True)
allow_public = fields.BoolField(default=True)
class ProjectsGetRequest(models.Base):
include_stats = fields.BoolField(default=False)
stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active)
non_public = fields.BoolField(default=False)
active_users = fields.ListField(str)
check_own_contents = fields.BoolField(default=False)
shallow_search = fields.BoolField(default=False)

View File

@@ -3,6 +3,11 @@ from jsonmodels.fields import StringField, IntField, BoolField, FloatField
from jsonmodels.models import Base
from apiserver.apimodels import ListField
from apiserver.apimodels.metadata import (
MetadataItem,
DeleteMetadata,
AddOrUpdateMetadata,
)
class GetDefaultResp(Base):
@@ -14,6 +19,7 @@ class CreateRequest(Base):
name = StringField(required=True)
tags = ListField(items_types=[str])
system_tags = ListField(items_types=[str])
metadata = ListField(items_types=[MetadataItem])
class QueueRequest(Base):
@@ -28,6 +34,7 @@ class UpdateRequest(QueueRequest):
name = StringField()
tags = ListField(items_types=[str])
system_tags = ListField(items_types=[str])
metadata = ListField(items_types=[MetadataItem])
class TaskRequest(QueueRequest):
@@ -58,3 +65,11 @@ class QueueMetrics(Base):
class GetMetricsResponse(Base):
queues = ListField(QueueMetrics)
class DeleteMetadataRequest(DeleteMetadata):
queue = StringField(required=True)
class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
queue = StringField(required=True)

View File

@@ -1,16 +1,17 @@
from typing import Sequence
import six
from jsonmodels import models
from jsonmodels.fields import StringField, BoolField, IntField, EmbeddedField
from jsonmodels.validators import Enum, Length
from apiserver.apimodels import DictField, ListField
from apiserver.apimodels.base import UpdateResponse
from apiserver.apimodels.batch import BatchRequest, UpdateBatchItem, BatchResponse
from apiserver.database.model.task.task import (
TaskType,
ArtifactModes,
DEFAULT_ARTIFACT_MODE,
TaskModelTypes,
)
from apiserver.database.utils import get_options
@@ -43,26 +44,54 @@ class EnqueueResponse(UpdateResponse):
queued = IntField()
class EnqueueBatchItem(UpdateBatchItem):
queued: bool = BoolField()
class EnqueueManyResponse(BatchResponse):
succeeded: Sequence[EnqueueBatchItem] = ListField(EnqueueBatchItem)
class DequeueResponse(UpdateResponse):
dequeued = IntField()
class DequeueBatchItem(UpdateBatchItem):
dequeued: bool = BoolField()
class DequeueManyResponse(BatchResponse):
succeeded: Sequence[DequeueBatchItem] = ListField(DequeueBatchItem)
class ResetResponse(UpdateResponse):
deleted_indices = ListField(items_types=six.string_types)
dequeued = DictField()
frames = DictField()
events = DictField()
model_deleted = IntField()
deleted_models = IntField()
urls = DictField()
class ResetBatchItem(UpdateBatchItem):
dequeued: bool = BoolField()
deleted_models = IntField()
urls = DictField()
class ResetManyResponse(BatchResponse):
succeeded: Sequence[ResetBatchItem] = ListField(ResetBatchItem)
class TaskRequest(models.Base):
task = StringField(required=True)
class UpdateRequest(TaskRequest):
class TaskUpdateRequest(TaskRequest):
force = BoolField(default=False)
class UpdateRequest(TaskUpdateRequest):
status_reason = StringField(default="")
status_message = StringField(default="")
force = BoolField(default=False)
class EnqueueRequest(UpdateRequest):
@@ -71,6 +100,8 @@ class EnqueueRequest(UpdateRequest):
class DeleteRequest(UpdateRequest):
move_to_trash = BoolField(default=True)
return_file_urls = BoolField(default=False)
delete_output_models = BoolField(default=True)
class SetRequirementsRequest(TaskRequest):
@@ -81,10 +112,6 @@ class PublishRequest(UpdateRequest):
publish_model = BoolField(default=True)
class PublishResponse(UpdateResponse):
pass
class TaskData(models.Base):
"""
This is a partial description of task can be updated incrementally
@@ -104,6 +131,11 @@ class GetTypesRequest(models.Base):
projects = ListField(items_types=[str])
class TaskInputModel(models.Base):
name = StringField()
model = StringField()
class CloneRequest(TaskRequest):
new_task_name = StringField()
new_task_comment = StringField()
@@ -113,14 +145,15 @@ class CloneRequest(TaskRequest):
new_task_project = StringField()
new_task_hyperparams = DictField()
new_task_configuration = DictField()
new_task_container = DictField()
new_task_input_models = ListField([TaskInputModel])
execution_overrides = DictField()
validate_references = BoolField(default=False)
new_project_name = StringField()
class AddOrUpdateArtifactsRequest(TaskRequest):
class AddOrUpdateArtifactsRequest(TaskUpdateRequest):
artifacts = ListField([Artifact], validators=Length(minimum_value=1))
force = BoolField(default=False)
class ArtifactId(models.Base):
@@ -130,13 +163,14 @@ class ArtifactId(models.Base):
)
class DeleteArtifactsRequest(TaskRequest):
class DeleteArtifactsRequest(TaskUpdateRequest):
artifacts = ListField([ArtifactId], validators=Length(minimum_value=1))
force = BoolField(default=False)
class ResetRequest(UpdateRequest):
clear_all = BoolField(default=False)
return_file_urls = BoolField(default=False)
delete_output_models = BoolField(default=True)
class MultiTaskRequest(models.Base):
@@ -161,7 +195,7 @@ class ReplaceHyperparams(object):
all = "all"
class EditHyperParamsRequest(TaskRequest):
class EditHyperParamsRequest(TaskUpdateRequest):
hyperparams: Sequence[HyperParamItem] = ListField(
[HyperParamItem], validators=Length(minimum_value=1)
)
@@ -169,7 +203,6 @@ class EditHyperParamsRequest(TaskRequest):
validators=Enum(*get_options(ReplaceHyperparams)),
default=ReplaceHyperparams.none,
)
force = BoolField(default=False)
class HyperParamKey(models.Base):
@@ -177,11 +210,10 @@ class HyperParamKey(models.Base):
name = StringField(nullable=True)
class DeleteHyperParamsRequest(TaskRequest):
class DeleteHyperParamsRequest(TaskUpdateRequest):
hyperparams: Sequence[HyperParamKey] = ListField(
[HyperParamKey], validators=Length(minimum_value=1)
)
force = BoolField(default=False)
class GetConfigurationsRequest(MultiTaskRequest):
@@ -189,7 +221,7 @@ class GetConfigurationsRequest(MultiTaskRequest):
class GetConfigurationNamesRequest(MultiTaskRequest):
pass
skip_empty = BoolField(default=True)
class Configuration(models.Base):
@@ -199,17 +231,15 @@ class Configuration(models.Base):
description = StringField()
class EditConfigurationRequest(TaskRequest):
class EditConfigurationRequest(TaskUpdateRequest):
configuration: Sequence[Configuration] = ListField(
[Configuration], validators=Length(minimum_value=1)
)
replace_configuration = BoolField(default=False)
force = BoolField(default=False)
class DeleteConfigurationRequest(TaskRequest):
class DeleteConfigurationRequest(TaskUpdateRequest):
configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
force = BoolField(default=False)
class ArchiveRequest(MultiTaskRequest):
@@ -219,3 +249,54 @@ class ArchiveRequest(MultiTaskRequest):
class ArchiveResponse(models.Base):
archived = IntField()
class TaskBatchRequest(BatchRequest):
status_reason = StringField(default="")
status_message = StringField(default="")
class StopManyRequest(TaskBatchRequest):
force = BoolField(default=False)
class EnqueueManyRequest(TaskBatchRequest):
queue = StringField()
validate_tasks = BoolField(default=False)
class DeleteManyRequest(TaskBatchRequest):
move_to_trash = BoolField(default=True)
return_file_urls = BoolField(default=False)
delete_output_models = BoolField(default=True)
force = BoolField(default=False)
class ResetManyRequest(TaskBatchRequest):
clear_all = BoolField(default=False)
return_file_urls = BoolField(default=False)
delete_output_models = BoolField(default=True)
force = BoolField(default=False)
class PublishManyRequest(TaskBatchRequest):
publish_model = BoolField(default=True)
force = BoolField(default=False)
class AddUpdateModelRequest(TaskRequest):
name = StringField(required=True)
model = StringField(required=True)
type = StringField(required=True, validators=Enum(*get_options(TaskModelTypes)))
iteration = IntField()
class ModelItemKey(models.Base):
name = StringField(required=True)
type = StringField(required=True, validators=Enum(*get_options(TaskModelTypes)))
class DeleteModelsRequest(TaskRequest):
models: Sequence[ModelItemKey] = ListField(
[ModelItemKey], validators=Length(minimum_value=1)
)

View File

@@ -1,19 +1,17 @@
from collections import defaultdict
from concurrent.futures.thread import ThreadPoolExecutor
from datetime import datetime
from functools import partial
from itertools import chain
from operator import attrgetter, itemgetter
from typing import Sequence, Tuple, Optional, Mapping
from operator import itemgetter
from typing import Sequence, Tuple, Optional, Mapping, Set
import attr
import dpath
from boltons.iterutils import bucketize
from boltons.iterutils import first
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, ListField, IntField
from jsonmodels.models import Base
from redis import StrictRedis
from apiserver.apierrors import errors
from apiserver.apimodels import JsonSerializableMixin
from apiserver.bll.event.event_common import (
EventSettings,
@@ -28,19 +26,22 @@ from apiserver.database.model.task.task import Task
from apiserver.timing_context import TimingContext
class VariantScrollState(Base):
name: str = StringField(required=True)
recycle_url_marker: str = StringField()
class VariantState(Base):
variant: str = StringField(required=True)
last_invalid_iteration: int = IntField()
class MetricScrollState(Base):
class MetricState(Base):
metric: str = StringField(required=True)
variants: Sequence[VariantState] = ListField([VariantState], required=True)
timestamp: int = IntField(default=0)
class TaskScrollState(Base):
task: str = StringField(required=True)
name: str = StringField(required=True)
metrics: Sequence[MetricState] = ListField([MetricState], required=True)
last_min_iter: Optional[int] = IntField()
last_max_iter: Optional[int] = IntField()
timestamp: int = IntField(default=0)
variants: Sequence[VariantScrollState] = ListField([VariantScrollState])
def reset(self):
"""Reset the scrolling state for the metric"""
@@ -49,7 +50,7 @@ class MetricScrollState(Base):
class DebugImageEventsScrollState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
metrics: Sequence[MetricScrollState] = ListField([MetricScrollState])
tasks: Sequence[TaskScrollState] = ListField([TaskScrollState])
warning: str = StringField()
@@ -73,7 +74,7 @@ class DebugImagesIterator:
def get_task_events(
self,
company_id: str,
metrics: Sequence[Tuple[str, str]],
task_metrics: Mapping[str, Set[str]],
iter_count: int,
navigate_earlier: bool = True,
refresh: bool = False,
@@ -83,8 +84,7 @@ class DebugImagesIterator:
return DebugImagesResult()
def init_state(state_: DebugImageEventsScrollState):
unique_metrics = set(metrics)
state_.metrics = self._init_metric_states(company_id, list(unique_metrics))
state_.tasks = self._init_task_states(company_id, task_metrics)
def validate_state(state_: DebugImageEventsScrollState):
"""
@@ -92,16 +92,8 @@ class DebugImagesIterator:
as requested in the current call.
Refresh the state if requested
"""
state_metrics = set((m.task, m.name) for m in state_.metrics)
if state_metrics != set(metrics):
raise errors.bad_request.InvalidScrollId(
"Task metrics stored in the state do not match the passed ones",
scroll_id=state_.id,
)
if refresh:
self._reinit_outdated_metric_states(company_id, state_)
for metric_state in state_.metrics:
metric_state.reset()
self._reinit_outdated_task_states(company_id, state_, task_metrics)
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state
@@ -116,101 +108,124 @@ class DebugImagesIterator:
iter_count=iter_count,
navigate_earlier=navigate_earlier,
),
state.metrics,
state.tasks,
)
)
return res
def _reinit_outdated_metric_states(
self, company_id, state: DebugImageEventsScrollState
def _reinit_outdated_task_states(
self,
company_id,
state: DebugImageEventsScrollState,
task_metrics: Mapping[str, Set[str]],
):
"""
Determines the metrics for which new debug image events were added
since their states were initialized and reinits these states
Determine the metrics for which new debug image events were added
since their states were initialized and re-init these states
"""
task_ids = set(metric.task for metric in state.metrics)
tasks = Task.objects(id__in=list(task_ids), company=company_id).only(
tasks = Task.objects(id__in=list(task_metrics), company=company_id).only(
"id", "metric_stats"
)
def get_last_update_times_for_task_metrics(task: Task) -> Sequence[Tuple]:
"""For metrics that reported debug image events get tuples of task_id/metric_name and last update times"""
def get_last_update_times_for_task_metrics(
task: Task,
) -> Mapping[str, datetime]:
"""For metrics that reported debug image events get mapping of the metric name to the last update times"""
metric_stats: Mapping[str, MetricEventStats] = task.metric_stats
if not metric_stats:
return []
return {}
return [
(
(task.id, stats.metric),
stats.event_stats_by_type[self.EVENT_TYPE.value].last_update,
)
requested_metrics = task_metrics[task.id]
return {
stats.metric: stats.event_stats_by_type[
self.EVENT_TYPE.value
].last_update
for stats in metric_stats.values()
if self.EVENT_TYPE.value in stats.event_stats_by_type
]
and (not requested_metrics or stats.metric in requested_metrics)
}
update_times = dict(
chain.from_iterable(
get_last_update_times_for_task_metrics(task) for task in tasks
update_times = {
task.id: get_last_update_times_for_task_metrics(task) for task in tasks
}
task_metric_states = {
task_state.task: {
metric_state.metric: metric_state for metric_state in task_state.metrics
}
for task_state in state.tasks
}
task_metrics_to_recalc = {}
for task, metrics_times in update_times.items():
old_metric_states = task_metric_states[task]
metrics_to_recalc = set(
m
for m, t in metrics_times.items()
if m not in old_metric_states or old_metric_states[m].timestamp < t
)
)
outdated_metrics = [
metric
for metric in state.metrics
if (metric.task, metric.name) in update_times
and update_times[metric.task, metric.name] > metric.timestamp
]
state.metrics = [
*(metric for metric in state.metrics if metric not in outdated_metrics),
*(
self._init_metric_states(
company_id,
[(metric.task, metric.name) for metric in outdated_metrics],
)
),
if metrics_to_recalc:
task_metrics_to_recalc[task] = metrics_to_recalc
updated_task_states = self._init_task_states(company_id, task_metrics_to_recalc)
def merge_with_updated_task_states(
old_state: TaskScrollState, updates: Sequence[TaskScrollState]
) -> TaskScrollState:
task = old_state.task
updated_state = first(uts for uts in updates if uts.task == task)
if not updated_state:
old_state.reset()
return old_state
updated_metrics = [m.metric for m in updated_state.metrics]
return TaskScrollState(
task=task,
metrics=[
*updated_state.metrics,
*(
old_metric
for old_metric in old_state.metrics
if old_metric.metric not in updated_metrics
),
],
)
state.tasks = [
merge_with_updated_task_states(task_state, updated_task_states)
for task_state in state.tasks
]
def _init_metric_states(
self, company_id: str, metrics: Sequence[Tuple[str, str]]
) -> Sequence[MetricScrollState]:
def _init_task_states(
self, company_id: str, task_metrics: Mapping[str, Set[str]]
) -> Sequence[TaskScrollState]:
"""
Returned initialized metric scroll stated for the requested task metrics
"""
tasks = defaultdict(list)
for (task, metric) in metrics:
tasks[task].append(metric)
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
return list(
chain.from_iterable(
pool.map(
partial(
self._init_metric_states_for_task, company_id=company_id
),
tasks.items(),
)
)
task_metric_states = pool.map(
partial(self._init_metric_states_for_task, company_id=company_id),
task_metrics.items(),
)
return [
TaskScrollState(task=task, metrics=metric_states,)
for task, metric_states in zip(task_metrics, task_metric_states)
]
def _init_metric_states_for_task(
self, task_metrics: Tuple[str, Sequence[str]], company_id: str
) -> Sequence[MetricScrollState]:
self, task_metrics: Tuple[str, Set[str]], company_id: str
) -> Sequence[MetricState]:
"""
Return metric scroll states for the task filled with the variant states
for the variants that reported any debug images
"""
task, metrics = task_metrics
must = [{"term": {"task": task}}, {"exists": {"field": "url"}}]
if metrics:
must.append({"terms": {"metric": list(metrics)}})
es_req: dict = {
"size": 0,
"query": {
"bool": {
"must": [
{"term": {"task": task}},
{"terms": {"metric": metrics}},
{"exists": {"field": "url"}},
]
}
},
"query": {"bool": {"must": must}},
"aggs": {
"metrics": {
"terms": {
@@ -254,20 +269,17 @@ class DebugImagesIterator:
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.EVENT_TYPE,
body=es_req,
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req,
)
if "aggregations" not in es_res:
return []
def init_variant_scroll_state(variant: dict):
def init_variant_state(variant: dict):
"""
Return new variant scroll state for the passed variant bucket
Return new variant state for the passed variant bucket
If the image urls get recycled then fill the last_invalid_iteration field
"""
state = VariantScrollState(name=variant["key"])
state = VariantState(variant=variant["key"])
top_iter_url = dpath.get(variant, "urls/buckets")[0]
iters = dpath.get(top_iter_url, "iters/hits/hits")
if len(iters) > 1:
@@ -275,102 +287,52 @@ class DebugImagesIterator:
return state
return [
MetricScrollState(
task=task,
name=metric["key"],
MetricState(
metric=metric["key"],
timestamp=dpath.get(metric, "last_event_timestamp/value"),
variants=[
init_variant_scroll_state(variant)
init_variant_state(variant)
for variant in dpath.get(metric, "variants/buckets")
],
timestamp=dpath.get(metric, "last_event_timestamp/value"),
)
for metric in dpath.get(es_res, "aggregations/metrics/buckets")
]
def _get_task_metric_events(
self,
metric: MetricScrollState,
task_state: TaskScrollState,
company_id: str,
iter_count: int,
navigate_earlier: bool,
) -> Tuple:
"""
Return task metric events grouped by iterations
Update metric scroll state
Update task scroll state
"""
if metric.last_max_iter is None:
if not task_state.metrics:
return task_state.task, []
if task_state.last_max_iter is None:
# the first fetch is always from the latest iteration to the earlier ones
navigate_earlier = True
must_conditions = [
{"term": {"task": metric.task}},
{"term": {"metric": metric.name}},
{"term": {"task": task_state.task}},
{"terms": {"metric": [m.metric for m in task_state.metrics]}},
{"exists": {"field": "url"}},
]
must_not_conditions = []
range_condition = None
if navigate_earlier and metric.last_min_iter is not None:
range_condition = {"lt": metric.last_min_iter}
elif not navigate_earlier and metric.last_max_iter is not None:
range_condition = {"gt": metric.last_max_iter}
if navigate_earlier and task_state.last_min_iter is not None:
range_condition = {"lt": task_state.last_min_iter}
elif not navigate_earlier and task_state.last_max_iter is not None:
range_condition = {"gt": task_state.last_max_iter}
if range_condition:
must_conditions.append({"range": {"iter": range_condition}})
if navigate_earlier:
"""
When navigating to earlier iterations consider only
variants whose invalid iterations border is lower than
our starting iteration. For these variants make sure
that only events from the valid iterations are returned
"""
if not metric.last_min_iter:
variants = metric.variants
else:
variants = list(
v
for v in metric.variants
if v.last_invalid_iteration is None
or v.last_invalid_iteration < metric.last_min_iter
)
if not variants:
return metric.task, metric.name, []
must_conditions.append(
{"terms": {"variant": list(v.name for v in variants)}}
)
else:
"""
When navigating to later iterations all variants may be relevant.
For the variants whose invalid border is higher than our starting
iteration make sure that only events from valid iterations are returned
"""
variants = list(
v
for v in metric.variants
if v.last_invalid_iteration is not None
and v.last_invalid_iteration > metric.last_max_iter
)
variants_conditions = [
{
"bool": {
"must": [
{"term": {"variant": v.name}},
{"range": {"iter": {"lte": v.last_invalid_iteration}}},
]
}
}
for v in variants
if v.last_invalid_iteration is not None
]
if variants_conditions:
must_not_conditions.append({"bool": {"should": variants_conditions}})
es_req = {
"size": 0,
"query": {
"bool": {"must": must_conditions, "must_not": must_not_conditions}
},
"query": {"bool": {"must": must_conditions}},
"aggs": {
"iters": {
"terms": {
@@ -379,15 +341,26 @@ class DebugImagesIterator:
"order": {"_key": "desc" if navigate_earlier else "asc"},
},
"aggs": {
"variants": {
"metrics": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"field": "metric",
"size": EventSettings.max_metrics_count,
"order": {"_key": "asc"},
},
"aggs": {
"events": {
"top_hits": {"sort": {"url": {"order": "desc"}}}
"variants": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"order": {"_key": "asc"},
},
"aggs": {
"events": {
"top_hits": {
"sort": {"url": {"order": "desc"}}
}
}
},
}
},
}
@@ -397,80 +370,44 @@ class DebugImagesIterator:
}
with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.EVENT_TYPE,
body=es_req,
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req,
)
if "aggregations" not in es_res:
return metric.task, metric.name, []
return task_state.task, []
def get_iteration_events(variant_buckets: Sequence[dict]) -> Sequence:
invalid_iterations = {
(m.metric, v.variant): v.last_invalid_iteration
for m in task_state.metrics
for v in m.variants
}
def is_valid_event(event: dict) -> bool:
key = event.get("metric"), event.get("variant")
if key not in invalid_iterations:
return False
max_invalid = invalid_iterations[key]
return max_invalid is None or event.get("iter") > max_invalid
def get_iteration_events(it_: dict) -> Sequence:
return [
ev["_source"]
for v in variant_buckets
for m in dpath.get(it_, "metrics/buckets")
for v in dpath.get(m, "variants/buckets")
for ev in dpath.get(v, "events/hits/hits")
if is_valid_event(ev["_source"])
]
iterations = [
{
"iter": it["key"],
"events": get_iteration_events(dpath.get(it, "variants/buckets")),
}
for it in dpath.get(es_res, "aggregations/iters/buckets")
]
iterations = []
for it in dpath.get(es_res, "aggregations/iters/buckets"):
events = get_iteration_events(it)
if events:
iterations.append({"iter": it["key"], "events": events})
if not navigate_earlier:
iterations.sort(key=itemgetter("iter"), reverse=True)
if iterations:
metric.last_max_iter = iterations[0]["iter"]
metric.last_min_iter = iterations[-1]["iter"]
task_state.last_max_iter = iterations[0]["iter"]
task_state.last_min_iter = iterations[-1]["iter"]
# Commented for now since the last invalid iteration is calculated in the beginning
# if navigate_earlier and any(
# variant.last_invalid_iteration is None for variant in variants
# ):
# """
# Variants validation flags due to recycling can
# be set only on navigation to earlier frames
# """
# iterations = self._update_variants_invalid_iterations(variants, iterations)
return metric.task, metric.name, iterations
@staticmethod
def _update_variants_invalid_iterations(
variants: Sequence[VariantScrollState], iterations: Sequence[dict]
) -> Sequence[dict]:
"""
This code is currently not in used since the invalid iterations
are calculated during MetricState initialization
For variants that do not have recycle url marker set it from the
first event
For variants that do not have last_invalid_iteration set check if the
recycle marker was reached on a certain iteration and set it to the
corresponding iteration
For variants that have a newly set last_invalid_iteration remove
events from the invalid iterations
Return the updated iterations list
"""
variants_lookup = bucketize(variants, attrgetter("name"))
for it in iterations:
iteration = it["iter"]
events_to_remove = []
for event in it["events"]:
variant = variants_lookup[event["variant"]][0]
if (
variant.last_invalid_iteration
and variant.last_invalid_iteration >= iteration
):
events_to_remove.append(event)
continue
event_url = event.get("url")
if not variant.recycle_url_marker:
variant.recycle_url_marker = event_url
elif variant.recycle_url_marker == event_url:
variant.last_invalid_iteration = iteration
events_to_remove.append(event)
if events_to_remove:
it["events"] = [ev for ev in it["events"] if ev not in events_to_remove]
return [it for it in iterations if it["events"]]
return task_state.task, iterations

View File

@@ -1,5 +1,6 @@
import base64
import hashlib
import re
import zlib
from collections import defaultdict
from contextlib import closing
@@ -9,6 +10,7 @@ from typing import Sequence, Set, Tuple, Optional, Dict
import six
from elasticsearch import helpers
from elasticsearch.helpers import BulkIndexError
from mongoengine import Q
from nested_dict import nested_dict
@@ -36,12 +38,13 @@ from apiserver.redis_manager import redman
from apiserver.timing_context import TimingContext
from apiserver.tools import safe_get
from apiserver.utilities.dicts import flatten_nested_items
# noinspection PyTypeChecker
from apiserver.utilities.json import loads
EVENT_TYPES = set(map(attrgetter("value"), EventType))
# noinspection PyTypeChecker
EVENT_TYPES: Set[str] = set(map(attrgetter("value"), EventType))
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
MAX_LONG = 2**63 - 1
MIN_LONG = -2**63
class PlotFields:
@@ -49,11 +52,16 @@ class PlotFields:
plot_len = "plot_len"
plot_str = "plot_str"
plot_data = "plot_data"
source_urls = "source_urls"
class EventBLL(object):
id_fields = ("task", "iter", "metric", "variant", "key")
empty_scroll = "FFFF"
img_source_regex = re.compile(
r"['\"]source['\"]:\s?['\"]([a-z][a-z0-9+\-.]*://.*?)['\"]",
flags=re.IGNORECASE,
)
def __init__(self, events_es=None, redis=None):
self.es = events_es or es_factory.connect("events")
@@ -96,6 +104,7 @@ class EventBLL(object):
3, dict
) # task_id -> metric_hash -> event_type -> MetricEvent
errors_per_type = defaultdict(int)
invalid_iteration_error = f"Iteration number should not exceed {MAX_LONG}"
valid_tasks = self._get_valid_tasks(
company_id,
task_ids={
@@ -145,6 +154,9 @@ class EventBLL(object):
iter = event.get("iter")
if iter is not None:
iter = int(iter)
if iter > MAX_LONG or iter < MIN_LONG:
errors_per_type[invalid_iteration_error] += 1
continue
event["iter"] = iter
# used to have "values" to indicate array. no need anymore
@@ -201,47 +213,55 @@ class EventBLL(object):
)
added = 0
if actions:
chunk_size = 500
with translate_errors_context(), TimingContext("es", "events_add_batch"):
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
with closing(
helpers.streaming_bulk(
self.es,
actions,
chunk_size=chunk_size,
# thread_count=8,
refresh=True,
)
) as it:
for success, info in it:
if success:
added += 1
else:
errors_per_type["Error when indexing events batch"] += 1
with translate_errors_context():
if actions:
chunk_size = 500
with TimingContext("es", "events_add_batch"):
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
with closing(
helpers.streaming_bulk(
self.es,
actions,
chunk_size=chunk_size,
# thread_count=8,
refresh=True,
)
) as it:
for success, info in it:
if success:
added += 1
else:
errors_per_type["Error when indexing events batch"] += 1
remaining_tasks = set()
now = datetime.utcnow()
for task_id in task_ids:
# Update related tasks. For reasons of performance, we prefer to update
# all of them and not only those who's events were successful
updated = self._update_task(
company_id=company_id,
task_id=task_id,
now=now,
iter_max=task_iteration.get(task_id),
last_scalar_events=task_last_scalar_events.get(task_id),
last_events=task_last_events.get(task_id),
)
remaining_tasks = set()
now = datetime.utcnow()
for task_id in task_ids:
# Update related tasks. For reasons of performance, we prefer to update
# all of them and not only those who's events were successful
updated = self._update_task(
company_id=company_id,
task_id=task_id,
now=now,
iter_max=task_iteration.get(task_id),
last_scalar_events=task_last_scalar_events.get(task_id),
last_events=task_last_events.get(task_id),
)
if not updated:
remaining_tasks.add(task_id)
continue
if not updated:
remaining_tasks.add(task_id)
continue
if remaining_tasks:
TaskBLL.set_last_update(
remaining_tasks, company_id, last_update=now
)
if remaining_tasks:
TaskBLL.set_last_update(
remaining_tasks, company_id, last_update=now
)
# this is for backwards compatibility with streaming bulk throwing exception on those
invalid_iterations_count = errors_per_type.get(invalid_iteration_error)
if invalid_iterations_count:
raise BulkIndexError(
f"{invalid_iterations_count} document(s) failed to index.", [invalid_iteration_error]
)
if not added:
raise errors.bad_request.EventsNotAdded(**errors_per_type)
@@ -269,6 +289,11 @@ class EventBLL(object):
event[PlotFields.plot_len] = plot_len
if validate:
event[PlotFields.valid_plot] = self._is_valid_json(plot_str)
urls = {match for match in self.img_source_regex.findall(plot_str)}
if urls:
event[PlotFields.source_urls] = list(urls)
if compression_threshold and plot_len >= compression_threshold:
event[PlotFields.plot_data] = base64.encodebytes(
zlib.compress(plot_str.encode(), level=1)
@@ -430,6 +455,9 @@ class EventBLL(object):
)
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
if event_type in (EventType.metrics_plot, EventType.all):
self.uncompress_plots(events)
return events, next_scroll_id, total_events
def get_last_iterations_per_event_metric_variant(
@@ -501,7 +529,7 @@ class EventBLL(object):
scroll_id: str = None,
):
if scroll_id == self.empty_scroll:
return [], scroll_id, 0
return TaskEventsResult()
if scroll_id:
with translate_errors_context(), TimingContext("es", "get_task_events"):
@@ -595,6 +623,41 @@ class EventBLL(object):
return events, total_events, next_scroll_id
def get_plot_image_urls(
self, company_id: str, task_id: str, scroll_id: Optional[str]
) -> Tuple[Sequence[dict], Optional[str]]:
if scroll_id == self.empty_scroll:
return [], None
if scroll_id:
es_res = self.es.scroll(scroll_id=scroll_id, scroll="10m")
else:
if check_empty_data(self.es, company_id, EventType.metrics_plot):
return [], None
es_req = {
"size": 1000,
"_source": [PlotFields.source_urls],
"query": {
"bool": {
"must": [
{"term": {"task": task_id}},
{"exists": {"field": PlotFields.source_urls}},
]
}
},
}
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=EventType.metrics_plot,
body=es_req,
scroll="10m",
)
events, _, next_scroll_id = self._get_events_from_es_res(es_res)
return events, next_scroll_id
def get_task_events(
self,
company_id: str,
@@ -672,6 +735,9 @@ class EventBLL(object):
)
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
if event_type in (EventType.metrics_plot, EventType.all):
self.uncompress_plots(events)
return TaskEventsResult(
events=events, next_scroll_id=next_scroll_id, total_events=total_events
)
@@ -892,3 +958,20 @@ class EventBLL(object):
)
return es_res.get("deleted", 0)
def delete_multi_task_events(self, company_id: str, task_ids: Sequence[str]):
"""
Delete mutliple task events. No check is done for tasks write access
so it should be checked by the calling code
"""
es_req = {"query": {"terms": {"task": task_ids}}}
with translate_errors_context(), TimingContext("es", "delete_multi_tasks_events"):
es_res = delete_company_events(
es=self.es,
company_id=company_id,
event_type=EventType.all,
body=es_req,
refresh=True,
)
return es_res.get("deleted", 0)

View File

@@ -4,8 +4,8 @@ Module for polymorphism over different types of X axes in scalar aggregations
from abc import ABC, abstractmethod
from enum import auto
from apiserver.utilities import extract_properties_to_lists
from apiserver.utilities.stringenum import StringEnum
from apiserver.bll.util import extract_properties_to_lists
from apiserver.config_repo import config
log = config.logger(__file__)

View File

@@ -1,18 +1,129 @@
from typing import Optional, Sequence
from mongoengine import Q
from datetime import datetime
from typing import Callable, Tuple
from apiserver.apierrors import errors
from apiserver.apimodels.models import ModelTaskPublishResponse
from apiserver.bll.task.utils import deleted_prefix
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.utils import get_company_or_none_constraint
from apiserver.database.model.task.task import Task, TaskStatus
class ModelBLL:
def get_frameworks(self, company, project_ids: Optional[Sequence]) -> Sequence:
"""
Return the list of unique frameworks used by company and public models
If project ids passed then only models from these projects are considered
"""
query = get_company_or_none_constraint(company)
if project_ids:
query &= Q(project__in=project_ids)
return Model.objects(query).distinct(field="framework")
@classmethod
def get_company_model_by_id(
cls, company_id: str, model_id: str, only_fields=None
) -> Model:
query = dict(company=company_id, id=model_id)
qs = Model.objects(**query)
if only_fields:
qs = qs.only(*only_fields)
model = qs.first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
return model
@classmethod
def publish_model(
cls,
model_id: str,
company_id: str,
force_publish_task: bool = False,
publish_task_func: Callable[[str, str, bool], dict] = None,
) -> Tuple[int, ModelTaskPublishResponse]:
model = cls.get_company_model_by_id(company_id=company_id, model_id=model_id)
if model.ready:
raise errors.bad_request.ModelIsReady(company=company_id, model=model_id)
published_task = None
if model.task and publish_task_func:
task = (
Task.objects(id=model.task, company=company_id)
.only("id", "status")
.first()
)
if task and task.status != TaskStatus.published:
task_publish_res = publish_task_func(
model.task, company_id, force_publish_task
)
published_task = ModelTaskPublishResponse(
id=model.task, data=task_publish_res
)
updated = model.update(upsert=False, ready=True, last_update=datetime.utcnow())
return updated, published_task
@classmethod
def delete_model(
cls, model_id: str, company_id: str, force: bool
) -> Tuple[int, Model]:
model = cls.get_company_model_by_id(
company_id=company_id,
model_id=model_id,
only_fields=("id", "task", "project", "uri"),
)
deleted_model_id = f"{deleted_prefix}{model_id}"
using_tasks = Task.objects(models__input__model=model_id).only("id")
if using_tasks:
if not force:
raise errors.bad_request.ModelInUse(
"as execution model, use force=True to delete",
num_tasks=len(using_tasks),
)
# update deleted model id in using tasks
Task._get_collection().update_many(
filter={"_id": {"$in": [t.id for t in using_tasks]}},
update={"$set": {"models.input.$[elem].model": deleted_model_id}},
array_filters=[{"elem.model": model_id}],
upsert=False,
)
if model.task:
task = Task.objects(id=model.task).first()
if task and task.status == TaskStatus.published:
if not force:
raise errors.bad_request.ModelCreatingTaskExists(
"and published, use force=True to delete", task=model.task
)
if task.models.output and model_id in task.models.output:
now = datetime.utcnow()
Task._get_collection().update_one(
filter={"_id": model.task, "models.output.model": model_id},
update={
"$set": {
"models.output.$[elem].model": deleted_model_id,
"output.error": f"model deleted on {now.isoformat()}",
},
"last_change": now,
},
array_filters=[{"elem.model": model_id}],
upsert=False,
)
del_count = Model.objects(id=model_id, company=company_id).delete()
return del_count, model
@classmethod
def archive_model(cls, model_id: str, company_id: str):
cls.get_company_model_by_id(
company_id=company_id, model_id=model_id, only_fields=("id",)
)
archived = Model.objects(company=company_id, id=model_id).update(
add_to_set__system_tags=EntityVisibility.archived.value,
last_update=datetime.utcnow(),
)
return archived
@classmethod
def unarchive_model(cls, model_id: str, company_id: str):
cls.get_company_model_by_id(
company_id=company_id, model_id=model_id, only_fields=("id",)
)
unarchived = Model.objects(company=company_id, id=model_id).update(
pull__system_tags=EntityVisibility.archived.value,
last_update=datetime.utcnow(),
)
return unarchived

View File

@@ -1,12 +1,8 @@
from collections import defaultdict
from enum import Enum
from operator import itemgetter
from typing import Sequence, Dict, Optional
from mongoengine import Q
from typing import Sequence, Dict
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
from apiserver.redis_manager import redman
@@ -65,34 +61,3 @@ class OrgBLL:
def _get_tags_cache_for_entity(self, entity: Tags) -> _TagsCache:
return self._task_tags if entity == Tags.Task else self._model_tags
@classmethod
def get_parent_tasks(
cls,
company_id: str,
projects: Sequence[str],
state: Optional[EntityVisibility] = None,
) -> Sequence[dict]:
"""
Get list of unique parent tasks sorted by task name for the passed company projects
If projects is None or empty then get parents for all the company tasks
"""
query = Q(company=company_id)
if projects:
query &= Q(project__in=projects)
if state == EntityVisibility.archived:
query &= Q(system_tags__in=[EntityVisibility.archived.value])
elif state == EntityVisibility.active:
query &= Q(system_tags__nin=[EntityVisibility.archived.value])
parent_ids = set(Task.objects(query).distinct("parent"))
if not parent_ids:
return []
parents = Task.get_many_with_join(
company_id,
query=Q(id__in=parent_ids),
allow_public=True,
override_projection=("id", "name", "project.name"),
)
return sorted(parents, key=itemgetter("name"))

View File

@@ -5,6 +5,7 @@ from mongoengine import Q
from redis import Redis
from apiserver.config_repo import config
from apiserver.bll.project import project_ids_with_children
from apiserver.database.model.base import GetMixin
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
@@ -40,7 +41,7 @@ class _TagsCache:
if vals:
query &= GetMixin.get_list_field_query(name, vals)
if project:
query &= Q(project=project)
query &= Q(project__in=project_ids_with_children([project]))
return self.db_cls.objects(query).distinct(field)

View File

@@ -1 +1,2 @@
from .project_bll import ProjectBLL
from .sub_projects import _ids_with_children as project_ids_with_children

View File

@@ -1,40 +1,187 @@
import itertools
from collections import defaultdict
from datetime import datetime
from typing import Sequence, Optional, Type
from functools import reduce
from itertools import groupby
from operator import itemgetter
from typing import (
Sequence,
Optional,
Type,
Tuple,
Dict,
Set,
TypeVar,
Callable,
Mapping,
)
from mongoengine import Q, Document
from apiserver import database
from apiserver.apierrors import errors
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility, AttributedDocument
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task
from apiserver.database.model.task.task import Task, TaskStatus, external_task_types
from apiserver.database.utils import get_options, get_company_or_none_constraint
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get
from .sub_projects import (
_reposition_project_with_children,
_ensure_project,
_validate_project_name,
_update_subproject_names,
_save_under_parent,
_get_sub_projects,
_ids_with_children,
_ids_with_parents,
_get_project_depth,
)
log = config.logger(__file__)
max_depth = config.get("services.projects.sub_projects.max_depth", 10)
class ProjectBLL:
@classmethod
def get_active_users(
cls, company, project_ids: Sequence, user_ids: Optional[Sequence] = None
) -> set:
def merge_project(
cls, company, source_id: str, destination_id: str
) -> Tuple[int, int, Set[str]]:
"""
Get the set of user ids that created tasks/models in the given projects
If project_ids is empty then all projects are examined
If user_ids are passed then only subset of these users is returned
Move all the tasks and sub projects from the source project to the destination
Remove the source project
Return the amounts of moved entities and subprojects + set of all the affected project ids
"""
with TimingContext("mongo", "active_users_in_projects"):
res = set()
query = Q(company=company)
if project_ids:
query &= Q(project__in=project_ids)
if user_ids:
query &= Q(user__in=user_ids)
for cls_ in (Task, Model):
res |= set(cls_.objects(query).distinct(field="user"))
with TimingContext("mongo", "move_project"):
if source_id == destination_id:
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
parent=source_id
)
source = Project.get(company, source_id)
destination = Project.get(company, destination_id)
return res
children = _get_sub_projects(
[source.id], _only=("id", "name", "parent", "path")
)[source.id]
cls.validate_projects_depth(
projects=children,
old_parent_depth=len(source.path) + 1,
new_parent_depth=len(destination.path) + 1,
)
moved_entities = 0
for entity_type in (Task, Model):
moved_entities += entity_type.objects(
company=company,
project=source_id,
system_tags__nin=[EntityVisibility.archived.value],
).update(upsert=False, project=destination_id)
moved_sub_projects = 0
for child in Project.objects(company=company, parent=source_id):
_reposition_project_with_children(
project=child,
children=[c for c in children if c.parent == child.id],
parent=destination,
)
moved_sub_projects += 1
affected = {source.id, *(source.path or [])}
source.delete()
if destination:
destination.update(last_update=datetime.utcnow())
affected.update({destination.id, *(destination.path or [])})
return moved_entities, moved_sub_projects, affected
@staticmethod
def validate_projects_depth(
projects: Sequence[Project], old_parent_depth: int, new_parent_depth: int
):
for current in projects:
current_depth = len(current.path) + 1
if current_depth - old_parent_depth + new_parent_depth > max_depth:
raise errors.bad_request.ProjectPathExceedsMax(max_depth=max_depth)
@classmethod
def move_project(
cls, company: str, user: str, project_id: str, new_location: str
) -> Tuple[int, Set[str]]:
"""
Move project with its sub projects from its current location to the target one.
If the target location does not exist then it will be created. If it exists then
it should be writable. The source location should be writable too.
Return the number of moved projects + set of all the affected project ids
"""
with TimingContext("mongo", "move_project"):
project = Project.get(company, project_id)
old_parent_id = project.parent
old_parent = (
Project.get_for_writing(company=project.company, id=old_parent_id)
if old_parent_id
else None
)
children = _get_sub_projects([project.id], _only=("id", "name", "path"))[
project.id
]
cls.validate_projects_depth(
projects=[project, *children],
old_parent_depth=len(project.path),
new_parent_depth=_get_project_depth(new_location),
)
new_parent = _ensure_project(company=company, user=user, name=new_location)
new_parent_id = new_parent.id if new_parent else None
if old_parent_id == new_parent_id:
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
location=new_parent.name if new_parent else ""
)
moved = _reposition_project_with_children(
project, children=children, parent=new_parent
)
now = datetime.utcnow()
affected = set()
for p in filter(None, (old_parent, new_parent)):
p.update(last_update=now)
affected.update({p.id, *(p.path or [])})
return moved, affected
@classmethod
def update(cls, company: str, project_id: str, **fields):
with TimingContext("mongo", "projects_update"):
project = Project.get_for_writing(company=company, id=project_id)
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
new_name = fields.pop("name", None)
if new_name:
new_name, new_location = _validate_project_name(new_name)
old_name, old_location = _validate_project_name(project.name)
if new_location != old_location:
raise errors.bad_request.CannotUpdateProjectLocation(name=new_name)
fields["name"] = new_name
fields["last_update"] = datetime.utcnow()
updated = project.update(upsert=False, **fields)
if new_name:
old_name = project.name
project.name = new_name
children = _get_sub_projects(
[project.id], _only=("id", "name", "path")
)[project.id]
_update_subproject_names(
project=project, children=children, old_name=old_name
)
return updated
@classmethod
def create(
@@ -42,7 +189,7 @@ class ProjectBLL:
user: str,
company: str,
name: str,
description: str,
description: str = "",
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
default_output_destination: str = None,
@@ -51,6 +198,10 @@ class ProjectBLL:
Create a new project.
Returns project ID
"""
if _get_project_depth(name) > max_depth:
raise errors.bad_request.ProjectPathExceedsMax(max_depth=max_depth)
name, location = _validate_project_name(name)
now = datetime.utcnow()
project = Project(
id=database.utils.id(),
@@ -64,7 +215,11 @@ class ProjectBLL:
created=now,
last_update=now,
)
project.save()
parent = _ensure_project(company=company, user=user, name=location)
_save_under_parent(project=project, parent=parent)
if parent:
parent.update(last_update=now)
return project.id
@classmethod
@@ -92,6 +247,7 @@ class ProjectBLL:
raise errors.bad_request.InvalidProjectId(id=project_id)
return project_id
project_name, _ = _validate_project_name(project_name)
project = Project.objects(company=company, name=project_name).only("id").first()
if project:
return project.id
@@ -125,13 +281,439 @@ class ProjectBLL:
company=company,
project_id=project,
project_name=project_name,
description="Auto-generated during move",
description="",
)
extra = (
{"set__last_change": datetime.utcnow()}
if hasattr(entity_cls, "last_change")
else {}
)
entity_cls.objects(company=company, id__in=ids).update(set__project=project, **extra)
entity_cls.objects(company=company, id__in=ids).update(
set__project=project, **extra
)
return project
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}
@classmethod
def make_projects_get_all_pipelines(
cls,
company_id: str,
project_ids: Sequence[str],
specific_state: Optional[EntityVisibility] = None,
) -> Tuple[Sequence, Sequence]:
archived = EntityVisibility.archived.value
def ensure_valid_fields():
"""
Make sure system tags is always an array (required by subsequent $in in archived_tasks_cond
"""
return {
"$addFields": {
"system_tags": {
"$cond": {
"if": {"$ne": [{"$type": "$system_tags"}, "array"]},
"then": [],
"else": "$system_tags",
}
},
"status": {"$ifNull": ["$status", "unknown"]},
}
}
status_count_pipeline = [
# count tasks per project per status
{
"$match": {
"company": {"$in": [None, "", company_id]},
"project": {"$in": project_ids},
}
},
ensure_valid_fields(),
{
"$group": {
"_id": {
"project": "$project",
"status": "$status",
archived: cls.archived_tasks_cond,
},
"count": {"$sum": 1},
}
},
# for each project, create a list of (status, count, archived)
{
"$group": {
"_id": "$_id.project",
"counts": {
"$push": {
"status": "$_id.status",
"count": "$count",
archived: "$_id.%s" % archived,
}
},
}
},
]
def runtime_subquery(additional_cond):
return {
# the sum of
"$sum": {
# for each task
"$cond": {
# if completed and started and completed > started
"if": {
"$and": [
"$started",
"$completed",
{"$gt": ["$completed", "$started"]},
additional_cond,
]
},
# then: floor((completed - started) / 1000)
"then": {
"$floor": {
"$divide": [
{"$subtract": ["$completed", "$started"]},
1000.0,
]
}
},
"else": 0,
}
}
}
group_step = {"_id": "$project"}
for state in EntityVisibility:
if specific_state and state != specific_state:
continue
if state == EntityVisibility.active:
group_step[state.value] = runtime_subquery(
{"$not": cls.archived_tasks_cond}
)
elif state == EntityVisibility.archived:
group_step[state.value] = runtime_subquery(cls.archived_tasks_cond)
runtime_pipeline = [
# only count run time for these types of tasks
{
"$match": {
"company": {"$in": [None, "", company_id]},
"type": {"$in": ["training", "testing", "annotation"]},
"project": {"$in": project_ids},
}
},
ensure_valid_fields(),
{
# for each project
"$group": group_step
},
]
return status_count_pipeline, runtime_pipeline
T = TypeVar("T")
@staticmethod
def aggregate_project_data(
func: Callable[[T, T], T],
project_ids: Sequence[str],
child_projects: Mapping[str, Sequence[Project]],
data: Mapping[str, T],
) -> Dict[str, T]:
"""
Given a list of project ids and data collected over these projects and their subprojects
For each project aggregates the data from all of its subprojects
"""
aggregated = {}
if not data:
return aggregated
for pid in project_ids:
relevant_projects = {p.id for p in child_projects.get(pid, [])} | {pid}
relevant_data = [data for p, data in data.items() if p in relevant_projects]
if not relevant_data:
continue
aggregated[pid] = reduce(func, relevant_data)
return aggregated
@classmethod
def get_project_stats(
cls,
company: str,
project_ids: Sequence[str],
specific_state: Optional[EntityVisibility] = None,
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
if not project_ids:
return {}, {}
child_projects = _get_sub_projects(project_ids, _only=("id", "name"))
project_ids_with_children = set(project_ids) | {
c.id for c in itertools.chain.from_iterable(child_projects.values())
}
status_count_pipeline, runtime_pipeline = cls.make_projects_get_all_pipelines(
company,
project_ids=list(project_ids_with_children),
specific_state=specific_state,
)
default_counts = dict.fromkeys(get_options(TaskStatus), 0)
def set_default_count(entry):
return dict(default_counts, **entry)
status_count = defaultdict(lambda: {})
key = itemgetter(EntityVisibility.archived.value)
for result in Task.aggregate(status_count_pipeline):
for k, group in groupby(sorted(result["counts"], key=key), key):
section = (
EntityVisibility.archived if k else EntityVisibility.active
).value
status_count[result["_id"]][section] = set_default_count(
{
count_entry["status"]: count_entry["count"]
for count_entry in group
}
)
def sum_status_count(
a: Mapping[str, Mapping], b: Mapping[str, Mapping]
) -> Dict[str, dict]:
return {
section: {
status: nested_get(a, (section, status), 0)
+ nested_get(b, (section, status), 0)
for status in set(a.get(section, {})) | set(b.get(section, {}))
}
for section in set(a) | set(b)
}
status_count = cls.aggregate_project_data(
func=sum_status_count,
project_ids=project_ids,
child_projects=child_projects,
data=status_count,
)
runtime = {
result["_id"]: {k: v for k, v in result.items() if k != "_id"}
for result in Task.aggregate(runtime_pipeline)
}
def sum_runtime(
a: Mapping[str, Mapping], b: Mapping[str, Mapping]
) -> Dict[str, dict]:
return {
section: a.get(section, 0) + b.get(section, 0)
for section in set(a) | set(b)
}
runtime = cls.aggregate_project_data(
func=sum_runtime,
project_ids=project_ids,
child_projects=child_projects,
data=runtime,
)
def get_status_counts(project_id, section):
return {
"total_runtime": nested_get(runtime, (project_id, section), 0),
"status_count": nested_get(
status_count, (project_id, section), default_counts
),
}
report_for_states = [
s for s in EntityVisibility if not specific_state or specific_state == s
]
stats = {
project: {
task_state.value: get_status_counts(project, task_state.value)
for task_state in report_for_states
}
for project in project_ids
}
children = {
project: sorted(
[{"id": c.id, "name": c.name} for c in child_projects.get(project, [])],
key=itemgetter("name"),
)
for project in project_ids
}
return stats, children
@classmethod
def get_active_users(
cls,
company,
project_ids: Sequence[str],
user_ids: Optional[Sequence[str]] = None,
) -> Set[str]:
"""
Get the set of user ids that created tasks/models/dataviews in the given projects
If project_ids is empty then all projects are examined
If user_ids are passed then only subset of these users is returned
"""
with TimingContext("mongo", "active_users_in_projects"):
query = Q(company=company)
if user_ids:
query &= Q(user__in=user_ids)
projects_query = query
if project_ids:
project_ids = _ids_with_children(project_ids)
query &= Q(project__in=project_ids)
projects_query &= Q(id__in=project_ids)
res = set(Project.objects(projects_query).distinct(field="user"))
for cls_ in (Task, Model):
res |= set(cls_.objects(query).distinct(field="user"))
return res
@classmethod
def get_projects_with_active_user(
cls,
company: str,
users: Sequence[str],
project_ids: Optional[Sequence[str]] = None,
allow_public: bool = True,
) -> Sequence[str]:
"""
Get the projects ids where user created any tasks including all the parents of these projects
If project ids are specified then filter the results by these project ids
"""
query = Q(user__in=users)
if allow_public:
query &= get_company_or_none_constraint(company)
else:
query &= Q(company=company)
user_projects_query = query
if project_ids:
ids_with_children = _ids_with_children(project_ids)
query &= Q(project__in=ids_with_children)
user_projects_query &= Q(id__in=ids_with_children)
res = {p.id for p in Project.objects(user_projects_query).only("id")}
for cls_ in (Task, Model):
res |= set(cls_.objects(query).distinct(field="project"))
res = list(res)
if not res:
return res
ids_with_parents = _ids_with_parents(res)
if project_ids:
return [pid for pid in ids_with_parents if pid in project_ids]
return ids_with_parents
@classmethod
def get_task_parents(
cls,
company_id: str,
projects: Sequence[str],
include_subprojects: bool,
state: Optional[EntityVisibility] = None,
) -> Sequence[dict]:
"""
Get list of unique parent tasks sorted by task name for the passed company projects
If projects is None or empty then get parents for all the company tasks
"""
query = Q(company=company_id)
if projects:
if include_subprojects:
projects = _ids_with_children(projects)
query &= Q(project__in=projects)
if state == EntityVisibility.archived:
query &= Q(system_tags__in=[EntityVisibility.archived.value])
elif state == EntityVisibility.active:
query &= Q(system_tags__nin=[EntityVisibility.archived.value])
parent_ids = set(Task.objects(query).distinct("parent"))
if not parent_ids:
return []
parents = Task.get_many_with_join(
company_id,
query=Q(id__in=parent_ids),
allow_public=True,
override_projection=("id", "name", "project.name"),
)
return sorted(parents, key=itemgetter("name"))
@classmethod
def get_task_types(cls, company, project_ids: Optional[Sequence]) -> set:
"""
Return the list of unique task types used by company and public tasks
If project ids passed then only tasks from these projects are considered
"""
query = get_company_or_none_constraint(company)
if project_ids:
project_ids = _ids_with_children(project_ids)
query &= Q(project__in=project_ids)
res = Task.objects(query).distinct(field="type")
return set(res).intersection(external_task_types)
@classmethod
def get_model_frameworks(cls, company, project_ids: Optional[Sequence]) -> Sequence:
"""
Return the list of unique frameworks used by company and public models
If project ids passed then only models from these projects are considered
"""
query = get_company_or_none_constraint(company)
if project_ids:
project_ids = _ids_with_children(project_ids)
query &= Q(project__in=project_ids)
return Model.objects(query).distinct(field="framework")
@classmethod
def calc_own_contents(cls, company: str, project_ids: Sequence[str]) -> Dict[str, dict]:
"""
Returns the amount of task/dataviews/models per requested project
Use separate aggregation calls on Task/Dataview/Model instead of lookup
aggregation on projects in order not to hit memory limits on large tasks
"""
if not project_ids:
return {}
pipeline = [
{
"$match": {
"company": {"$in": [None, "", company]},
"project": {"$in": project_ids},
}
},
{
"$project": {"project": 1}
},
{
"$group": {
"_id": "$project",
"count": {"$sum": 1},
}
}
]
def get_agrregate_res(cls_: Type[AttributedDocument]) -> dict:
return {
data["_id"]: data["count"]
for data in cls_.aggregate(pipeline)
}
with TimingContext("mongo", "get_security_groups"):
tasks = get_agrregate_res(Task)
models = get_agrregate_res(Model)
return {
pid: {
"own_tasks": tasks.get(pid, 0),
"own_models": models.get(pid, 0),
}
for pid in project_ids
}

View File

@@ -0,0 +1,154 @@
from typing import Tuple, Set, Sequence
import attr
from apiserver.apierrors import errors
from apiserver.bll.event import EventBLL
from apiserver.bll.task.task_cleanup import (
collect_debug_image_urls,
collect_plot_image_urls,
TaskUrls,
)
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, ArtifactModes
from apiserver.timing_context import TimingContext
from .sub_projects import _ids_with_children
log = config.logger(__file__)
event_bll = EventBLL()
@attr.s(auto_attribs=True)
class DeleteProjectResult:
deleted: int = 0
disassociated_tasks: int = 0
deleted_models: int = 0
deleted_tasks: int = 0
urls: TaskUrls = None
def delete_project(
company: str, project_id: str, force: bool, delete_contents: bool
) -> Tuple[DeleteProjectResult, Set[str]]:
project = Project.get_for_writing(
company=company, id=project_id, _only=("id", "path")
)
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
project_ids = _ids_with_children([project_id])
if not force:
for cls, error in (
(Task, errors.bad_request.ProjectHasTasks),
(Model, errors.bad_request.ProjectHasModels),
):
non_archived = cls.objects(
project__in=project_ids,
system_tags__nin=[EntityVisibility.archived.value],
).only("id")
if non_archived:
raise error("use force=true to delete", id=project_id)
if not delete_contents:
with TimingContext("mongo", "update_children"):
for cls in (Model, Task):
updated_count = cls.objects(project__in=project_ids).update(
project=None
)
res = DeleteProjectResult(disassociated_tasks=updated_count)
else:
deleted_models, model_urls = _delete_models(projects=project_ids)
deleted_tasks, event_urls, artifact_urls = _delete_tasks(
company=company, projects=project_ids
)
res = DeleteProjectResult(
deleted_tasks=deleted_tasks,
deleted_models=deleted_models,
urls=TaskUrls(
model_urls=list(model_urls),
event_urls=list(event_urls),
artifact_urls=list(artifact_urls),
),
)
affected = {*project_ids, *(project.path or [])}
res.deleted = Project.objects(id__in=project_ids).delete()
return res, affected
def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]:
"""
Delete only the task themselves and their non published version.
Child models under the same project are deleted separately.
Children tasks should be deleted in the same api call.
If any child entities are left in another projects then updated their parent task to None
"""
tasks = Task.objects(project__in=projects).only("id", "execution__artifacts")
if not tasks:
return 0, set(), set()
task_ids = {t.id for t in tasks}
with TimingContext("mongo", "delete_tasks_update_children"):
Task.objects(parent__in=task_ids, project__nin=projects).update(parent=None)
Model.objects(task__in=task_ids, project__nin=projects).update(task=None)
event_urls, artifact_urls = set(), set()
for task in tasks:
event_urls.update(collect_debug_image_urls(company, task.id))
event_urls.update(collect_plot_image_urls(company, task.id))
if task.execution and task.execution.artifacts:
artifact_urls.update(
{
a.uri
for a in task.execution.artifacts.values()
if a.mode == ArtifactModes.output and a.uri
}
)
event_bll.delete_multi_task_events(company, list(task_ids))
deleted = tasks.delete()
return deleted, event_urls, artifact_urls
def _delete_models(projects: Sequence[str]) -> Tuple[int, Set[str]]:
"""
Delete project models and update the tasks from other projects
that reference them to reference None.
"""
with TimingContext("mongo", "delete_models"):
models = Model.objects(project__in=projects).only("task", "id", "uri")
if not models:
return 0, set()
model_ids = list({m.id for m in models})
Task._get_collection().update_many(
filter={
"project": {"$nin": projects},
"models.input.model": {"$in": model_ids},
},
update={"$set": {"models.input.$[elem].model": None}},
array_filters=[{"elem.model": {"$in": model_ids}}],
upsert=False,
)
model_tasks = list({m.task for m in models if m.task})
if model_tasks:
Task._get_collection().update_many(
filter={
"_id": {"$in": model_tasks},
"project": {"$nin": projects},
"models.output.model": {"$in": model_ids},
},
update={"$set": {"models.output.$[elem].model": None}},
array_filters=[{"elem.model": {"$in": model_ids}}],
upsert=False,
)
urls = {m.uri for m in models if m.uri}
deleted = models.delete()
return deleted, urls

View File

@@ -0,0 +1,176 @@
import itertools
from datetime import datetime
from typing import Tuple, Optional, Sequence, Mapping
from apiserver import database
from apiserver.apierrors import errors
from apiserver.database.model.project import Project
name_separator = "/"
def _get_project_depth(project_name: str) -> int:
return len(list(filter(None, project_name.split(name_separator))))
def _validate_project_name(project_name: str) -> Tuple[str, str]:
"""
Remove redundant '/' characters. Ensure that the project name is not empty
Return the cleaned up project name and location
"""
name_parts = list(filter(None, project_name.split(name_separator)))
if not name_parts:
raise errors.bad_request.InvalidProjectName(name=project_name)
return name_separator.join(name_parts), name_separator.join(name_parts[:-1])
def _ensure_project(company: str, user: str, name: str) -> Optional[Project]:
"""
Makes sure that the project with the given name exists
If needed auto-create the project and all the missing projects in the path to it
Return the project
"""
name = name.strip(name_separator)
if not name:
return None
project = _get_writable_project_from_name(company, name)
if project:
return project
now = datetime.utcnow()
name, location = _validate_project_name(name)
project = Project(
id=database.utils.id(),
user=user,
company=company,
created=now,
last_update=now,
name=name,
description="",
)
parent = _ensure_project(company, user, location)
_save_under_parent(project=project, parent=parent)
if parent:
parent.update(last_update=now)
return project
def _save_under_parent(project: Project, parent: Optional[Project]):
"""
Save the project under the given parent project or top level (parent=None)
Check that the project location matches the parent name
"""
location, _, _ = project.name.rpartition(name_separator)
if not parent:
if location:
raise ValueError(
f"Project location {location} does not match empty parent name"
)
project.parent = None
project.path = []
project.save()
return
if location != parent.name:
raise ValueError(
f"Project location {location} does not match parent name {parent.name}"
)
project.parent = parent.id
project.path = [*(parent.path or []), parent.id]
project.save()
def _get_writable_project_from_name(
company,
name,
_only: Optional[Sequence[str]] = ("id", "name", "path", "company", "parent"),
) -> Optional[Project]:
"""
Return a project from name. If the project not found then return None
"""
qs = Project.objects(company=company, name=name)
if _only:
qs = qs.only(*_only)
return qs.first()
def _get_sub_projects(
project_ids: Sequence[str], _only: Sequence[str] = ("id", "path")
) -> Mapping[str, Sequence[Project]]:
"""
Return the list of child projects of all the levels for the parent project ids
"""
qs = Project.objects(path__in=project_ids)
if _only:
_only = set(_only) | {"path"}
qs = qs.only(*_only)
subprojects = list(qs)
return {
pid: [s for s in subprojects if pid in (s.path or [])] for pid in project_ids
}
def _ids_with_parents(project_ids: Sequence[str]) -> Sequence[str]:
"""
Return project ids with all the parent projects
"""
projects = Project.objects(id__in=project_ids).only("id", "path")
parent_ids = set(itertools.chain.from_iterable(p.path for p in projects if p.path))
return list({*(p.id for p in projects), *parent_ids})
def _ids_with_children(project_ids: Sequence[str]) -> Sequence[str]:
"""
Return project ids with the ids of all the subprojects
"""
subprojects = Project.objects(path__in=project_ids).only("id")
return list({*project_ids, *(child.id for child in subprojects)})
def _update_subproject_names(
project: Project,
children: Sequence[Project],
old_name: str,
update_path: bool = False,
old_path: Sequence[str] = None,
) -> int:
"""
Update sub project names when the base project name changes
Optionally update the paths
"""
updated = 0
for child in children:
child_suffix = name_separator.join(
child.name.split(name_separator)[len(old_name.split(name_separator)) :]
)
updates = {"name": name_separator.join((project.name, child_suffix))}
if update_path:
updates["path"] = project.path + child.path[len(old_path) :]
updated += child.update(upsert=False, **updates)
return updated
def _reposition_project_with_children(
project: Project, children: Sequence[Project], parent: Project
) -> int:
new_location = parent.name if parent else None
old_name = project.name
old_path = project.path
project.name = name_separator.join(
filter(None, (new_location, project.name.split(name_separator)[-1]))
)
_save_under_parent(project, parent=parent)
moved = 1 + _update_subproject_names(
project=project,
children=children,
old_name=old_name,
update_path=True,
old_path=old_path,
)
return moved

View File

@@ -32,6 +32,7 @@ class QueueBLL(object):
name: str,
tags: Optional[Sequence[str]] = None,
system_tags: Optional[Sequence[str]] = None,
metadata: Optional[Sequence[dict]] = None,
) -> Queue:
"""Creates a queue"""
with translate_errors_context():
@@ -43,6 +44,7 @@ class QueueBLL(object):
name=name,
tags=tags or [],
system_tags=system_tags or [],
metadata=metadata,
last_update=now,
)
queue.save()

View File

@@ -45,7 +45,7 @@ class StatisticsReporter:
def start_reporter(cls):
"""
Periodically send statistics reports for companies who have opted in.
Note: in trains we usually have only a single company
Note: in clearml we usually have only a single company
"""
if not cls.supported:
return

View File

@@ -3,5 +3,4 @@ from .utils import (
ChangeStatusRequest,
update_project_time,
validate_status_change,
split_by,
)

View File

@@ -1,10 +1,10 @@
from hashlib import md5
from operator import itemgetter
from typing import Sequence
from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
from apiserver.bll.task.utils import get_task_for_update, update_task
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
from apiserver.database.utils import hash_field_name
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get, nested_set
from apiserver.utilities.parameter_key_escaper import mongoengine_safe
@@ -15,7 +15,7 @@ def get_artifact_id(artifact: dict):
Calculate id from 'key' and 'mode' fields
Return hash on on the id so that it will not contain mongo illegal characters
"""
key_hash: str = md5(artifact["key"].encode()).hexdigest()
key_hash: str = hash_field_name(artifact["key"])
mode: str = artifact.get("mode", DEFAULT_ARTIFACT_MODE)
return f"{key_hash}_{mode}"
@@ -40,7 +40,7 @@ def artifacts_unprepare_from_saved(fields):
nested_set(
fields,
artifacts_field,
value=sorted(artifacts.values(), key=itemgetter("key", "mode")),
value=sorted(artifacts.values(), key=itemgetter("key")),
)

View File

@@ -175,21 +175,23 @@ class HyperParams:
@classmethod
def get_configuration_names(
cls, company_id: str, task_ids: Sequence[str]
cls, company_id: str, task_ids: Sequence[str], skip_empty: bool
) -> Dict[str, list]:
with TimingContext("mongo", "get_configuration_names"):
pipeline = [
{
"$match": {
"company": {"$in": [None, "", company_id]},
"_id": {"$in": task_ids},
}
},
{"$project": {"items": {"$objectToArray": "$configuration"}}},
{"$unwind": "$items"},
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
]
skip_empty_condition = {"$match": {"items.v.value": {"$nin": [None, ""]}}}
pipeline = [
{
"$match": {
"company": {"$in": [None, "", company_id]},
"_id": {"$in": task_ids},
}
},
{"$project": {"items": {"$objectToArray": "$configuration"}}},
{"$unwind": "$items"},
*([skip_empty_condition] if skip_empty else []),
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
]
with TimingContext("mongo", "get_configuration_names"):
tasks = Task.aggregate(pipeline)
return {

View File

@@ -185,6 +185,7 @@ def escape_paths(paths: Sequence[str]) -> Sequence[str]:
for old_prefix, new_prefix in (
("execution.parameters", f"hyperparams.{hyperparams_default_section}"),
("execution.model_desc", f"configuration"),
("execution.docker_cmd", "container")
):
path: str
paths = [path.replace(old_prefix, new_prefix) for path in paths]

View File

@@ -1,17 +1,20 @@
import json
from collections import OrderedDict
from datetime import datetime
from datetime import datetime, timedelta
from typing import Collection, Sequence, Tuple, Any, Optional, Dict
import dpath
import six
from mongoengine import Q
from redis import StrictRedis
from six import string_types
import apiserver.database.utils as dbutils
from apiserver.apierrors import errors
from apiserver.apimodels.tasks import TaskInputModel
from apiserver.bll.queue import QueueBLL
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.bll.project import ProjectBLL, project_ids_with_children
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.model import Model
@@ -21,21 +24,29 @@ from apiserver.database.model.task.output import Output
from apiserver.database.model.task.task import (
Task,
TaskStatus,
TaskStatusMessage,
TaskSystemTags,
ArtifactModes,
external_task_types,
ModelItem,
Models,
DEFAULT_ARTIFACT_MODE,
TaskModelNames,
TaskModelTypes,
)
from apiserver.database.model import EntityVisibility
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
from apiserver.es_factory import es_factory
from apiserver.redis_manager import redman
from apiserver.service_repo import APICall
from apiserver.services.utils import validate_tags
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
from apiserver.timing_context import TimingContext
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from .artifacts import artifacts_prepare_for_save
from .param_utils import params_prepare_for_save
from .utils import ChangeStatusRequest, validate_status_change, update_project_time
from .utils import (
ChangeStatusRequest,
update_project_time,
deleted_prefix,
)
log = config.logger(__file__)
org_bll = OrgBLL()
@@ -44,22 +55,9 @@ project_bll = ProjectBLL()
class TaskBLL:
def __init__(self, events_es=None):
self.events_es = (
events_es if events_es is not None else es_factory.connect("events")
)
@classmethod
def get_types(cls, company, project_ids: Optional[Sequence]) -> set:
"""
Return the list of unique task types used by company and public tasks
If project ids passed then only tasks from these projects are considered
"""
query = get_company_or_none_constraint(company)
if project_ids:
query &= Q(project__in=project_ids)
res = Task.objects(query).distinct(field="type")
return set(res).intersection(external_task_types)
def __init__(self, events_es=None, redis=None):
self.events_es = events_es or es_factory.connect("events")
self.redis: StrictRedis = redis or redman.connection("apiserver")
@staticmethod
def get_task_with_access(
@@ -151,19 +149,20 @@ class TaskBLL:
)
@staticmethod
def validate_execution_model(task, allow_only_public=False):
if not task.execution or not task.execution.model:
def validate_input_models(task, allow_only_public=False):
if not task.models.input:
return
company = None if allow_only_public else task.company
model_id = task.execution.model
model = Model.objects(
Q(id=model_id) & get_company_or_none_constraint(company)
).first()
if not model:
raise errors.bad_request.InvalidModelId(model=model_id)
model_ids = set(m.model for m in task.models.input)
models = Model.objects(
Q(id__in=model_ids) & get_company_or_none_constraint(company)
).only("id")
missing = model_ids - {m.id for m in models}
if missing:
raise errors.bad_request.InvalidModelId(models=missing)
return model
return
@classmethod
def clone_task(
@@ -179,7 +178,9 @@ class TaskBLL:
system_tags: Optional[Sequence[str]] = None,
hyperparams: Optional[dict] = None,
configuration: Optional[dict] = None,
container: Optional[dict] = None,
execution_overrides: Optional[dict] = None,
input_models: Optional[Sequence[TaskInputModel]] = None,
validate_references: bool = False,
new_project_name: str = None,
) -> Tuple[Task, dict]:
@@ -195,10 +196,29 @@ class TaskBLL:
task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
now = datetime.utcnow()
if input_models:
input_models = [
ModelItem(model=m.model, name=m.name, updated=now) for m in input_models
]
execution_dict = task.execution.to_proper_dict() if task.execution else {}
execution_model_overriden = False
if execution_overrides:
execution_model_overriden = execution_overrides.get("model") is not None
execution_model = execution_overrides.pop("model", None)
if not input_models and execution_model:
input_models = [
ModelItem(
model=execution_model,
name=TaskModelNames[TaskModelTypes.input],
updated=now,
)
]
docker_cmd = execution_overrides.pop("docker_cmd", None)
if not container and docker_cmd:
image, _, arguments = docker_cmd.partition(" ")
container = {"image": image, "arguments": arguments}
artifacts_prepare_for_save({"execution": execution_overrides})
params_dict["execution"] = {}
@@ -207,6 +227,8 @@ class TaskBLL:
if legacy_value is not None:
params_dict["execution"] = legacy_value
escape_dict_field(execution_overrides, "model_labels")
execution_dict.update(execution_overrides)
params_prepare_for_save(params_dict, previous_task=task)
@@ -216,7 +238,7 @@ class TaskBLL:
execution_dict["artifacts"] = {
k: a
for k, a in artifacts.items()
if a.get("mode") != ArtifactModes.output
if a.get("mode", DEFAULT_ARTIFACT_MODE) != ArtifactModes.output
}
execution_dict.pop("queue", None)
@@ -227,12 +249,10 @@ class TaskBLL:
project_name=new_project_name,
user=user_id,
company=company_id,
description="Auto-generated while cloning",
description="",
)
new_project_data = {"id": project, "name": new_project_name}
now = datetime.utcnow()
def clean_system_tags(input_tags: Sequence[str]) -> Sequence[str]:
if not input_tags:
return input_tags
@@ -240,10 +260,16 @@ class TaskBLL:
return [
tag
for tag in input_tags
if tag not in [TaskSystemTags.development, EntityVisibility.archived.value]
if tag
not in [TaskSystemTags.development, EntityVisibility.archived.value]
]
with TimingContext("mongo", "clone task"):
parent_task = (
task.parent
if task.parent and not task.parent.startswith(deleted_prefix)
else None
)
new_task = Task(
id=create_id(),
user=user_id,
@@ -253,7 +279,7 @@ class TaskBLL:
last_change=now,
name=name or task.name,
comment=comment or task.comment,
parent=parent or task.parent,
parent=parent or parent_task,
project=project or task.project,
tags=tags or task.tags,
system_tags=system_tags or clean_system_tags(task.system_tags),
@@ -262,13 +288,15 @@ class TaskBLL:
output=Output(destination=task.output.destination)
if task.output
else None,
models=Models(input=input_models or task.models.input),
container=escape_dict(container) or task.container,
execution=execution_dict,
configuration=params_dict.get("configuration") or task.configuration,
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
)
cls.validate(
new_task,
validate_model=validate_references or execution_model_overriden,
validate_models=validate_references or input_models,
validate_parent=validate_references or parent,
validate_project=validate_references or project,
)
@@ -295,7 +323,7 @@ class TaskBLL:
def validate(
cls,
task: Task,
validate_model=True,
validate_models=True,
validate_parent=True,
validate_project=True,
):
@@ -307,6 +335,7 @@ class TaskBLL:
if (
validate_parent
and task.parent
and not task.parent.startswith(deleted_prefix)
and not Task.get(
company=task.company, id=task.parent, _only=("id",), include_public=True
)
@@ -318,16 +347,23 @@ class TaskBLL:
if validate_project and not project:
raise errors.bad_request.InvalidProjectId(id=task.project)
if validate_model:
cls.validate_execution_model(task)
if validate_models:
cls.validate_input_models(task)
@staticmethod
def get_unique_metric_variants(company_id, project_ids=None):
def get_unique_metric_variants(
company_id, project_ids: Sequence[str], include_subprojects: bool
):
if project_ids:
if include_subprojects:
project_ids = project_ids_with_children(project_ids)
project_constraint = {"project": {"$in": project_ids}}
else:
project_constraint = {}
pipeline = [
{
"$match": dict(
company={"$in": [None, "", company_id]},
**({"project": {"$in": project_ids}} if project_ids else {}),
company={"$in": [None, "", company_id]}, **project_constraint,
)
},
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
@@ -372,6 +408,7 @@ class TaskBLL:
tasks = Task.objects(id__in=task_ids, company=company_id).only(
"status", "started"
)
count = 0
for task in tasks:
updates = extra_updates
if task.status == TaskStatus.in_progress and task.started:
@@ -381,12 +418,13 @@ class TaskBLL:
).total_seconds(),
**extra_updates,
}
Task.objects(id=task.id, company=company_id).update(
count += Task.objects(id=task.id, company=company_id).update(
upsert=False,
last_update=last_update,
last_change=last_update,
**updates,
)
return count
@staticmethod
def update_statistics(
@@ -449,168 +487,27 @@ class TaskBLL:
}
extra_updates["metric_stats"] = metric_stats
TaskBLL.set_last_update(
return TaskBLL.set_last_update(
task_ids=[task_id],
company_id=company_id,
last_update=last_update,
**extra_updates,
)
@classmethod
def model_set_ready(
cls,
model_id: str,
company_id: str,
publish_task: bool,
force_publish_task: bool = False,
) -> tuple:
with translate_errors_context():
query = dict(id=model_id, company=company_id)
model = Model.objects(**query).first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
elif model.ready:
raise errors.bad_request.ModelIsReady(**query)
published_task_data = {}
if model.task and publish_task:
task = (
Task.objects(id=model.task, company=company_id)
.only("id", "status")
.first()
)
if task and task.status != TaskStatus.published:
published_task_data["data"] = cls.publish_task(
task_id=model.task,
company_id=company_id,
publish_model=False,
force=force_publish_task,
)
published_task_data["id"] = model.task
updated = model.update(upsert=False, ready=True)
return updated, published_task_data
@classmethod
def publish_task(
cls,
task_id: str,
company_id: str,
publish_model: bool,
force: bool,
status_reason: str = "",
status_message: str = "",
) -> dict:
task = cls.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
)
if not force:
validate_status_change(task.status, TaskStatus.published)
previous_task_status = task.status
output = task.output or Output()
publish_failed = False
try:
# set state to publishing
task.status = TaskStatus.publishing
task.save()
# publish task models
if task.output.model and publish_model:
output_model = (
Model.objects(id=task.output.model)
.only("id", "task", "ready")
.first()
)
if output_model and not output_model.ready:
cls.model_set_ready(
model_id=task.output.model,
company_id=company_id,
publish_task=False,
)
# set task status to published, and update (or set) it's new output (view and models)
return ChangeStatusRequest(
task=task,
new_status=TaskStatus.published,
force=force,
status_reason=status_reason,
status_message=status_message,
).execute(published=datetime.utcnow(), output=output)
except Exception as ex:
publish_failed = True
raise ex
finally:
if publish_failed:
task.status = previous_task_status
task.save()
@classmethod
def stop_task(
cls,
task_id: str,
company_id: str,
user_name: str,
status_reason: str,
force: bool,
) -> dict:
"""
Stop a running task. Requires task status 'in_progress' and
execution_progress 'running', or force=True. Development task or
task that has no associated worker is stopped immediately.
For a non-development task with worker only the status message
is set to 'stopping' to allow the worker to stop the task and report by itself
:return: updated task fields
"""
task = cls.get_task_with_access(
task_id,
company_id=company_id,
only=(
"status",
"project",
"tags",
"system_tags",
"last_worker",
"last_update",
),
requires_write_access=True,
)
def is_run_by_worker(t: Task) -> bool:
"""Checks if there is an active worker running the task"""
update_timeout = config.get("apiserver.workers.task_update_timeout", 600)
return (
t.last_worker
and t.last_update
and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout
)
if TaskSystemTags.development in task.system_tags or not is_run_by_worker(task):
new_status = TaskStatus.stopped
status_message = f"Stopped by {user_name}"
else:
new_status = task.status
status_message = TaskStatusMessage.stopping
return ChangeStatusRequest(
task=task,
new_status=new_status,
status_reason=status_reason,
status_message=status_message,
force=force,
).execute()
@staticmethod
def get_aggregated_project_parameters(
company_id,
project_ids: Sequence[str] = None,
project_ids: Sequence[str],
include_subprojects: bool,
page: int = 0,
page_size: int = 500,
) -> Tuple[int, int, Sequence[dict]]:
if project_ids:
if include_subprojects:
project_ids = project_ids_with_children(project_ids)
project_constraint = {"project": {"$in": project_ids}}
else:
project_constraint = {}
page = max(0, page)
page_size = max(1, page_size)
pipeline = [
@@ -618,7 +515,7 @@ class TaskBLL:
"$match": {
"company": {"$in": [None, "", company_id]},
"hyperparams": {"$exists": True, "$gt": {}},
**({"project": {"$in": project_ids}} if project_ids else {}),
**project_constraint,
}
},
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
@@ -632,6 +529,8 @@ class TaskBLL:
{"$unwind": "$names"},
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
{"$skip": page * page_size},
{"$limit": page_size},
{
"$group": {
"_id": 1,
@@ -639,16 +538,9 @@ class TaskBLL:
"results": {"$push": "$$ROOT"},
}
},
{
"$project": {
"total": 1,
"results": {"$slice": ["$results", page * page_size, page_size]},
}
},
]
with translate_errors_context():
result = next(Task.aggregate(pipeline), None)
result = next(Task.aggregate(pipeline), None)
total = 0
remaining = 0
@@ -669,6 +561,106 @@ class TaskBLL:
return total, remaining, results
HyperParamValues = Tuple[int, Sequence[str]]
def _get_cached_hyperparam_values(
self, key: str, last_update: datetime
) -> Optional[HyperParamValues]:
allowed_delta = timedelta(
seconds=config.get(
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
)
)
try:
cached = self.redis.get(key)
if not cached:
return
data = json.loads(cached)
cached_last_update = datetime.fromtimestamp(data["last_update"])
if (last_update - cached_last_update) < allowed_delta:
return data["total"], data["values"]
except Exception as ex:
log.error(f"Error retrieving hyperparam cached values: {str(ex)}")
def get_hyperparam_distinct_values(
self,
company_id: str,
project_ids: Sequence[str],
section: str,
name: str,
include_subprojects: bool,
allow_public: bool = True,
) -> HyperParamValues:
if allow_public:
company_constraint = {"company": {"$in": [None, "", company_id]}}
else:
company_constraint = {"company": company_id}
if project_ids:
if include_subprojects:
project_ids = project_ids_with_children(project_ids)
project_constraint = {"project": {"$in": project_ids}}
else:
project_constraint = {}
key_path = f"hyperparams.{ParameterKeyEscaper.escape(section)}.{ParameterKeyEscaper.escape(name)}"
last_updated_task = (
Task.objects(
**company_constraint,
**project_constraint,
**{f"{key_path.replace('.', '__')}__exists": True},
)
.only("last_update")
.order_by("-last_update")
.limit(1)
.first()
)
if not last_updated_task:
return 0, []
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
last_update = last_updated_task.last_update or datetime.utcnow()
cached_res = self._get_cached_hyperparam_values(
key=redis_key, last_update=last_update
)
if cached_res:
return cached_res
max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
pipeline = [
{
"$match": {
**company_constraint,
**project_constraint,
key_path: {"$exists": True},
}
},
{"$project": {"value": f"${key_path}.value"}},
{"$group": {"_id": "$value"}},
{"$sort": {"_id": 1}},
{"$limit": max_values},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT._id"},
}
},
]
result = next(Task.aggregate(pipeline, collation=Task._numeric_locale), None)
if not result:
return 0, []
total = int(result.get("total", 0))
values = result.get("results", [])
ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
self.redis.setex(redis_key, ttl, json.dumps(cached))
return total, values
@classmethod
def dequeue_and_change_status(
cls, task: Task, company_id: str, status_message: str, status_reason: str,
@@ -677,10 +669,10 @@ class TaskBLL:
return ChangeStatusRequest(
task=task,
new_status=TaskStatus.created,
new_status=task.enqueue_status or TaskStatus.created,
status_reason=status_reason,
status_message=status_message,
).execute(unset__execution__queue=1)
).execute(enqueue_status=None)
@classmethod
def dequeue(cls, task: Task, company_id: str, silent_fail=False):

View File

@@ -0,0 +1,278 @@
from itertools import chain
from operator import attrgetter
from typing import Sequence, Generic, Callable, Type, Iterable, TypeVar, List, Set
import attr
from boltons.iterutils import partition
from mongoengine import QuerySet, Document
from apiserver.apierrors import errors
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_bll import PlotFields
from apiserver.bll.event.event_common import EventType
from apiserver.bll.task.utils import deleted_prefix
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task, TaskStatus, ArtifactModes
from apiserver.timing_context import TimingContext
event_bll = EventBLL()
T = TypeVar("T", bound=Document)
class DocumentGroup(List[T]):
"""
Operate on a list of documents as if they were a query result
"""
def __init__(self, document_type: Type[T], documents: Iterable[T]):
super(DocumentGroup, self).__init__(documents)
self.type = document_type
@property
def ids(self) -> Set[str]:
return {obj.id for obj in self}
def objects(self, *args, **kwargs) -> QuerySet:
return self.type.objects(id__in=self.ids, *args, **kwargs)
class TaskOutputs(Generic[T]):
"""
Split task outputs of the same type by the ready state
"""
published: DocumentGroup[T]
draft: DocumentGroup[T]
def __init__(
self,
is_published: Callable[[T], bool],
document_type: Type[T],
children: Iterable[T],
):
"""
:param is_published: predicate returning whether items is considered published
:param document_type: type of output
:param children: output documents
"""
self.published, self.draft = map(
lambda x: DocumentGroup(document_type, x),
partition(children, key=is_published),
)
@attr.s(auto_attribs=True)
class TaskUrls:
model_urls: Sequence[str]
event_urls: Sequence[str]
artifact_urls: Sequence[str]
def __add__(self, other: "TaskUrls"):
if not other:
return self
return TaskUrls(
model_urls=list(set(self.model_urls) | set(other.model_urls)),
event_urls=list(set(self.event_urls) | set(other.event_urls)),
artifact_urls=list(set(self.artifact_urls) | set(other.artifact_urls)),
)
@attr.s(auto_attribs=True)
class CleanupResult:
"""
Counts of objects modified in task cleanup operation
"""
updated_children: int
updated_models: int
deleted_models: int
urls: TaskUrls = None
def __add__(self, other: "CleanupResult"):
if not other:
return self
return CleanupResult(
updated_children=self.updated_children + other.updated_children,
updated_models=self.updated_models + other.updated_models,
deleted_models=self.deleted_models + other.deleted_models,
urls=self.urls + other.urls if self.urls else other.urls,
)
def collect_plot_image_urls(company: str, task: str) -> Set[str]:
urls = set()
next_scroll_id = None
with TimingContext("es", "collect_plot_image_urls"):
while True:
events, next_scroll_id = event_bll.get_plot_image_urls(
company_id=company, task_id=task, scroll_id=next_scroll_id
)
if not events:
break
for event in events:
event_urls = event.get(PlotFields.source_urls)
if event_urls:
urls.update(set(event_urls))
return urls
def collect_debug_image_urls(company: str, task: str) -> Set[str]:
"""
Return the set of unique image urls
Uses DebugImagesIterator to make sure that we do not retrieve recycled urls
"""
metrics = event_bll.get_metrics_and_variants(
company_id=company, task_id=task, event_type=EventType.metrics_image
)
if not metrics:
return set()
task_metrics = {task: set(metrics)}
scroll_id = None
urls = set()
while True:
res = event_bll.debug_images_iterator.get_task_events(
company_id=company,
task_metrics=task_metrics,
iter_count=100,
state_id=scroll_id,
)
if not res.metric_events or not any(
iterations for _, iterations in res.metric_events
):
break
scroll_id = res.next_scroll_id
for task, iterations in res.metric_events:
urls.update(ev.get("url") for it in iterations for ev in it["events"])
urls.discard({None})
return urls
def cleanup_task(
task: Task,
force: bool = False,
update_children=True,
return_file_urls=False,
delete_output_models=True,
) -> CleanupResult:
"""
Validate task deletion and delete/modify all its output.
:param task: task object
:param force: whether to delete task with published outputs
:return: count of delete and modified items
"""
models = verify_task_children_and_ouptuts(task, force)
event_urls, artifact_urls, model_urls = set(), set(), set()
if return_file_urls:
event_urls = collect_debug_image_urls(task.company, task.id)
event_urls.update(collect_plot_image_urls(task.company, task.id))
if task.execution and task.execution.artifacts:
artifact_urls = {
a.uri
for a in task.execution.artifacts.values()
if a.mode == ArtifactModes.output and a.uri
}
model_urls = {m.uri for m in models.draft.objects().only("uri") if m.uri}
deleted_task_id = f"{deleted_prefix}{task.id}"
if update_children:
with TimingContext("mongo", "update_task_children"):
updated_children = Task.objects(parent=task.id).update(
parent=deleted_task_id
)
else:
updated_children = 0
if models.draft and delete_output_models:
with TimingContext("mongo", "delete_models"):
deleted_models = models.draft.objects().delete()
else:
deleted_models = 0
if models.published and update_children:
with TimingContext("mongo", "update_task_models"):
updated_models = models.published.objects().update(task=deleted_task_id)
else:
updated_models = 0
event_bll.delete_task_events(task.company, task.id, allow_locked=force)
return CleanupResult(
deleted_models=deleted_models,
updated_children=updated_children,
updated_models=updated_models,
urls=TaskUrls(
event_urls=list(event_urls),
artifact_urls=list(artifact_urls),
model_urls=list(model_urls),
)
if return_file_urls
else None,
)
def verify_task_children_and_ouptuts(task: Task, force: bool) -> TaskOutputs[Model]:
if not force:
with TimingContext("mongo", "count_published_children"):
published_children_count = Task.objects(
parent=task.id, status=TaskStatus.published
).count()
if published_children_count:
raise errors.bad_request.TaskCannotBeDeleted(
"has children, use force=True",
task=task.id,
children=published_children_count,
)
with TimingContext("mongo", "get_task_models"):
models = TaskOutputs(
attrgetter("ready"),
Model,
Model.objects(task=task.id).only("id", "task", "ready"),
)
if not force and models.published:
raise errors.bad_request.TaskCannotBeDeleted(
"has output models, use force=True",
task=task.id,
models=len(models.published),
)
if task.models and task.models.output:
with TimingContext("mongo", "get_task_output_model"):
model_ids = [m.model for m in task.models.output]
for output_model in Model.objects(id__in=model_ids):
if output_model.ready:
if not force:
raise errors.bad_request.TaskCannotBeDeleted(
"has output model, use force=True",
task=task.id,
model=output_model.id,
)
models.published.append(output_model)
else:
models.draft.append(output_model)
if models.draft:
with TimingContext("mongo", "get_execution_models"):
model_ids = models.draft.ids
dependent_tasks = Task.objects(models__input__model__in=model_ids).only(
"id", "models"
)
input_models = {
m.model
for m in chain.from_iterable(
t.models.input for t in dependent_tasks if t.models
)
}
if input_models:
models.draft = DocumentGroup(
Model, (m for m in models.draft if m.id not in input_models)
)
return models

View File

@@ -0,0 +1,380 @@
from datetime import datetime
from typing import Callable, Any, Tuple, Union
from apiserver.apierrors import errors, APIError
from apiserver.bll.queue import QueueBLL
from apiserver.bll.task import (
TaskBLL,
validate_status_change,
ChangeStatusRequest,
update_project_time,
)
from apiserver.bll.task.task_cleanup import cleanup_task, CleanupResult
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.task.output import Output
from apiserver.database.model.task.task import (
TaskStatus,
Task,
TaskSystemTags,
TaskStatusMessage,
ArtifactModes,
Execution,
DEFAULT_LAST_ITERATION,
)
from apiserver.utilities.dicts import nested_set
queue_bll = QueueBLL()
def archive_task(
task: Union[str, Task], company_id: str, status_message: str, status_reason: str,
) -> int:
"""
Deque and archive task
Return 1 if successful
"""
if isinstance(task, str):
task = TaskBLL.get_task_with_access(
task,
company_id=company_id,
only=(
"id",
"execution",
"status",
"project",
"system_tags",
"enqueue_status",
),
requires_write_access=True,
)
try:
TaskBLL.dequeue_and_change_status(
task, company_id, status_message, status_reason,
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
return task.update(
status_message=status_message,
status_reason=status_reason,
add_to_set__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
)
def unarchive_task(
task: str, company_id: str, status_message: str, status_reason: str,
) -> int:
"""
Unarchive task. Return 1 if successful
"""
task = TaskBLL.get_task_with_access(
task, company_id=company_id, only=("id",), requires_write_access=True,
)
return task.update(
status_message=status_message,
status_reason=status_reason,
pull__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
)
def dequeue_task(
task_id: str,
company_id: str,
status_message: str,
status_reason: str,
) -> Tuple[int, dict]:
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(**query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
res = TaskBLL.dequeue_and_change_status(
task,
company_id,
status_message=status_message,
status_reason=status_reason,
)
return 1, res
def enqueue_task(
task_id: str,
company_id: str,
queue_id: str,
status_message: str,
status_reason: str,
validate: bool = False,
) -> Tuple[int, dict]:
if not queue_id:
# try to get default queue
queue_id = queue_bll.get_default(company_id).id
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(**query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
if validate:
TaskBLL.validate(task)
res = ChangeStatusRequest(
task=task,
new_status=TaskStatus.queued,
status_reason=status_reason,
status_message=status_message,
allow_same_state_transition=False,
).execute(enqueue_status=task.status)
try:
queue_bll.add_task(company_id=company_id, queue_id=queue_id, task_id=task.id)
except Exception:
# failed enqueueing, revert to previous state
ChangeStatusRequest(
task=task,
current_status_override=TaskStatus.queued,
new_status=task.status,
force=True,
status_reason="failed enqueueing",
).execute(enqueue_status=None)
raise
# set the current queue ID in the task
if task.execution:
Task.objects(**query).update(execution__queue=queue_id, multi=False)
else:
Task.objects(**query).update(execution=Execution(queue=queue_id), multi=False)
nested_set(res, ("fields", "execution.queue"), queue_id)
return 1, res
def delete_task(
task_id: str,
company_id: str,
move_to_trash: bool,
force: bool,
return_file_urls: bool,
delete_output_models: bool,
) -> Tuple[int, Task, CleanupResult]:
task = TaskBLL.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
)
if (
task.status != TaskStatus.created
and EntityVisibility.archived.value not in task.system_tags
and not force
):
raise errors.bad_request.TaskCannotBeDeleted(
"due to status, use force=True",
task=task.id,
expected=TaskStatus.created,
current=task.status,
)
cleanup_res = cleanup_task(
task,
force=force,
return_file_urls=return_file_urls,
delete_output_models=delete_output_models,
)
if move_to_trash:
collection_name = task._get_collection_name()
archived_collection = "{}__trash".format(collection_name)
task.switch_collection(archived_collection)
try:
# A simple save() won't do due to mongoengine caching (nothing will be saved), so we have to force
# an insert. However, if for some reason such an ID exists, let's make sure we'll keep going.
task.save(force_insert=True)
except Exception:
pass
task.switch_collection(collection_name)
task.delete()
update_project_time(task.project)
return 1, task, cleanup_res
def reset_task(
task_id: str,
company_id: str,
force: bool,
return_file_urls: bool,
delete_output_models: bool,
clear_all: bool,
) -> Tuple[dict, CleanupResult, dict]:
task = TaskBLL.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
)
if not force and task.status == TaskStatus.published:
raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)
dequeued = {}
updates = {}
try:
dequeued = TaskBLL.dequeue(task, company_id, silent_fail=True)
except APIError:
# dequeue may fail if the task was not enqueued
pass
cleaned_up = cleanup_task(
task,
force=force,
update_children=False,
return_file_urls=return_file_urls,
delete_output_models=delete_output_models,
)
updates.update(
set__last_iteration=DEFAULT_LAST_ITERATION,
set__last_metrics={},
set__metric_stats={},
set__models__output=[],
unset__output__result=1,
unset__output__error=1,
unset__last_worker=1,
unset__last_worker_report=1,
)
if clear_all:
updates.update(
set__execution=Execution(), unset__script=1,
)
else:
updates.update(unset__execution__queue=1)
if task.execution and task.execution.artifacts:
updates.update(
set__execution__artifacts={
key: artifact
for key, artifact in task.execution.artifacts.items()
if artifact.mode == ArtifactModes.input
}
)
res = ChangeStatusRequest(
task=task,
new_status=TaskStatus.created,
force=force,
status_reason="reset",
status_message="reset",
).execute(
started=None,
completed=None,
published=None,
active_duration=None,
enqueue_status=None,
**updates,
)
return dequeued, cleaned_up, res
def publish_task(
task_id: str,
company_id: str,
force: bool,
publish_model_func: Callable[[str, str], Any] = None,
status_message: str = "",
status_reason: str = "",
) -> dict:
task = TaskBLL.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
)
if not force:
validate_status_change(task.status, TaskStatus.published)
previous_task_status = task.status
output = task.output or Output()
publish_failed = False
try:
# set state to publishing
task.status = TaskStatus.publishing
task.save()
# publish task models
if task.models and task.models.output and publish_model_func:
model_id = task.models.output[-1].model
model = (
Model.objects(id=model_id, company=company_id)
.only("id", "ready")
.first()
)
if model and not model.ready:
publish_model_func(model.id, company_id)
# set task status to published, and update (or set) it's new output (view and models)
return ChangeStatusRequest(
task=task,
new_status=TaskStatus.published,
force=force,
status_reason=status_reason,
status_message=status_message,
).execute(published=datetime.utcnow(), output=output)
except Exception as ex:
publish_failed = True
raise ex
finally:
if publish_failed:
task.status = previous_task_status
task.save()
def stop_task(
task_id: str, company_id: str, user_name: str, status_reason: str, force: bool,
) -> dict:
"""
Stop a running task. Requires task status 'in_progress' and
execution_progress 'running', or force=True. Development task or
task that has no associated worker is stopped immediately.
For a non-development task with worker only the status message
is set to 'stopping' to allow the worker to stop the task and report by itself
:return: updated task fields
"""
task = TaskBLL.get_task_with_access(
task_id,
company_id=company_id,
only=(
"status",
"project",
"tags",
"system_tags",
"last_worker",
"last_update",
),
requires_write_access=True,
)
def is_run_by_worker(t: Task) -> bool:
"""Checks if there is an active worker running the task"""
update_timeout = config.get("apiserver.workers.task_update_timeout", 600)
return (
t.last_worker
and t.last_update
and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout
)
if TaskSystemTags.development in task.system_tags or not is_run_by_worker(task):
new_status = TaskStatus.stopped
status_message = f"Stopped by {user_name}"
else:
new_status = task.status
status_message = TaskStatusMessage.stopping
return ChangeStatusRequest(
task=task,
new_status=new_status,
status_reason=status_reason,
status_message=status_message,
force=force,
).execute()

View File

@@ -1,5 +1,5 @@
from datetime import datetime
from typing import TypeVar, Callable, Tuple, Sequence, Union
from typing import Sequence, Union
import attr
import six
@@ -13,6 +13,7 @@ from apiserver.timing_context import TimingContext
from apiserver.utilities.attrs import typed_attrs
valid_statuses = get_options(TaskStatus)
deleted_prefix = "__DELETED__"
@typed_attrs
@@ -105,7 +106,7 @@ def validate_status_change(current_status, new_status):
state_machine = {
TaskStatus.created: {TaskStatus.queued, TaskStatus.in_progress},
TaskStatus.queued: {TaskStatus.created, TaskStatus.in_progress},
TaskStatus.queued: {TaskStatus.created, TaskStatus.in_progress, TaskStatus.stopped},
TaskStatus.in_progress: {
TaskStatus.stopped,
TaskStatus.failed,
@@ -116,6 +117,7 @@ state_machine = {
TaskStatus.closed,
TaskStatus.created,
TaskStatus.failed,
TaskStatus.queued,
TaskStatus.in_progress,
TaskStatus.published,
TaskStatus.publishing,
@@ -163,22 +165,6 @@ def update_project_time(project_ids: Union[str, Sequence[str]]):
return Project.objects(id__in=project_ids).update(last_update=datetime.utcnow())
T = TypeVar("T")
def split_by(
condition: Callable[[T], bool], items: Sequence[T]
) -> Tuple[Sequence[T], Sequence[T]]:
"""
split "items" to two lists by "condition"
"""
applied = zip(map(condition, items), items)
return (
[item for cond, item in applied if cond],
[item for cond, item in applied if not cond],
)
def get_task_for_update(
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
) -> Task:

View File

@@ -1,33 +1,25 @@
import functools
import itertools
from concurrent.futures.thread import ThreadPoolExecutor
from operator import itemgetter
from typing import Sequence, Optional, Callable, Tuple, Dict, Any, Set, Iterable
from typing import (
Optional,
Callable,
Dict,
Any,
Set,
Iterable,
Tuple,
Sequence,
TypeVar,
)
from boltons import iterutils
from apiserver.apierrors import APIError
from apiserver.database.model import AttributedDocument
from apiserver.database.model.settings import Settings
def extract_properties_to_lists(
key_names: Sequence[str],
data: Sequence[dict],
extract_func: Optional[Callable[[dict], Tuple]] = None,
) -> dict:
"""
Given a list of dictionaries and names of dictionary keys
builds a dictionary with the requested keys and values lists
:param key_names: names of the keys in the resulting dictionary
:param data: sequence of dictionaries to extract values from
:param extract_func: the optional callable that extracts properties
from a dictionary and put them in a tuple in the order corresponding to
key_names. If not specified then properties are extracted according to key_names
"""
value_sequences = zip(*map(extract_func or itemgetter(*key_names), data))
return dict(zip(key_names, map(list, value_sequences)))
class SetFieldsResolver:
"""
The class receives set fields dictionary
@@ -115,3 +107,28 @@ def parallel_chunked_decorator(func: Callable = None, chunk_size: int = 100):
)
return wrapper
T = TypeVar("T")
def run_batch_operation(
func: Callable[[str], T], ids: Sequence[str]
) -> Tuple[Sequence[Tuple[str, T]], Sequence[dict]]:
results = list()
failures = list()
for _id in ids:
try:
results.append((_id, func(_id)))
except APIError as err:
failures.append(
{
"id": _id,
"error": {
"codes": [err.code, err.subcode],
"msg": err.msg,
"data": err.error_data,
},
}
)
return results, failures

View File

@@ -6,9 +6,10 @@ from functools import reduce
from os import getenv
from os.path import expandvars
from pathlib import Path
from typing import List, Any, TypeVar
from typing import List, Any, TypeVar, Sequence
from pyhocon import ConfigTree, ConfigFactory
from boltons.iterutils import first
from pyhocon import ConfigTree, ConfigFactory, ConfigValues
from pyparsing import (
ParseFatalException,
ParseException,
@@ -18,8 +19,8 @@ from pyparsing import (
from apiserver.utilities import json
EXTRA_CONFIG_PATHS = ("/opt/trains/config",)
EXTRA_CONFIG_PATH_OVERRIDE_VAR = "TRAINS_CONFIG_DIR"
EXTRA_CONFIG_PATHS = ("/opt/trains/config", "/opt/clearml/config")
DEFAULT_PREFIXES = ("clearml", "trains")
EXTRA_CONFIG_PATH_SEP = ":" if platform.system() != "Windows" else ";"
@@ -30,7 +31,10 @@ class BasicConfig:
default_config_dir = "default"
def __init__(
self, folder: str = None, verbose: bool = True, prefix: str = "trains"
self,
folder: str = None,
verbose: bool = True,
prefix: Sequence[str] = DEFAULT_PREFIXES,
):
folder = (
Path(folder)
@@ -41,8 +45,16 @@ class BasicConfig:
raise ValueError("Invalid configuration folder")
self.verbose = verbose
self.prefix = prefix
self.extra_config_values_env_key_prefix = f"{self.prefix.upper()}__"
self.extra_config_path_override_var = [
f"{p.upper()}_CONFIG_DIR" for p in prefix
]
self.prefix = prefix[0]
self.extra_config_values_env_key_prefix = [
f"{p.upper()}{self.extra_config_values_env_key_sep}"
for p in reversed(prefix)
]
self._paths = [folder, *self._get_paths()]
self._config = self._reload()
@@ -73,24 +85,24 @@ class BasicConfig:
def _read_extra_env_config_values(self) -> ConfigTree:
""" Loads extra configuration from environment-injected values """
result = ConfigTree()
prefix = self.extra_config_values_env_key_prefix
keys = sorted(k for k in os.environ if k.startswith(prefix))
for key in keys:
path = (
key[len(prefix) :]
.replace(self.extra_config_values_env_key_sep, ".")
.lower()
)
result = ConfigTree.merge_configs(
result, ConfigFactory.parse_string(f"{path}: {os.environ[key]}")
)
for prefix in self.extra_config_values_env_key_prefix:
keys = sorted(k for k in os.environ if k.startswith(prefix))
for key in keys:
path = (
key[len(prefix) :]
.replace(self.extra_config_values_env_key_sep, ".")
.lower()
)
result = self._merge_configs(
result, ConfigFactory.parse_string(f"{path}: {os.environ[key]}")
)
return result
def _get_paths(self) -> List[Path]:
default_paths = EXTRA_CONFIG_PATH_SEP.join(EXTRA_CONFIG_PATHS)
value = getenv(EXTRA_CONFIG_PATH_OVERRIDE_VAR, default_paths)
value = first(map(getenv, self.extra_config_path_override_var), default_paths)
paths = [
Path(expandvars(v)).expanduser() for v in value.split(EXTRA_CONFIG_PATH_SEP)
@@ -100,7 +112,7 @@ class BasicConfig:
invalid = [path for path in paths if not path.is_dir()]
if invalid:
print(
f"WARNING: Invalid paths in {EXTRA_CONFIG_PATH_OVERRIDE_VAR} env var: {' '.join(map(str, invalid))}"
f"WARNING: Invalid paths in {self.extra_config_path_override_var} env var: {' '.join(map(str, invalid))}"
)
return [path for path in paths if path.is_dir()]
@@ -114,13 +126,40 @@ class BasicConfig:
configs = [self._read_recursive(path) for path in self._paths]
return reduce(
lambda last, config: ConfigTree.merge_configs(
lambda last, config: self._merge_configs(
last, config, copy_trees=True
),
configs + [extra_config_values],
ConfigTree(),
)
@classmethod
def _merge_configs(cls, a, b, copy_trees=False, override_prefix="-"):
"""Based on pyhocon.ConfigTree.merge_configs, with dict override support using a `-` key prefix"""
for key, value in b.items():
override = key.startswith(override_prefix)
if override:
key = key[len(override_prefix):]
# if key is in both a and b and both values are dictionary then merge it otherwise override it
if not override and key in a and isinstance(a[key], ConfigTree) and isinstance(b[key], ConfigTree):
if copy_trees:
a[key] = a[key].copy()
cls._merge_configs(a[key], b[key], copy_trees=copy_trees)
else:
if isinstance(value, ConfigValues):
value.parent = a
value.key = key
if key in a:
value.overriden_value = a[key]
a[key] = value
if a.root:
if b.root:
a.history[key] = a.history.get(key, []) + b.history.get(key, [value])
else:
a.history[key] = a.history.get(key, []) + [value]
return a
def _read_recursive(self, conf_root) -> ConfigTree:
conf = ConfigTree()

View File

@@ -69,7 +69,7 @@
default_expiration_sec: 2592000
# cookie containing auth token, for requests arriving from a web-browser
session_auth_cookie_name: "trains_token_basic"
session_auth_cookie_name: "clearml_token_basic"
# cookie configuration for authorization cookies generated by auth.login
cookies {
@@ -80,8 +80,10 @@
}
# # A list of fixed users
# # Note: password may be bcrypt-hashed (generate using `python -c 'import bcrypt; print(bcrypt.hashpw("password", bcrypt.gensalt()))'`)
# fixed_users {
# enabled: true
# pass_hashed: false
# users: [
# {
# username: "john"
@@ -116,9 +118,9 @@
# Check for updates every 24 hours
check_interval_sec: 86400
url: "https://updates.trains.allegro.ai/updates"
url: "https://updates.clear.ml/updates"
component_name: "trains-server"
component_name: "clearml-server"
# GET request timeout
request_timeout_sec: 3.0
@@ -128,7 +130,7 @@
# Note: statistics are sent ONLY if the user has actively opted-in
supported: true
url: "https://updates.trains.allegro.ai/stats"
url: "https://updates.clear.ml/stats"
report_interval_hours: 24
agent_relevant_threshold_days: 30

View File

@@ -16,7 +16,7 @@
backupCount: 3
maxBytes: 10240000,
class: "logging.handlers.RotatingFileHandler",
filename: "/var/log/trains/apiserver.log"
filename: "/var/log/clearml/apiserver.log"
}
}
root {

View File

@@ -28,6 +28,7 @@
display_name: "Default User"
user_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
user_secret: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
revoke_in_fixed_mode: true
}
}
}

View File

@@ -10,4 +10,9 @@ featured {
# default featured index for public projects not specified in the order
public_default: 9999
}
sub_projects {
# the max sub project depth
max_depth: 10
}

View File

@@ -9,3 +9,14 @@ non_responsive_tasks_watchdog {
}
multi_task_histogram_limit: 100
hyperparam_values {
# maximal amount of distinct hyperparam values to retrieve
max_count: 100
# max allowed outdate time for the cashed result
cache_allowed_outdate_sec: 60
# cache ttl sec
cache_ttl_sec: 86400
}

View File

@@ -2,6 +2,8 @@ from functools import lru_cache
from os import getenv
from pathlib import Path
from boltons.iterutils import first
from apiserver.config_repo import config
from apiserver.version import __version__
@@ -9,7 +11,9 @@ root = Path(__file__).parent.parent
def _get(prop_name, env_suffix=None, default=""):
value = getenv(f"TRAINS_SERVER_{env_suffix or prop_name}")
suffix = env_suffix or prop_name
keys = [f"{p}_SERVER_{suffix}" for p in ("CLEARML", "TRAINS")]
value = first(map(getenv, keys))
if value:
return value

View File

@@ -17,11 +17,16 @@ log = config.logger("database")
strict = config.get("apiserver.mongo.strict", True)
OVERRIDE_HOST_ENV_KEY = (
"CLEARML_MONGODB_SERVICE_HOST",
"TRAINS_MONGODB_SERVICE_HOST",
"MONGODB_SERVICE_HOST",
"MONGODB_SERVICE_SERVICE_HOST",
)
OVERRIDE_PORT_ENV_KEY = ("TRAINS_MONGODB_SERVICE_PORT", "MONGODB_SERVICE_PORT")
OVERRIDE_PORT_ENV_KEY = (
"CLEARML_MONGODB_SERVICE_PORT",
"TRAINS_MONGODB_SERVICE_PORT",
"MONGODB_SERVICE_PORT",
)
class DatabaseEntry(models.Base):
@@ -32,6 +37,10 @@ class DatabaseEntry(models.Base):
class DatabaseFactory:
_entries = []
@classmethod
def _create_db_entry(cls, alias: str, settings: dict) -> DatabaseEntry:
return DatabaseEntry(alias=alias, **settings)
@classmethod
def initialize(cls):
db_entries = config.get("hosts.mongo", {})
@@ -51,7 +60,7 @@ class DatabaseFactory:
missing.append(key)
continue
entry = DatabaseEntry(alias=alias, **db_entries.get(key))
entry = cls._create_db_entry(alias=alias, settings=db_entries.get(key))
if override_hostname:
entry.host = furl(entry.host).set(host=override_hostname).url
@@ -64,13 +73,15 @@ class DatabaseFactory:
log.info(
"Registering connection to %(alias)s (%(host)s)" % entry.to_struct()
)
register_connection(alias=alias, host=entry.host)
register_connection(**entry.to_struct())
cls._entries.append(entry)
except ValidationError as ex:
raise Exception("Invalid database entry `%s`: %s" % (key, ex.args[0]))
if missing:
raise ValueError("Missing database configuration for %s" % ", ".join(missing))
raise ValueError(
"Missing database configuration for %s" % ", ".join(missing)
)
@classmethod
def get_entries(cls):
@@ -91,7 +102,7 @@ class DatabaseFactory:
# reconnection from work so workaround this
# get_connection(entry.alias, reconnect=True)
disconnect(entry.alias)
register_connection(alias=entry.alias, host=entry.host)
register_connection(**entry.to_struct())
get_connection(entry.alias)

View File

@@ -1,7 +1,7 @@
import re
from collections import namedtuple
from functools import reduce
from typing import Collection, Sequence, Union, Optional, Type, Tuple
from typing import Collection, Sequence, Union, Optional, Type, Tuple, Mapping, Any
from boltons.iterutils import first, bucketize, partition
from dateutil.parser import parse as parse_datetime
@@ -86,6 +86,7 @@ class GetMixin(PropsMixin):
list_fields=("tags", "system_tags", "id"),
datetime_fields=None,
fields=None,
range_fields=None,
):
"""
:param pattern_fields: Fields for which a "string contains" condition should be generated
@@ -97,6 +98,7 @@ class GetMixin(PropsMixin):
self.fields = fields
self.datetime_fields = datetime_fields
self.list_fields = list_fields
self.range_fields = range_fields
self.pattern_fields = pattern_fields
class ListFieldBucketHelper:
@@ -183,6 +185,53 @@ class GetMixin(PropsMixin):
parameters, parameters_options
) & cls._prepare_perm_query(company, allow_public=allow_public)
@staticmethod
def _pop_matching_params(
patterns: Sequence[str], parameters: dict
) -> Mapping[str, Any]:
"""
Pop the parameters that match the specified patterns and return
the dictionary of matching parameters
Pop None parameters since they are not the real queries
"""
if not patterns:
return {}
fields = set()
for pattern in patterns:
if pattern.endswith("*"):
prefix = pattern[:-1]
fields.update(
{field for field in parameters if field.startswith(prefix)}
)
elif pattern in parameters:
fields.add(pattern)
pairs = ((field, parameters.pop(field, None)) for field in fields)
return {k: v for k, v in pairs if v is not None}
@classmethod
def _try_convert_to_numeric(cls, value: Union[str, Sequence[str]]):
def convert_str(val: str) -> Union[float, str]:
try:
return float(val)
except ValueError:
return val
if isinstance(value, str):
return convert_str(value)
if isinstance(value, (list, tuple)):
return [convert_str(v) if isinstance(v, str) else v for v in value]
return value
@classmethod
def _get_fixed_field_value(cls, field: str, value):
if field.startswith("last_metrics."):
return cls._try_convert_to_numeric(value)
return value
@classmethod
def _prepare_query_no_company(
cls, parameters=None, parameters_options=QueryParameterOptions()
@@ -205,17 +254,24 @@ class GetMixin(PropsMixin):
dict_query = {}
query = RegexQ()
if parameters:
parameters = parameters.copy()
parameters = {
k: cls._get_fixed_field_value(k, v) for k, v in parameters.items()
}
opts = parameters_options
for field in opts.pattern_fields:
pattern = parameters.pop(field, None)
if pattern:
dict_query[field] = RegexWrapper(pattern)
for field in tuple(opts.list_fields or ()):
data = parameters.pop(field, None)
if data:
query &= cls.get_list_field_query(field, data)
for field, data in cls._pop_matching_params(
patterns=opts.list_fields, parameters=parameters
).items():
query &= cls.get_list_field_query(field, data)
for field, data in cls._pop_matching_params(
patterns=opts.range_fields, parameters=parameters
).items():
query &= cls.get_range_field_query(field, data)
for field in opts.fields or []:
data = parameters.pop(field, None)
@@ -250,15 +306,53 @@ class GetMixin(PropsMixin):
raise MakeGetAllQueryError("incorrect field format", field)
if not data.fields:
break
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
sep_fields = [f.replace(".", "__") for f in data.fields]
q = reduce(
lambda a, x: func(a, RegexQ(**{x: regex})), sep_fields, RegexQ()
)
if any("._" in f for f in data.fields):
q = reduce(
lambda a, x: func(a, Q(__raw__={x: {"$regex": data.pattern, "$options": "i"}})),
data.fields,
Q()
)
else:
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
sep_fields = [f.replace(".", "__") for f in data.fields]
q = reduce(
lambda a, x: func(a, RegexQ(**{x: regex})), sep_fields, RegexQ()
)
query = query & q
return query & RegexQ(**dict_query)
@classmethod
def get_range_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
"""
Return a range query for the provided field. The data should contain min and max values
Both intervals are included. For open range queries either min or max can be None
In case the min value is None the records with missing or None value from db are included
"""
if not isinstance(data, (list, tuple)) or len(data) != 2:
raise errors.bad_request.ValidationError(
f"Min and max values should be specified for range field {field}"
)
min_val, max_val = data
if min_val is None and max_val is None:
raise errors.bad_request.ValidationError(
f"At least one of min or max values should be provided for field {field}"
)
mongoengine_field = field.replace(".", "__")
query = {}
if min_val is not None:
query[f"{mongoengine_field}__gte"] = min_val
if max_val is not None:
query[f"{mongoengine_field}__lte"] = max_val
q = Q(**query)
if min_val is None:
q |= Q(**{mongoengine_field: None})
return q
@classmethod
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
"""
@@ -271,7 +365,8 @@ class GetMixin(PropsMixin):
- AND can be achieved using a preceding "__$all" or "__$and" value (operator)
"""
if not isinstance(data, (list, tuple)):
raise MakeGetAllQueryError("expected list", field)
data = [data]
# raise MakeGetAllQueryError("expected list", field)
# TODO: backwards compatibility only for older API versions
helper = cls.ListFieldBucketHelper(legacy=True)
@@ -285,11 +380,7 @@ class GetMixin(PropsMixin):
q = RegexQ()
for action in filter(None, actions):
q &= RegexQ(
**{
f"{mongoengine_field}__{action}": list(
set(filter(None, actions[action]))
)
}
**{f"{mongoengine_field}__{action}": list(set(actions[action]))}
)
if not allow_empty:
@@ -448,6 +539,12 @@ class GetMixin(PropsMixin):
return helper.project(results, projection_func)
@classmethod
def _get_collation_override(cls, field: str) -> Optional[dict]:
return first(
v for k, v in cls._field_collation_overrides.items() if field.startswith(k)
)
@classmethod
def get_many(
cls,
@@ -485,6 +582,13 @@ class GetMixin(PropsMixin):
:param allow_public: If True, objects marked as public (no associated company) are also queried.
:return: A list of objects matching the query.
"""
override_collation = None
if query_dict:
for field in query_dict:
override_collation = cls._get_collation_override(field)
if override_collation:
break
if query_dict is not None:
q = cls.prepare_query(
parameters=query_dict,
@@ -501,10 +605,14 @@ class GetMixin(PropsMixin):
query=_query,
parameters=parameters,
override_projection=override_projection,
override_collation=override_collation,
)
return cls._get_many_no_company(
query=_query, parameters=parameters, override_projection=override_projection
query=_query,
parameters=parameters,
override_projection=override_projection,
override_collation=override_collation,
)
@classmethod
@@ -528,6 +636,7 @@ class GetMixin(PropsMixin):
query: Q,
parameters=None,
override_projection=None,
override_collation=None,
):
"""
Fetch all documents matching a provided query.
@@ -547,12 +656,16 @@ class GetMixin(PropsMixin):
parameters = parameters or {}
search_text = parameters.get(cls._search_text_key)
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
if order_by and not override_collation:
override_collation = cls._get_collation_override(order_by[0])
page, page_size = cls.validate_paging(parameters=parameters)
include, exclude = cls.split_projection(
cls.get_projection(parameters, override_projection)
)
qs = cls.objects(query)
if override_collation:
qs = qs.collation(collation=override_collation)
if search_text:
qs = qs.search_text(search_text)
if order_by:
@@ -572,12 +685,41 @@ class GetMixin(PropsMixin):
return qs
@classmethod
def _get_queries_for_order_field(
cls, query: Q, order_field: str
) -> Union[None, Tuple[Q, Q]]:
"""
In case the order_field is one of the cls fields and the sorting is ascending
then return the tuple of 2 queries:
1. original query with not empty constraint on the order_by field
2. original query with empty constraint on the order_by field
"""
if not order_field or order_field.startswith("-") or "[" in order_field:
return
mongo_field_name = order_field.replace(".", "__")
mongo_field = first(
v for k, v in cls.get_all_fields_with_instance() if k == mongo_field_name
)
if isinstance(mongo_field, ListField):
params = {"is_list": True}
elif isinstance(mongo_field, StringField):
params = {"empty_value": ""}
else:
params = {}
non_empty = query & field_exists(mongo_field_name, **params)
empty = query & field_does_not_exist(mongo_field_name, **params)
return non_empty, empty
@classmethod
def _get_many_override_none_ordering(
cls: Union[Document, "GetMixin"],
query: Q = None,
parameters: dict = None,
override_projection: Collection[str] = None,
override_collation: dict = None,
) -> Sequence[dict]:
"""
Fetch all documents matching a provided query. For the first order by field
@@ -610,32 +752,17 @@ class GetMixin(PropsMixin):
order_field = first(
field for field in order_by if not field.startswith("$")
)
if (
order_field
and not order_field.startswith("-")
and "[" not in order_field
):
params = {}
mongo_field = order_field.replace(".", "__")
if mongo_field in cls.get_field_names_for_type(of_type=ListField):
params["is_list"] = True
elif mongo_field in cls.get_field_names_for_type(of_type=StringField):
params["empty_value"] = ""
non_empty = query & field_exists(mongo_field, **params)
empty = query & field_does_not_exist(mongo_field, **params)
query_sets = [cls.objects(non_empty), cls.objects(empty)]
res = cls._get_queries_for_order_field(query, order_field)
if res:
query_sets = [cls.objects(q) for q in res]
query_sets = [qs.order_by(*order_by) for qs in query_sets]
if order_field:
collation_override = first(
v
for k, v in cls._field_collation_overrides.items()
if order_field.startswith(k)
)
if collation_override:
query_sets = [
qs.collation(collation=collation_override) for qs in query_sets
]
if order_field and not override_collation:
override_collation = cls._get_collation_override(order_field)
if override_collation:
query_sets = [
qs.collation(collation=override_collation) for qs in query_sets
]
if search_text:
query_sets = [qs.search_text(search_text) for qs in query_sets]

View File

@@ -0,0 +1,44 @@
from typing import Sequence, Type
from mongoengine import EmbeddedDocument, StringField, Document
from pymongo import UpdateOne
from pymongo.collection import Collection
from apiserver.database.model.base import ProperDictMixin
class MetadataItem(EmbeddedDocument, ProperDictMixin):
key = StringField(required=True)
type = StringField(required=True)
value = StringField(required=True)
def metadata_add_or_update(cls: Type[Document], _id: str, items: Sequence[dict]) -> int:
collection: Collection = cls._get_collection()
res = collection.update_one(
filter={"_id": _id},
update={
"$set": {f"metadata.$[elem{idx}]": item for idx, item in enumerate(items)}
},
array_filters=[
{f"elem{idx}.key": item["key"]} for idx, item in enumerate(items)
],
upsert=False,
)
if len(items) == 1 and res.modified_count == 1:
return res.modified_count
requests = [
UpdateOne(
filter={"_id": _id, "metadata.key": {"$ne": item["key"]}},
update={"$push": {"metadata": item}},
)
for item in items
]
res = collection.bulk_write(requests)
return 1 if res.modified_count else 0
def metadata_delete(cls: Type[Document], _id: str, keys: Sequence[str]) -> int:
return cls.objects(id=_id).update_one(pull__metadata__key__in=keys)

View File

@@ -1,9 +1,22 @@
from mongoengine import Document, StringField, DateTimeField, BooleanField
from typing import Sequence
from mongoengine import (
Document,
StringField,
DateTimeField,
BooleanField,
EmbeddedDocumentListField,
)
from apiserver.database import Database, strict
from apiserver.database.fields import StrippedStringField, SafeDictField, SafeSortedListField
from apiserver.database.fields import (
StrippedStringField,
SafeDictField,
SafeSortedListField,
)
from apiserver.database.model import DbModelMixin
from apiserver.database.model.base import GetMixin
from apiserver.database.model.metadata import MetadataItem
from apiserver.database.model.model_labels import ModelLabels
from apiserver.database.model.company import Company
from apiserver.database.model.project import Project
@@ -19,6 +32,9 @@ class Model(DbModelMixin, Document):
"parent",
"project",
"task",
"last_update",
"metadata.key",
"metadata.type",
("company", "framework"),
("company", "name"),
("company", "user"),
@@ -51,6 +67,7 @@ class Model(DbModelMixin, Document):
"task",
"parent",
),
datetime_fields=("last_update",),
)
id = StringField(primary_key=True)
@@ -69,7 +86,11 @@ class Model(DbModelMixin, Document):
design = SafeDictField()
labels = ModelLabels()
ready = BooleanField(required=True)
last_update = DateTimeField()
ui_cache = SafeDictField(
default=dict, user_set_allowed=True, exclude_by_default=True
)
company_origin = StringField(exclude_by_default=True)
metadata: Sequence[MetadataItem] = EmbeddedDocumentListField(
MetadataItem, default=list, user_set_allowed=True
)

View File

@@ -1,4 +1,4 @@
from mongoengine import StringField, DateTimeField, IntField
from mongoengine import StringField, DateTimeField, IntField, ListField
from apiserver.database import Database, strict
from apiserver.database.fields import StrippedStringField, SafeSortedListField
@@ -10,13 +10,15 @@ class Project(AttributedDocument):
get_all_query_options = GetMixin.QueryParameterOptions(
pattern_fields=("name", "description"),
list_fields=("tags", "system_tags", "id"),
list_fields=("tags", "system_tags", "id", "parent", "path"),
)
meta = {
"db_alias": Database.backend,
"strict": strict,
"indexes": [
"parent",
"path",
("company", "name"),
{
"name": "%s.project.main_text_index" % Database.backend,
@@ -34,7 +36,7 @@ class Project(AttributedDocument):
min_length=3,
sparse=True,
)
description = StringField(required=True)
description = StringField()
created = DateTimeField(required=True)
tags = SafeSortedListField(StringField(required=True))
system_tags = SafeSortedListField(StringField(required=True))
@@ -44,3 +46,5 @@ class Project(AttributedDocument):
logo_url = StringField()
logo_blob = StringField(exclude_by_default=True)
company_origin = StringField(exclude_by_default=True)
parent = StringField(reference_field="Project")
path = ListField(StringField(required=True), exclude_by_default=True)

View File

@@ -1,3 +1,5 @@
from typing import Sequence
from mongoengine import (
Document,
EmbeddedDocument,
@@ -11,6 +13,7 @@ from apiserver.database.fields import StrippedStringField, SafeSortedListField
from apiserver.database.model import DbModelMixin
from apiserver.database.model.base import ProperDictMixin, GetMixin
from apiserver.database.model.company import Company
from apiserver.database.model.metadata import MetadataItem
from apiserver.database.model.task.task import Task
@@ -32,6 +35,7 @@ class Queue(DbModelMixin, Document):
meta = {
'db_alias': Database.backend,
'strict': strict,
"indexes": ["metadata.key", "metadata.type"],
}
id = StringField(primary_key=True)
@@ -44,3 +48,6 @@ class Queue(DbModelMixin, Document):
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
entries = EmbeddedDocumentListField(Entry, default=list)
last_update = DateTimeField()
metadata: Sequence[MetadataItem] = EmbeddedDocumentListField(
MetadataItem, default=list, user_set_allowed=True
)

View File

@@ -11,6 +11,5 @@ class Result(object):
class Output(EmbeddedDocument):
destination = StrippedStringField()
model = StringField(reference_field='Model')
error = StringField(user_set_allowed=True)
result = StringField(choices=get_options(Result))

View File

@@ -1,4 +1,4 @@
from typing import Dict
from typing import Dict, Sequence
from mongoengine import (
StringField,
@@ -17,6 +17,7 @@ from apiserver.database.fields import (
SafeDictField,
UnionField,
SafeSortedListField,
EmbeddedDocumentListField,
)
from apiserver.database.model import AttributedDocument
from apiserver.database.model.base import ProperDictMixin, GetMixin
@@ -79,7 +80,9 @@ DEFAULT_ARTIFACT_MODE = ArtifactModes.output
class Artifact(EmbeddedDocument):
key = StringField(required=True)
type = StringField(required=True)
mode = StringField(choices=get_options(ArtifactModes), default=DEFAULT_ARTIFACT_MODE)
mode = StringField(
choices=get_options(ArtifactModes), default=DEFAULT_ARTIFACT_MODE
)
uri = StringField()
hash = StringField()
content_size = LongField()
@@ -103,17 +106,37 @@ class ConfigurationItem(EmbeddedDocument, ProperDictMixin):
description = StringField()
class TaskModelTypes:
input = "input"
output = "output"
TaskModelNames = {
TaskModelTypes.input: "Input Model",
TaskModelTypes.output: "Output Model",
}
class ModelItem(EmbeddedDocument, ProperDictMixin):
name = StringField(required=True)
model = StringField(required=True, reference_field="Model")
updated = DateTimeField()
class Models(EmbeddedDocument, ProperDictMixin):
input: Sequence[ModelItem] = EmbeddedDocumentListField(ModelItem, default=list)
output: Sequence[ModelItem] = EmbeddedDocumentListField(ModelItem, default=list)
class Execution(EmbeddedDocument, ProperDictMixin):
meta = {"strict": strict}
test_split = IntField(default=0)
parameters = SafeDictField(default=dict)
model = StringField(reference_field="Model")
model_desc = SafeMapField(StringField(default=""))
model_labels = ModelLabels()
framework = StringField()
artifacts: Dict[str, Artifact] = SafeMapField(field=EmbeddedDocumentField(Artifact))
docker_cmd = StringField()
queue = StringField()
queue = StringField(reference_field="Queue")
""" Queue ID where task was queued """
@@ -140,7 +163,6 @@ class Task(AttributedDocument):
"execution.parameters.": _numeric_locale,
"last_metrics.": _numeric_locale,
"hyperparams.": _numeric_locale,
"configuration.": _numeric_locale,
}
meta = {
@@ -153,6 +175,7 @@ class Task(AttributedDocument):
"active_duration",
"parent",
"project",
"models.input.model",
("company", "name"),
("company", "user"),
("company", "status", "type"),
@@ -160,14 +183,15 @@ class Task(AttributedDocument):
("company", "type", "system_tags", "status"),
("company", "project", "type", "system_tags", "status"),
("status", "last_update"), # for maintenance tasks
{"fields": ["company", "project"], "collation": _numeric_locale},
{
"name": "%s.task.main_text_index" % Database.backend,
"fields": [
"$name",
"$id",
"$comment",
"$execution.model",
"$output.model",
"$models.input.model",
"$models.output.model",
"$script.repository",
"$script.entry_point",
],
@@ -176,8 +200,8 @@ class Task(AttributedDocument):
"name": 10,
"id": 10,
"comment": 10,
"execution.model": 2,
"output.model": 2,
"models.output.model": 2,
"models.input.model": 2,
"script.repository": 1,
"script.entry_point": 1,
},
@@ -185,8 +209,18 @@ class Task(AttributedDocument):
],
}
get_all_query_options = GetMixin.QueryParameterOptions(
list_fields=("id", "user", "tags", "system_tags", "type", "status", "project", "parent"),
datetime_fields=("status_changed",),
list_fields=(
"id",
"user",
"tags",
"system_tags",
"type",
"status",
"project",
"parent",
),
range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
datetime_fields=("status_changed", "last_update"),
pattern_fields=("name", "comment"),
)
@@ -225,7 +259,11 @@ class Task(AttributedDocument):
hyperparams = SafeMapField(field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
runtime = SafeDictField(default=dict)
docker_init_script = StringField()
models: Models = EmbeddedDocumentField(Models, default=Models)
container = SafeMapField(field=StringField(default=""))
enqueue_status = StringField(
choices=get_options(TaskStatus), exclude_by_default=True
)
def get_index_company(self) -> str:
"""

View File

@@ -1,12 +1,11 @@
from collections import OrderedDict, defaultdict
from itertools import chain
from collections import OrderedDict
from operator import attrgetter
from threading import Lock
from typing import Sequence
import six
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
from mongoengine.base import get_document, BaseField
from mongoengine.base import get_document
from apiserver.database.fields import (
LengthRangeEmbeddedDocumentListField,
@@ -21,7 +20,7 @@ class PropsMixin(object):
__cached_reference_fields = None
__cached_exclude_fields = None
__cached_fields_with_instance = None
__cached_field_names_per_type = None
__cached_all_fields_with_instance = None
__cached_dpath_computed_fields_lock = Lock()
__cached_dpath_computed_fields = None
@@ -33,37 +32,12 @@ class PropsMixin(object):
return cls.__cached_fields
@classmethod
def get_field_names_for_type(cls, of_type=BaseField):
"""
Return field names per type including subfields
The fields of derived types are also returned
"""
assert issubclass(of_type, BaseField)
if cls.__cached_field_names_per_type is None:
fields = defaultdict(list)
for name, field in get_fields(cls, return_instance=True, subfields=True):
fields[type(field)].append(name)
for type_ in fields:
fields[type_].extend(
chain.from_iterable(
fields[other_type]
for other_type in fields
if other_type != type_ and issubclass(other_type, type_)
)
)
cls.__cached_field_names_per_type = fields
if of_type not in cls.__cached_field_names_per_type:
names = list(
chain.from_iterable(
field_names
for type_, field_names in cls.__cached_field_names_per_type.items()
if issubclass(type_, of_type)
)
def get_all_fields_with_instance(cls):
if cls.__cached_all_fields_with_instance is None:
cls.__cached_all_fields_with_instance = get_fields(
cls, return_instance=True, subfields=True
)
cls.__cached_field_names_per_type[of_type] = names
return cls.__cached_field_names_per_type[of_type]
return cls.__cached_all_fields_with_instance
@classmethod
def get_fields_with_instance(cls, doc_cls):

View File

@@ -1,5 +1,6 @@
from datetime import datetime
from os import getenv
from typing import Tuple
from boltons.iterutils import first
from elasticsearch import Elasticsearch, Transport
@@ -9,11 +10,16 @@ from apiserver.config_repo import config
log = config.logger(__file__)
OVERRIDE_HOST_ENV_KEY = (
"CLEARML_ELASTIC_SERVICE_HOST",
"TRAINS_ELASTIC_SERVICE_HOST",
"ELASTIC_SERVICE_HOST",
"ELASTIC_SERVICE_SERVICE_HOST",
)
OVERRIDE_PORT_ENV_KEY = ("TRAINS_ELASTIC_SERVICE_PORT", "ELASTIC_SERVICE_PORT")
OVERRIDE_PORT_ENV_KEY = (
"CLEARML_ELASTIC_SERVICE_PORT",
"TRAINS_ELASTIC_SERVICE_PORT",
"ELASTIC_SERVICE_PORT",
)
OVERRIDE_HOST = first(filter(None, map(getenv, OVERRIDE_HOST_ENV_KEY)))
if OVERRIDE_HOST:
@@ -70,6 +76,10 @@ class ESFactory:
def get_all_cluster_names(cls):
return list(config.get("hosts.elastic"))
@classmethod
def get_override(cls, cluster_name: str) -> Tuple[str, str]:
return OVERRIDE_HOST, OVERRIDE_PORT
@classmethod
def get_cluster_config(cls, cluster_name):
"""
@@ -84,14 +94,16 @@ class ESFactory:
raise MissingClusterConfiguration(cluster_name)
def set_host_prop(key, value):
for host in cluster_config.get("hosts", []):
host[key] = value
for entry in cluster_config.get("hosts", []):
entry[key] = value
if OVERRIDE_HOST:
set_host_prop("host", OVERRIDE_HOST)
host, port = cls.get_override(cluster_name)
if OVERRIDE_PORT:
set_host_prop("port", OVERRIDE_PORT)
if host:
set_host_prop("host", host)
if port:
set_host_prop("port", port)
return cluster_config
@@ -120,7 +132,9 @@ class ESFactory:
@classmethod
def get_es_timestamp_str(cls):
now = datetime.utcnow()
return now.strftime("%Y-%m-%dT%H:%M:%S") + ".%03d" % (now.microsecond / 1000) + "Z"
return (
now.strftime("%Y-%m-%dT%H:%M:%S") + ".%03d" % (now.microsecond / 1000) + "Z"
)
es_factory = ESFactory()

View File

@@ -4,19 +4,20 @@ from logging import Logger
from pathlib import Path
from mongoengine.connection import get_db
from semantic_version import Version
from packaging.version import Version, parse
from apiserver.database import utils
from apiserver.database import Database
from apiserver.database.model.version import Version as DatabaseVersion
migration_dir = Path(__file__).resolve().parent.with_name("migrations")
_migrations = "migrations"
_parent_dir = Path(__file__).resolve().parents[1]
_migration_dir = _parent_dir / _migrations
def check_mongo_empty() -> bool:
return not all(
get_db(alias).collection_names()
for alias in utils.get_options(Database)
get_db(alias).collection_names() for alias in utils.get_options(Database)
)
@@ -41,8 +42,8 @@ def _apply_migrations(log: Logger):
log.info(f"Started mongodb migrations")
if not migration_dir.is_dir():
raise ValueError(f"Invalid migration dir {migration_dir}")
if not _migration_dir.is_dir():
raise ValueError(f"Invalid migration dir {_migration_dir}")
empty_dbs = check_mongo_empty()
last_version = get_last_server_version()
@@ -50,7 +51,10 @@ def _apply_migrations(log: Logger):
try:
new_scripts = {
ver: path
for ver, path in ((Version(f.stem), f) for f in migration_dir.glob("*.py"))
for ver, path in (
(parse(f.stem.replace("_", ".")), f)
for f in _migration_dir.glob("*.py")
)
if ver > last_version
}
except ValueError as ex:
@@ -64,7 +68,10 @@ def _apply_migrations(log: Logger):
if empty_dbs:
log.info(f"Skipping migration {script.name} (empty databases)")
else:
spec = importlib.util.spec_from_file_location(script.stem, str(script))
spec = importlib.util.spec_from_file_location(
".".join(("apiserver", _parent_dir.name, _migrations, script.stem)),
str(script),
)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
@@ -83,7 +90,7 @@ def _apply_migrations(log: Logger):
DatabaseVersion(
id=utils.id(),
num=script.stem,
num=str(script_version),
created=datetime.utcnow(),
desc="Applied on server startup",
).save()

View File

@@ -25,7 +25,6 @@ from typing import (
from urllib.parse import unquote, urlparse
from zipfile import ZipFile, ZIP_BZIP2
import dpath
import mongoengine
from boltons.iterutils import chunked_iter, first
from furl import furl
@@ -33,6 +32,7 @@ from mongoengine import Q
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_common import EventType
from apiserver.bll.project import project_ids_with_children
from apiserver.bll.task.artifacts import get_artifact_id
from apiserver.bll.task.param_utils import (
split_param_name,
@@ -44,11 +44,16 @@ from apiserver.config.info import get_default_company
from apiserver.database.model import EntityVisibility, User
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, ArtifactModes, TaskStatus
from apiserver.database.model.task.task import (
Task,
ArtifactModes,
TaskStatus,
TaskModelTypes,
TaskModelNames,
)
from apiserver.database.utils import get_options
from apiserver.tools import safe_get
from apiserver.utilities import json
from apiserver.utilities.dicts import nested_get, nested_set
from apiserver.utilities.dicts import nested_get, nested_set, nested_delete
class PrePopulate:
@@ -343,31 +348,10 @@ class PrePopulate:
return upadated
@staticmethod
def _upgrade_task_data(task_data: dict):
for old_param_field, new_param_field, default_section in (
("execution/parameters", "hyperparams", hyperparams_default_section),
("execution/model_desc", "configuration", None),
):
legacy = safe_get(task_data, old_param_field)
if not legacy:
continue
for full_name, value in legacy.items():
section, name = split_param_name(full_name, default_section)
new_path = list(filter(None, (new_param_field, section, name)))
if not safe_get(task_data, new_path):
new_param = dict(
name=name, type=hyperparams_legacy_type, value=str(value)
)
if section is not None:
new_param["section"] = section
dpath.new(task_data, new_path, new_param)
dpath.delete(task_data, old_param_field)
@classmethod
def _upgrade_tasks(cls, f: IO[bytes]) -> bytes:
"""
Build content array that contains fixed tasks from the passed file
Build content array that contains upgraded tasks from the passed file
For each task the old execution.parameters and model.design are
converted to the new structure.
The fix is done on Task objects (not the dictionary) so that
@@ -423,6 +407,24 @@ class PrePopulate:
items.append(results[0])
return items
@classmethod
def _check_projects_hierarchy(cls, projects: Set[Project]):
"""
For any exported project all its parents up to the root should be present
"""
if not projects:
return
project_ids = {p.id for p in projects}
orphans = [p.id for p in projects if p.parent and p.parent not in project_ids]
if not orphans:
return
print(
f"ERROR: the following projects are exported without their parents: {orphans}"
)
exit(1)
@classmethod
def _resolve_entities(
cls,
@@ -434,6 +436,7 @@ class PrePopulate:
if projects:
print("Reading projects...")
projects = project_ids_with_children(projects)
entities[cls.project_cls].update(
cls._resolve_type(cls.project_cls, projects)
)
@@ -463,12 +466,16 @@ class PrePopulate:
project_ids = {p.id for p in entities[cls.project_cls]}
entities[cls.project_cls].update(o for o in objs if o.id not in project_ids)
model_ids = {
model_id
cls._check_projects_hierarchy(entities[cls.project_cls])
task_models = chain.from_iterable(
models
for task in entities[cls.task_cls]
for model_id in (task.output.model, task.execution.model)
if model_id
}
if task.models
for models in (task.models.input, task.models.output)
if models
)
model_ids = {tm.model for tm in task_models}
if model_ids:
print("Reading models...")
entities[cls.model_cls] = set(cls.model_cls.objects(id__in=list(model_ids)))
@@ -634,11 +641,12 @@ class PrePopulate:
"""
Export the requested experiments, projects and models and return the list of artifact files
Always do the export on sorted items since the order of items influence hash
The projects should be sorted by name so that on import the hierarchy is correctly restored from top to bottom
"""
artifacts = []
now = datetime.utcnow()
for cls_ in sorted(entities, key=attrgetter("__name__")):
items = sorted(entities[cls_], key=attrgetter("id"))
items = sorted(entities[cls_], key=attrgetter("name", "id"))
if not items:
continue
base_filename = cls._get_base_filename(cls_)
@@ -735,6 +743,77 @@ class PrePopulate:
module = importlib.import_module(module_name)
return getattr(module, class_name)
@staticmethod
def _upgrade_task_data(task_data: dict) -> dict:
"""
Migrate from execution/parameters and model_desc to hyperparams and configuration fiields
Upgrade artifacts list to dict
Migrate from execution.model and output.model to the new models field
Move docker_cmd contents into the container field
:param task_data: Upgraded in place
:return: The upgraded task data
"""
for old_param_field, new_param_field, default_section in (
("execution.parameters", "hyperparams", hyperparams_default_section),
("execution.model_desc", "configuration", None),
):
legacy_path = old_param_field.split(".")
legacy = nested_get(task_data, legacy_path)
if legacy:
for full_name, value in legacy.items():
section, name = split_param_name(full_name, default_section)
new_path = list(filter(None, (new_param_field, section, name)))
if not nested_get(task_data, new_path):
new_param = dict(
name=name, type=hyperparams_legacy_type, value=str(value)
)
if section is not None:
new_param["section"] = section
nested_set(task_data, path=new_path, value=new_param)
nested_delete(task_data, legacy_path)
artifacts_path = ("execution", "artifacts")
artifacts = nested_get(task_data, artifacts_path)
if isinstance(artifacts, list):
nested_set(
task_data,
path=artifacts_path,
value={get_artifact_id(a): a for a in artifacts},
)
models = task_data.get("models", {})
now = datetime.utcnow()
for old_field, type_ in (
("execution.model", TaskModelTypes.input),
("output.model", TaskModelTypes.output),
):
old_path = old_field.split(".")
old_model = nested_get(task_data, old_path)
new_models = models.get(type_, [])
name = TaskModelNames[type_]
if old_model and not any(
m
for m in new_models
if m.get("model") == old_model or m.get("name") == name
):
model_item = {"model": old_model, "name": name, "updated": now}
if type_ == TaskModelTypes.input:
new_models = [model_item, *new_models]
else:
new_models = [*new_models, model_item]
models[type_] = new_models
nested_delete(task_data, old_path)
task_data["models"] = models
docker_cmd_path = ("execution", "docker_cmd")
docker_cmd = nested_get(task_data, docker_cmd_path)
if docker_cmd and not task_data.get("container"):
image, _, arguments = docker_cmd.partition(" ")
task_data["container"] = {"image": image, "arguments": arguments}
nested_delete(task_data, docker_cmd_path)
return task_data
@classmethod
def _import_entity(
cls,
@@ -750,16 +829,7 @@ class PrePopulate:
override_project_count = 0
for item in cls.json_lines(f):
if cls_ == cls.task_cls:
task_data = json.loads(item)
artifacts_path = ("execution", "artifacts")
artifacts = nested_get(task_data, artifacts_path)
if isinstance(artifacts, list):
nested_set(
task_data,
artifacts_path,
value={get_artifact_id(a): a for a in artifacts},
)
item = json.dumps(task_data)
item = json.dumps(cls._upgrade_task_data(task_data=json.loads(item)))
doc = cls_.from_json(item, created=True)
if hasattr(doc, "user"):

View File

@@ -5,8 +5,7 @@ from pymongo.database import Database, Collection
def migrate_auth(db: Database):
collection: Collection = db["user"]
if "name_1_company_1" in [doc["name"] for doc in collection.list_indexes()]:
collection.drop_index("name_1_company_1")
collection.drop_indexes()
def migrate_backend(db: Database):

View File

@@ -31,8 +31,8 @@ def migrate_auth(db: Database):
if not uuids:
return
collection = db["user"]
collection.drop_index("name_1_company_1")
collection: Collection = db["user"]
collection.drop_indexes()
_switch_uuid(collection=collection, uuid_field="_id", uuids=uuids)

View File

@@ -1,15 +1,6 @@
from collections import Collection
from typing import Sequence
from pymongo.database import Database
from pymongo.database import Database, Collection
def _drop_all_indices_from_collections(db: Database, names: Sequence[str]):
for collection_name in db.list_collection_names():
if collection_name not in names:
continue
collection: Collection = db[collection_name]
collection.drop_indexes()
from .utils import _drop_all_indices_from_collections
def migrate_auth(db: Database):

View File

@@ -0,0 +1,145 @@
import os
import re
from datetime import datetime
from pymongo.collection import Collection
from pymongo.database import Database
from pymongo.errors import DuplicateKeyError
from apiserver.database.model.task.task import TaskModelTypes, TaskModelNames
from apiserver.services.utils import escape_dict
from apiserver.utilities.dicts import nested_get
from .utils import _drop_all_indices_from_collections
def _migrate_task_models(db: Database):
"""
Move the execution and output models to new models.input and output lists
"""
tasks: Collection = db["task"]
models_field = "models"
now = datetime.utcnow()
fields = {
TaskModelTypes.input: "execution.model",
TaskModelTypes.output: "output.model",
}
query = {"$or": [{field: {"$exists": True}} for field in fields.values()]}
for doc in tasks.find(filter=query, projection=[*fields.values(), models_field]):
set_commands = {}
for mode, field in fields.items():
value = nested_get(doc, field.split("."))
if value:
name = TaskModelNames[mode]
model_item = {"model": value, "name": name, "updated": now}
existing_models = nested_get(doc, (models_field, mode), default=[])
existing_models = (
m
for m in existing_models
if m.get("name") != name and m.get("model") != value
)
if mode == TaskModelTypes.input:
updated_models = [model_item, *existing_models]
else:
updated_models = [*existing_models, model_item]
set_commands[f"{models_field}.{mode}"] = updated_models
tasks.update_one(
{"_id": doc["_id"]},
{
"$unset": {field: 1 for field in fields.values()},
**({"$set": set_commands} if set_commands else {}),
},
)
def _migrate_docker_cmd(db: Database):
tasks: Collection = db["task"]
docker_cmd_field = "execution.docker_cmd"
query = {docker_cmd_field: {"$exists": True}}
for doc in tasks.find(filter=query, projection=(docker_cmd_field,)):
set_commands = {}
docker_cmd = nested_get(doc, docker_cmd_field.split("."))
if docker_cmd:
image, _, arguments = docker_cmd.partition(" ")
set_commands["container"] = {"image": image, "arguments": arguments}
tasks.update_one(
{"_id": doc["_id"]},
{
"$unset": {docker_cmd_field: 1},
**({"$set": set_commands} if set_commands else {}),
},
)
def _migrate_model_labels(db: Database):
tasks: Collection = db["task"]
fields = ("execution.model_labels", "container")
query = {"$or": [{field: {"$nin": [None, {}]}} for field in fields]}
for doc in tasks.find(filter=query, projection=fields):
set_commands = {}
for field in fields:
data = nested_get(doc, field.split("."))
if not data:
continue
escaped = escape_dict(data)
if data == escaped:
continue
set_commands[field] = escaped
if set_commands:
tasks.update_one({"_id": doc["_id"]}, {"$set": set_commands})
def _migrate_project_description(db: Database):
projects: Collection = db["project"]
filter = {
"$or": [
{
"$expr": {"$lt": [{"$strLenCP": "$description"}, 100]},
"description": {"$regex": "^Auto-generated at ", "$options": "i"},
},
{"description": {"$regex": "^Auto-generated during move$", "$options": "i"}},
{"description": {"$regex": "^Auto-generated while cloning$", "$options": "i"}},
]
}
for doc in projects.find(filter=filter):
projects.update_one({"_id": doc["_id"]}, {"$unset": {"description": 1}})
def _migrate_project_names(db: Database):
projects: Collection = db["project"]
regx = re.compile("/", re.IGNORECASE)
for doc in projects.find(filter={"name": regx, "path": {"$in": [None, []]}}):
name = doc.get("name")
if not name:
continue
max_tries = int(os.getenv("CLEARML_MIGRATION_PROJECT_RENAME_MAX_TRIES", 10))
iteration = 0
for iteration in range(max_tries):
new_name = name.replace("/", "_" * (iteration + 1))
try:
projects.update_one({"_id": doc["_id"]}, {"$set": {"name": new_name}})
break
except DuplicateKeyError:
pass
if iteration >= max_tries - 1:
print(f"Could not upgrade the name {name} of the project {doc.get('_id')}")
def migrate_backend(db: Database):
_migrate_task_models(db)
_migrate_docker_cmd(db)
_migrate_model_labels(db)
_migrate_project_names(db)
_migrate_project_description(db)
_drop_all_indices_from_collections(db, ["task*"])

View File

@@ -0,0 +1,20 @@
from typing import Sequence
from boltons.iterutils import partition
from pymongo.database import Database, Collection
def _drop_all_indices_from_collections(db: Database, names: Sequence[str]):
"""
Drop all indices for the existing collections from the specified list
"""
prefixes, names = partition(names, key=lambda x: x.endswith("*"))
prefixes = {p.rstrip("*") for p in prefixes}
for collection_name in db.list_collection_names():
if not (
collection_name in names
or any(p for p in prefixes if collection_name.startswith(p))
):
continue
collection: Collection = db[collection_name]
collection.drop_indexes()

View File

@@ -11,8 +11,16 @@ from apiserver.config_repo import config
log = config.logger(__file__)
OVERRIDE_HOST_ENV_KEY = ("TRAINS_REDIS_SERVICE_HOST", "REDIS_SERVICE_HOST")
OVERRIDE_PORT_ENV_KEY = ("TRAINS_REDIS_SERVICE_PORT", "REDIS_SERVICE_PORT")
OVERRIDE_HOST_ENV_KEY = (
"CLEARML_REDIS_SERVICE_HOST",
"TRAINS_REDIS_SERVICE_HOST",
"REDIS_SERVICE_HOST",
)
OVERRIDE_PORT_ENV_KEY = (
"CLEARML_REDIS_SERVICE_PORT",
"TRAINS_REDIS_SERVICE_PORT",
"REDIS_SERVICE_PORT",
)
OVERRIDE_HOST = first(filter(None, map(getenv, OVERRIDE_HOST_ENV_KEY)))
if OVERRIDE_HOST:

View File

@@ -1,4 +1,5 @@
attrs>=19.1.0
bcrypt>=3.1.4
boltons>=19.1.0
boto3==1.14.13
dpath>=1.4.2,<2.0
@@ -11,12 +12,13 @@ funcsigs==1.0.2
furl>=2.0.0
gunicorn>=19.7.1
humanfriendly==4.18
jinja2==2.10
jinja2==2.11.3
jsonmodels>=2.3
jsonschema>=2.6.0
luqum>=0.10.0
mongoengine==0.19.1
nested_dict>=1.61
packaging==20.3
psutil>=5.6.5
pyhocon>=0.3.35
pyjwt<2.0.0

View File

@@ -1,3 +1,20 @@
metadata_item {
type: object
properties {
key {
type: string
description: The key uniquely identifying the metadata item inside the given entity
}
type {
type: string
description: The type of the metadata item
}
value {
type: string
description: The value stored in the metadata item
}
}
}
credentials {
type: object
properties {
@@ -11,3 +28,61 @@ credentials {
}
}
}
batch_operation {
request {
type: object
required: [ids]
properties {
ids {
type: array
items {type: string}
}
}
}
response {
type: object
properties {
succeeded {
type: array
items {
type: object
properties {
id: {
description: ID of the succeeded entity
type: string
}
}
}
}
failed {
type: array
items {
type: object
properties {
id: {
description: ID of the failed entity
type: string
}
error: {
description: Error info
type: object
properties {
codes {
type: array
items {type: integer}
}
msg {
type: string
}
data {
type: object
additionalProperties: True
}
}
}
}
}
}
}
}
}

View File

@@ -187,16 +187,12 @@
}
task_metric {
type: object
required: [task, metric]
required: [task]
properties {
task {
description: "Task ID"
type: string
}
metric {
description: "Metric name"
type: string
}
}
}
task_log_event {
@@ -370,7 +366,7 @@
}
}
"2.7" {
description: "Get the debug image events for the requested amount of iterations per each task's metric"
description: "Get the debug image events for the requested amount of iterations per each task"
request {
type: object
required: [

View File

@@ -85,7 +85,27 @@ supported_modes {
}
}
}
authenticated {
description: "Is user authenticated"
type: boolean
}
}
}
}
}
logout {
authorize: false
allow_roles = [ "*" ]
"2.13" {
description: """ Logout (including SSO, if used)) """
request {
type: object
additionalProperties: false
}
response {
type: object
additionalProperties: false
}
}
}

View File

@@ -1,5 +1,6 @@
_description: """This service provides a management interface for models (results of training tasks) stored in the system."""
_definitions {
include "_common.conf"
multi_field_pattern_data {
type: object
properties {
@@ -38,6 +39,11 @@ _definitions {
type: string
format: "date-time"
}
last_update {
description: "Model last update time"
type: string
format: "date-time"
}
task {
description: "Task ID of task in which the model was created"
type: string
@@ -91,6 +97,37 @@ _definitions {
type: object
additionalProperties: true
}
metadata {
type: array
description: "Model metadata"
items {"$ref": "#/definitions/metadata_item"}
}
}
}
published_task_item {
description: "Result of publishing of the model's associated task (if exists). Returned only if the task was published successfully as part of the model publishing."
type: object
properties {
id {
description: "Task id"
type: string
}
data {
description: "Data returned from the task publishing operation."
type: object
properties {
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
}
}
}
@@ -151,6 +188,17 @@ get_by_id_ex {
get_all_ex {
internal: true
"2.1": ${get_all."2.1"}
"2.13": ${get_all_ex."2.1"} {
request {
properties {
include_subprojects {
description: "If set to 'true' and project field is set then models from the subprojects are searched too"
type: boolean
default: false
}
}
}
}
}
get_all {
"2.1" {
@@ -433,6 +481,13 @@ create {
}
}
}
"2.13": ${create."2.1"} {
metadata {
type: array
description: "Model metadata"
items {"$ref": "#/definitions/metadata_item"}
}
}
}
edit {
"2.1" {
@@ -521,6 +576,13 @@ edit {
}
}
}
"2.13": ${edit."2.1"} {
metadata {
type: array
description: "Model metadata"
items {"$ref": "#/definitions/metadata_item"}
}
}
}
update {
"2.1" {
@@ -597,6 +659,40 @@ update {
}
}
}
"2.13": ${update."2.1"} {
metadata {
type: array
description: "Model metadata"
items {"$ref": "#/definitions/metadata_item"}
}
}
}
publish_many {
"2.13": ${_definitions.batch_operation} {
description: Publish models
request {
properties {
ids.description: "IDs of models to publish"
force_publish_task {
description: "Publish the associated tasks (if exist) even if they are not in the 'stopped' state. Optional, the default value is False."
type: boolean
}
publish_tasks {
description: "Indicates that the associated tasks (if exist) should be published. Optional, the default value is True."
type: boolean
}
}
}
response {
properties {
succeeded.items.properties.updated {
description: "Indicates whether the model was updated"
type: boolean
}
succeeded.items.properties.published_task: ${_definitions.published_task_item}
}
}
}
}
set_ready {
"2.1" {
@@ -627,39 +723,69 @@ set_ready {
type: integer
enum: [0, 1]
}
published_task {
description: "Result of publishing of the model's associated task (if exists). Returned only if the task was published successfully as part of the model publishing."
type: object
properties {
id {
description: "Task id"
type: string
}
data {
description: "Data returned from the task publishing operation."
type: object
properties {
committed_versions_results {
description: "Committed versions results"
type: array
items {
type: object
additionalProperties: true
}
}
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
}
published_task: ${_definitions.published_task_item}
}
}
}
}
archive_many {
"2.13": ${_definitions.batch_operation} {
description: Archive models
request {
properties {
ids.description: "IDs of models to archive"
}
}
response {
properties {
succeeded.items.properties.archived {
description: "Indicates whether the model was archived"
type: boolean
}
}
}
}
}
unarchive_many {
"2.13": ${_definitions.batch_operation} {
description: Unarchive models
request {
properties {
ids.description: "IDs of the models to unarchive"
}
}
response {
properties {
succeeded.items.properties.unarchived {
description: "Indicates whether the model was unarchived"
type: boolean
}
}
}
}
}
delete_many {
"2.13": ${_definitions.batch_operation} {
description: Delete models
request {
properties {
ids.description: "IDs of models to delete"
force {
description: """Force. Required if there are tasks that use the model as an execution model, or if the model's creating task is published.
"""
type: boolean
}
}
}
response {
properties {
succeeded.items.properties.deleted {
description: "Indicates whether the model was deleted"
type: boolean
}
succeeded.items.properties.url {
description: "The url of the model file"
type: string
}
}
}
@@ -697,6 +823,16 @@ delete {
}
}
}
"2.13": ${delete."2.1"} {
response {
properties {
url {
description: "The url of the model file"
type: string
}
}
}
}
}
make_public {
@@ -777,4 +913,63 @@ move {
}
}
}
add_or_update_metadata {
"2.13" {
description: "Add or update model metadata"
request {
type: object
required: [model, metadata]
properties {
model {
description: "ID of the model"
type: string
}
metadata {
type: array
description: "Metadata items to add or update"
items {"$ref": "#/definitions/metadata_item"}
}
}
}
response {
type: object
properties {
updated {
description: "Number of models updated (0 or 1)"
type: integer
enum: [0, 1]
}
}
}
}
}
delete_metadata {
"2.13" {
description: "Delete metadata from model"
request {
type: object
required: [ model, keys ]
properties {
model {
description: "ID of the model"
type: string
}
keys {
description: "The list of metadata keys to delete"
type: array
items {type: string}
}
}
}
response {
type: object
properties {
updated {
description: "Number of models updated (0 or 1)"
type: integer
enum: [0, 1]
}
}
}
}
}

View File

@@ -167,10 +167,27 @@ _definitions {
type: string
}
// extra properties
stats: {
stats {
description: "Additional project stats"
"$ref": "#/definitions/stats"
}
sub_projects {
description: "The list of sub projects"
type: array
items {
type: object
properties {
id {
description: "Subproject ID"
type: string
}
name {
description: "Subproject name"
type: string
}
}
}
}
}
}
metric_variant_result {
@@ -242,6 +259,23 @@ _definitions {
}
}
}
urls {
type: object
properties {
model_urls {
type: array
items {type: string}
}
event_urls {
type: array
items {type: string}
}
artifact_urls {
type: array
items {type: string}
}
}
}
}
create {
@@ -249,17 +283,14 @@ create {
description: "Create a new project"
request {
type: object
required :[
name
description
]
required :[name]
properties {
name {
description: "Project name Unique within the company."
type: string
}
description {
description: "Project description. "
description: "Project description."
type: string
}
tags {
@@ -388,6 +419,17 @@ get_all {
}
}
}
"2.13": ${get_all."2.1"} {
request {
properties {
shallow_search {
description: "If set to 'true' then the search with the specified criteria is performed among top level projects only (or if parents specified, among the direct children of the these parents). Otherwise the search is performed among all the company projects (or among all of the descendants of the specified parents)."
type: boolean
default: false
}
}
}
}
}
get_all_ex {
internal: true
@@ -413,6 +455,39 @@ get_all_ex {
}
}
}
"2.13": ${get_all_ex."2.1"} {
request {
properties {
active_users {
descritpion: "The list of users that were active in the project. If passes then the resulting projects are filtered to the ones that have tasks created by these users"
type: array
items: {type: string}
}
shallow_search {
description: "If set to 'true' then the search with the specified criteria is performed among top level projects only (or if parents specified, among the direct children of the these parents). Otherwise the search is performed among all the company projects (or among all of the descendants of the specified parents)."
type: boolean
default: false
}
check_own_contents {
description: "If set to 'true' and project ids are passed to the query then for these projects their own tasks, models and dataviews are counted"
type: boolean
default: false
}
}
}
response {
properties {
own_tasks {
description: "The amount of tasks under this project (without children projects). Returned if 'check_own_contents' flag is set in the request"
type: integer
}
own_models {
description: "The amount of models under this project (without children projects). Returned if 'check_own_contents' flag is set in the request"
type: integer
}
}
}
}
}
update {
"2.1" {
@@ -470,6 +545,66 @@ update {
}
}
}
move {
"2.13" {
description: "Moves a project and all of its subprojects under the different location"
request {
type: object
required: [project]
properties {
project {
description: "Project id"
type: string
}
new_location {
description: "The name location for the project"
type: string
}
}
}
response {
type: object
properties {
moved {
description: "The number of projects moved"
type: integer
}
}
}
}
}
merge {
"2.13" {
description: "Moves all the source project's contents to the destination project and remove the source project"
request {
type: object
required: [project]
properties {
project {
description: "Project id"
type: string
}
destination_project {
description: "The ID of the destination project"
type: string
}
}
}
response {
type: object
properties {
moved_entities {
description: "The number of tasks, models and dataviews moved from the merged project into the destination"
type: integer
}
moved_projects {
description: "The number of child projects moved from the merged project into the destination"
type: integer
}
}
}
}
}
delete {
"2.1" {
description: "Deletes a project"
@@ -504,6 +639,32 @@ delete {
}
}
}
"2.13": ${delete."2.1"} {
request {
properties {
delete_contents {
description: "If set to 'true' then the project tasks and models will be deleted. Otherwise their project property will be unassigned. Default value is 'false'"
type: boolean
}
}
}
response {
properties {
urls {
description: "The urls of the files that were uploaded by the project tasks and models. Returned if the 'delete_contents' was set to 'true'"
"$ref": "#/definitions/urls"
}
deleted_models {
description: "Number of models deleted"
type: integer
}
deleted_tasks {
description: "Number of tasks deleted"
type: integer
}
}
}
}
}
get_unique_metric_variants {
"2.1" {
@@ -530,6 +691,64 @@ get_unique_metric_variants {
}
}
}
"2.13": ${get_unique_metric_variants."2.1"} {
request {
properties {
include_subprojects {
description: "If set to 'true' and the project field is set then the result includes metrics/variants from the subproject tasks"
type: boolean
default: true
}
}
}
}
}
get_hyperparam_values {
"2.13" {
description: """Get a list of distinct values for the chosen hyperparameter"""
request {
type: object
required: [section, name]
properties {
projects {
description: "Project IDs"
type: array
items {type: string}
}
section {
description: "Hyperparameter section name"
type: string
}
name {
description: "Hyperparameter name"
type: string
}
allow_public {
description: "If set to 'true' then collect values from both company and public tasks otherwise company tasks only. The default is 'true'"
type: boolean
}
include_subprojects {
description: "If set to 'true' and the project field is set then the result includes hyper parameters values from the subproject tasks"
type: boolean
default: true
}
}
}
response {
type: object
properties {
total {
description: "Total number of distinct parameter values"
type: integer
}
values {
description: "The list of the unique values for the parameter"
type: array
items {type: string}
}
}
}
}
}
get_hyper_parameters {
"2.9" {
@@ -572,6 +791,17 @@ get_hyper_parameters {
}
}
}
"2.13": ${get_hyper_parameters."2.9"} {
request {
properties {
include_subprojects {
description: "If set to 'true' and the project field is set then the result includes hyper parameters from the subproject tasks"
type: boolean
default: true
}
}
}
}
}
get_task_tags {
@@ -641,7 +871,7 @@ make_private {
}
get_task_parents {
"2.12" {
description: "Get unique parent tasks for the tasks in the specified pprojects"
description: "Get unique parent tasks for the tasks in the specified projects"
request {
type: object
properties {
@@ -692,4 +922,15 @@ get_task_parents {
}
}
}
"2.13": ${get_task_parents."2.12"} {
request {
properties {
include_subprojects {
description: "If set to 'true' and the projects field is not empty then the result includes tasks parents from the subproject tasks"
type: boolean
default: true
}
}
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -26,6 +26,49 @@ _references {
}
_definitions {
include "_common.conf"
change_many_request: ${_definitions.batch_operation} {
request {
properties {
status_reason {
description: Reason for status change
type: string
}
status_message {
description: Extra information regarding status change
type: string
}
}
}
response {
properties {
succeeded.items.properties.updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
succeeded.items.properties.fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
}
update_response {
type: object
properties {
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
multi_field_pattern_data {
type: object
properties {
@@ -40,6 +83,24 @@ _definitions {
}
}
}
model_type_enum {
type: string
enum: ["input", "output"]
}
task_model_item {
type: object
required: [ name, model]
properties {
name {
description: "The task model name"
type: string
}
model {
description: "The model ID"
type: string
}
}
}
script {
type: object
properties {
@@ -207,6 +268,22 @@ _definitions {
}
}
}
task_models {
type: object
properties {
input {
description: "The list of task input models"
type: array
items {"$ref": "#/definitions/task_model_item"}
}
output {
description: "The list of task output models"
type: array
items {"$ref": "#/definitions/task_model_item"}
}
}
}
execution {
type: object
properties {
@@ -454,6 +531,15 @@ _definitions {
description: "Task execution params"
"$ref": "#/definitions/execution"
}
container {
description: "Docker container parameters"
type: object
additionalProperties { type: string }
}
models {
description: "Task models"
"$ref": "#/definitions/task_models"
}
// TODO: will be removed
script {
description: "Script info"
@@ -531,6 +617,28 @@ _definitions {
"$ref": "#/definitions/configuration_item"
}
}
runtime {
description: "Task runtime mapping"
type: object
additionalProperties: true
}
}
}
task_urls {
type: object
properties {
model_urls {
type: array
items {type: string}
}
event_urls {
type: array
items {type: string}
}
artifact_urls {
type: array
items {type: string}
}
}
}
}
@@ -566,6 +674,17 @@ get_by_id_ex {
get_all_ex {
internal: true
"2.1": ${get_all."2.1"}
"2.13": ${get_all_ex."2.1"} {
request {
properties {
include_subprojects {
description: "If set to 'true' and project field is set then tasks from the subprojects are searched too"
type: boolean
default: false
}
}
}
}
}
get_all {
"2.1" {
@@ -805,6 +924,106 @@ clone {
}
}
}
"2.13": ${clone."2.12"}{
request {
properties {
new_task_input_models {
description: "The list of input models for the cloned task. If not specifed then copied from the original task"
type: array
items {"$ref": "#/definitions/task_model_item"}
}
new_task_container {
description: "The docker container properties for the new task. If not provided then taken from the original task"
type: object
additionalProperties { type: string }
}
}
}
}
}
add_or_update_model {
"2.13" {
description: "Add or update task model"
request {
type: object
required: [task, name, model, type]
properties {
task {
description: "ID of the task"
type: string
}
name {
description: "The task model name"
type: string
}
model {
description: "The model ID"
type: string
}
type {
description: "The task model type"
"$ref": "#/definitions/model_type_enum"
}
iteration {
description: "Iteration (used to update task statistics)"
type: integer
}
}
}
response {
type: object
properties {
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [0, 1]
}
}
}
}
}
delete_models {
"2.13" {
description: "Delete models from task"
request {
type: object
required: [ task, models ]
properties {
task {
description: "ID of the task"
type: string
}
models {
description: "The list of models to delete"
type: array
items {
type: object
required: [name, type]
properties {
name {
description: "The task model name"
type: string
}
type {
description: "The task model type"
"$ref": "#/definitions/model_type_enum"
}
}
}
}
}
}
response {
type: object
properties {
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [0, 1]
}
}
}
}
}
create {
"2.1" {
@@ -884,6 +1103,21 @@ create {
}
}
}
"2.13": ${create."2.1"} {
request {
properties {
models {
description: "Task models"
"$ref": "#/definitions/task_models"
}
container {
description: "Docker container parameters"
type: object
additionalProperties { type: string }
}
}
}
}
}
validate {
"2.1" {
@@ -958,6 +1192,21 @@ validate {
additionalProperties: false
}
}
"2.13": ${validate."2.1"} {
request {
properties {
models {
description: "Task models"
"$ref": "#/definitions/task_models"
}
container {
description: "Docker container parameters"
type: object
additionalProperties { type: string }
}
}
}
}
}
update {
"2.1" {
@@ -1005,21 +1254,7 @@ update {
}
}
}
response {
type: object
properties {
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
response: ${_definitions.update_response}
}
}
update_batch {
@@ -1117,16 +1352,22 @@ edit {
}
}
}
response {
type: object
response: ${_definitions.update_response}
}
"2.13": ${edit."2.1"} {
request {
properties {
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
models {
description: "Task models"
"$ref": "#/definitions/task_models"
}
fields {
description: "Updated fields names and values"
container {
description: "Docker container parameters"
type: object
additionalProperties { type: string }
}
runtime {
description: "Task runtime mapping"
type: object
additionalProperties: true
}
@@ -1151,24 +1392,13 @@ reset {
default: false
}
} ${_references.status_change_request}
response {
type: object
response: ${_definitions.update_response} {
properties {
deleted_indices {
description: "List of deleted ES indices that were removed as part of the reset process"
type: array
items { type: string }
}
dequeued {
description: "Response from queues.remove_task"
type: object
additionalProperties: true
}
frames {
description: "Response from frames.rollback"
type: object
additionalProperties: true
}
events {
description: "Response from events.delete_for_task"
type: object
@@ -1178,16 +1408,126 @@ reset {
description: "Number of output models deleted by the reset"
type: integer
}
updated {
}
}
}
"2.13": ${reset."2.1"} {
request {
properties {
return_file_urls {
description: "If set to 'true' then return the urls of the files that were uploaded by this task. Default value is 'false'"
type: boolean
}
}
}
response {
properties {
urls {
description: "The urls of the files that were uploaded by this task. Returned if the 'return_file_urls' was set to 'true'"
"$ref": "#/definitions/task_urls"
}
}
}
}
}
reset_many {
"2.13": ${_definitions.batch_operation} {
description: Reset tasks
request {
properties {
ids.description: "IDs of the tasks to reset"
force = ${_references.force_arg} {
description: "If not true, call fails if the task status is 'completed'"
}
clear_all {
description: "Clear script and execution sections completely"
type: boolean
default: false
}
return_file_urls {
description: "If set to 'true' then return the urls of the files that were uploaded by the tasks. Default value is 'false'"
type: boolean
}
delete_output_models {
description: "If set to 'true' then delete output models of the tasks that are not referenced by other tasks. Default value is 'true'"
type: boolean
}
}
}
response {
properties {
succeeded.items.properties.dequeued {
description: "Indicates whether the task was dequeued"
type: boolean
}
succeeded.items.properties.updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
succeeded.items.properties.fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
succeeded.items.properties.deleted_models {
description: "Number of output models deleted by the reset"
type: integer
}
succeeded.items.properties.urls {
description: "The urls of the files that were uploaded by the task. Returned if the 'return_file_urls' was set to 'true'"
"$ref": "#/definitions/task_urls"
}
}
}
}
}
delete_many {
"2.13": ${_definitions.batch_operation} {
description: Delete tasks
request {
properties {
ids.description: "IDs of the tasks to delete"
move_to_trash {
description: "Move task to trash instead of deleting it. For internal use only, tasks in the trash are not visible from the API and cannot be restored!"
type: boolean
default: false
}
force = ${_references.force_arg} {
description: "If not true, call fails if the task status is 'in_progress'"
}
return_file_urls {
description: "If set to 'true' then return the urls of the files that were uploaded by the tasks. Default value is 'false'"
type: boolean
}
delete_output_models {
description: "If set to 'true' then delete output models of the tasks that are not referenced by other tasks. Default value is 'true'"
type: boolean
}
}
}
response {
properties {
succeeded.items.properties.deleted {
description: "Indicates whether the task was deleted"
type: boolean
}
succeeded.items.properties.updated_children {
description: "Number of child tasks whose parent property was updated"
type: integer
}
succeeded.items.properties.updated_models {
description: "Number of models whose task property was updated"
type: integer
}
succeeded.items.properties.deleted_models {
description: "Number of deleted output models"
type: integer
}
succeeded.items.properties.urls {
description: "The urls of the files that were uploaded by the task. Returned if the 'return_file_urls' was set to 'true'"
"$ref": "#/definitions/task_urls"
}
}
}
}
@@ -1229,15 +1569,6 @@ delete {
description: "Number of models whose task property was updated"
type: integer
}
updated_versions {
description: "Number of dataset versions whose task property was updated"
type: integer
}
frames {
description: "Response from frames.rollback"
type: object
additionalProperties: true
}
events {
description: "Response from events.delete_for_task"
type: object
@@ -1246,6 +1577,24 @@ delete {
}
}
}
"2.13": ${delete."2.1"} {
request {
properties {
return_file_urls {
description: "If set to 'true' then return the urls of the files that were uploaded by this task. Default value is 'false'"
type: boolean
}
}
}
response {
properties {
urls {
description: "The urls of the files that were uploaded by this task. Returned if the 'return_file_urls' was set to 'true'"
"$ref": "#/definitions/task_urls"
}
}
}
}
}
archive {
"2.12" {
@@ -1284,6 +1633,58 @@ archive {
}
}
}
archive_many {
"2.13": ${_definitions.batch_operation} {
description: Archive tasks
request {
properties {
ids.description: "IDs of the tasks to archive"
status_reason {
description: Reason for status change
type: string
}
status_message {
description: Extra information regarding status change
type: string
}
}
response {
properties {
succeeded.items.properties.archived {
description: "Indicates whether the task was archived"
type: boolean
}
}
}
}
}
}
unarchive_many {
"2.13": ${_definitions.batch_operation} {
description: Unarchive tasks
request {
properties {
ids.description: "IDs of the tasks to unarchive"
status_reason {
description: Reason for status change
type: string
}
status_message {
description: Extra information regarding status change
type: string
}
}
}
response {
properties {
succeeded.items.properties.unarchived {
description: "Indicates whether the task was unarchived"
type: boolean
}
}
}
}
}
started {
"2.1" {
description: "Mark a task status as in_progress. Optionally allows to set the task's execution progress."
@@ -1296,24 +1697,13 @@ started {
description: "If not true, call fails if the task status is not 'not_started'"
}
} ${_references.status_change_request}
response {
type: object
response: ${_definitions.update_response} {
properties {
started {
description: "Number of tasks started (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
}
@@ -1330,18 +1720,17 @@ stop {
description: "If not true, call fails if the task status is not 'in_progress'"
}
} ${_references.status_change_request}
response {
type: object
response: ${_definitions.update_response}
}
}
stop_many {
"2.13": ${_definitions.change_many_request} {
description: "Request to stop running tasks"
request {
properties {
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
ids.description: "IDs of the tasks to stop"
force = ${_references.force_arg} {
description: "If not true, call fails if the task status is not 'in_progress'"
}
}
}
@@ -1359,21 +1748,7 @@ stopped {
description: "If not true, call fails if the task status is not 'stopped'"
}
} ${_references.status_change_request}
response {
type: object
properties {
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
response: ${_definitions.update_response}
}
}
failed {
@@ -1386,21 +1761,7 @@ failed {
]
properties.force = ${_references.force_arg}
} ${_references.status_change_request}
response {
type: object
properties {
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
response: ${_definitions.update_response}
}
}
close {
@@ -1413,21 +1774,7 @@ close {
]
properties.force = ${_references.force_arg}
} ${_references.status_change_request}
response {
type: object
properties {
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
response: ${_definitions.update_response}
}
}
publish {
@@ -1452,26 +1799,21 @@ publish {
}
}
} ${_references.status_change_request}
response {
type: object
response: ${_definitions.update_response}
}
}
publish_many {
"2.13": ${_definitions.change_many_request} {
description: Publish tasks
request {
properties {
committed_versions_results {
description: "Committed versions results"
type: array
items {
type: object
additionalProperties: true
}
ids.description: "IDs of the tasks to publish"
force = ${_references.force_arg} {
description: "If not true, call fails if the task status is not 'stopped'"
}
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
publish_model {
description: "Indicates that the task output model (if exists) should be published. Optional, the default value is True."
type: boolean
}
}
}
@@ -1502,23 +1844,39 @@ Fails if the following parameters in the task were not filled:
}
} ${_references.status_change_request}
response {
type: object
response: ${_definitions.update_response} {
properties {
queued {
description: "Number of tasks queued (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
}
}
}
enqueue_many {
"2.13": ${_definitions.change_many_request} {
description: Enqueue tasks
request {
properties {
ids.description: "IDs of the tasks to enqueue"
queue {
description: "Queue id. If not provided, tasks are added to the default queue."
type: string
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
validate_tasks {
description: "If set then tasks are validated before enqueue"
type: boolean
default: false
}
}
}
response {
properties {
succeeded.items.properties.queued {
description: "Indicates whether the task was queued"
type: boolean
}
}
}
@@ -1534,23 +1892,30 @@ dequeue {
task
]
} ${_references.status_change_request}
response {
type: object
response: ${_definitions.update_response} {
properties {
dequeued {
description: "Number of tasks dequeued (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
}
dequeue_many {
"2.13": ${_definitions.change_many_request} {
description: Dequeue tasks
request {
properties {
ids.description: "IDs of the tasks to dequeue"
}
}
response {
properties {
succeeded.items.properties.dequeued {
description: "Indicates whether the task was dequeued"
type: boolean
}
}
}
@@ -1576,21 +1941,7 @@ set_requirements {
}
}
}
response {
type: object
properties {
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
response: ${_definitions.update_response}
}
}
@@ -1606,21 +1957,7 @@ completed {
description: "If not true, call fails if the task status is not in_progress/stopped"
}
} ${_references.status_change_request}
response {
type: object
properties {
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
response: ${_definitions.update_response}
}
}
@@ -1932,6 +2269,11 @@ get_configuration_names {
type: array
items { type: string }
}
skip_empty {
description: If set to 'true' then the names for configurations with missing values are not returned
type: boolean
default: true
}
}
}
response {

View File

@@ -1,6 +1,9 @@
from functools import partial
from flask import request, Response, redirect
from werkzeug.exceptions import BadRequest
from apiserver.apierrors import APIError
from apiserver.apierrors.base import BaseError
from apiserver.config_repo import config
from apiserver.service_repo import ServiceRepo, APICall
@@ -25,7 +28,10 @@ class RequestHandlers:
try:
call = self._create_api_call(request)
content, content_type = ServiceRepo.handle_call(call)
load_data_callback = partial(self._load_call_data, req=request)
content, content_type = ServiceRepo.handle_call(
call, load_data_callback=load_data_callback
)
if call.result.redirect:
response = redirect(call.result.redirect.url, call.result.redirect.code)
@@ -137,9 +143,6 @@ class RequestHandlers:
auth_cookie=auth_cookie,
)
# Update call data from request
self._update_call_data(call, req)
except PathParsingError as ex:
call = self._call_or_empty_with_error(call, req, ex.args[0], 400)
call.log_api = False
@@ -156,3 +159,18 @@ class RequestHandlers:
)
return call
def _load_call_data(self, call: APICall, req):
"""Update call data from request"""
try:
self._update_call_data(call, req)
except BadRequest as ex:
call.set_error_result(msg=ex.description, code=400)
except BaseError as ex:
call.set_error_result(msg=ex.msg, code=ex.code, subcode=ex.subcode)
except APIError as ex:
call.set_error_result(
msg=ex.msg, code=ex.code, subcode=ex.subcode, error_data=ex.error_data
)
except Exception as ex:
call.set_error_result(msg=ex.args[0] if ex.args else type(ex).__name__)

View File

@@ -1,6 +1,6 @@
from typing import Text, Sequence, Callable, Union, Type
from funcsigs import signature
from inspect import signature
from jsonmodels import models
from .apicall import APICall, APICallResult

View File

@@ -5,6 +5,7 @@ from typing import Type, Optional, Union, Tuple
import attr
from jsonmodels import models
from requests.structures import CaseInsensitiveDict
from six import string_types
from apiserver import database
@@ -313,6 +314,13 @@ class APICall(DataContainer):
def HEADER_TRANSACTION(self):
return self._transaction_headers[0]
_client_headers = _get_headers("Client")
""" Client """
@property
def HEADER_CLIENT(self):
return self._client_headers[0]
_worker_headers = _get_headers("Worker")
""" Worker (machine) ID """
@@ -366,7 +374,7 @@ class APICall(DataContainer):
assert isinstance(endpoint_version, PartialVersion), endpoint_version
self._requested_endpoint_version = endpoint_version
self._actual_endpoint_version = None
self._headers = {}
self._headers = CaseInsensitiveDict()
self._kpis = {}
self._log_api = True
if headers:
@@ -379,6 +387,7 @@ class APICall(DataContainer):
self._requires_authorization = True
self._host = host
self._auth_cookie = auth_cookie
self._json_flags = {}
@property
def id(self):
@@ -420,7 +429,7 @@ class APICall(DataContainer):
:param header: Header name options (more than on supported, all will be cleared)
"""
for value in header if isinstance(header, (tuple, list)) else (header,):
self.headers.pop(value, None)
self._headers.pop(value, None)
def set_header(self, header, value):
"""
@@ -514,7 +523,7 @@ class APICall(DataContainer):
@property
def headers(self):
return self._headers
return dict(self._headers.items())
@property
def kpis(self):
@@ -532,6 +541,10 @@ class APICall(DataContainer):
def trx(self, value):
self.set_header(self._transaction_headers, value)
@property
def client(self):
return self.get_header(self._client_headers)
@property
def worker(self):
return self.get_worker(default="<unknown>")
@@ -567,6 +580,10 @@ class APICall(DataContainer):
def auth_cookie(self):
return self._auth_cookie
@property
def json_flags(self):
return self._json_flags
def mark_end(self):
self._end_ts = time.time()
self._duration = int((self._end_ts - self._start_ts) * 1000)
@@ -617,7 +634,8 @@ class APICall(DataContainer):
}
if self.content_type.lower() == JSON_CONTENT_TYPE:
try:
res = json.dumps(res)
func = json.dumps if self._json_flags.pop("ensure_ascii", True) else json.dumps_notascii
res = func(res, **(self._json_flags or {}))
except Exception as ex:
# JSON serialization may fail, probably problem with data or error_data so pop it and try again
if not (self.result.data or self.result.error_data):

View File

@@ -1,6 +1,7 @@
import base64
from datetime import datetime
import bcrypt
import jwt
from mongoengine import Q
@@ -31,8 +32,8 @@ def get_auth_func(auth_type):
def authorize_token(jwt_token, *_, **__):
""" Validate token against service/endpoint and requests data (dicts).
Returns a parsed token object (auth payload)
"""Validate token against service/endpoint and requests data (dicts).
Returns a parsed token object (auth payload)
"""
try:
return Token.from_encoded_token(jwt_token)
@@ -51,13 +52,15 @@ def authorize_token(jwt_token, *_, **__):
def authorize_credentials(auth_data, service, action, call_data_items):
""" Validate credentials against service/action and request data (dicts).
Returns a new basic object (auth payload)
"""Validate credentials against service/action and request data (dicts).
Returns a new basic object (auth payload)
"""
try:
access_key, _, secret_key = base64.b64decode(auth_data.encode()).decode('latin-1').partition(':')
access_key, _, secret_key = (
base64.b64decode(auth_data.encode()).decode("latin-1").partition(":")
)
except Exception as e:
log.exception('malformed credentials')
log.exception("malformed credentials")
raise errors.unauthorized.BadCredentials(str(e))
query = Q(credentials__match=Credentials(key=access_key, secret=secret_key))
@@ -67,18 +70,32 @@ def authorize_credentials(auth_data, service, action, call_data_items):
if FixedUser.enabled():
fixed_user = FixedUser.get_by_username(access_key)
if fixed_user:
if secret_key != fixed_user.password:
raise errors.unauthorized.InvalidCredentials('bad username or password')
if FixedUser.pass_hashed():
if not compare_secret_key_hash(secret_key, fixed_user.password):
raise errors.unauthorized.InvalidCredentials(
"bad username or password"
)
else:
if secret_key != fixed_user.password:
raise errors.unauthorized.InvalidCredentials(
"bad username or password"
)
if fixed_user.is_guest and not FixedUser.is_guest_endpoint(service, action):
raise errors.unauthorized.InvalidCredentials('endpoint not allowed for guest')
raise errors.unauthorized.InvalidCredentials(
"endpoint not allowed for guest"
)
query = Q(id=fixed_user.user_id)
with TimingContext("mongo", "user_by_cred"), translate_errors_context('authorizing request'):
with TimingContext("mongo", "user_by_cred"), translate_errors_context(
"authorizing request"
):
user = User.objects(query).first()
if not user:
raise errors.unauthorized.InvalidCredentials('failed to locate provided credentials')
raise errors.unauthorized.InvalidCredentials(
"failed to locate provided credentials"
)
if not fixed_user:
# In case these are proper credentials, update last used time
@@ -87,13 +104,18 @@ def authorize_credentials(auth_data, service, action, call_data_items):
)
with TimingContext("mongo", "company_by_id"):
company = Company.objects(id=user.company).only('id', 'name').first()
company = Company.objects(id=user.company).only("id", "name").first()
if not company:
raise errors.unauthorized.InvalidCredentials('invalid user company')
raise errors.unauthorized.InvalidCredentials("invalid user company")
identity = Identity(user=user.id, company=user.company, role=user.role,
user_name=user.name, company_name=company.name)
identity = Identity(
user=user.id,
company=user.company,
role=user.role,
user_name=user.name,
company_name=company.name,
)
basic = Basic(user_key=access_key, identity=identity)
@@ -110,3 +132,13 @@ def authorize_impersonation(user, identity, service, action, call):
raise errors.unauthorized.InvalidCredentials("invalid user company")
return Payload(auth_type=None, identity=identity)
def compare_secret_key_hash(secret_key: str, hashed_secret: str) -> bool:
"""
Compare hash for the passed secret key with the passed hash
:return: True if equal. Otherwise False
"""
return bcrypt.checkpw(
secret_key.encode(), base64.b64decode(hashed_secret.encode("ascii"))
)

View File

@@ -32,6 +32,10 @@ class FixedUser:
def guest_enabled(cls):
return cls.enabled() and config.get("services.auth.fixed_users.guest.enabled", False)
@classmethod
def pass_hashed(cls):
return config.get("apiserver.auth.fixed_users.pass_hashed", False)
@classmethod
def validate(cls):
if not cls.enabled():

View File

@@ -13,7 +13,12 @@ from .apicall import APICall
from .endpoint import Endpoint
from .errors import MalformedPathError, InvalidVersionError, CallFailedError
from .util import parse_return_stack_on_code
from .validators import validate_all
from .validators import (
validate_data,
validate_auth,
validate_role,
validate_impersonation,
)
log = config.logger(__file__)
@@ -32,7 +37,7 @@ class ServiceRepo(object):
"""If the check is set, parsing will fail for endpoint request with the version that is grater than the current
maximum """
_max_version = PartialVersion("2.12")
_max_version = PartialVersion("2.13")
""" Maximum version number (the highest min_version value across all endpoints) """
_endpoint_exp = (
@@ -227,18 +232,6 @@ class ServiceRepo(object):
return True
return subcode in subcode_list
@classmethod
def _validate_call(cls, call: APICall) -> Optional[Endpoint]:
endpoint = cls._resolve_endpoint_from_call(call)
if call.failed:
return
validate_all(call, endpoint)
return endpoint
@classmethod
def validate_call(cls, call: APICall):
cls._validate_call(call)
@classmethod
def _get_company(
cls, call: APICall, endpoint: Endpoint = None, ignore_error: bool = False
@@ -252,7 +245,7 @@ class ServiceRepo(object):
return call.identity.company
@classmethod
def handle_call(cls, call: APICall):
def handle_call(cls, call: APICall, load_data_callback: Callable = None):
try:
if call.failed:
raise CallFailedError()
@@ -262,7 +255,18 @@ class ServiceRepo(object):
if call.failed:
raise CallFailedError()
validate_all(call, endpoint)
validate_auth(endpoint, call)
validate_role(endpoint, call)
if validate_impersonation(endpoint, call):
# if impersonating, validate role again
validate_role(endpoint, call)
if load_data_callback:
load_data_callback(call)
if call.failed:
raise CallFailedError()
validate_data(call, endpoint)
if call.failed:
raise CallFailedError()

View File

@@ -14,17 +14,9 @@ from .errors import CallParsingError
log = config.logger(__file__)
def validate_all(call: APICall, endpoint: Endpoint):
def validate_data(call: APICall, endpoint: Endpoint):
""" Perform all required call/endpoint validation, update call result appropriately """
try:
validate_auth(endpoint, call)
validate_role(endpoint, call)
if validate_impersonation(endpoint, call):
# if impersonating, validate role again
validate_role(endpoint, call)
# todo: remove vaildate_required_fields once all endpoints have json schema
validate_required_fields(endpoint, call)

View File

@@ -41,14 +41,12 @@ def login(call: APICall, *_, **__):
)
# Add authorization cookie
call.result.cookies[
config.get("apiserver.auth.session_auth_cookie_name")
] = call.result.data_model.token
call.result.set_auth_cookie(call.result.data_model.token)
@endpoint("auth.logout", min_version="2.2")
def logout(call: APICall, *_, **__):
call.result.cookies[config.get("apiserver.auth.session_auth_cookie_name")] = None
call.result.set_auth_cookie(None)
@endpoint(

View File

@@ -3,4 +3,4 @@ from apiserver.service_repo import APICall, endpoint
@endpoint("debug.ping")
def ping(call: APICall, _, __):
call.result.data = {"msg": "Because it trains cats and dogs"}
call.result.data = {"msg": "ClearML server"}

View File

@@ -594,10 +594,16 @@ def get_debug_images_v1_8(call, company_id, _):
response_data_model=DebugImageResponse,
)
def get_debug_images(call, company_id, request: DebugImagesRequest):
task_ids = {m.task for m in request.metrics}
task_metrics = defaultdict(set)
for tm in request.metrics:
task_metrics[tm.task].add(tm.metric)
for metrics in task_metrics.values():
if None in metrics:
metrics.clear()
tasks = task_bll.assert_exists(
company_id,
task_ids=task_ids,
task_ids=list(task_metrics),
allow_public=True,
only=("company", "company_origin"),
)
@@ -610,7 +616,7 @@ def get_debug_images(call, company_id, request: DebugImagesRequest):
result = event_bll.debug_images_iterator.get_task_events(
company_id=next(iter(companies)),
metrics=[(m.task, m.metric) for m in request.metrics],
task_metrics=task_metrics,
iter_count=request.iters,
navigate_earlier=request.navigate_earlier,
refresh=request.refresh,
@@ -622,13 +628,12 @@ def get_debug_images(call, company_id, request: DebugImagesRequest):
metrics=[
MetricEvents(
task=task,
metric=metric,
iterations=[
IterationEvents(iter=iteration["iter"], events=iteration["events"])
for iteration in iterations
],
)
for (task, metric, iterations) in result.metric_events
for (task, iterations) in result.metric_events
],
)

View File

@@ -6,12 +6,12 @@ from apiserver.apimodels.login import (
ServerErrors,
)
from apiserver.config import info
from apiserver.service_repo import endpoint
from apiserver.service_repo import endpoint, APICall
from apiserver.service_repo.auth.fixed_user import FixedUser
@endpoint("login.supported_modes", response_data_model=GetSupportedModesResponse)
def supported_modes(_, __, ___: GetSupportedModesRequest):
def supported_modes(call: APICall, _, __: GetSupportedModesRequest):
guest_user = FixedUser.get_guest_user()
if guest_user:
guest = BasicGuestMode(
@@ -31,4 +31,10 @@ def supported_modes(_, __, ___: GetSupportedModesRequest):
missed_es_upgrade=info.missed_es_upgrade,
es_connection_error=info.es_connection_error,
),
authenticated=call.auth is not None,
)
@endpoint("login.logout", min_version="2.13")
def logout(call: APICall, _, __):
call.result.set_auth_cookie(None)

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from functools import partial
from typing import Sequence
from mongoengine import Q, EmbeddedDocument
@@ -7,36 +8,55 @@ from apiserver import database
from apiserver.apierrors import errors
from apiserver.apierrors.errors.bad_request import InvalidModelId
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, MoveRequest
from apiserver.apimodels.batch import BatchResponse, BatchRequest
from apiserver.apimodels.models import (
CreateModelRequest,
CreateModelResponse,
PublishModelRequest,
PublishModelResponse,
ModelTaskPublishResponse,
GetFrameworksRequest,
DeleteModelRequest,
DeleteMetadataRequest,
AddOrUpdateMetadataRequest,
ModelsPublishManyRequest,
ModelsDeleteManyRequest,
)
from apiserver.bll.model import ModelBLL
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.bll.project import ProjectBLL, project_ids_with_children
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.task_operations import publish_task
from apiserver.bll.util import run_batch_operation
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import validate_id
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.database.model.task.task import (
Task,
TaskStatus,
ModelItem,
TaskModelNames,
TaskModelTypes,
)
from apiserver.database.utils import (
parse_from_call,
get_company_or_none_constraint,
filter_fields,
)
from apiserver.service_repo import APICall, endpoint
from apiserver.services.utils import conform_tag_fields, conform_output_tags
from apiserver.services.utils import (
conform_tag_fields,
conform_output_tags,
ModelsBackwardsCompatibility,
validate_metadata,
get_metadata_from_api,
)
from apiserver.timing_context import TimingContext
log = config.logger(__file__)
org_bll = OrgBLL()
model_bll = ModelBLL()
project_bll = ProjectBLL()
@@ -61,19 +81,20 @@ def get_by_id(call: APICall, company_id, _):
@endpoint("models.get_by_task_id", required_fields=["task"])
def get_by_task_id(call: APICall, company_id, _):
if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
raise errors.moved_permanently.NotSupported("use models.get_by_id/get_all apis")
task_id = call.data["task"]
with translate_errors_context():
query = dict(id=task_id, company=company_id)
task = Task.get(_only=["output"], **query)
task = Task.get(_only=["models"], **query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
if not task.output:
raise errors.bad_request.MissingTaskFields(field="output")
if not task.output.model:
raise errors.bad_request.MissingTaskFields(field="output.model")
if not task.models or not task.models.output:
raise errors.bad_request.MissingTaskFields(field="models.output")
model_id = task.output.model
model_id = task.models.output[-1].model
model = Model.objects(
Q(id=model_id) & get_company_or_none_constraint(company_id)
).first()
@@ -86,10 +107,22 @@ def get_by_task_id(call: APICall, company_id, _):
call.result.data = {"model": model_dict}
def _process_include_subprojects(call_data: dict):
include_subprojects = call_data.pop("include_subprojects", False)
project_ids = call_data.get("project")
if not project_ids or not include_subprojects:
return
if not isinstance(project_ids, list):
project_ids = [project_ids]
call_data["project"] = project_ids_with_children(project_ids)
@endpoint("models.get_all_ex", required_fields=[])
def get_all_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
with translate_errors_context():
_process_include_subprojects(call.data)
with TimingContext("mongo", "models_get_all_ex"):
models = Model.get_many_with_join(
company=company_id, query_dict=call.data, allow_public=True
@@ -129,7 +162,7 @@ def get_all(call: APICall, company_id, _):
def get_frameworks(call: APICall, company_id, request: GetFrameworksRequest):
call.result.data = {
"frameworks": sorted(
model_bll.get_frameworks(company_id, project_ids=request.projects)
project_bll.get_model_frameworks(company_id, project_ids=request.projects)
)
}
@@ -147,12 +180,18 @@ create_fields = {
"design": None,
"labels": dict,
"ready": None,
"metadata": list,
}
last_update_fields = ("uri", "framework", "design", "labels", "ready", "metadata")
def parse_model_fields(call, valid_fields):
fields = parse_from_call(call.data, valid_fields, Model.get_fields())
conform_tag_fields(call, fields, validate=True)
metadata = fields.get("metadata")
if metadata:
validate_metadata(metadata)
return fields
@@ -174,6 +213,9 @@ def _reset_cached_tags(company: str, projects: Sequence[str]):
@endpoint("models.update_for_task", required_fields=["task"])
def update_for_task(call: APICall, company_id, _):
if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
raise errors.moved_permanently.NotSupported("use tasks.add_or_update_model")
task_id = call.data["task"]
uri = call.data.get("uri")
iteration = call.data.get("iteration")
@@ -189,7 +231,7 @@ def update_for_task(call: APICall, company_id, _):
task = Task.get_for_writing(
id=task_id,
company=company_id,
_only=["output", "execution", "name", "status", "project"],
_only=["models", "execution", "name", "status", "project"],
)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
@@ -202,10 +244,9 @@ def update_for_task(call: APICall, company_id, _):
)
if override_model_id:
query = dict(company=company_id, id=override_model_id)
model = Model.objects(**query).first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
model = ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=override_model_id
)
else:
if "name" not in call.data:
# use task name if name not provided
@@ -214,12 +255,11 @@ def update_for_task(call: APICall, company_id, _):
if "comment" not in call.data:
call.data["comment"] = f"Created by task `{task.name}` ({task.id})"
if task.output and task.output.model:
if task.models and task.models.output:
# model exists, update
res = _update_model(
call, company_id, model_id=task.output.model
).to_struct()
res.update({"id": task.output.model, "created": False})
model_id = task.models.output[-1].model
res = _update_model(call, company_id, model_id=model_id).to_struct()
res.update({"id": model_id, "created": False})
call.result.data = res
return
@@ -227,14 +267,18 @@ def update_for_task(call: APICall, company_id, _):
fields = parse_model_fields(call, create_fields)
# create and save model
now = datetime.utcnow()
model = Model(
id=database.utils.id(),
created=datetime.utcnow(),
created=now,
last_update=now,
user=call.identity.user,
company=company_id,
project=task.project,
framework=task.execution.framework,
parent=task.execution.model,
parent=task.models.input[0].model
if task.models and task.models.input
else None,
design=task.execution.model_desc,
labels=task.execution.model_labels,
ready=(task.status == TaskStatus.published),
@@ -247,7 +291,13 @@ def update_for_task(call: APICall, company_id, _):
task_id=task_id,
company_id=company_id,
last_iteration_max=iteration,
output__model=model.id,
models__output=[
ModelItem(
model=model.id,
name=TaskModelNames[TaskModelTypes.output],
updated=datetime.utcnow(),
)
],
)
call.result.data = {"id": model.id, "created": True}
@@ -277,12 +327,16 @@ def create(call: APICall, company_id, req_model: CreateModelRequest):
fields = filter_fields(Model, req_data)
conform_tag_fields(call, fields, validate=True)
validate_metadata(fields.get("metadata"))
# create and save model
now = datetime.utcnow()
model = Model(
id=database.utils.id(),
user=call.identity.user,
company=company_id,
created=datetime.utcnow(),
created=now,
last_update=now,
**fields,
)
model.save()
@@ -335,10 +389,9 @@ def edit(call: APICall, company_id, _):
model_id = call.data["model"]
with translate_errors_context():
query = dict(id=model_id, company=company_id)
model = Model.objects(**query).first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
model = ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=model_id
)
fields = parse_model_fields(call, create_fields)
fields = prepare_update_fields(call, company_id, fields)
@@ -363,6 +416,9 @@ def edit(call: APICall, company_id, _):
)
if fields:
if any(uf in fields for uf in last_update_fields):
fields.update(last_update=datetime.utcnow())
updated = model.update(upsert=False, **fields)
if updated:
new_project = fields.get("project", model.project)
@@ -384,11 +440,9 @@ def _update_model(call: APICall, company_id, model_id=None):
model_id = model_id or call.data["model"]
with translate_errors_context():
# get model by id
query = dict(id=model_id, company=company_id)
model = Model.objects(**query).first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
model = ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=model_id
)
data = prepare_update_fields(call, company_id, call.data)
@@ -399,8 +453,15 @@ def _update_model(call: APICall, company_id, model_id=None):
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
)
metadata = data.get("metadata")
if metadata:
validate_metadata(metadata)
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
if updated_count:
if any(uf in updated_fields for uf in last_update_fields):
model.update(upsert=False, last_update=datetime.utcnow())
new_project = updated_fields.get("project", model.project)
if new_project != model.project:
_reset_cached_tags(company_id, projects=[new_project, model.project])
@@ -424,84 +485,130 @@ def update(call, company_id, _):
request_data_model=PublishModelRequest,
response_data_model=PublishModelResponse,
)
def set_ready(call: APICall, company_id, req_model: PublishModelRequest):
updated, published_task_data = TaskBLL.model_set_ready(
model_id=req_model.model,
def set_ready(call: APICall, company_id: str, request: PublishModelRequest):
updated, published_task = ModelBLL.publish_model(
model_id=request.model,
company_id=company_id,
publish_task=req_model.publish_task,
force_publish_task=req_model.force_publish_task,
force_publish_task=request.force_publish_task,
publish_task_func=publish_task if request.publish_task else None,
)
call.result.data_model = PublishModelResponse(
updated=updated,
published_task=ModelTaskPublishResponse(**published_task_data)
if published_task_data
else None,
updated=updated, published_task=published_task
)
@endpoint("models.delete", required_fields=["model"])
def update(call: APICall, company_id, _):
model_id = call.data["model"]
force = call.data.get("force", False)
@endpoint(
"models.publish_many",
request_data_model=ModelsPublishManyRequest,
response_data_model=BatchResponse,
)
def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest):
results, failures = run_batch_operation(
func=partial(
ModelBLL.publish_model,
company_id=company_id,
force_publish_task=request.force_publish_task,
publish_task_func=publish_task if request.publish_task else None,
),
ids=request.ids,
)
with translate_errors_context():
query = dict(id=model_id, company=company_id)
model = Model.objects(**query).only("id", "task", "project").first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
deleted_model_id = f"__DELETED__{model_id}"
using_tasks = Task.objects(execution__model=model_id).only("id")
if using_tasks:
if not force:
raise errors.bad_request.ModelInUse(
"as execution model, use force=True to delete",
num_tasks=len(using_tasks),
)
# update deleted model id in using tasks
using_tasks.update(
execution__model=deleted_model_id, upsert=False, multi=True
call.result.data_model = BatchResponse(
succeeded=[
dict(
id=_id,
updated=bool(updated),
published_task=published_task.to_struct() if published_task else None,
)
for _id, (updated, published_task) in results
],
failed=failures,
)
if model.task:
task = Task.objects(id=model.task).first()
if task and task.status == TaskStatus.published:
if not force:
raise errors.bad_request.ModelCreatingTaskExists(
"and published, use force=True to delete", task=model.task
)
now = datetime.utcnow()
task.update(
output__model=deleted_model_id,
output__error=f"model deleted on {now.isoformat()}",
last_change=now,
upsert=False,
)
del_count = Model.objects(**query).delete()
if del_count:
_reset_cached_tags(company_id, projects=[model.project])
call.result.data = dict(deleted=del_count > 0)
@endpoint("models.delete", request_data_model=DeleteModelRequest)
def delete(call: APICall, company_id, request: DeleteModelRequest):
del_count, model = ModelBLL.delete_model(
model_id=request.model, company_id=company_id, force=request.force
)
if del_count:
_reset_cached_tags(
company_id, projects=[model.project] if model.project else []
)
call.result.data = dict(deleted=bool(del_count), url=model.uri)
@endpoint(
"models.delete_many",
request_data_model=ModelsDeleteManyRequest,
response_data_model=BatchResponse,
)
def delete(call: APICall, company_id, request: ModelsDeleteManyRequest):
results, failures = run_batch_operation(
func=partial(ModelBLL.delete_model, company_id=company_id, force=request.force),
ids=request.ids,
)
if results:
projects = set(model.project for _, (_, model) in results)
_reset_cached_tags(company_id, projects=list(projects))
call.result.data_model = BatchResponse(
succeeded=[
dict(id=_id, deleted=bool(deleted), url=model.uri)
for _id, (deleted, model) in results
],
failed=failures,
)
@endpoint(
"models.archive_many",
request_data_model=BatchRequest,
response_data_model=BatchResponse,
)
def archive_many(call: APICall, company_id, request: BatchRequest):
results, failures = run_batch_operation(
func=partial(ModelBLL.archive_model, company_id=company_id), ids=request.ids,
)
call.result.data_model = BatchResponse(
succeeded=[dict(id=_id, archived=bool(archived)) for _id, archived in results],
failed=failures,
)
@endpoint(
"models.unarchive_many",
request_data_model=BatchRequest,
response_data_model=BatchResponse,
)
def unarchive_many(call: APICall, company_id, request: BatchRequest):
results, failures = run_batch_operation(
func=partial(ModelBLL.unarchive_model, company_id=company_id), ids=request.ids,
)
call.result.data_model = BatchResponse(
succeeded=[
dict(id=_id, unarchived=bool(unarchived)) for _id, unarchived in results
],
failed=failures,
)
@endpoint("models.make_public", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Model.set_public(
company_id, ids=request.ids, invalid_cls=InvalidModelId, enabled=True
)
call.result.data = Model.set_public(
company_id, ids=request.ids, invalid_cls=InvalidModelId, enabled=True
)
@endpoint(
"models.make_private", min_version="2.9", request_data_model=MakePublicRequest
)
def make_public(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Model.set_public(
company_id, request.ids, invalid_cls=InvalidModelId, enabled=False
)
call.result.data = Model.set_public(
company_id, request.ids, invalid_cls=InvalidModelId, enabled=False
)
@endpoint("models.move", request_data_model=MoveRequest)
@@ -511,14 +618,43 @@ def move(call: APICall, company_id: str, request: MoveRequest):
"project or project_name is required"
)
with translate_errors_context():
return {
"project_id": project_bll.move_under_project(
entity_cls=Model,
user=call.identity.user,
company=company_id,
ids=request.ids,
project=request.project,
project_name=request.project_name,
)
}
return {
"project_id": project_bll.move_under_project(
entity_cls=Model,
user=call.identity.user,
company=company_id,
ids=request.ids,
project=request.project,
project_name=request.project_name,
)
}
@endpoint("models.add_or_update_metadata", min_version="2.13")
def add_or_update_metadata(
_: APICall, company_id: str, request: AddOrUpdateMetadataRequest
):
model_id = request.model
ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
updated = metadata_add_or_update(
cls=Model, _id=model_id, items=get_metadata_from_api(request.metadata),
)
if updated:
Model.objects(id=model_id).update_one(last_update=datetime.utcnow())
return {"updated": updated}
@endpoint("models.delete_metadata", min_version="2.13")
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
model_id = request.model
ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=model_id, only_fields=("id",)
)
updated = metadata_delete(cls=Model, _id=model_id, keys=request.keys)
if updated:
Model.objects(id=model_id).update_one(last_update=datetime.utcnow())
return {"updated": updated}

View File

@@ -1,31 +1,30 @@
from collections import defaultdict
from datetime import datetime
from itertools import groupby
from operator import itemgetter
from typing import Sequence
import dpath
import attr
from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.apierrors.errors.bad_request import InvalidProjectId
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse
from apiserver.apimodels.projects import (
GetHyperParamReq,
ProjectReq,
GetHyperParamRequest,
ProjectTagsRequest,
ProjectTaskParentsRequest,
ProjectHyperparamValuesRequest,
ProjectsGetRequest,
DeleteRequest,
MoveRequest,
MergeRequest,
ProjectOrNoneRequest,
)
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.bll.project.project_cleanup import delete_project
from apiserver.bll.task import TaskBLL
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.database.utils import (
parse_from_call,
get_options,
get_company_or_none_constraint,
)
from apiserver.service_repo import APICall, endpoint
@@ -39,7 +38,7 @@ from apiserver.timing_context import TimingContext
org_bll = OrgBLL()
task_bll = TaskBLL()
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}
project_bll = ProjectBLL()
create_fields = {
"name": None,
@@ -49,10 +48,6 @@ create_fields = {
"default_output_destination": None,
}
get_all_query_options = Project.QueryParameterOptions(
pattern_fields=("name", "description"), list_fields=("tags", "system_tags", "id"),
)
@endpoint("projects.get_by_id", required_fields=["project"])
def get_by_id(call):
@@ -74,210 +69,88 @@ def get_by_id(call):
call.result.data = {"project": project_dict}
def make_projects_get_all_pipelines(company_id, project_ids, specific_state=None):
archived = EntityVisibility.archived.value
def _adjust_search_parameters(data: dict, shallow_search: bool):
"""
1. Make sure that there is no external query on path
2. If not shallow_search and parent is provided then parent can be at any place in path
3. If shallow_search and no parent provided then use a top level parent
"""
data.pop("path", None)
if not shallow_search:
if "parent" in data:
data["path"] = data.pop("parent")
return
def ensure_valid_fields():
"""
Make sure system tags is always an array (required by subsequent $in in archived_tasks_cond
"""
return {
"$addFields": {
"system_tags": {
"$cond": {
"if": {"$ne": [{"$type": "$system_tags"}, "array"]},
"then": [],
"else": "$system_tags",
}
},
"status": {"$ifNull": ["$status", "unknown"]},
}
}
status_count_pipeline = [
# count tasks per project per status
{
"$match": {
"company": {"$in": [None, "", company_id]},
"project": {"$in": project_ids},
}
},
ensure_valid_fields(),
{
"$group": {
"_id": {
"project": "$project",
"status": "$status",
archived: archived_tasks_cond,
},
"count": {"$sum": 1},
}
},
# for each project, create a list of (status, count, archived)
{
"$group": {
"_id": "$_id.project",
"counts": {
"$push": {
"status": "$_id.status",
"count": "$count",
archived: "$_id.%s" % archived,
}
},
}
},
]
def runtime_subquery(additional_cond):
return {
# the sum of
"$sum": {
# for each task
"$cond": {
# if completed and started and completed > started
"if": {
"$and": [
"$started",
"$completed",
{"$gt": ["$completed", "$started"]},
additional_cond,
]
},
# then: floor((completed - started) / 1000)
"then": {
"$floor": {
"$divide": [
{"$subtract": ["$completed", "$started"]},
1000.0,
]
}
},
"else": 0,
}
}
}
group_step = {"_id": "$project"}
for state in EntityVisibility:
if specific_state and state != specific_state:
continue
if state == EntityVisibility.active:
group_step[state.value] = runtime_subquery({"$not": archived_tasks_cond})
elif state == EntityVisibility.archived:
group_step[state.value] = runtime_subquery(archived_tasks_cond)
runtime_pipeline = [
# only count run time for these types of tasks
{
"$match": {
"type": {"$in": ["training", "testing"]},
"company": {"$in": [None, "", company_id]},
"project": {"$in": project_ids},
}
},
ensure_valid_fields(),
{
# for each project
"$group": group_step
},
]
return status_count_pipeline, runtime_pipeline
if "parent" not in data:
data["parent"] = [None]
@endpoint("projects.get_all_ex")
def get_all_ex(call: APICall):
include_stats = call.data.get("include_stats")
stats_for_state = call.data.get("stats_for_state", EntityVisibility.active.value)
allow_public = not call.data.get("non_public", False)
if stats_for_state:
try:
specific_state = EntityVisibility(stats_for_state)
except ValueError:
raise errors.bad_request.FieldsValueError(stats_for_state=stats_for_state)
else:
specific_state = None
@endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest)
def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
conform_tag_fields(call, call.data)
with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
projects = Project.get_many_with_join(
company=call.identity.company,
query_dict=call.data,
query_options=get_all_query_options,
allow_public=allow_public,
)
conform_output_tags(call, projects)
allow_public = not request.non_public
data = call.data
requested_ids = data.get("id")
with TimingContext("mongo", "projects_get_all"):
data = call.data
if request.active_users:
ids = project_bll.get_projects_with_active_user(
company=company_id,
users=request.active_users,
project_ids=requested_ids,
allow_public=allow_public,
)
if not ids:
call.result.data = {"projects": []}
return
data["id"] = ids
if not include_stats:
_adjust_search_parameters(data, shallow_search=request.shallow_search)
projects = Project.get_many_with_join(
company=company_id, query_dict=data, allow_public=allow_public,
)
if request.check_own_contents and requested_ids:
existing_requested_ids = {
project["id"] for project in projects if project["id"] in requested_ids
}
if existing_requested_ids:
contents = project_bll.calc_own_contents(
company=company_id, project_ids=list(existing_requested_ids)
)
for project in projects:
project.update(**contents.get(project["id"], {}))
conform_output_tags(call, projects)
if not request.include_stats:
call.result.data = {"projects": projects}
return
ids = [project["id"] for project in projects]
status_count_pipeline, runtime_pipeline = make_projects_get_all_pipelines(
call.identity.company, ids, specific_state=specific_state
project_ids = {project["id"] for project in projects}
stats, children = project_bll.get_project_stats(
company=company_id,
project_ids=list(project_ids),
specific_state=request.stats_for_state,
)
default_counts = dict.fromkeys(get_options(TaskStatus), 0)
for project in projects:
project["stats"] = stats[project["id"]]
project["sub_projects"] = children[project["id"]]
def set_default_count(entry):
return dict(default_counts, **entry)
status_count = defaultdict(lambda: {})
key = itemgetter(EntityVisibility.archived.value)
for result in Task.aggregate(status_count_pipeline):
for k, group in groupby(sorted(result["counts"], key=key), key):
section = (
EntityVisibility.archived if k else EntityVisibility.active
).value
status_count[result["_id"]][section] = set_default_count(
{
count_entry["status"]: count_entry["count"]
for count_entry in group
}
)
runtime = {
result["_id"]: {k: v for k, v in result.items() if k != "_id"}
for result in Task.aggregate(runtime_pipeline)
}
def safe_get(obj, path, default=None):
try:
return dpath.get(obj, path)
except KeyError:
return default
def get_status_counts(project_id, section):
path = "/".join((project_id, section))
return {
"total_runtime": safe_get(runtime, path, 0),
"status_count": safe_get(status_count, path, default_counts),
}
report_for_states = [
s for s in EntityVisibility if not specific_state or specific_state == s
]
for project in projects:
project["stats"] = {
task_state.value: get_status_counts(project["id"], task_state.value)
for task_state in report_for_states
}
call.result.data = {"projects": projects}
call.result.data = {"projects": projects}
@endpoint("projects.get_all")
def get_all(call: APICall):
conform_tag_fields(call, call.data)
data = call.data
_adjust_search_parameters(data, shallow_search=data.get("shallow_search", False))
with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
projects = Project.get_many(
company=call.identity.company,
query_dict=call.data,
query_options=get_all_query_options,
parameters=call.data,
query_dict=data,
parameters=data,
allow_public=True,
)
conform_output_tags(call, projects)
@@ -286,9 +159,7 @@ def get_all(call: APICall):
@endpoint(
"projects.create",
required_fields=["name", "description"],
response_data_model=IdResponse,
"projects.create", required_fields=["name"], response_data_model=IdResponse,
)
def create(call: APICall):
identity = call.identity
@@ -316,61 +187,72 @@ def update(call: APICall):
:return: updated - `int` - number of projects updated
fields - `[string]` - updated fields
"""
project_id = call.data["project"]
with translate_errors_context():
project = Project.get_for_writing(company=call.identity.company, id=project_id)
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
fields = parse_from_call(
call.data, create_fields, Project.get_fields(), discard_none_values=False
)
conform_tag_fields(call, fields, validate=True)
fields["last_update"] = datetime.utcnow()
with TimingContext("mongo", "projects_update"):
updated = project.update(upsert=False, **fields)
conform_output_tags(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
fields = parse_from_call(
call.data, create_fields, Project.get_fields(), discard_none_values=False
)
conform_tag_fields(call, fields, validate=True)
updated = ProjectBLL.update(
company=call.identity.company, project_id=call.data["project"], **fields
)
conform_output_tags(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
@endpoint("projects.delete", required_fields=["project"])
def delete(call):
assert isinstance(call, APICall)
project_id = call.data["project"]
force = call.data.get("force", False)
with translate_errors_context():
project = Project.get_for_writing(company=call.identity.company, id=project_id)
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
# NOTE: from this point on we'll use the project ID and won't check for company, since we assume we already
# have the correct project ID.
# Find the tasks which belong to the project
for cls, error in (
(Task, errors.bad_request.ProjectHasTasks),
(Model, errors.bad_request.ProjectHasModels),
):
res = cls.objects(
project=project_id, system_tags__nin=[EntityVisibility.archived.value]
).only("id")
if res and not force:
raise error("use force=true to delete", id=project_id)
updated_count = res.update(project=None)
project.delete()
call.result.data = {"deleted": 1, "disassociated_tasks": updated_count}
def _reset_cached_tags(company: str, projects: Sequence[str]):
org_bll.reset_tags(company, Tags.Task, projects=projects)
org_bll.reset_tags(company, Tags.Model, projects=projects)
@endpoint("projects.get_unique_metric_variants", request_data_model=ProjectReq)
def get_unique_metric_variants(call: APICall, company_id: str, request: ProjectReq):
@endpoint("projects.move", request_data_model=MoveRequest)
def move(call: APICall, company: str, request: MoveRequest):
moved, affected_projects = ProjectBLL.move_project(
company=company,
user=call.identity.user,
project_id=request.project,
new_location=request.new_location,
)
_reset_cached_tags(company, projects=list(affected_projects))
call.result.data = {"moved": moved}
@endpoint("projects.merge", request_data_model=MergeRequest)
def merge(call: APICall, company: str, request: MergeRequest):
moved_entitites, moved_projects, affected_projects = ProjectBLL.merge_project(
company, source_id=request.project, destination_id=request.destination_project
)
_reset_cached_tags(company, projects=list(affected_projects))
call.result.data = {
"moved_entities": moved_entitites,
"moved_projects": moved_projects,
}
@endpoint("projects.delete", request_data_model=DeleteRequest)
def delete(call: APICall, company_id: str, request: DeleteRequest):
res, affected_projects = delete_project(
company=company_id,
project_id=request.project,
force=request.force,
delete_contents=request.delete_contents,
)
_reset_cached_tags(company_id, projects=list(affected_projects))
call.result.data = {**attr.asdict(res)}
@endpoint(
"projects.get_unique_metric_variants", request_data_model=ProjectOrNoneRequest
)
def get_unique_metric_variants(
call: APICall, company_id: str, request: ProjectOrNoneRequest
):
metrics = task_bll.get_unique_metric_variants(
company_id, [request.project] if request.project else None
company_id,
[request.project] if request.project else None,
include_subprojects=request.include_subprojects,
)
call.result.data = {"metrics": metrics}
@@ -379,13 +261,14 @@ def get_unique_metric_variants(call: APICall, company_id: str, request: ProjectR
@endpoint(
"projects.get_hyper_parameters",
min_version="2.9",
request_data_model=GetHyperParamReq,
request_data_model=GetHyperParamRequest,
)
def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamReq):
def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamRequest):
total, remaining, parameters = TaskBLL.get_aggregated_project_parameters(
company_id,
project_ids=[request.project] if request.project else None,
include_subprojects=request.include_subprojects,
page=request.page,
page_size=request.page_size,
)
@@ -397,6 +280,28 @@ def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamR
}
@endpoint(
"projects.get_hyperparam_values",
min_version="2.13",
request_data_model=ProjectHyperparamValuesRequest,
)
def get_hyperparam_values(
call: APICall, company_id: str, request: ProjectHyperparamValuesRequest
):
total, values = task_bll.get_hyperparam_distinct_values(
company_id,
project_ids=request.projects,
section=request.section,
name=request.name,
include_subprojects=request.include_subprojects,
allow_public=request.allow_public,
)
call.result.data = {
"total": total,
"values": values,
}
@endpoint(
"projects.get_task_tags", min_version="2.8", request_data_model=ProjectTagsRequest
)
@@ -452,7 +357,10 @@ def get_task_parents(
call: APICall, company_id: str, request: ProjectTaskParentsRequest
):
call.result.data = {
"parents": org_bll.get_parent_tasks(
company_id, projects=request.projects, state=request.tasks_state
"parents": project_bll.get_task_parents(
company_id,
projects=request.projects,
include_subprojects=request.include_subprojects,
state=request.tasks_state,
)
}

View File

@@ -11,12 +11,21 @@ from apiserver.apimodels.queues import (
GetMetricsRequest,
GetMetricsResponse,
QueueMetrics,
AddOrUpdateMetadataRequest,
DeleteMetadataRequest,
)
from apiserver.bll.queue import QueueBLL
from apiserver.bll.util import extract_properties_to_lists
from apiserver.bll.workers import WorkerBLL
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
from apiserver.database.model.queue import Queue
from apiserver.service_repo import APICall, endpoint
from apiserver.services.utils import conform_tag_fields, conform_output_tags, conform_tags
from apiserver.services.utils import (
conform_tag_fields,
conform_output_tags,
conform_tags,
get_metadata_from_api,
)
from apiserver.utilities import extract_properties_to_lists
worker_bll = WorkerBLL()
queue_bll = QueueBLL(worker_bll)
@@ -62,7 +71,11 @@ def create(call: APICall, company_id, request: CreateRequest):
call, request.tags, request.system_tags, validate=True
)
queue = queue_bll.create(
company_id=company_id, name=request.name, tags=tags, system_tags=system_tags
company_id=company_id,
name=request.name,
tags=tags,
system_tags=system_tags,
metadata=get_metadata_from_api(request.metadata),
)
call.result.data = {"id": queue.id}
@@ -220,3 +233,25 @@ def get_queue_metrics(
for queue, data in queue_dicts.items()
]
)
@endpoint("queues.add_or_update_metadata", min_version="2.13")
def add_or_update_metadata(
_: APICall, company_id: str, request: AddOrUpdateMetadataRequest
):
queue_id = request.queue
queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
return {
"updated": metadata_add_or_update(
cls=Queue, _id=queue_id, items=get_metadata_from_api(request.metadata),
)
}
@endpoint("queues.delete_metadata", min_version="2.13")
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
queue_id = request.queue
queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
return {"updated": metadata_delete(cls=Queue, _id=queue_id, keys=request.keys)}

View File

@@ -1,16 +1,15 @@
from copy import deepcopy
from datetime import datetime
from operator import attrgetter
from typing import Sequence, Callable, Type, TypeVar, Union, Tuple
from functools import partial
from typing import Sequence, Union, Tuple
import attr
import dpath
import mongoengine
from mongoengine import EmbeddedDocument, Q
from mongoengine.queryset.transform import COMPARISON_OPERATORS
from pymongo import UpdateOne
from apiserver.apierrors import errors, APIError
from apiserver.apierrors import errors
from apiserver.apierrors.errors.bad_request import InvalidTaskId
from apiserver.apimodels.base import (
UpdateResponse,
@@ -18,11 +17,15 @@ from apiserver.apimodels.base import (
MakePublicRequest,
MoveRequest,
)
from apiserver.apimodels.batch import (
BatchResponse,
UpdateBatchResponse,
UpdateBatchItem,
)
from apiserver.apimodels.tasks import (
StartedResponse,
ResetResponse,
PublishRequest,
PublishResponse,
CreateRequest,
UpdateRequest,
SetRequirementsRequest,
@@ -46,16 +49,30 @@ from apiserver.apimodels.tasks import (
DeleteArtifactsRequest,
ArchiveResponse,
ArchiveRequest,
AddUpdateModelRequest,
DeleteModelsRequest,
StopManyRequest,
EnqueueManyRequest,
ResetManyRequest,
DeleteManyRequest,
PublishManyRequest,
TaskBatchRequest,
EnqueueManyResponse,
EnqueueBatchItem,
DequeueBatchItem,
DequeueManyResponse,
ResetManyResponse,
ResetBatchItem,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.model import ModelBLL
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.bll.project import ProjectBLL, project_ids_with_children
from apiserver.bll.queue import QueueBLL
from apiserver.bll.task import (
TaskBLL,
ChangeStatusRequest,
update_project_time,
split_by,
)
from apiserver.bll.task.artifacts import (
artifacts_prepare_for_save,
@@ -69,25 +86,36 @@ from apiserver.bll.task.param_utils import (
params_unprepare_from_saved,
escape_paths,
)
from apiserver.bll.task.utils import update_task
from apiserver.bll.util import SetFieldsResolver
from apiserver.bll.task.task_operations import (
stop_task,
enqueue_task,
dequeue_task,
reset_task,
archive_task,
delete_task,
publish_task,
unarchive_task,
)
from apiserver.bll.task.utils import update_task, get_task_for_update, deleted_prefix
from apiserver.bll.util import SetFieldsResolver, run_batch_operation
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.task.output import Output
from apiserver.database.model.task.task import (
Task,
TaskStatus,
Script,
DEFAULT_LAST_ITERATION,
Execution,
ArtifactModes,
ModelItem,
TaskModelTypes,
)
from apiserver.database.utils import get_fields_attr, parse_from_call
from apiserver.database.utils import get_fields_attr, parse_from_call, get_options
from apiserver.service_repo import APICall, endpoint
from apiserver.services.utils import (
conform_tag_fields,
conform_output_tags,
ModelsBackwardsCompatibility,
DockerCmdBackwardsCompatibility,
escape_dict_field,
unescape_dict_field,
)
from apiserver.timing_context import TimingContext
from apiserver.utilities.partial_version import PartialVersion
@@ -155,28 +183,48 @@ def get_by_id(call: APICall, company_id, req_model: TaskRequest):
call.result.data = {"task": task_dict}
def escape_execution_parameters(call: APICall):
projection = Task.get_projection(call.data)
if projection:
Task.set_projection(call.data, escape_paths(projection))
def escape_execution_parameters(call: APICall) -> dict:
if not call.data:
return call.data
ordering = Task.get_ordering(call.data)
keys = list(call.data)
call_data = {
safe_key: call.data[key] for key, safe_key in zip(keys, escape_paths(keys))
}
projection = Task.get_projection(call_data)
if projection:
Task.set_projection(call_data, escape_paths(projection))
ordering = Task.get_ordering(call_data)
if ordering:
Task.set_ordering(call.data, escape_paths(ordering))
Task.set_ordering(call_data, escape_paths(ordering))
return call_data
def _process_include_subprojects(call_data: dict):
include_subprojects = call_data.pop("include_subprojects", False)
project_ids = call_data.get("project")
if not project_ids or not include_subprojects:
return
if not isinstance(project_ids, list):
project_ids = [project_ids]
call_data["project"] = project_ids_with_children(project_ids)
@endpoint("tasks.get_all_ex", required_fields=[])
def get_all_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
escape_execution_parameters(call)
call_data = escape_execution_parameters(call)
with translate_errors_context():
with TimingContext("mongo", "task_get_all_ex"):
_process_include_subprojects(call_data)
tasks = Task.get_many_with_join(
company=company_id,
query_dict=call.data,
allow_public=True, # required in case projection is requested for public dataset/versions
company=company_id, query_dict=call_data, allow_public=True,
)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
@@ -186,12 +234,12 @@ def get_all_ex(call: APICall, company_id, _):
def get_by_id_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
escape_execution_parameters(call)
call_data = escape_execution_parameters(call)
with translate_errors_context():
with TimingContext("mongo", "task_get_by_id_ex"):
tasks = Task.get_many_with_join(
company=company_id, query_dict=call.data, allow_public=True,
company=company_id, query_dict=call_data, allow_public=True,
)
unprepare_from_saved(call, tasks)
@@ -202,15 +250,15 @@ def get_by_id_ex(call: APICall, company_id, _):
def get_all(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
escape_execution_parameters(call)
call_data = escape_execution_parameters(call)
with translate_errors_context():
with TimingContext("mongo", "task_get_all"):
tasks = Task.get_many(
company=company_id,
parameters=call.data,
query_dict=call.data,
allow_public=True, # required in case projection is requested for public dataset/versions
parameters=call_data,
query_dict=call_data,
allow_public=True,
)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
@@ -219,7 +267,9 @@ def get_all(call: APICall, company_id, _):
@endpoint("tasks.get_types", request_data_model=GetTypesRequest)
def get_types(call: APICall, company_id, request: GetTypesRequest):
call.result.data = {
"types": list(task_bll.get_types(company_id, project_ids=request.projects))
"types": list(
project_bll.get_task_types(company_id, project_ids=request.projects)
)
}
@@ -236,7 +286,7 @@ def stop(call: APICall, company_id, req_model: UpdateRequest):
"""
call.result.data_model = UpdateResponse(
**TaskBLL.stop_task(
**stop_task(
task_id=req_model.task,
company_id=company_id,
user_name=call.identity.user_name,
@@ -246,6 +296,28 @@ def stop(call: APICall, company_id, req_model: UpdateRequest):
)
@endpoint(
"tasks.stop_many",
request_data_model=StopManyRequest,
response_data_model=UpdateBatchResponse,
)
def stop_many(call: APICall, company_id, request: StopManyRequest):
results, failures = run_batch_operation(
func=partial(
stop_task,
company_id=company_id,
user_name=call.identity.user_name,
status_reason=request.status_reason,
force=request.force,
),
ids=request.ids,
)
call.result.data_model = UpdateBatchResponse(
succeeded=[UpdateBatchItem(id=_id, **res) for _id, res in results],
failed=failures,
)
@endpoint(
"tasks.stopped",
request_data_model=UpdateRequest,
@@ -308,18 +380,27 @@ create_fields = {
"parent": Task,
"project": None,
"input": None,
"models": None,
"container": None,
"output_dest": None,
"execution": None,
"hyperparams": None,
"configuration": None,
"script": None,
"runtime": None,
}
dict_fields_paths = [("execution", "model_labels"), "container"]
def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
conform_tag_fields(call, fields, validate=True)
params_prepare_for_save(fields, previous_task=previous_task)
artifacts_prepare_for_save(fields)
ModelsBackwardsCompatibility.prepare_for_save(call, fields)
DockerCmdBackwardsCompatibility.prepare_for_save(call, fields)
for path in dict_fields_paths:
escape_dict_field(fields, path)
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
for field in task_script_stripped_fields:
@@ -342,7 +423,15 @@ def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict])
conform_output_tags(call, tasks_data)
for data in tasks_data:
need_legacy_params = call.requested_endpoint_version < PartialVersion("2.9")
for path in dict_fields_paths:
unescape_dict_field(data, path)
ModelsBackwardsCompatibility.unprepare_from_saved(call, tasks_data)
DockerCmdBackwardsCompatibility.unprepare_from_saved(call, tasks_data)
need_legacy_params = call.requested_endpoint_version < PartialVersion("2.9")
for data in tasks_data:
params_unprepare_from_saved(
fields=data, copy_to_legacy=need_legacy_params,
)
@@ -368,6 +457,17 @@ def prepare_create_fields(
output = Output(destination=output_dest)
fields["output"] = output
# Add models updated time
models = fields.get("models")
if models:
now = datetime.utcnow()
for field in (TaskModelTypes.input, TaskModelTypes.output):
field_models = models.get(field)
if not field_models:
continue
for model in field_models:
model["updated"] = now
return prepare_for_save(call, fields, previous_task=previous_task)
@@ -386,6 +486,9 @@ def _validate_and_get_task_from_call(call: APICall, **kwargs) -> Tuple[Task, dic
@endpoint("tasks.validate", request_data_model=CreateRequest)
def validate(call: APICall, company_id, req_model: CreateRequest):
parent = call.data.get("parent")
if parent and parent.startswith(deleted_prefix):
call.data.pop("parent")
_validate_and_get_task_from_call(call)
@@ -431,7 +534,9 @@ def clone_task(call: APICall, company_id, request: CloneRequest):
system_tags=request.new_task_system_tags,
hyperparams=request.new_task_hyperparams,
configuration=request.new_task_configuration,
container=request.new_task_container,
execution_overrides=request.execution_overrides,
input_models=request.new_task_input_models,
validate_references=request.validate_references,
new_project_name=request.new_project_name,
)
@@ -698,7 +803,7 @@ def get_configuration_names(
):
with translate_errors_context():
tasks_params = HyperParams.get_configuration_names(
company_id, task_ids=request.tasks
company_id, task_ids=request.tasks, skip_empty=request.skip_empty
)
call.result.data = {
@@ -742,61 +847,41 @@ def delete_configuration(
request_data_model=EnqueueRequest,
response_data_model=EnqueueResponse,
)
def enqueue(call: APICall, company_id, req_model: EnqueueRequest):
task_id = req_model.task
queue_id = req_model.queue
status_message = req_model.status_message
status_reason = req_model.status_reason
def enqueue(call: APICall, company_id, request: EnqueueRequest):
queued, res = enqueue_task(
task_id=request.task,
company_id=company_id,
queue_id=request.queue,
status_message=request.status_message,
status_reason=request.status_reason,
)
call.result.data_model = EnqueueResponse(queued=queued, **res)
if not queue_id:
# try to get default queue
queue_id = queue_bll.get_default(company_id).id
with translate_errors_context():
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(
_only=("type", "script", "execution", "status", "project", "id"), **query
)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
res = EnqueueResponse(
**ChangeStatusRequest(
task=task,
new_status=TaskStatus.queued,
status_reason=status_reason,
status_message=status_message,
allow_same_state_transition=False,
).execute()
)
try:
queue_bll.add_task(
company_id=company_id, queue_id=queue_id, task_id=task.id
)
except Exception:
# failed enqueueing, revert to previous state
ChangeStatusRequest(
task=task,
current_status_override=TaskStatus.queued,
new_status=task.status,
force=True,
status_reason="failed enqueueing",
).execute()
raise
# set the current queue ID in the task
if task.execution:
Task.objects(**query).update(execution__queue=queue_id, multi=False)
else:
Task.objects(**query).update(
execution=Execution(queue=queue_id), multi=False
)
res.queued = 1
res.fields.update(**{"execution.queue": queue_id})
call.result.data_model = res
@endpoint(
"tasks.enqueue_many",
request_data_model=EnqueueManyRequest,
response_data_model=EnqueueManyResponse,
)
def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
results, failures = run_batch_operation(
func=partial(
enqueue_task,
company_id=company_id,
queue_id=request.queue,
status_message=request.status_message,
status_reason=request.status_reason,
validate=request.validate_tasks,
),
ids=request.ids,
)
call.result.data_model = EnqueueManyResponse(
succeeded=[
EnqueueBatchItem(id=_id, queued=bool(queued), **res)
for _id, (queued, res) in results
],
failed=failures,
)
@endpoint(
@@ -805,333 +890,250 @@ def enqueue(call: APICall, company_id, req_model: EnqueueRequest):
response_data_model=DequeueResponse,
)
def dequeue(call: APICall, company_id, request: UpdateRequest):
task = TaskBLL.get_task_with_access(
request.task,
dequeued, res = dequeue_task(
task_id=request.task,
company_id=company_id,
only=("id", "execution", "status", "project"),
requires_write_access=True,
status_message=request.status_message,
status_reason=request.status_reason,
)
res = DequeueResponse(
**TaskBLL.dequeue_and_change_status(
task,
company_id,
call.result.data_model = DequeueResponse(dequeued=dequeued, **res)
@endpoint(
"tasks.dequeue_many",
request_data_model=TaskBatchRequest,
response_data_model=DequeueManyResponse,
)
def dequeue_many(call: APICall, company_id, request: TaskBatchRequest):
results, failures = run_batch_operation(
func=partial(
dequeue_task,
company_id=company_id,
status_message=request.status_message,
status_reason=request.status_reason,
)
),
ids=request.ids,
)
call.result.data_model = DequeueManyResponse(
succeeded=[
DequeueBatchItem(id=_id, dequeued=bool(dequeued), **res)
for _id, (dequeued, res) in results
],
failed=failures,
)
res.dequeued = 1
call.result.data_model = res
@endpoint(
"tasks.reset", request_data_model=ResetRequest, response_data_model=ResetResponse
)
def reset(call: APICall, company_id, request: ResetRequest):
task = TaskBLL.get_task_with_access(
request.task, company_id=company_id, requires_write_access=True
)
force = request.force
if not force and task.status == TaskStatus.published:
raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)
api_results = {}
updates = {}
try:
dequeued = TaskBLL.dequeue(task, company_id, silent_fail=True)
except APIError:
# dequeue may fail if the task was not enqueued
pass
else:
if dequeued:
api_results.update(dequeued=dequeued)
cleaned_up = cleanup_task(task, force)
api_results.update(attr.asdict(cleaned_up))
updates.update(
set__last_iteration=DEFAULT_LAST_ITERATION,
set__last_metrics={},
set__metric_stats={},
unset__output__result=1,
unset__output__model=1,
unset__output__error=1,
unset__last_worker=1,
unset__last_worker_report=1,
)
if request.clear_all:
updates.update(
set__execution=Execution(), unset__script=1,
)
else:
updates.update(unset__execution__queue=1)
if task.execution and task.execution.artifacts:
updates.update(
set__execution__artifacts={
key: artifact
for key, artifact in task.execution.artifacts.items()
if artifact.mode == ArtifactModes.input
}
)
res = ResetResponse(
**ChangeStatusRequest(
task=task,
new_status=TaskStatus.created,
force=force,
status_reason="reset",
status_message="reset",
).execute(
started=None,
completed=None,
published=None,
active_duration=None,
**updates,
)
dequeued, cleanup_res, updates = reset_task(
task_id=request.task,
company_id=company_id,
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
clear_all=request.clear_all,
)
res = ResetResponse(**updates, dequeued=dequeued)
# do not return artifacts since they are not serializable
res.fields.pop("execution.artifacts", None)
for key, value in api_results.items():
for key, value in attr.asdict(cleanup_res).items():
setattr(res, key, value)
call.result.data_model = res
@endpoint(
"tasks.reset_many",
request_data_model=ResetManyRequest,
response_data_model=ResetManyResponse,
)
def reset_many(call: APICall, company_id, request: ResetManyRequest):
results, failures = run_batch_operation(
func=partial(
reset_task,
company_id=company_id,
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
clear_all=request.clear_all,
),
ids=request.ids,
)
def clean_res(res: dict) -> dict:
# do not return artifacts since they are not serializable
fields = res.get("fields")
if fields:
fields.pop("execution.artifacts", None)
return res
call.result.data_model = ResetManyResponse(
succeeded=[
ResetBatchItem(
id=_id,
dequeued=bool(dequeued.get("removed")) if dequeued else False,
**attr.asdict(cleanup),
**clean_res(res),
)
for _id, (dequeued, cleanup, res) in results
],
failed=failures,
)
@endpoint(
"tasks.archive",
request_data_model=ArchiveRequest,
response_data_model=ArchiveResponse,
)
def archive(call: APICall, company_id, request: ArchiveRequest):
archived = 0
tasks = TaskBLL.assert_exists(
company_id,
task_ids=request.tasks,
only=("id", "execution", "status", "project", "system_tags"),
only=("id", "execution", "status", "project", "system_tags", "enqueue_status"),
)
archived = 0
for task in tasks:
try:
TaskBLL.dequeue_and_change_status(
task, company_id, request.status_message, request.status_reason,
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
task.update(
archived += archive_task(
company_id=company_id,
task=task,
status_message=request.status_message,
status_reason=request.status_reason,
system_tags=sorted(
set(task.system_tags) | {EntityVisibility.archived.value}
),
last_change=datetime.utcnow(),
)
archived += 1
call.result.data_model = ArchiveResponse(archived=archived)
class DocumentGroup(list):
"""
Operate on a list of documents as if they were a query result
"""
def __init__(self, document_type, documents):
super(DocumentGroup, self).__init__(documents)
self.type = document_type
def objects(self, *args, **kwargs):
return self.type.objects(id__in=[obj.id for obj in self], *args, **kwargs)
T = TypeVar("T")
class TaskOutputs(object):
"""
Split task outputs of the same type by the ready state
"""
published = None # type: DocumentGroup
draft = None # type: DocumentGroup
def __init__(self, is_published, document_type, children):
# type: (Callable[[T], bool], Type[mongoengine.Document], Sequence[T]) -> ()
"""
:param is_published: predicate returning whether items is considered published
:param document_type: type of output
:param children: output documents
"""
self.published, self.draft = map(
lambda x: DocumentGroup(document_type, x), split_by(is_published, children)
)
@attr.s
class CleanupResult(object):
"""
Counts of objects modified in task cleanup operation
"""
updated_children = attr.ib(type=int)
updated_models = attr.ib(type=int)
deleted_models = attr.ib(type=int)
def cleanup_task(task: Task, force: bool = False):
"""
Validate task deletion and delete/modify all its output.
:param task: task object
:param force: whether to delete task with published outputs
:return: count of delete and modified items
"""
models, child_tasks = get_outputs_for_deletion(task, force)
deleted_task_id = trash_task_id(task.id)
if child_tasks:
with TimingContext("mongo", "update_task_children"):
updated_children = child_tasks.update(parent=deleted_task_id)
else:
updated_children = 0
if models.draft:
with TimingContext("mongo", "delete_models"):
deleted_models = models.draft.objects().delete()
else:
deleted_models = 0
if models.published:
with TimingContext("mongo", "update_task_models"):
updated_models = models.published.objects().update(task=deleted_task_id)
else:
updated_models = 0
event_bll.delete_task_events(task.company, task.id, allow_locked=force)
return CleanupResult(
deleted_models=deleted_models,
updated_children=updated_children,
updated_models=updated_models,
@endpoint(
"tasks.archive_many",
request_data_model=TaskBatchRequest,
response_data_model=BatchResponse,
)
def archive_many(call: APICall, company_id, request: TaskBatchRequest):
results, failures = run_batch_operation(
func=partial(
archive_task,
company_id=company_id,
status_message=request.status_message,
status_reason=request.status_reason,
),
ids=request.ids,
)
call.result.data_model = BatchResponse(
succeeded=[dict(id=_id, archived=bool(archived)) for _id, archived in results],
failed=failures,
)
def get_outputs_for_deletion(task, force=False):
with TimingContext("mongo", "get_task_models"):
models = TaskOutputs(
attrgetter("ready"),
Model,
Model.objects(task=task.id).only("id", "task", "ready"),
)
if not force and models.published:
raise errors.bad_request.TaskCannotBeDeleted(
"has output models, use force=True",
task=task.id,
models=len(models.published),
)
if task.output.model:
output_model = get_output_model(task, force)
if output_model:
if output_model.ready:
models.published.append(output_model)
else:
models.draft.append(output_model)
if models.draft:
with TimingContext("mongo", "get_execution_models"):
model_ids = [m.id for m in models.draft]
dependent_tasks = Task.objects(execution__model__in=model_ids).only(
"id", "execution.model"
)
busy_models = [t.execution.model for t in dependent_tasks]
models.draft[:] = [m for m in models.draft if m.id not in busy_models]
with TimingContext("mongo", "get_task_children"):
tasks = Task.objects(parent=task.id).only("id", "parent", "status")
published_tasks = [
task for task in tasks if task.status == TaskStatus.published
]
if not force and published_tasks:
raise errors.bad_request.TaskCannotBeDeleted(
"has children, use force=True", task=task.id, children=published_tasks
)
return models, tasks
def get_output_model(task, force=False):
with TimingContext("mongo", "get_task_output_model"):
output_model = Model.objects(id=task.output.model).first()
if output_model and output_model.ready and not force:
raise errors.bad_request.TaskCannotBeDeleted(
"has output model, use force=True", task=task.id, model=task.output.model
)
return output_model
def trash_task_id(task_id):
return "__DELETED__{}".format(task_id)
@endpoint(
"tasks.unarchive_many",
request_data_model=TaskBatchRequest,
response_data_model=BatchResponse,
)
def unarchive_many(call: APICall, company_id, request: TaskBatchRequest):
results, failures = run_batch_operation(
func=partial(
unarchive_task,
company_id=company_id,
status_message=request.status_message,
status_reason=request.status_reason,
),
ids=request.ids,
)
call.result.data_model = BatchResponse(
succeeded=[
dict(id=_id, unarchived=bool(unarchived)) for _id, unarchived in results
],
failed=failures,
)
@endpoint("tasks.delete", request_data_model=DeleteRequest)
def delete(call: APICall, company_id, req_model: DeleteRequest):
task = TaskBLL.get_task_with_access(
req_model.task, company_id=company_id, requires_write_access=True
def delete(call: APICall, company_id, request: DeleteRequest):
deleted, task, cleanup_res = delete_task(
task_id=request.task,
company_id=company_id,
move_to_trash=request.move_to_trash,
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
)
if deleted:
_reset_cached_tags(company_id, projects=[task.project] if task.project else [])
call.result.data = dict(deleted=bool(deleted), **attr.asdict(cleanup_res))
@endpoint("tasks.delete_many", request_data_model=DeleteManyRequest)
def delete_many(call: APICall, company_id, request: DeleteManyRequest):
results, failures = run_batch_operation(
func=partial(
delete_task,
company_id=company_id,
move_to_trash=request.move_to_trash,
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
),
ids=request.ids,
)
move_to_trash = req_model.move_to_trash
force = req_model.force
if results:
projects = set(task.project for _, (_, task, _) in results)
_reset_cached_tags(company_id, projects=list(projects))
if task.status != TaskStatus.created and not force:
raise errors.bad_request.TaskCannotBeDeleted(
"due to status, use force=True",
task=task.id,
expected=TaskStatus.created,
current=task.status,
)
with translate_errors_context():
result = cleanup_task(task, force)
if move_to_trash:
collection_name = task._get_collection_name()
archived_collection = "{}__trash".format(collection_name)
task.switch_collection(archived_collection)
try:
# A simple save() won't do due to mongoengine caching (nothing will be saved), so we have to force
# an insert. However, if for some reason such an ID exists, let's make sure we'll keep going.
with TimingContext("mongo", "save_task"):
task.save(force_insert=True)
except Exception:
pass
task.switch_collection(collection_name)
task.delete()
_reset_cached_tags(company_id, projects=[task.project])
update_project_time(task.project)
call.result.data = dict(deleted=True, **attr.asdict(result))
call.result.data = dict(
succeeded=[
dict(id=_id, deleted=bool(deleted), **attr.asdict(cleanup_res))
for _id, (deleted, _, cleanup_res) in results
],
failed=failures,
)
@endpoint(
"tasks.publish",
request_data_model=PublishRequest,
response_data_model=PublishResponse,
response_data_model=UpdateResponse,
)
def publish(call: APICall, company_id, req_model: PublishRequest):
call.result.data_model = PublishResponse(
**TaskBLL.publish_task(
task_id=req_model.task,
def publish(call: APICall, company_id, request: PublishRequest):
updates = publish_task(
task_id=request.task,
company_id=company_id,
force=request.force,
publish_model_func=ModelBLL.publish_model if request.publish_model else None,
status_reason=request.status_reason,
status_message=request.status_message,
)
call.result.data_model = UpdateResponse(**updates)
@endpoint(
"tasks.publish_many",
request_data_model=PublishManyRequest,
response_data_model=UpdateBatchResponse,
)
def publish_many(call: APICall, company_id, request: PublishManyRequest):
results, failures = run_batch_operation(
func=partial(
publish_task,
company_id=company_id,
publish_model=req_model.publish_model,
force=req_model.force,
status_reason=req_model.status_reason,
status_message=req_model.status_message,
)
force=request.force,
publish_model_func=ModelBLL.publish_model
if request.publish_model
else None,
status_reason=request.status_reason,
status_message=request.status_message,
),
ids=request.ids,
)
call.result.data_model = UpdateBatchResponse(
succeeded=[UpdateBatchItem(id=_id, **res) for _id, res in results],
failed=failures,
)
@@ -1235,3 +1237,40 @@ def move(call: APICall, company_id: str, request: MoveRequest):
update_project_time(projects)
return {"project_id": project_id}
@endpoint("tasks.add_or_update_model", min_version="2.13")
def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequest):
get_task_for_update(company_id=company_id, task_id=request.task, force=True)
models_field = f"models__{request.type}"
model = ModelItem(name=request.name, model=request.model, updated=datetime.utcnow())
query = {"id": request.task, f"{models_field}__name": request.name}
updated = Task.objects(**query).update_one(**{f"set__{models_field}__S": model})
updated = TaskBLL.update_statistics(
task_id=request.task,
company_id=company_id,
last_iteration_max=request.iteration,
**({f"push__{models_field}": model} if not updated else {}),
)
return {"updated": updated}
@endpoint("tasks.delete_models", min_version="2.13")
def delete_models(_: APICall, company_id: str, request: DeleteModelsRequest):
task = get_task_for_update(company_id=company_id, task_id=request.task, force=True)
delete_names = {
type_: [m.name for m in request.models if m.type == type_]
for type_ in get_options(TaskModelTypes)
}
commands = {
f"pull__models__{field}__name__in": names
for field, names in delete_names.items()
if names
}
updated = task.update(last_change=datetime.utcnow(), **commands,)
return {"updated": updated}

View File

@@ -1,10 +1,15 @@
from datetime import datetime
from typing import Union, Sequence, Tuple
from apiserver.apierrors import errors
from apiserver.apimodels.metadata import MetadataItem as ApiMetadataItem
from apiserver.apimodels.organization import Filter
from apiserver.database.model.base import GetMixin
from apiserver.database.model.task.task import TaskModelTypes, TaskModelNames
from apiserver.database.utils import partition_tags
from apiserver.service_repo import APICall
from apiserver.utilities.dicts import nested_set, nested_get, nested_delete
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from apiserver.utilities.partial_version import PartialVersion
@@ -25,16 +30,23 @@ def get_tags_response(ret: dict) -> dict:
def conform_output_tags(call: APICall, documents: Union[dict, Sequence[dict]]):
"""
Make sure that tags are always returned sorted
For old clients both tags and system tags are returned in 'tags' field
"""
if call.requested_endpoint_version >= PartialVersion("2.3"):
return
if isinstance(documents, dict):
documents = [documents]
merge_tags = call.requested_endpoint_version < PartialVersion("2.3")
for doc in documents:
system_tags = doc.get("system_tags")
if system_tags:
doc["tags"] = list(set(doc.get("tags", [])) | set(system_tags))
if merge_tags:
system_tags = doc.get("system_tags")
if system_tags:
doc["tags"] = list(set(doc.get("tags", [])) | set(system_tags))
for field in ("system_tags", "tags"):
tags = doc.get(field)
if tags:
doc[field] = sorted(tags)
def conform_tag_fields(call: APICall, document: dict, validate=False):
@@ -84,3 +96,148 @@ def validate_tags(tags: Sequence[str], system_tags: Sequence[str]):
raise errors.bad_request.FieldsValueError(
"unsupported tag prefix", values=unsupported
)
def escape_dict(data: dict) -> dict:
if not data:
return data
return {ParameterKeyEscaper.escape(k): v for k, v in data.items()}
def unescape_dict(data: dict) -> dict:
if not data:
return data
return {ParameterKeyEscaper.unescape(k): v for k, v in data.items()}
def escape_dict_field(fields: dict, path: Union[str, Sequence[str]]):
if isinstance(path, str):
path = (path,)
data = nested_get(fields, path)
if not data or not isinstance(data, dict):
return
nested_set(fields, path, escape_dict(data))
def unescape_dict_field(fields: dict, path: Union[str, Sequence[str]]):
if isinstance(path, str):
path = (path,)
data = nested_get(fields, path)
if not data or not isinstance(data, dict):
return
nested_set(fields, path, unescape_dict(data))
class ModelsBackwardsCompatibility:
max_version = PartialVersion("2.13")
mode_to_fields = {
TaskModelTypes.input: ("execution", "model"),
TaskModelTypes.output: ("output", "model"),
}
models_field = "models"
@classmethod
def prepare_for_save(cls, call: APICall, fields: dict):
if call.requested_endpoint_version >= cls.max_version:
return
for mode, field in cls.mode_to_fields.items():
value = nested_get(fields, field)
if value is None:
continue
val = [
dict(
name=TaskModelNames[mode],
model=value,
updated=datetime.utcnow(),
)
] if value else []
nested_set(fields, (cls.models_field, mode), value=val)
nested_delete(fields, field)
@classmethod
def unprepare_from_saved(
cls, call: APICall, tasks_data: Union[Sequence[dict], dict]
):
if call.requested_endpoint_version >= cls.max_version:
return
if isinstance(tasks_data, dict):
tasks_data = [tasks_data]
for task in tasks_data:
for mode, field in cls.mode_to_fields.items():
models = nested_get(task, (cls.models_field, mode))
if not models:
continue
model = models[0] if mode == TaskModelTypes.input else models[-1]
if model:
nested_set(task, field, model.get("model"))
class DockerCmdBackwardsCompatibility:
max_version = PartialVersion("2.13")
field = ("execution", "docker_cmd")
@classmethod
def prepare_for_save(cls, call: APICall, fields: dict):
if call.requested_endpoint_version >= cls.max_version:
return
docker_cmd = nested_get(fields, cls.field)
if docker_cmd is not None:
image, _, arguments = docker_cmd.partition(" ")
nested_set(fields, ("container", "image"), value=image)
nested_set(fields, ("container", "arguments"), value=arguments)
nested_delete(fields, cls.field)
@classmethod
def unprepare_from_saved(
cls, call: APICall, tasks_data: Union[Sequence[dict], dict]
):
if call.requested_endpoint_version >= cls.max_version:
return
if isinstance(tasks_data, dict):
tasks_data = [tasks_data]
for task in tasks_data:
container = task.get("container")
if not container or not container.get("image"):
continue
docker_cmd = " ".join(
filter(None, map(container.get, ("image", "arguments")))
)
if docker_cmd:
nested_set(task, cls.field, docker_cmd)
def validate_metadata(metadata: Sequence[dict]):
if not metadata:
return
keys = [m.get("key") for m in metadata]
unique_keys = set(keys)
unique_keys.discard(None)
if len(keys) != len(set(keys)):
raise errors.bad_request.ValidationError("Metadata keys should be unique")
def get_metadata_from_api(api_metadata: Sequence[ApiMetadataItem]) -> Sequence:
if not api_metadata:
return api_metadata
metadata = [m.to_struct() for m in api_metadata]
validate_metadata(metadata)
return metadata

View File

@@ -23,10 +23,10 @@ from apiserver.apimodels.workers import (
GetActivityReportResponse,
ActivityReportSeries,
)
from apiserver.bll.util import extract_properties_to_lists
from apiserver.bll.workers import WorkerBLL
from apiserver.config_repo import config
from apiserver.service_repo import APICall, endpoint
from apiserver.utilities import extract_properties_to_lists
log = config.logger(__file__)

View File

@@ -1,3 +1,4 @@
import os
import time
from contextlib import contextmanager
from time import sleep
@@ -18,10 +19,12 @@ def distributed_lock(name: str, timeout: int, max_wait: int = 0):
lock_name = f"dist_lock_{name}"
start = time.time()
max_wait = max_wait or timeout * 2
while not _redis.set(lock_name, value="", ex=timeout, nx=True):
pid = os.getpid()
while _redis.set(lock_name, value=pid, ex=timeout, nx=True) is None:
sleep(1)
if time.time() - start > max_wait:
raise Exception(f"Could not acquire {name} lock for {max_wait} seconds")
holder = _redis.get(lock_name)
raise Exception(f"Could not acquire {name} lock for {max_wait} seconds. The lock is hold by {holder}")
try:
yield
finally:

View File

@@ -1,5 +1,4 @@
import json
import logging
import os
import time
from contextlib import contextmanager
@@ -10,16 +9,14 @@ import requests
import six
from boltons.iterutils import remap
from boltons.typeutils import issubclass
from pyhocon import ConfigFactory
from requests.adapters import HTTPAdapter
from requests.auth import HTTPBasicAuth
from requests.packages.urllib3.util.retry import Retry
from apiserver.apierrors.base import BaseError
from apiserver.config_repo import config
config = ConfigFactory.parse_file("api_client.conf")
log = logging.getLogger("api_client")
log = config.logger(__file__)
class APICallResult:
@@ -111,7 +108,7 @@ class APIClient:
self.api_key = (
api_key
or os.environ.get("SM_API_KEY")
or config.get("api_key")
or config.get("apiclient.api_key")
)
if not self.api_key:
raise ValueError("APIClient requires api_key in constructor or config")
@@ -119,7 +116,7 @@ class APIClient:
self.secret_key = (
secret_key
or os.environ.get("SM_API_SECRET")
or config.get("secret_key")
or config.get("apiclient.secret_key")
)
if not self.secret_key:
raise ValueError(
@@ -127,7 +124,7 @@ class APIClient:
)
self.base_url = (
base_url or os.environ.get("SM_API_URL") or config.get("base_url")
base_url or os.environ.get("SM_API_URL") or config.get("apiclient.base_url")
)
if not self.base_url:
raise ValueError("APIClient requires base_url in constructor or config")
@@ -139,9 +136,9 @@ class APIClient:
# create http session
self.http_session = requests.session()
retries = config.get("retries", 7)
backoff_factor = config.get("backoff_factor", 0.3)
status_forcelist = config.get("status_forcelist", (500, 502, 504))
retries = config.get("apiclient.retries", 7)
backoff_factor = config.get("apiclient.backoff_factor", 0.3)
status_forcelist = config.get("apiclient.status_forcelist", (500, 502, 504))
retry = Retry(
total=retries,
read=retries,
@@ -154,7 +151,7 @@ class APIClient:
self.http_session.mount("https://", adapter)
if impersonated_user_id:
self.http_session.headers["X-Trains-Impersonate-As"] = impersonated_user_id
self.http_session.headers["X-ClearML-Impersonate-As"] = impersonated_user_id
if not self.session_token:
self.login()
@@ -211,7 +208,7 @@ class APIClient:
headers = {"Content-Type": "application/json"}
headers.update(headers_overrides)
if is_async:
headers["X-Trains-Async"] = "1"
headers["X-ClearML-Async"] = "1"
if not isinstance(data, six.string_types):
data = json.dumps(data)
@@ -241,7 +238,7 @@ class APIClient:
call_id = res.meta.call_id
async_res_url = "%s/async.result?id=%s" % (self.base_url, call_id)
async_res_headers = headers.copy()
async_res_headers.pop("X-Trains-Async")
async_res_headers.pop("X-ClearML-Async")
while not got_result:
log.info("Got 202. Checking async result for %s (%s)" % (url, call_id))
http_res = self.http_session.get(

View File

@@ -5,6 +5,8 @@ from functools import partial
from typing import Iterable
from unittest import TestCase
from packaging.version import parse
from apiserver.tests.api_client import APIClient
from apiserver.config_repo import config
@@ -69,19 +71,10 @@ class TestService(TestCase, TestServiceInterface):
delete_params=delete_params,
)
def create_temp_version(self, *, client=None, **kwargs) -> str:
return self._create_temp_helper(
service="datasets",
create_endpoint="create_version",
delete_endpoint="delete_version",
object_name="version",
create_params=kwargs,
client=client,
)
def setUp(self, version="1.7"):
self._api = APIClient(base_url=f"http://localhost:8008/v{version}")
self._deferred = []
self._version = parse(version)
header(self.id())
def tearDown(self):

View File

@@ -1 +0,0 @@
parameterized

View File

@@ -0,0 +1,141 @@
from apiserver.database.utils import id as db_id
from apiserver.tests.automated import TestService
class TestBatchOperations(TestService):
name = "batch operation test"
comment = "this is a comment"
delete_params = dict(can_fail=True, force=True)
def setUp(self, version="2.13"):
super().setUp(version=version)
def test_tasks(self):
tasks = [self._temp_task() for _ in range(2)]
models = [
self._temp_task_model(task=t, uri=f"uri_{idx}")
for idx, t in enumerate(tasks)
]
missing_id = db_id()
ids = [*tasks, missing_id]
# enqueue
res = self.api.tasks.enqueue_many(ids=ids)
self._assert_succeeded(res, tasks)
self._assert_failed(res, [missing_id])
data = self.api.tasks.get_all_ex(id=ids).tasks
self.assertEqual({t.status for t in data}, {"queued"})
# stop
for t in tasks:
self.api.tasks.started(task=t)
res = self.api.tasks.stop_many(ids=ids)
self._assert_succeeded(res, tasks)
self._assert_failed(res, [missing_id])
data = self.api.tasks.get_all_ex(id=ids).tasks
self.assertEqual({t.status for t in data}, {"stopped"})
# publish
res = self.api.tasks.publish_many(ids=ids, publish_model=False)
self._assert_succeeded(res, tasks)
self._assert_failed(res, [missing_id])
data = self.api.tasks.get_all_ex(id=ids).tasks
self.assertEqual({t.status for t in data}, {"published"})
# reset
res = self.api.tasks.reset_many(
ids=ids, delete_output_models=True, return_file_urls=True, force=True
)
self._assert_succeeded(res, tasks)
self.assertEqual(sum(t.deleted_models for t in res.succeeded), 2)
self.assertEqual(
set(url for t in res.succeeded for url in t.urls.model_urls),
{"uri_0", "uri_1"},
)
self._assert_failed(res, [missing_id])
data = self.api.tasks.get_all_ex(id=ids).tasks
self.assertEqual({t.status for t in data}, {"created"})
# archive/unarchive
res = self.api.tasks.archive_many(ids=ids)
self._assert_succeeded(res, tasks)
self._assert_failed(res, [missing_id])
data = self.api.tasks.get_all_ex(id=ids).tasks
self.assertTrue(all("archived" in t.system_tags for t in data))
res = self.api.tasks.unarchive_many(ids=ids)
self._assert_succeeded(res, tasks)
self._assert_failed(res, [missing_id])
data = self.api.tasks.get_all_ex(id=ids).tasks
self.assertFalse(any("archived" in t.system_tags for t in data))
# delete
res = self.api.tasks.delete_many(
ids=ids, delete_output_models=True, return_file_urls=True
)
self._assert_succeeded(res, tasks)
self._assert_failed(res, [missing_id])
data = self.api.tasks.get_all_ex(id=ids).tasks
self.assertEqual(data, [])
def test_models(self):
uris = [f"file:///{i}" for i in range(2)]
models = [self._temp_model(uri=uri) for uri in uris]
missing_id = db_id()
ids = [*models, missing_id]
# publish
task = self._temp_task()
self.api.models.edit(model=ids[0], ready=False, task=task)
self.api.tasks.add_or_update_model(
task=task, name="output", type="input", model=ids[0]
)
res = self.api.models.publish_many(
ids=ids, publish_task=True, force_publish_task=True
)
self._assert_succeeded(res, [ids[0]])
self.assertEqual(res.succeeded[0].published_task.id, task)
self._assert_failed(res, [ids[1], missing_id])
# archive/unarchive
res = self.api.models.archive_many(ids=ids)
self._assert_succeeded(res, models)
self._assert_failed(res, [missing_id])
data = self.api.models.get_all_ex(id=ids).models
self.assertTrue(all("archived" in m.system_tags for m in data))
res = self.api.models.unarchive_many(ids=ids)
self._assert_succeeded(res, models)
self._assert_failed(res, [missing_id])
data = self.api.models.get_all_ex(id=ids).models
self.assertFalse(any("archived" in m.system_tags for m in data))
# delete
res = self.api.models.delete_many(ids=[*models, missing_id], force=True)
self._assert_succeeded(res, models)
self.assertEqual(set(m.url for m in res.succeeded), set(uris))
self._assert_failed(res, [missing_id])
data = self.api.models.get_all_ex(id=ids).models
self.assertEqual(data, [])
def _assert_succeeded(self, res, succeeded_ids):
self.assertEqual(set(f.id for f in res.succeeded), set(succeeded_ids))
def _assert_failed(self, res, failed_ids):
self.assertEqual(set(f.id for f in res.failed), set(failed_ids))
def _temp_model(self, **kwargs):
self.update_missing(kwargs, name=self.name, uri="file:///a/b", labels={})
return self.create_temp("models", delete_params=self.delete_params, **kwargs)
def _temp_task(self):
return self.create_temp(
service="tasks", type="testing", name=self.name, input=dict(view={}),
)
def _temp_task_model(self, task, **kwargs) -> str:
model = self._temp_model(ready=False, task=task, **kwargs)
self.api.tasks.add_or_update_model(
task=task, name="output", type="output", model=model
)
return model

View File

@@ -28,7 +28,9 @@ class TestEntityOrdering(TestService):
self._assertGetTasksWithOrdering(order_by="comment")
# sort by parameter which type is not part of db schema
self._assertGetTasksWithOrdering(order_by="execution.parameters.test")
self._assertGetTasksWithOrdering(
order_by="execution.parameters.test", valid_order=False
)
def test_order_with_paging(self):
order_field = "started"
@@ -97,7 +99,9 @@ class TestEntityOrdering(TestService):
return val
def _assertGetTasksWithOrdering(self, order_by: str = None, **kwargs):
def _assertGetTasksWithOrdering(
self, order_by: str = None, valid_order=True, **kwargs
):
tasks = self.api.tasks.get_all_ex(
only_fields=self.only_fields,
order_by=[order_by] if isinstance(order_by, str) else order_by,
@@ -105,14 +109,16 @@ class TestEntityOrdering(TestService):
**kwargs,
).tasks
self.assertLessEqual(set(self.task_ids), set(t.id for t in tasks))
if order_by:
if order_by and valid_order:
# test that the output is correctly ordered
field_name = order_by if not order_by.startswith("-") else order_by[1:]
field_vals = [self._get_value_for_path(t, field_name.split(".")) for t in tasks]
field_vals = [
self._get_value_for_path(t, field_name.split(".")) for t in tasks
]
self._assertSorted(
field_vals,
ascending=not order_by.startswith("-"),
is_numeric=field_name.startswith("execution.parameters.")
is_numeric=field_name.startswith("execution.parameters."),
)
def _create_tasks(self):

View File

@@ -0,0 +1,74 @@
from functools import partial
from typing import Sequence
from apiserver.tests.api_client import APIClient
from apiserver.tests.automated import TestService
class TestQueueAndModelMetadata(TestService):
def setUp(self, version="2.13"):
super().setUp(version=version)
meta1 = [{"key": "test_key", "type": "str", "value": "test_value"}]
def test_queue_metas(self):
queue_id = self._temp_queue("TestMetadata", metadata=self.meta1)
self._test_meta_operations(
service=self.api.queues, entity="queue", _id=queue_id
)
def test_models_metas(self):
service = self.api.models
entity = "model"
model_id = self._temp_model("TestMetadata", metadata=self.meta1)
self._test_meta_operations(
service=self.api.models, entity="model", _id=model_id
)
model_id = self._temp_model("TestMetadata1")
self.api.models.edit(model=model_id, metadata=[self.meta1[0]])
self._assertMeta(service=service, entity=entity, _id=model_id, meta=self.meta1)
def _test_meta_operations(
self, service: APIClient.Service, entity: str, _id: str,
):
assert_meta = partial(self._assertMeta, service=service, entity=entity)
assert_meta(_id=_id, meta=self.meta1)
meta2 = [
{"key": "test1", "type": "str", "value": "data1"},
{"key": "test2", "type": "str", "value": "data2"},
{"key": "test3", "type": "str", "value": "data3"},
]
service.update(**{entity: _id, "metadata": meta2})
assert_meta(_id=_id, meta=meta2)
updates = [
{"key": "test2", "type": "int", "value": "10"},
{"key": "test3", "type": "int", "value": "20"},
{"key": "test4", "type": "array", "value": "xxx,yyy"},
{"key": "test5", "type": "array", "value": "zzz"},
]
res = service.add_or_update_metadata(**{entity: _id, "metadata": updates})
self.assertEqual(res.updated, 1)
assert_meta(_id=_id, meta=[meta2[0], *updates])
res = service.delete_metadata(
**{entity: _id, "keys": [f"test{idx}" for idx in range(2, 6)]}
)
self.assertEqual(res.updated, 1)
assert_meta(_id=_id, meta=meta2[:1])
def _assertMeta(
self, service: APIClient.Service, entity: str, _id: str, meta: Sequence[dict]
):
res = service.get_all_ex(id=[_id])[f"{entity}s"][0]
self.assertEqual(res.metadata, meta)
def _temp_queue(self, name, **kwargs):
return self.create_temp("queues", name=name, **kwargs)
def _temp_model(self, name: str, **kwargs):
return self.create_temp(
"models", uri="file://test", name=name, labels={}, **kwargs
)

View File

@@ -0,0 +1,267 @@
from time import sleep
from typing import Sequence, Optional, Tuple
from boltons.iterutils import first
from apiserver.apierrors import errors
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.utils import id as db_id
from apiserver.tests.api_client import APIClient
from apiserver.tests.automated import TestService
class TestSubProjects(TestService):
def setUp(self, **kwargs):
super().setUp(version="2.13")
def test_project_aggregations(self):
"""This test requires user with user_auth_only... credentials in db"""
user2_client = APIClient(
api_key=config.get("apiclient.user_auth_only"),
secret_key=config.get("apiclient.user_auth_only_secret"),
base_url=f"http://localhost:8008/v2.13",
)
child = self._temp_project(name="Aggregation/Pr1", client=user2_client)
project = self.api.projects.get_all_ex(name="^Aggregation$").projects[0].id
child_project = self.api.projects.get_all_ex(id=[child]).projects[0]
self.assertEqual(child_project.parent.id, project)
user = self.api.users.get_current_user().user.id
# test aggregations on project with empty subprojects
res = self.api.users.get_all_ex(active_in_projects=[project])
self.assertEqual(res.users, [])
res = self.api.projects.get_all_ex(id=[project], active_users=[user])
self.assertEqual(res.projects, [])
res = self.api.models.get_frameworks(projects=[project])
self.assertEqual(res.frameworks, [])
res = self.api.tasks.get_types(projects=[project])
self.assertEqual(res.types, [])
res = self.api.projects.get_task_parents(projects=[project])
self.assertEqual(res.parents, [])
# test aggregations with non-empty subprojects
task1 = self._temp_task(project=child)
self._temp_task(project=child, parent=task1)
framework = "Test framework"
self._temp_model(project=child, framework=framework)
res = self.api.users.get_all_ex(active_in_projects=[project])
self._assert_ids(res.users, [user])
res = self.api.projects.get_all_ex(id=[project], active_users=[user])
self._assert_ids(res.projects, [project])
res = self.api.projects.get_task_parents(projects=[project])
self._assert_ids(res.parents, [task1])
res = self.api.models.get_frameworks(projects=[project])
self.assertEqual(res.frameworks, [framework])
res = self.api.tasks.get_types(projects=[project])
self.assertEqual(res.types, ["testing"])
def _assert_ids(self, actual: Sequence[dict], expected: Sequence[str]):
self.assertEqual([a["id"] for a in actual], expected)
def test_project_operations(self):
# create
with self.api.raises(errors.bad_request.InvalidProjectName):
self._temp_project(name="/")
project1 = self._temp_project(name="Root1/Pr1")
project1_child = self._temp_project(name="Root1/Pr1/Pr2")
with self.api.raises(errors.bad_request.ExpectedUniqueData):
self._temp_project(name="Root1/Pr1/Pr2")
# update
with self.api.raises(errors.bad_request.CannotUpdateProjectLocation):
self.api.projects.update(project=project1, name="Root2/Pr2")
res = self.api.projects.update(project=project1, name="Root1/Pr2")
self.assertEqual(res.updated, 1)
res = self.api.projects.get_by_id(project=project1_child)
self.assertEqual(res.project.name, "Root1/Pr2/Pr2")
# move
res = self.api.projects.move(project=project1, new_location="Root2")
self.assertEqual(res.moved, 2)
res = self.api.projects.get_by_id(project=project1_child)
self.assertEqual(res.project.name, "Root2/Pr2/Pr2")
# merge
project_with_task, (active, archived) = self._temp_project_with_tasks(
"Root1/Pr3/Pr4"
)
project1_parent = self._getProjectParent(project1)
self._assertTags(project1_parent, tags=[], system_tags=[])
self._assertTags(project1_parent, tags=[], system_tags=[])
project_with_task_parent = self._getProjectParent(project_with_task)
self._assertTags(project_with_task_parent)
# self._assertTags(project_id=None)
merge_source = self.api.projects.get_by_id(
project=project_with_task
).project.parent
res = self.api.projects.merge(
project=merge_source, destination_project=project1
)
self.assertEqual(res.moved_entities, 0)
self.assertEqual(res.moved_projects, 1)
res = self.api.projects.get_by_id(project=project_with_task)
self.assertEqual(res.project.name, "Root2/Pr2/Pr4")
with self.api.raises(errors.bad_request.InvalidProjectId):
self.api.projects.get_by_id(project=merge_source)
self._assertTags(project1_parent)
self._assertTags(project1)
self._assertTags(project_with_task_parent, tags=[], system_tags=[])
# self._assertTags(project_id=None)
# delete
with self.api.raises(errors.bad_request.ProjectHasTasks):
self.api.projects.delete(project=project1)
res = self.api.projects.delete(project=project1, force=True)
self.assertEqual(res.deleted, 3)
self.assertEqual(res.disassociated_tasks, 2)
res = self.api.tasks.get_by_id(task=active).task
self.assertIsNone(res.get("project"))
for p_id in (project1, project1_child, project_with_task):
with self.api.raises(errors.bad_request.InvalidProjectId):
self.api.projects.get_by_id(project=p_id)
self._assertTags(project1_parent, tags=[], system_tags=[])
# self._assertTags(project_id=None, tags=[], system_tags=[])
def _getProjectParent(self, project_id: str):
return self.api.projects.get_all_ex(id=[project_id]).projects[0].parent.id
def _assertTags(
self,
project_id: Optional[str],
tags: Sequence[str] = ("test",),
system_tags: Sequence[str] = (EntityVisibility.archived.value,),
):
if project_id:
res = self.api.projects.get_task_tags(
projects=[project_id], include_system=True
)
else:
res = self.api.organization.get_tags(include_system=True)
self.assertEqual(set(res.tags), set(tags))
self.assertEqual(set(res.system_tags), set(system_tags))
def test_get_all_search_options(self):
project1 = self._temp_project(name="project1")
project2 = self._temp_project(name="project1/project2")
self._temp_project(name="project3")
# local search finds only at the specified level
res = self.api.projects.get_all_ex(
name="project1", shallow_search=True
).projects
self.assertEqual([p.id for p in res], [project1])
res = self.api.projects.get_all_ex(name="project1", parent=[project1]).projects
self.assertEqual([p.id for p in res], [project2])
# global search finds all or below the specified level
res = self.api.projects.get_all_ex(name="project1").projects
self.assertEqual(set(p.id for p in res), {project1, project2})
project4 = self._temp_project(name="project1/project2/project1")
res = self.api.projects.get_all_ex(name="project1", parent=[project2]).projects
self.assertEqual([p.id for p in res], [project4])
self.api.projects.delete(project=project1, force=True)
def test_get_all_with_check_own_contents(self):
project1, _ = self._temp_project_with_tasks(name="project1x")
project2 = self._temp_project(name="project2x")
self._temp_project_with_tasks(name="project2x/project22")
self._temp_model(project=project1)
res = self.api.projects.get_all_ex(
id=[project1, project2], check_own_contents=True
).projects
res1 = next(p for p in res if p.id == project1)
self.assertEqual(res1.own_tasks, 2)
self.assertEqual(res1.own_models, 1)
res2 = next(p for p in res if p.id == project2)
self.assertEqual(res2.own_tasks, 0)
self.assertEqual(res2.own_models, 0)
def test_get_all_with_stats(self):
project4, _ = self._temp_project_with_tasks(name="project1/project3/project4")
project5, _ = self._temp_project_with_tasks(name="project1/project3/project5")
project2 = self._temp_project(name="project2")
res = self.api.projects.get_all(shallow_search=True).projects
self.assertTrue(any(p for p in res if p.id == project2))
self.assertFalse(any(p for p in res if p.id in [project4, project5]))
project1 = first(p.id for p in res if p.name == "project1")
res = self.api.projects.get_all_ex(
id=[project1, project2], include_stats=True
).projects
self.assertEqual(set(p.id for p in res), {project1, project2})
res1 = next(p for p in res if p.id == project1)
self.assertEqual(res1.stats["active"]["status_count"]["created"], 0)
self.assertEqual(res1.stats["active"]["status_count"]["stopped"], 2)
self.assertEqual(res1.stats["active"]["total_runtime"], 2)
self.assertEqual(
{sp.name for sp in res1.sub_projects},
{
"project1/project3",
"project1/project3/project4",
"project1/project3/project5",
},
)
res2 = next(p for p in res if p.id == project2)
self.assertEqual(res2.stats["active"]["status_count"]["created"], 0)
self.assertEqual(res2.stats["active"]["status_count"]["stopped"], 0)
self.assertEqual(res2.stats["active"]["total_runtime"], 0)
self.assertEqual(res2.sub_projects, [])
def _run_tasks(self, *tasks):
"""Imitate 1 second of running"""
for task_id in tasks:
self.api.tasks.started(task=task_id)
sleep(1)
for task_id in tasks:
self.api.tasks.stopped(task=task_id)
def _temp_project_with_tasks(self, name) -> Tuple[str, Tuple[str, str]]:
pr_id = self._temp_project(name=name)
task_active = self._temp_task(project=pr_id)
task_archived = self._temp_task(
project=pr_id, system_tags=[EntityVisibility.archived.value], tags=["test"]
)
self._run_tasks(task_active, task_archived)
return pr_id, (task_active, task_archived)
delete_params = dict(can_fail=True, force=True)
def _temp_project(self, name, client=None, **kwargs):
return self.create_temp(
"projects",
delete_params=self.delete_params,
name=name,
description="",
client=client,
**kwargs,
)
def _temp_task(self, **kwargs):
return self.create_temp(
"tasks",
delete_params=self.delete_params,
type="testing",
name=db_id(),
input=dict(view=dict()),
**kwargs,
)
def _temp_model(self, **kwargs):
return self.create_temp(
service="models",
delete_params=self.delete_params,
name="test",
uri="file:///a",
labels={},
**kwargs,
)

View File

@@ -47,9 +47,9 @@ class TestTasksArtifacts(TestService):
self._assertTaskArtifacts([a for a in artifacts if a["mode"] != "output"], res)
new_artifacts = [
dict(key="x", type="str", uri="x_test"),
dict(key="y", type="int", uri="y_test"),
dict(key="z", type="int", uri="y_test"),
dict(key="x", type="str", uri="x_test", mode="input"),
dict(key="y", type="int", uri="y_test", mode="input"),
dict(key="z", type="int", uri="y_test", mode="input"),
]
new_task = self.api.tasks.clone(
task=task, execution_overrides={"artifacts": new_artifacts}

View File

@@ -1,6 +1,5 @@
from functools import partial
from typing import Sequence
from typing import Sequence, Mapping
from apiserver.es_factory import es_factory
from apiserver.tests.automated import TestService
@@ -128,6 +127,96 @@ class TestTaskDebugImages(TestService):
def test_task_debug_images(self):
task = self._temp_task()
# test empty
res = self.api.events.debug_images(metrics=[{"task": task}], iters=5)
self.assertFalse(res.metrics[0].iterations)
res = self.api.events.debug_images(
metrics=[{"task": task}], iters=5, scroll_id=res.scroll_id, refresh=True
)
self.assertFalse(res.metrics[0].iterations)
# test not empty
metrics = {
"Metric1": ["Variant1", "Variant2"],
"Metric2": ["Variant3", "Variant4"],
}
events = [
self._create_task_event(
task=task,
iteration=1,
metric=metric,
variant=variant,
url=f"{metric}_{variant}_{1}",
)
for metric, variants in metrics.items()
for variant in variants
]
self.send_batch(events)
scroll_id = self._assertTaskMetrics(
task=task, expected_metrics=metrics, iterations=1
)
# test refresh
update = {
"Metric2": ["Variant3", "Variant4", "Variant5"],
"Metric3": ["VariantA", "VariantB"],
}
events = [
self._create_task_event(
task=task,
iteration=2,
metric=metric,
variant=variant,
url=f"{metric}_{variant}_{2}",
)
for metric, variants in update.items()
for variant in variants
]
self.send_batch(events)
# without refresh the metric states are not updated
scroll_id = self._assertTaskMetrics(
task=task, expected_metrics=metrics, iterations=0, scroll_id=scroll_id
)
# with refresh there are new metrics and existing ones are updated
self._assertTaskMetrics(
task=task,
expected_metrics=update,
iterations=1,
scroll_id=scroll_id,
refresh=True,
)
pass
def _assertTaskMetrics(
self,
task: str,
expected_metrics: Mapping[str, Sequence[str]],
iterations,
scroll_id: str = None,
refresh=False,
) -> str:
res = self.api.events.debug_images(
metrics=[{"task": task}], iters=1, scroll_id=scroll_id, refresh=refresh
)
if not iterations:
self.assertTrue(all(m.iterations == [] for m in res.metrics))
return res.scroll_id
expected_variants = set((m, var) for m, vars_ in expected_metrics.items() for var in vars_)
for metric_data in res.metrics:
self.assertEqual(len(metric_data.iterations), iterations)
for it_data in metric_data.iterations:
self.assertEqual(
set((e.metric, e.variant) for e in it_data.events), expected_variants
)
return res.scroll_id
def test_get_debug_images_navigation(self):
task = self._temp_task()
metric = "Metric1"
variants = [("Variant1", 7), ("Variant2", 4)]
iterations = 10
@@ -136,7 +225,7 @@ class TestTaskDebugImages(TestService):
res = self.api.events.debug_images(
metrics=[{"task": task, "metric": metric}], iters=5,
)
self.assertFalse(res.metrics)
self.assertFalse(res.metrics[0].iterations)
# create events
events = [
@@ -195,7 +284,7 @@ class TestTaskDebugImages(TestService):
expected_page: int,
iters: int = 5,
**extra_params,
):
) -> str:
res = self.api.events.debug_images(
metrics=[{"task": task, "metric": metric}],
iters=iters,
@@ -204,7 +293,6 @@ class TestTaskDebugImages(TestService):
)
data = res["metrics"][0]
self.assertEqual(data["task"], task)
self.assertEqual(data["metric"], metric)
left_iterations = max(0, max(unique_images) - expected_page * iters)
self.assertEqual(len(data["iterations"]), min(iters, left_iterations))
for it in data["iterations"]:

View File

@@ -204,7 +204,9 @@ class TestTaskEvents(TestService):
self.send_batch(events)
for key in None, "iter", "timestamp", "iso_time":
with self.subTest(key=key):
data = self.api.events.scalar_metrics_iter_histogram(task=task, key=key)
data = self.api.events.scalar_metrics_iter_histogram(
task=task, **(dict(key=key) if key is not None else {})
)
self.assertIn(metric, data)
self.assertIn(variant, data[metric])
self.assertIn("x", data[metric][variant])

View File

@@ -105,7 +105,9 @@ class TestTasksHyperparams(TestService):
)
# clone task
new_task = self.api.tasks.clone(task=task, new_task_hyperparams=new_params_dict).id
new_task = self.api.tasks.clone(
task=task, new_task_hyperparams=new_params_dict
).id
try:
res = self.api.tasks.get_hyper_params(tasks=[new_task]).params[0]
self.assertEqual(new_params, res.hyperparams)
@@ -123,7 +125,9 @@ class TestTasksHyperparams(TestService):
task=task, hyperparams=[dict(section="test")]
)
self.api.tasks.edit_hyper_params(
task=task, hyperparams=[dict(section="test", name="x", value="123")], force=True
task=task,
hyperparams=[dict(section="test", name="x", value="123")],
force=True,
)
self.api.tasks.delete_hyper_params(
task=task, hyperparams=[dict(section="test")], force=True
@@ -146,7 +150,12 @@ class TestTasksHyperparams(TestService):
return [
dict(section="Args", name=k, value=str(v), type="legacy")
if not k.startswith("TF_DEFINE/")
else dict(section="TF_DEFINE", name=k[len("TF_DEFINE/"):], value=str(v), type="legacy")
else dict(
section="TF_DEFINE",
name=k[len("TF_DEFINE/") :],
value=str(v),
type="legacy",
)
for k, v in legacy.items()
]
@@ -168,6 +177,7 @@ class TestTasksHyperparams(TestService):
new_config = [
dict(name="param$1", type="type1", value="10"),
dict(name="param/2", type="type1", value="20"),
dict(name="param_empty", type="type1", value=""),
]
new_config_dict = self._config_dict_from_list(new_config)
task, _ = self.new_task(
@@ -188,7 +198,14 @@ class TestTasksHyperparams(TestService):
# names
res = self.api.tasks.get_configuration_names(tasks=[task]).configurations[0]
self.assertEqual(task, res.task)
self.assertEqual(["design", "param$1", "param/2"], res.names)
self.assertEqual(
["design", *[c["name"] for c in new_config if c["value"]]], res.names
)
res = self.api.tasks.get_configuration_names(
tasks=[task], skip_empty=False
).configurations[0]
self.assertEqual(task, res.task)
self.assertEqual(["design", *[c["name"] for c in new_config]], res.names)
# returned as one list with names filtering
res = self.api.tasks.get_configurations(
@@ -216,14 +233,14 @@ class TestTasksHyperparams(TestService):
# delete
new_to_delete = self._get_config_keys(new_config[1:])
self.api.tasks.delete_configuration(
task=task, configuration=new_to_delete
)
self.api.tasks.delete_configuration(task=task, configuration=new_to_delete)
res = self.api.tasks.get_configurations(tasks=[task]).configurations[0]
self.assertEqual(old_config + new_config[:1], res.configuration)
# clone task
new_task = self.api.tasks.clone(task=task, new_task_configuration=new_config_dict).id
new_task = self.api.tasks.clone(
task=task, new_task_configuration=new_config_dict
).id
try:
res = self.api.tasks.get_configurations(tasks=[new_task]).configurations[0]
self.assertEqual(new_config, res.configuration)
@@ -233,13 +250,9 @@ class TestTasksHyperparams(TestService):
# edit/delete of running task
self.api.tasks.started(task=task)
with self.api.raises(InvalidTaskStatus):
self.api.tasks.edit_configuration(
task=task, configuration=new_config
)
self.api.tasks.edit_configuration(task=task, configuration=new_config)
with self.api.raises(InvalidTaskStatus):
self.api.tasks.delete_configuration(
task=task, configuration=new_to_delete
)
self.api.tasks.delete_configuration(task=task, configuration=new_to_delete)
self.api.tasks.edit_configuration(
task=task, configuration=new_config, force=True
)
@@ -292,7 +305,9 @@ class TestTasksHyperparams(TestService):
task_id, _ = self.new_task(
execution={"parameters": legacy_params, "model_desc": legacy_config}
)
config = self._config_dict_from_list(self._new_config_from_legacy(legacy_config))
config = self._config_dict_from_list(
self._new_config_from_legacy(legacy_config)
)
params = self._param_dict_from_list(self._new_params_from_legacy(legacy_params))
old_api = APIClient(base_url="http://localhost:8008/v2.8")
@@ -304,7 +319,10 @@ class TestTasksHyperparams(TestService):
modified_params = {"legacy.2": "val2"}
modified_config = {"design": "by"}
old_api.tasks.edit(task=task_id, execution=dict(parameters=modified_params, model_desc=modified_config))
old_api.tasks.edit(
task=task_id,
execution=dict(parameters=modified_params, model_desc=modified_config),
)
task = old_api.tasks.get_all_ex(id=[task_id]).tasks[0]
self.assertEqual(modified_params, task.execution.parameters)
self.assertEqual(modified_config, task.execution.model_desc)

View File

@@ -0,0 +1,114 @@
from copy import deepcopy
from typing import Sequence, Optional
from packaging.version import parse
from apiserver.tests.automated import TestService
class TestTaskModels(TestService):
def setUp(self, version="2.13"):
super().setUp(version=version)
def test_new_apis(self):
# no models
empty_task = self.new_task()
self.assertModels(empty_task, [], [])
id1, id2 = self.new_model("model1"), self.new_model("model2")
input_models = [
{"name": "input1", "model": id1},
{"name": "input2", "model": id2},
]
output_models = [
{"name": "output1", "model": "id3"},
{"name": "output2", "model": "id4"},
]
# task creation with models
task = self.new_task(models={"input": input_models, "output": output_models})
self.assertModels(task, input_models, output_models)
# add_or_update existing model
res = self.api.tasks.add_or_update_model(
task=task, name="input1", type="input", model="Test"
)
self.assertEqual(res.updated, 1)
modified_input = deepcopy(input_models)
modified_input[0]["model"] = "Test"
self.assertModels(task, modified_input, output_models)
# add_or_update new mode
res = self.api.tasks.add_or_update_model(
task=task, name="output3", type="output", model="TestOutput"
)
self.assertEqual(res.updated, 1)
modified_output = deepcopy(output_models)
modified_output.append({"name": "output3", "model": "TestOutput"})
self.assertModels(task, modified_input, modified_output)
# task editing
self.api.tasks.edit(
task=task, models={"input": input_models, "output": output_models}
)
self.assertModels(task, input_models, output_models)
# delete models
res = self.api.tasks.delete_models(
task=task,
models=[
{"name": "input1", "type": "input"},
{"name": "input2", "type": "input"},
{"name": "output1", "type": "output"},
{"name": "not_existing", "type": "output"},
]
)
self.assertEqual(res.updated, 1)
self.assertModels(task, [], output_models[1:])
def assertModels(
self, task_id: str, input_models: Sequence[dict], output_models: Sequence[dict],
):
def get_model_id(model: dict) -> Optional[str]:
if not model:
return None
id_ = model.get("model")
if isinstance(id_, str):
return id_
if id_ is None or id_ == {}:
return None
return id_.get("id")
def compare_models(actual: Sequence[dict], expected: Sequence[dict]):
self.assertEqual(
[(m["name"], get_model_id(m)) for m in actual],
[(m["name"], m["model"]) for m in expected],
)
for task in (
self.api.tasks.get_all_ex(id=task_id).tasks[0],
self.api.tasks.get_all(id=task_id).tasks[0],
self.api.tasks.get_by_id(task=task_id).task,
):
compare_models(task.models.input, input_models)
compare_models(task.models.output, output_models)
if self._version < parse("2.13"):
self.assertEqual(
get_model_id(task.execution),
input_models[0]["model"] if input_models else None,
)
self.assertEqual(
get_model_id(task.output),
output_models[-1]["model"] if output_models else None,
)
def new_task(self, **kwargs):
self.update_missing(
kwargs, type="testing", name="test task models", input=dict(view=dict())
)
return self.create_temp("tasks", **kwargs)
def new_model(self, name: str, **kwargs):
return self.create_temp(
"models", uri="file://test", name=name, labels={}, **kwargs
)

Some files were not shown because too many files have changed in this diff Show More