diff --git a/LICENSE b/LICENSE
index 24d07bf..19931a1 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,7 +1,7 @@
Server Side Public License
VERSION 1, OCTOBER 16, 2018
- Copyright © 2019 allegro.ai, Inc.
+ Copyright © 2025 ClearML Inc.
Everyone is permitted to copy and distribute verbatim copies of this
license document, but changing it is not allowed.
diff --git a/README.md b/README.md
index f85bc9e..13a971c 100644
--- a/README.md
+++ b/README.md
@@ -7,42 +7,15 @@
[](https://img.shields.io/badge/license-SSPL-green.svg)
[](https://img.shields.io/badge/python-3.9-blue.svg)
-[](https://img.shields.io/github/release-pre/allegroai/trains-server.svg)
-[](https://artifacthub.io/packages/search?repo=allegroai)
+[](https://img.shields.io/github/release-pre/clearml/trains-server.svg)
+[](https://artifacthub.io/packages/search?repo=clearml)
----
-
-
-**Note regarding Apache Log4j2 Remote Code Execution (RCE) Vulnerability - CVE-2021-44228 - ESA-2021-31**
-
-
-
-According to [ElasticSearch's latest report](https://discuss.elastic.co/t/apache-log4j2-remote-code-execution-rce-vulnerability-cve-2021-44228-esa-2021-31/291476),
-supported versions of Elasticsearch (6.8.9+, 7.8+) used with recent versions of the JDK (JDK9+) **are not susceptible to either remote code execution or information leakage**
-due to Elasticsearch’s usage of the Java Security Manager.
-
-**As the latest version of ClearML Server uses Elasticsearch 7.10+ with JDK15, it is not affected by these vulnerabilities.**
-
-As a precaution, we've upgraded the ES version to 7.16.2 and added the mitigation recommended by ElasticSearch to our latest [docker-compose.yml](https://github.com/allegroai/clearml-server/blob/cfccbe05c158b75e520581f86e9668291da5c70a/docker/docker-compose.yml#L42) file.
-
-While previous Elasticsearch versions (5.6.11+, 6.4.0+ and 7.0.0+) used by older ClearML Server versions are only susceptible to the information leakage vulnerability
-(which in any case **does not permit access to data within the Elasticsearch cluster**),
-we still recommend upgrading to the latest version of ClearML Server. Alternatively, you can apply the mitigation as implemented in our latest
-[docker-compose.yml](https://github.com/allegroai/clearml-server/blob/cfccbe05c158b75e520581f86e9668291da5c70a/docker/docker-compose.yml#L42) file.
-
-**Update 15 December**: A further vulnerability (CVE-2021-45046) was disclosed on December 14th.
-ElasticSearch's guidance for Elasticsearch remains unchanged by this new vulnerability, thus **not affecting ClearML Server**.
-
-**Update 22 December**: To keep with ElasticSearch's recommendations, we've upgraded the ES version to the newly released 7.16.2
-
----
-
## ClearML Server
#### *Formerly known as Trains Server*
-The **ClearML Server** is the backend service infrastructure for [ClearML](https://github.com/allegroai/clearml).
+The **ClearML Server** is the backend service infrastructure for [ClearML](https://github.com/clearml/clearml).
It allows multiple users to collaborate and manage their experiments.
**ClearML** offers a [free hosted service](https://app.clear.ml/), which is maintained by **ClearML** and open to anyone.
In order to host your own server, you will need to launch the **ClearML Server** and point **ClearML** to it.
@@ -99,8 +72,10 @@ Launch The **ClearML Server** in any of the following formats:
- [Linux](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_linux_mac)
- [macOS](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_linux_mac)
- [Windows 10](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_win)
-- [Kubernetes Helm](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_kubernetes_helm)
-
+- Kubernetes
+ - [Kubernetes Helm](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_kubernetes_helm)
+ - Manual [Kubernetes installation](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_kubernetes)
+
## Connecting ClearML to your ClearML Server
In order to set up the **ClearML** client to work with your **ClearML Server**:
@@ -118,16 +93,13 @@ In order to set up the **ClearML** client to work with your **ClearML Server**:
files_server: "http://localhost:8081"
}
-
-> [!NOTE]
->
-> If you have set up your **ClearML Server** in a sub-domain configuration, then there is no need to specify a port number,
+**Note**: If you have set up your **ClearML Server** in a sub-domain configuration, there is no need to specify a port number;
it will be inferred from the http/s scheme.
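+
+For example, in a sub-domain setup the settings might look like this (hostnames are illustrative):
+
+    api {
+        web_server: "https://app.example.com"
+        api_server: "https://api.example.com"
+        files_server: "https://files.example.com"
+    }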
-After launching the ClearML Server and configuring the **ClearML** client to use the ClearML Server,
-you can use [ClearML](https://github.com/allegroai/clearml) in your experiments and view them in your ClearML Server web server,
+After launching the **ClearML Server** and configuring the **ClearML** client to use the **ClearML Server**,
+you can use [**ClearML**](https://github.com/clearml/clearml) in your experiments and view them in your **ClearML Server** web server,
for example http://localhost:8080.
-For more information about the ClearML client, see [**ClearML**](https://github.com/allegroai/clearml).
+For more information about the ClearML client, see [**ClearML**](https://github.com/clearml/clearml).
## ClearML-Agent Services
@@ -144,11 +116,9 @@ increased data transparency)
ClearML-Agent Services container will spin **any** task enqueued into the dedicated `services` queue.
Every task launched by ClearML-Agent Services will be registered as a new node in the system,
providing tracking and transparency capabilities.
-You can also run the ClearML-Agent Services manually, see details in [ClearML-agent services mode](https://github.com/allegroai/clearml-agent#clearml-agent-services-mode-)
+You can also run the ClearML-Agent Services manually, see details in [ClearML-agent services mode](https://github.com/clearml/clearml-agent#clearml-agent-services-mode-)
-> [!NOTE]
->
-> It is the user's responsibility to make sure the proper tasks are pushed into the `services` queue.
+**Note**: It is the user's responsibility to make sure the proper tasks are pushed into the `services` queue.
Do not enqueue training / inference tasks into the `services` queue, as it will put unnecessary load on the server.
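+
+For example, a pipeline controller task can be sent to the `services` queue through the
+**ClearML** SDK (an illustrative snippet; the task id is a placeholder):
+
+```python
+from clearml import Task
+
+controller = Task.get_task(task_id="<controller-task-id>")
+Task.enqueue(controller, queue_name="services")
+```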
## Advanced Functionality
@@ -169,7 +139,7 @@ To restart the **ClearML Server**, you must first stop the containers, and then
## Upgrading
-**ClearML Server** releases are also reflected in the [docker compose configuration file](https://github.com/allegroai/trains-server/blob/master/docker/docker-compose.yml).
+**ClearML Server** releases are also reflected in the [docker compose configuration file](https://github.com/clearml/trains-server/blob/master/docker/docker-compose.yml).
We strongly encourage you to keep your **ClearML Server** up to date, by keeping up with the current release.
**Note**: The following upgrade instructions use the Linux OS as an example.
@@ -202,7 +172,7 @@ To upgrade your existing **ClearML Server** deployment:
1. Download the latest `docker-compose.yml` file.
```bash
- curl https://raw.githubusercontent.com/allegroai/trains-server/master/docker/docker-compose.yml -o docker-compose.yml
+ curl https://raw.githubusercontent.com/clearml/trains-server/master/docker/docker-compose.yml -o docker-compose.yml
```
1. Configure the ClearML-Agent Services (not supported on Windows installation).
@@ -227,10 +197,10 @@ To upgrade your existing **ClearML Server** deployment:
## Community & Support
-If you have any questions, look at the ClearML [FAQ](https://clear.ml/docs/latest/docs/faq), or
+If you have any questions, see the ClearML [FAQ](https://clear.ml/docs/latest/docs/faq), or
tag your questions on [stackoverflow](https://stackoverflow.com/questions/tagged/clearml) with '**clearml**' tag.
-For feature requests or bug reports, please use [GitHub issues](https://github.com/allegroai/clearml-server/issues).
+For feature requests or bug reports, please use [GitHub issues](https://github.com/clearml/clearml-server/issues).
Additionally, you can always find us at *clearml@allegro.ai*
diff --git a/apiserver/LICENSE b/apiserver/LICENSE
index 24d07bf..19931a1 100644
--- a/apiserver/LICENSE
+++ b/apiserver/LICENSE
@@ -1,7 +1,7 @@
Server Side Public License
VERSION 1, OCTOBER 16, 2018
- Copyright © 2019 allegro.ai, Inc.
+ Copyright © 2025 ClearML Inc.
Everyone is permitted to copy and distribute verbatim copies of this
license document, but changing it is not allowed.
diff --git a/apiserver/apierrors/errors.conf b/apiserver/apierrors/errors.conf
index ca78c81..da8d27b 100644
--- a/apiserver/apierrors/errors.conf
+++ b/apiserver/apierrors/errors.conf
@@ -84,6 +84,7 @@
411: ["project_cannot_be_moved_under_itself", "Project can not be moved under itself in the projects hierarchy"]
412: ["project_cannot_be_merged_into_its_child", "Project can not be merged into its own child"]
413: ["project_has_pipelines", "project has associated pipelines with active controllers"]
+ 414: ["public_project_exists", "Cannot create project. Public project with the same name already exists"]
# Queues
701: ["invalid_queue_id", "invalid queue id"]
@@ -106,6 +107,11 @@
1004: ["worker_not_registered", "worker is not registered"]
1005: ["worker_stats_not_found", "worker stats not found"]
+ # Serving
+ 1050: ["invalid_container_id", "invalid container id"]
+ 1051: ["container_not_registered", "container is not registered"]
+  1052: ["no_containers_for_url", "no container instances found for service url"]
+
1104: ["invalid_scroll_id", "Invalid scroll id"]
}
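
Each errors.conf entry is generated into an exception class keyed by its snake_case name; a minimal sketch of how the new 1052 entry is raised (mirroring its usage in the serving BLL later in this diff):

```python
from apiserver.apierrors import errors

# the "no_containers_for_url" entry becomes errors.bad_request.NoContainersForUrl
raise errors.bad_request.NoContainersForUrl(url="<endpoint-url>")
```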
diff --git a/apiserver/apimodels/__init__.py b/apiserver/apimodels/__init__.py
index 9c3d804..49123c7 100644
--- a/apiserver/apimodels/__init__.py
+++ b/apiserver/apimodels/__init__.py
@@ -1,10 +1,11 @@
from enum import Enum
-from typing import Union, Type, Iterable
+from numbers import Number
+from typing import Union, Type, Iterable, Mapping
import jsonmodels.errors
import six
from jsonmodels import fields
-from jsonmodels.fields import _LazyType, NotSet
+from jsonmodels.fields import _LazyType, NotSet, EmbeddedField
from jsonmodels.models import Base as ModelBase
from jsonmodels.validators import Enum as EnumValidator
from mongoengine.base import BaseDocument
@@ -40,6 +41,34 @@ def make_default(field_cls, default_value):
return _FieldWithDefault
+class OneOfEmbeddedField(EmbeddedField):
+ def __init__(
+ self,
+ *args,
+ discriminator_property: str,
+ discriminator_mapping: Mapping[str, type],
+ **kwargs,
+ ):
+ self.discriminator_property = discriminator_property
+ self.discriminator_mapping = discriminator_mapping
+ model_types = tuple(set(self.discriminator_mapping.values()))
+
+ super().__init__(model_types, *args, **kwargs)
+
+ def parse_value(self, value):
+ """Parse value to proper model type."""
+ if not isinstance(value, dict) or self.discriminator_property not in value:
+ return super().parse_value(value)
+
+ property_value = value.get(self.discriminator_property)
+ embed_type = self.discriminator_mapping.get(property_value)
+ if not embed_type:
+ raise jsonmodels.errors.ValidationError(
+ f"Could not find type matching discriminator property value: {property_value}"
+ )
+ return embed_type(**value)
+
+
class ListField(fields.ListField):
def __init__(self, items_types=None, *args, default=NotSet, **kwargs):
if default is not NotSet and callable(default):
@@ -68,6 +97,15 @@ class ScalarField(fields.BaseField):
types = (str, int, float, bool)
+class SafeStringField(fields.StringField):
+ """String field that can also accept numbers as input"""
+ def parse_value(self, value):
+ if isinstance(value, Number):
+ value = str(value)
+
+ return super().parse_value(value)
+
+
class DictField(fields.BaseField):
types = (dict,)
@@ -115,9 +153,7 @@ class DictField(fields.BaseField):
if len(self.value_types) != 1:
tpl = 'Cannot decide which type to choose from "{types}".'
raise jsonmodels.errors.ValidationError(
- tpl.format(
- types=', '.join([t.__name__ for t in self.value_types])
- )
+ tpl.format(types=", ".join([t.__name__ for t in self.value_types]))
)
return self.value_types[0](**value)
@@ -179,7 +215,7 @@ class EnumField(fields.StringField):
*args,
required=False,
default=None,
- **kwargs
+ **kwargs,
):
choices = list(map(self.parse_value, values_or_type))
validator_cls = EnumValidator if required else NullableEnumValidator
@@ -202,7 +238,7 @@ class ActualEnumField(fields.StringField):
validators=None,
required=False,
default=None,
- **kwargs
+ **kwargs,
):
self.__enum = enum_class
self.types = (enum_class,)
@@ -215,7 +251,7 @@ class ActualEnumField(fields.StringField):
*args,
required=required,
validators=validators,
- **kwargs
+ **kwargs,
)
def parse_value(self, value):
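
To illustrate the discriminator mechanics of the new OneOfEmbeddedField (and the number coercion of SafeStringField), here is a hypothetical sketch; the CsvSource/S3Source/IngestRequest models are invented for the example:

```python
from jsonmodels.models import Base
from jsonmodels.fields import StringField

from apiserver.apimodels import OneOfEmbeddedField, SafeStringField


class CsvSource(Base):  # hypothetical embedded model
    type = StringField()
    delimiter = StringField()


class S3Source(Base):  # hypothetical embedded model
    type = StringField()
    bucket = StringField()


class IngestRequest(Base):  # hypothetical request model
    source = OneOfEmbeddedField(
        discriminator_property="type",
        discriminator_mapping={"csv": CsvSource, "s3": S3Source},
    )
    batch = SafeStringField()  # numbers are coerced to strings


# parse_value() picks the embedded model type by the "type" key
req = IngestRequest(source={"type": "s3", "bucket": "my-bucket"}, batch=16)
assert isinstance(req.source, S3Source)
assert req.batch == "16"
```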
diff --git a/apiserver/apimodels/queues.py b/apiserver/apimodels/queues.py
index e03b517..d9f87e1 100644
--- a/apiserver/apimodels/queues.py
+++ b/apiserver/apimodels/queues.py
@@ -17,6 +17,7 @@ class GetDefaultResp(Base):
class CreateRequest(Base):
name = StringField(required=True)
+ display_name = StringField()
tags = ListField(items_types=[str])
system_tags = ListField(items_types=[str])
metadata = DictField(value_types=[MetadataItem])
@@ -47,6 +48,7 @@ class DeleteRequest(QueueRequest):
class UpdateRequest(QueueRequest):
name = StringField()
+ display_name = StringField()
tags = ListField(items_types=[str])
system_tags = ListField(items_types=[str])
metadata = DictField(value_types=[MetadataItem])
@@ -56,6 +58,14 @@ class TaskRequest(QueueRequest):
task = StringField(required=True)
+class RemoveTaskRequest(TaskRequest):
+ update_task_status = BoolField(default=False)
+
+
+class AddTaskRequest(TaskRequest):
+ update_execution_queue = BoolField(default=True)
+
+
class MoveTaskRequest(TaskRequest):
count = IntField(default=1)
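
As a quick illustration of the defaults carried by the new request models (required fields are only enforced on validate(), so partial construction works here):

```python
from apiserver.apimodels.queues import AddTaskRequest, RemoveTaskRequest

# removing a task leaves its status untouched unless explicitly requested,
# while adding a task updates the task's execution queue by default
assert RemoveTaskRequest(task="<task-id>").update_task_status is False
assert AddTaskRequest(task="<task-id>").update_execution_queue is True
```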
diff --git a/apiserver/apimodels/serving.py b/apiserver/apimodels/serving.py
new file mode 100644
index 0000000..64e39d9
--- /dev/null
+++ b/apiserver/apimodels/serving.py
@@ -0,0 +1,104 @@
+from enum import Enum
+from typing import Sequence
+
+from jsonmodels.models import Base
+from jsonmodels.fields import (
+ StringField,
+ EmbeddedField,
+ DateTimeField,
+ IntField,
+ FloatField,
+ BoolField,
+)
+from jsonmodels import validators
+from jsonmodels.validators import Min
+
+from apiserver.apimodels import ListField, JsonSerializableMixin, SafeStringField
+from apiserver.apimodels import ActualEnumField
+from apiserver.config_repo import config
+from .workers import MachineStats
+
+
+class ReferenceItem(Base):
+ type = StringField(
+ required=True,
+ validators=validators.Enum("app_id", "app_instance", "model", "task", "url"),
+ )
+ value = StringField(required=True)
+
+
+class ServingModel(Base):
+ container_id = StringField(required=True)
+ endpoint_name = StringField(required=True)
+    endpoint_url = StringField()  # may not exist yet at registration time
+ model_name = StringField(required=True)
+ model_source = StringField()
+ model_version = StringField()
+ preprocess_artifact = StringField()
+ input_type = StringField()
+ input_size = SafeStringField()
+ tags = ListField(str)
+ system_tags = ListField(str)
+ reference: Sequence[ReferenceItem] = ListField(ReferenceItem)
+
+
+class RegisterRequest(ServingModel):
+ timeout = IntField(
+ default=int(
+ config.get("services.serving.default_container_timeout_sec", 10 * 60)
+ ),
+ validators=[Min(1)],
+ )
+ """ registration timeout in seconds (default is 10min) """
+
+
+class UnregisterRequest(Base):
+ container_id = StringField(required=True)
+
+
+class StatusReportRequest(ServingModel):
+ uptime_sec = IntField()
+ requests_num = IntField()
+ requests_min = FloatField()
+ latency_ms = IntField()
+ machine_stats: MachineStats = EmbeddedField(MachineStats)
+
+
+class ServingContainerEntry(StatusReportRequest, JsonSerializableMixin):
+ key = StringField(required=True)
+ company_id = StringField(required=True)
+ ip = StringField()
+ register_time = DateTimeField(required=True)
+ register_timeout = IntField(required=True)
+ last_activity_time = DateTimeField(required=True)
+
+
+class GetEndpointDetailsRequest(Base):
+ endpoint_url = StringField(required=True)
+
+
+class MetricType(Enum):
+ requests = "requests"
+ requests_min = "requests_min"
+ latency_ms = "latency_ms"
+ cpu_count = "cpu_count"
+ gpu_count = "gpu_count"
+ cpu_util = "cpu_util"
+ gpu_util = "gpu_util"
+ ram_total = "ram_total"
+ ram_used = "ram_used"
+ ram_free = "ram_free"
+ gpu_ram_total = "gpu_ram_total"
+ gpu_ram_used = "gpu_ram_used"
+ gpu_ram_free = "gpu_ram_free"
+ network_rx = "network_rx"
+ network_tx = "network_tx"
+
+
+class GetEndpointMetricsHistoryRequest(Base):
+ from_date = FloatField(required=True, validators=Min(0))
+ to_date = FloatField(required=True, validators=Min(0))
+ interval = IntField(required=True, validators=Min(1))
+ endpoint_url = StringField(required=True)
+ metric_type = ActualEnumField(MetricType, default=MetricType.requests)
+ instance_charts = BoolField(default=True)
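
As an illustration of the models above, a registration request might be built like this (all values are placeholders; the service method that consumes it is outside this file):

```python
from apiserver.apimodels.serving import RegisterRequest

register = RegisterRequest(
    container_id="c-0001",        # required
    endpoint_name="sentiment",    # required
    model_name="sentiment-onnx",  # required
    input_size=224,               # SafeStringField: the number is stored as "224"
    timeout=300,                  # overrides the 10-minute default, must be >= 1
)
register.validate()
```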
diff --git a/apiserver/apimodels/storage.py b/apiserver/apimodels/storage.py
new file mode 100644
index 0000000..b0c244a
--- /dev/null
+++ b/apiserver/apimodels/storage.py
@@ -0,0 +1,60 @@
+from jsonmodels.fields import StringField, BoolField, ListField, EmbeddedField
+from jsonmodels.models import Base
+from jsonmodels.validators import Enum
+
+
+class AWSBucketSettings(Base):
+ bucket = StringField()
+ subdir = StringField()
+ host = StringField()
+ key = StringField()
+ secret = StringField()
+ token = StringField()
+ multipart = BoolField(default=True)
+ acl = StringField()
+ secure = BoolField(default=True)
+ region = StringField()
+ verify = BoolField(default=True)
+ use_credentials_chain = BoolField(default=False)
+
+
+class AWSSettings(Base):
+ key = StringField()
+ secret = StringField()
+ region = StringField()
+ token = StringField()
+ use_credentials_chain = BoolField(default=False)
+ buckets = ListField(items_types=[AWSBucketSettings])
+
+
+class GoogleBucketSettings(Base):
+ bucket = StringField()
+ subdir = StringField()
+ project = StringField()
+ credentials_json = StringField()
+
+
+class GoogleSettings(Base):
+ project = StringField()
+ credentials_json = StringField()
+ buckets = ListField(items_types=[GoogleBucketSettings])
+
+
+class AzureContainerSettings(Base):
+ account_name = StringField()
+ account_key = StringField()
+ container_name = StringField()
+
+
+class AzureSettings(Base):
+ containers = ListField(items_types=[AzureContainerSettings])
+
+
+class SetSettingsRequest(Base):
+ aws = EmbeddedField(AWSSettings)
+ google = EmbeddedField(GoogleSettings)
+ azure = EmbeddedField(AzureSettings)
+
+
+class ResetSettingsRequest(Base):
+ keys = ListField([str], item_validators=[Enum("aws", "google", "azure")])
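
A hypothetical settings payload built from the models above (bucket names, regions, and account names are placeholders):

```python
from apiserver.apimodels.storage import (
    AWSBucketSettings,
    AWSSettings,
    ResetSettingsRequest,
    SetSettingsRequest,
)

settings = SetSettingsRequest(
    aws=AWSSettings(
        use_credentials_chain=True,  # rely on the environment instead of key/secret
        buckets=[AWSBucketSettings(bucket="my-bucket", region="us-east-1")],
    ),
)
settings.validate()

# resetting just the AWS section afterwards:
reset = ResetSettingsRequest(keys=["aws"])
```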
diff --git a/apiserver/apimodels/tasks.py b/apiserver/apimodels/tasks.py
index 50a07d8..a96f017 100644
--- a/apiserver/apimodels/tasks.py
+++ b/apiserver/apimodels/tasks.py
@@ -109,6 +109,7 @@ class EnqueueRequest(UpdateRequest):
queue = StringField()
queue_name = StringField()
verify_watched_queue = BoolField(default=False)
+ update_execution_queue = BoolField(default=True)
class DeleteRequest(UpdateRequest):
diff --git a/apiserver/apimodels/workers.py b/apiserver/apimodels/workers.py
index ba98503..d1f631d 100644
--- a/apiserver/apimodels/workers.py
+++ b/apiserver/apimodels/workers.py
@@ -86,6 +86,7 @@ class CurrentTaskEntry(IdNameEntry):
class QueueEntry(IdNameEntry):
+ display_name = StringField()
next_task = EmbeddedField(IdNameEntry)
num_tasks = IntField()
diff --git a/apiserver/bll/event/event_bll.py b/apiserver/bll/event/event_bll.py
index 80d8392..e210534 100644
--- a/apiserver/bll/event/event_bll.py
+++ b/apiserver/bll/event/event_bll.py
@@ -41,7 +41,7 @@ from apiserver.bll.event.event_metrics import EventMetrics
from apiserver.bll.task import TaskBLL
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
-from apiserver.database.model.task.task import Task, TaskStatus
+from apiserver.database.model.task.task import TaskStatus
from apiserver.redis_manager import redman
from apiserver.service_repo.auth import Identity
from apiserver.utilities.dicts import nested_get
@@ -201,6 +201,8 @@ class EventBLL(object):
invalid_iteration_error = f"Iteration number should not exceed {MAX_LONG}"
for event in events:
+ x_axis_label = event.pop("x_axis_label", None)
+
# remove spaces from event type
event_type = event.get("type")
if event_type is None:
@@ -296,6 +298,7 @@ class EventBLL(object):
self._update_last_scalar_events_for_task(
last_events=task_last_scalar_events[task_or_model_id],
event=event,
+ x_axis_label=x_axis_label,
)
actions.append(es_action)
@@ -319,6 +322,7 @@ class EventBLL(object):
if actions:
chunk_size = 500
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
+ # noinspection PyTypeChecker
with closing(
elasticsearch.helpers.streaming_bulk(
self.es,
@@ -430,7 +434,7 @@ class EventBLL(object):
return False
return True
- def _update_last_scalar_events_for_task(self, last_events, event):
+ def _update_last_scalar_events_for_task(self, last_events, event, x_axis_label=None):
"""
Update last_events structure with the provided event details if this event is more
recent than the currently stored event for its metric/variant combination.
@@ -438,45 +442,47 @@ class EventBLL(object):
last_events contains [hashed_metric_name -> hashed_variant_name -> event]. Keys are hashed to avoid mongodb
key conflicts due to invalid characters and/or long field names.
"""
+ value = event.get("value")
+ if value is None:
+ return
+
metric = event.get("metric") or ""
variant = event.get("variant") or ""
-
metric_hash = dbutils.hash_field_name(metric)
variant_hash = dbutils.hash_field_name(variant)
last_event = last_events[metric_hash][variant_hash]
+ last_event["metric"] = metric
+ last_event["variant"] = variant
+ last_event["count"] = last_event.get("count", 0) + 1
+ last_event["total"] = last_event.get("total", 0) + value
+
event_iter = event.get("iter", 0)
event_timestamp = event.get("timestamp", 0)
- value = event.get("value")
- if value is not None and (
- (event_iter, event_timestamp)
- >= (
- last_event.get("iter", event_iter),
- last_event.get("timestamp", event_timestamp),
- )
+ if (event_iter, event_timestamp) >= (
+ last_event.get("iter", event_iter),
+ last_event.get("timestamp", event_timestamp),
):
- event_data = {
- k: event[k]
- for k in ("value", "metric", "variant", "iter", "timestamp")
- if k in event
- }
- last_event_min_value = last_event.get("min_value", value)
- last_event_min_value_iter = last_event.get("min_value_iter", event_iter)
- if value < last_event_min_value:
- event_data["min_value"] = value
- event_data["min_value_iter"] = event_iter
- else:
- event_data["min_value"] = last_event_min_value
- event_data["min_value_iter"] = last_event_min_value_iter
- last_event_max_value = last_event.get("max_value", value)
- last_event_max_value_iter = last_event.get("max_value_iter", event_iter)
- if value > last_event_max_value:
- event_data["max_value"] = value
- event_data["max_value_iter"] = event_iter
- else:
- event_data["max_value"] = last_event_max_value
- event_data["max_value_iter"] = last_event_max_value_iter
- last_events[metric_hash][variant_hash] = event_data
+ last_event["value"] = value
+ last_event["iter"] = event_iter
+ last_event["timestamp"] = event_timestamp
+ if x_axis_label is not None:
+ last_event["x_axis_label"] = x_axis_label
+
+ first_value_iter = last_event.get("first_value_iter")
+ if first_value_iter is None or event_iter < first_value_iter:
+ last_event["first_value"] = value
+ last_event["first_value_iter"] = event_iter
+
+ last_event_min_value = last_event.get("min_value")
+ if last_event_min_value is None or value < last_event_min_value:
+ last_event["min_value"] = value
+ last_event["min_value_iter"] = event_iter
+
+ last_event_max_value = last_event.get("max_value")
+ if last_event_max_value is None or value > last_event_max_value:
+ last_event["max_value"] = value
+ last_event["max_value_iter"] = event_iter
def _update_last_metric_events_for_task(self, last_events, event):
"""
@@ -659,7 +665,9 @@ class EventBLL(object):
Release the scroll once it is exhausted
"""
total_events = nested_get(es_res, ("hits", "total", "value"), default=0)
- events = [doc["_source"] for doc in nested_get(es_res, ("hits", "hits"), default=[])]
+ events = [
+ doc["_source"] for doc in nested_get(es_res, ("hits", "hits"), default=[])
+ ]
next_scroll_id = es_res.get("_scroll_id")
if next_scroll_id and not events:
self.clear_scroll(next_scroll_id)
@@ -1148,34 +1156,6 @@ class EventBLL(object):
for tb in es_res["aggregations"]["tasks"]["buckets"]
}
- @staticmethod
- def _validate_model_state(
- company_id: str, model_id: str, allow_locked: bool = False
- ):
- extra_msg = None
- query = Q(id=model_id, company=company_id)
- if not allow_locked:
- query &= Q(ready__ne=True)
- extra_msg = "or model published"
- res = Model.objects(query).only("id").first()
- if not res:
- raise errors.bad_request.InvalidModelId(
- extra_msg, company=company_id, id=model_id
- )
-
- @staticmethod
- def _validate_task_state(company_id: str, task_id: str, allow_locked: bool = False):
- extra_msg = None
- query = Q(id=task_id, company=company_id)
- if not allow_locked:
- query &= Q(status__nin=LOCKED_TASK_STATUSES)
- extra_msg = "or task published"
- res = Task.objects(query).only("id").first()
- if not res:
- raise errors.bad_request.InvalidTaskId(
- extra_msg, company=company_id, id=task_id
- )
-
@staticmethod
def _get_events_deletion_params(async_delete: bool) -> dict:
if async_delete:
@@ -1188,51 +1168,53 @@ class EventBLL(object):
return {"refresh": True}
- def delete_task_events(self, company_id, task_id, allow_locked=False, model=False):
- if model:
- self._validate_model_state(
- company_id=company_id,
- model_id=task_id,
- allow_locked=allow_locked,
- )
- else:
- self._validate_task_state(
- company_id=company_id, task_id=task_id, allow_locked=allow_locked
- )
- async_delete = async_task_events_delete
- if async_delete:
- total = self.events_iterator.count_task_events(
- event_type=EventType.all,
- company_id=company_id,
- task_ids=[task_id],
- )
- if total <= async_delete_threshold:
- async_delete = False
- es_req = {"query": {"term": {"task": task_id}}}
+ def delete_task_events(
+ self,
+ company_id,
+ task_ids: Union[str, Sequence[str]],
+ wait_for_delete: bool,
+ model=False,
+ ):
+ """
+        Delete task events. No check is done for task write access,
+        so it should be performed by the calling code
+ """
+ if isinstance(task_ids, str):
+ task_ids = [task_ids]
+ deleted = 0
with translate_errors_context():
- es_res = delete_company_events(
- es=self.es,
- company_id=company_id,
- event_type=EventType.all,
- body=es_req,
- **self._get_events_deletion_params(async_delete),
- )
+ async_delete = async_task_events_delete and not wait_for_delete
+ if async_delete and len(task_ids) < 100:
+ total = self.events_iterator.count_task_events(
+ event_type=EventType.all,
+ company_id=company_id,
+ task_ids=task_ids,
+ )
+ if total <= async_delete_threshold:
+ async_delete = False
+ for tasks in chunked_iter(task_ids, 100):
+ es_req = {"query": {"terms": {"task": tasks}}}
+ es_res = delete_company_events(
+ es=self.es,
+ company_id=company_id,
+ event_type=EventType.all,
+ body=es_req,
+ **self._get_events_deletion_params(async_delete),
+ )
+ if not async_delete:
+ deleted += es_res.get("deleted", 0)
if not async_delete:
- return es_res.get("deleted", 0)
+ return deleted
def clear_task_log(
self,
company_id: str,
task_id: str,
- allow_locked: bool = False,
threshold_sec: int = None,
include_metrics: Sequence[str] = None,
exclude_metrics: Sequence[str] = None,
):
- self._validate_task_state(
- company_id=company_id, task_id=task_id, allow_locked=allow_locked
- )
if check_empty_data(
self.es, company_id=company_id, event_type=EventType.task_log
):
@@ -1274,39 +1256,6 @@ class EventBLL(object):
)
return es_res.get("deleted", 0)
- def delete_multi_task_events(
- self, company_id: str, task_ids: Sequence[str], model=False
- ):
- """
- Delete multiple task events. No check is done for tasks write access
- so it should be checked by the calling code
- """
- deleted = 0
- with translate_errors_context():
- async_delete = async_task_events_delete
- if async_delete and len(task_ids) < 100:
- total = self.events_iterator.count_task_events(
- event_type=EventType.all,
- company_id=company_id,
- task_ids=task_ids,
- )
- if total <= async_delete_threshold:
- async_delete = False
- for tasks in chunked_iter(task_ids, 100):
- es_req = {"query": {"terms": {"task": tasks}}}
- es_res = delete_company_events(
- es=self.es,
- company_id=company_id,
- event_type=EventType.all,
- body=es_req,
- **self._get_events_deletion_params(async_delete),
- )
- if not async_delete:
- deleted += es_res.get("deleted", 0)
-
- if not async_delete:
- return deleted
-
def clear_scroll(self, scroll_id: str):
if scroll_id == self.empty_scroll:
return
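
To make the reworked last-metric bookkeeping of `_update_last_scalar_events_for_task` concrete, here is a standalone sketch of the per-variant accumulation it performs (simplified: no field-name hashing and no x_axis_label handling):

```python
def update_last_event(last_event: dict, value: float, event_iter: int, timestamp: int):
    # count/total accumulate for every event that carries a value
    last_event["count"] = last_event.get("count", 0) + 1
    last_event["total"] = last_event.get("total", 0) + value
    # the remaining fields only advance for events at or past the stored iter/timestamp
    if (event_iter, timestamp) >= (
        last_event.get("iter", event_iter),
        last_event.get("timestamp", timestamp),
    ):
        last_event.update(value=value, iter=event_iter, timestamp=timestamp)
        if last_event.get("first_value_iter") is None or event_iter < last_event["first_value_iter"]:
            last_event.update(first_value=value, first_value_iter=event_iter)
        if last_event.get("min_value") is None or value < last_event["min_value"]:
            last_event.update(min_value=value, min_value_iter=event_iter)
        if last_event.get("max_value") is None or value > last_event["max_value"]:
            last_event.update(max_value=value, max_value_iter=event_iter)

e = {}
for i, v in enumerate([0.5, 0.9, 0.7]):
    update_last_event(e, v, event_iter=i, timestamp=1000 + i)
# e now holds value=0.7 (the latest), first_value=0.5, min_value=0.5,
# max_value=0.9, count=3 and total=2.1
```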
diff --git a/apiserver/bll/event/event_metrics.py b/apiserver/bll/event/event_metrics.py
index c38b4c5..2ad3541 100644
--- a/apiserver/bll/event/event_metrics.py
+++ b/apiserver/bll/event/event_metrics.py
@@ -24,6 +24,8 @@ from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
from apiserver.bll.query import Builder as QueryBuilder
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
+from apiserver.database.model.model import Model
+from apiserver.database.model.task.task import Task
from apiserver.utilities.dicts import nested_get
log = config.logger(__file__)
@@ -43,6 +45,7 @@ class EventMetrics:
samples: int,
key: ScalarKeyEnum,
metric_variants: MetricVariants = None,
+ model_events: bool = False,
) -> dict:
"""
Get scalar metric histogram per metric and variant
@@ -60,6 +63,7 @@ class EventMetrics:
samples=samples,
key=ScalarKey.resolve(key),
metric_variants=metric_variants,
+ model_events=model_events,
)
def _get_scalar_average_per_iter_core(
@@ -71,6 +75,7 @@ class EventMetrics:
key: ScalarKey,
run_parallel: bool = True,
metric_variants: MetricVariants = None,
+ model_events: bool = False,
) -> dict:
intervals = self._get_task_metric_intervals(
company_id=company_id,
@@ -102,7 +107,22 @@ class EventMetrics:
)
ret = defaultdict(dict)
+ if not metrics:
+ return ret
+
+ last_metrics = {}
+ cls_ = Model if model_events else Task
+ task = cls_.objects(id=task_id).only("last_metrics").first()
+ if task and task.last_metrics:
+ for m_data in task.last_metrics.values():
+ for v_data in m_data.values():
+ last_metrics[(v_data.metric, v_data.variant)] = v_data
+
for metric_key, metric_values in metrics:
+ for variant_key, data in metric_values.items():
+ last_metrics_data = last_metrics.get((metric_key, variant_key))
+ if last_metrics_data and last_metrics_data.x_axis_label is not None:
+ data["x_axis_label"] = last_metrics_data.x_axis_label
ret[metric_key].update(metric_values)
return ret
@@ -113,6 +133,7 @@ class EventMetrics:
samples,
key: ScalarKeyEnum,
metric_variants: MetricVariants = None,
+ model_events: bool = False,
):
"""
Compare scalar metrics for different tasks per metric and variant
@@ -136,6 +157,7 @@ class EventMetrics:
key=ScalarKey.resolve(key),
metric_variants=metric_variants,
run_parallel=False,
+ model_events=model_events,
)
task_ids, company_ids = zip(
*(
@@ -165,7 +187,7 @@ class EventMetrics:
self,
companies: TaskCompanies,
metric_variants: MetricVariants = None,
- ) -> Mapping[str, dict]:
+ ) -> Mapping[str, Sequence[dict]]:
"""
For the requested tasks return all the events delivered for the single iteration (-2**31)
"""
diff --git a/apiserver/bll/event/history_debug_image_iterator.py b/apiserver/bll/event/history_debug_image_iterator.py
index 26cbca4..88878be 100644
--- a/apiserver/bll/event/history_debug_image_iterator.py
+++ b/apiserver/bll/event/history_debug_image_iterator.py
@@ -183,7 +183,7 @@ class HistoryDebugImageIterator:
order = "desc" if navigate_earlier else "asc"
es_req = {
"size": 1,
- "sort": [{"metric": order}, {"variant": order}],
+ "sort": [{"metric": order}, {"variant": order}, {"url": "desc"}],
"query": {"bool": {"must": must_conditions}},
}
@@ -242,7 +242,7 @@ class HistoryDebugImageIterator:
]
es_req = {
"size": 1,
- "sort": [{"iter": order}, {"metric": order}, {"variant": order}],
+ "sort": [{"iter": order}, {"metric": order}, {"variant": order}, {"url": "desc"}],
"query": {"bool": {"must": must_conditions}},
}
es_res = search_company_events(
@@ -338,7 +338,7 @@ class HistoryDebugImageIterator:
es_req = {
"size": 1,
- "sort": {"iter": "desc"},
+ "sort": [{"iter": "desc"}, {"url": "desc"}],
"query": {"bool": {"must": must_conditions}},
}
diff --git a/apiserver/bll/event/metric_events_iterator.py b/apiserver/bll/event/metric_events_iterator.py
index 0c5e61f..95a0fb8 100644
--- a/apiserver/bll/event/metric_events_iterator.py
+++ b/apiserver/bll/event/metric_events_iterator.py
@@ -384,7 +384,8 @@ class MetricEventsIterator:
"aggs": {
"events": {
"top_hits": {
- "sort": self._get_same_variant_events_order()
+ "sort": self._get_same_variant_events_order(),
+ "size": 1,
}
}
},
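
For context, the resulting aggregation fragment sent to Elasticsearch looks roughly like this (the sort spec is elided):

```python
# the per-variant bucket aggregation now requests a single top hit:
{
    "events": {
        "top_hits": {
            "sort": "<same-variant event ordering>",  # placeholder for the sort spec
            "size": 1,  # top_hits returns 3 hits by default; one per variant suffices
        }
    }
}
```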
diff --git a/apiserver/bll/model/__init__.py b/apiserver/bll/model/__init__.py
index 3dc1ffe..5a4367b 100644
--- a/apiserver/bll/model/__init__.py
+++ b/apiserver/bll/model/__init__.py
@@ -6,7 +6,6 @@ from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.apimodels.models import ModelTaskPublishResponse
from apiserver.bll.task.utils import deleted_prefix, get_last_metric_updates
-from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task, TaskStatus
@@ -15,8 +14,6 @@ from .metadata import Metadata
class ModelBLL:
- event_bll = None
-
@classmethod
def get_company_model_by_id(
cls, company_id: str, model_id: str, only_fields=None
@@ -94,7 +91,7 @@ class ModelBLL:
@classmethod
def delete_model(
- cls, model_id: str, company_id: str, user_id: str, force: bool, delete_external_artifacts: bool = True,
+ cls, model_id: str, company_id: str, user_id: str, force: bool
) -> Tuple[int, Model]:
model = cls.get_company_model_by_id(
company_id=company_id,
@@ -147,34 +144,6 @@ class ModelBLL:
set__last_changed_by=user_id,
)
- delete_external_artifacts = delete_external_artifacts and config.get(
- "services.async_urls_delete.enabled", True
- )
- if delete_external_artifacts:
- from apiserver.bll.task.task_cleanup import (
- collect_debug_image_urls,
- collect_plot_image_urls,
- _schedule_for_delete,
- )
- urls = set()
- urls.update(collect_debug_image_urls(company_id, model_id))
- urls.update(collect_plot_image_urls(company_id, model_id))
- if model.uri:
- urls.add(model.uri)
- if urls:
- _schedule_for_delete(
- task_id=model_id,
- company=company_id,
- user=user_id,
- urls=urls,
- can_delete_folders=False,
- )
-
- if not cls.event_bll:
- from apiserver.bll.event import EventBLL
- cls.event_bll = EventBLL()
-
- cls.event_bll.delete_task_events(company_id, model_id, allow_locked=True, model=True)
del_count = Model.objects(id=model_id, company=company_id).delete()
return del_count, model
@@ -217,7 +186,7 @@ class ModelBLL:
[
{
"$match": {
- "company": {"$in": [None, "", company]},
+ "company": {"$in": ["", company]},
"_id": {"$in": model_ids},
}
},
diff --git a/apiserver/bll/organization/__init__.py b/apiserver/bll/organization/__init__.py
index 18ee20e..a8fd329 100644
--- a/apiserver/bll/organization/__init__.py
+++ b/apiserver/bll/organization/__init__.py
@@ -1,4 +1,5 @@
from collections import defaultdict
+from datetime import datetime
from enum import Enum
from typing import Sequence, Dict, Type
@@ -28,6 +29,7 @@ class OrgBLL:
def edit_entity_tags(
self,
company_id,
+ user_id: str,
entity_cls: Type[AttributedDocument],
entity_ids: Sequence[str],
add_tags: Sequence[str],
@@ -47,13 +49,17 @@ class OrgBLL:
)
updated = 0
+ last_changed = {
+ "set__last_change": datetime.utcnow(),
+ "set__last_changed_by": user_id,
+ }
if add_tags:
updated += entity_cls.objects(company=company_id, id__in=entity_ids).update(
- add_to_set__tags=add_tags
+ add_to_set__tags=add_tags, **last_changed,
)
if remove_tags:
updated += entity_cls.objects(company=company_id, id__in=entity_ids).update(
- pull_all__tags=remove_tags
+ pull_all__tags=remove_tags, **last_changed,
)
if not updated:
return 0
diff --git a/apiserver/bll/organization/tags_cache.py b/apiserver/bll/organization/tags_cache.py
index 8196290..0cf9d91 100644
--- a/apiserver/bll/organization/tags_cache.py
+++ b/apiserver/bll/organization/tags_cache.py
@@ -6,7 +6,6 @@ from redis import Redis
from apiserver.config_repo import config
from apiserver.bll.project import project_ids_with_children
-from apiserver.database.model import EntityVisibility
from apiserver.database.model.base import GetMixin
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
@@ -43,8 +42,8 @@ class _TagsCache:
query &= GetMixin.get_list_field_query(name, vals)
if project:
query &= Q(project__in=project_ids_with_children([project]))
- else:
- query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
+ # else:
+ # query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
return self.db_cls.objects(query).distinct(field)
diff --git a/apiserver/bll/project/project_bll.py b/apiserver/bll/project/project_bll.py
index a500c19..911f58f 100644
--- a/apiserver/bll/project/project_bll.py
+++ b/apiserver/bll/project/project_bll.py
@@ -41,6 +41,7 @@ from .sub_projects import (
_ids_with_parents,
_get_project_depth,
ProjectsChildren,
+ _get_writable_project_from_name,
)
log = config.logger(__file__)
@@ -169,6 +170,7 @@ class ProjectBLL:
now = datetime.utcnow()
affected = set()
+ p: Project
for p in filter(None, (old_parent, new_parent)):
p.update(last_update=now)
affected.update({p.id, *(p.path or [])})
@@ -183,6 +185,7 @@ class ProjectBLL:
new_name = fields.pop("name", None)
if new_name:
+ # noinspection PyTypeChecker
new_name, new_location = _validate_project_name(new_name)
old_name, old_location = _validate_project_name(project.name)
if new_location != old_location:
@@ -225,6 +228,18 @@ class ProjectBLL:
raise errors.bad_request.ProjectPathExceedsMax(max_depth=max_depth)
name, location = _validate_project_name(name)
+
+ existing = _get_writable_project_from_name(
+ company=company,
+ name=name,
+ )
+ if existing:
+ raise errors.bad_request.ExpectedUniqueData(
+ replacement_msg="Project with the same name already exists",
+ name=name,
+ company=company,
+ )
+
now = datetime.utcnow()
project = Project(
id=database.utils.id(),
@@ -810,7 +825,7 @@ class ProjectBLL:
}
def sum_runtime(
- a: Mapping[str, Mapping], b: Mapping[str, Mapping]
+ a: Mapping[str, dict], b: Mapping[str, dict]
) -> Dict[str, dict]:
return {
section: a.get(section, 0) + b.get(section, 0)
@@ -1015,8 +1030,8 @@ class ProjectBLL:
if include_subprojects:
projects = _ids_with_children(projects)
query &= Q(project__in=projects)
- else:
- query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
+ # else:
+ # query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
if state == EntityVisibility.archived:
query &= Q(system_tags__in=[EntityVisibility.archived.value])
@@ -1046,7 +1061,7 @@ class ProjectBLL:
if not parent_ids:
return []
- parents = Task.get_many_with_join(
+ parents: Sequence[dict] = Task.get_many_with_join(
company_id,
query=Q(id__in=parent_ids),
query_dict={"name": name} if name else None,
@@ -1101,7 +1116,7 @@ class ProjectBLL:
project_field: str = "project",
):
conditions = {
- "company": {"$in": [None, "", company]},
+ "company": {"$in": ["", company]},
project_field: {"$in": project_ids},
}
if users:
@@ -1153,7 +1168,7 @@ class ProjectBLL:
if or_conditions:
if len(or_conditions) == 1:
- conditions = next(iter(or_conditions))
+ conditions.update(next(iter(or_conditions)))
else:
conditions["$and"] = [c for c in or_conditions]
diff --git a/apiserver/bll/project/project_cleanup.py b/apiserver/bll/project/project_cleanup.py
index d10fe7c..d137b7c 100644
--- a/apiserver/bll/project/project_cleanup.py
+++ b/apiserver/bll/project/project_cleanup.py
@@ -8,10 +8,9 @@ from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.bll.event import EventBLL
from apiserver.bll.task.task_cleanup import (
- collect_debug_image_urls,
- collect_plot_image_urls,
TaskUrls,
- _schedule_for_delete,
+ schedule_for_delete,
+ delete_task_events_and_collect_urls,
)
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
@@ -192,7 +191,7 @@ def delete_project(
)
event_urls = task_event_urls | model_event_urls
if delete_external_artifacts:
- scheduled = _schedule_for_delete(
+ scheduled = schedule_for_delete(
task_id=project_id,
company=company,
user=user,
@@ -206,7 +205,6 @@ def delete_project(
deleted_models=deleted_models,
urls=TaskUrls(
model_urls=list(model_urls),
- event_urls=list(event_urls),
artifact_urls=list(artifact_urls),
),
)
@@ -243,9 +241,6 @@ def _delete_tasks(
last_changed_by=user,
)
- event_urls = collect_debug_image_urls(company, task_ids) | collect_plot_image_urls(
- company, task_ids
- )
artifact_urls = set()
for task in tasks:
if task.execution and task.execution.artifacts:
@@ -257,8 +252,11 @@ def _delete_tasks(
}
)
- event_bll.delete_multi_task_events(company, task_ids)
+ event_urls = delete_task_events_and_collect_urls(
+ company=company, task_ids=task_ids, wait_for_delete=False
+ )
deleted = tasks.delete()
+
return deleted, event_urls, artifact_urls
@@ -317,11 +315,10 @@ def _delete_models(
set__last_changed_by=user,
)
- event_urls = collect_debug_image_urls(company, model_ids) | collect_plot_image_urls(
- company, model_ids
- )
model_urls = {m.uri for m in models if m.uri}
-
- event_bll.delete_multi_task_events(company, model_ids, model=True)
+ event_urls = delete_task_events_and_collect_urls(
+ company=company, task_ids=model_ids, model=True, wait_for_delete=False
+ )
deleted = models.delete()
+
return deleted, event_urls, model_urls
diff --git a/apiserver/bll/project/project_queries.py b/apiserver/bll/project/project_queries.py
index 5fd05b9..b4c96d2 100644
--- a/apiserver/bll/project/project_queries.py
+++ b/apiserver/bll/project/project_queries.py
@@ -47,7 +47,7 @@ class ProjectQueries:
@staticmethod
def _get_company_constraint(company_id: str, allow_public: bool = True) -> dict:
if allow_public:
- return {"company": {"$in": [None, "", company_id]}}
+ return {"company": {"$in": ["", company_id]}}
return {"company": company_id}
diff --git a/apiserver/bll/project/sub_projects.py b/apiserver/bll/project/sub_projects.py
index ca0883d..d7564cf 100644
--- a/apiserver/bll/project/sub_projects.py
+++ b/apiserver/bll/project/sub_projects.py
@@ -2,6 +2,8 @@ import itertools
from datetime import datetime
from typing import Tuple, Optional, Sequence, Mapping
+from boltons.iterutils import first
+
from apiserver import database
from apiserver.apierrors import errors
from apiserver.database.model import EntityVisibility
@@ -96,10 +98,21 @@ def _get_writable_project_from_name(
"""
Return a project from name. If the project not found then return None
"""
- qs = Project.objects(company=company, name=name)
+ qs = Project.objects(company__in=[company, ""], name=name)
if _only:
+ if "company" not in _only:
+ _only = ["company", *_only]
qs = qs.only(*_only)
- return qs.first()
+ projects = list(qs)
+
+ if not projects:
+ return
+
+ project = first(p for p in projects if p.company == company)
+ if not project:
+ raise errors.bad_request.PublicProjectExists(name=name)
+
+ return project
ProjectsChildren = Mapping[str, Sequence[Project]]
diff --git a/apiserver/bll/query/builder.py b/apiserver/bll/query/builder.py
index 0581b5d..f28ea9c 100644
--- a/apiserver/bll/query/builder.py
+++ b/apiserver/bll/query/builder.py
@@ -9,20 +9,35 @@ RANGE_IGNORE_VALUE = -1
class Builder:
@staticmethod
- def dates_range(from_date: Union[int, float], to_date: Union[int, float]) -> dict:
+ def dates_range(
+ from_date: Optional[Union[int, float]] = None,
+ to_date: Optional[Union[int, float]] = None,
+ ) -> dict:
+ assert (
+ from_date or to_date
+ ), "range condition requires that at least one of from_date or to_date specified"
+ conditions = {}
+ if from_date:
+ conditions["gte"] = int(from_date)
+ if to_date:
+ conditions["lte"] = int(to_date)
return {
"range": {
"timestamp": {
- "gte": int(from_date),
- "lte": int(to_date),
+ **conditions,
"format": "epoch_second",
}
}
}
@staticmethod
- def terms(field: str, values: Iterable[str]) -> dict:
+ def terms(field: str, values: Iterable) -> dict:
+        assert not isinstance(values, str), "apparently 'term' should be used here"
return {"terms": {field: list(values)}}
+ @staticmethod
+ def term(field: str, value) -> dict:
+ return {"term": {field: value}}
@staticmethod
def normalize_range(
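
For reference, the helpers above produce plain Elasticsearch query fragments (values are illustrative):

```python
from apiserver.bll.query import Builder

Builder.dates_range(from_date=1700000000)
# -> {"range": {"timestamp": {"gte": 1700000000, "format": "epoch_second"}}}

Builder.terms("task", ["t1", "t2"])
# -> {"terms": {"task": ["t1", "t2"]}}

Builder.term("company", "<company-id>")
# -> {"term": {"company": "<company-id>"}}
```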
diff --git a/apiserver/bll/queue/queue_bll.py b/apiserver/bll/queue/queue_bll.py
index 65d4957..097aa44 100644
--- a/apiserver/bll/queue/queue_bll.py
+++ b/apiserver/bll/queue/queue_bll.py
@@ -1,6 +1,6 @@
from collections import defaultdict
from datetime import datetime
-from typing import Sequence, Optional, Tuple, Union
+from typing import Sequence, Optional, Tuple, Union, Iterable
from elasticsearch import Elasticsearch
from mongoengine import Q
@@ -34,6 +34,7 @@ class QueueBLL(object):
def create(
company_id: str,
name: str,
+ display_name: str = None,
tags: Optional[Sequence[str]] = None,
system_tags: Optional[Sequence[str]] = None,
metadata: Optional[dict] = None,
@@ -46,6 +47,7 @@ class QueueBLL(object):
company=company_id,
created=now,
name=name,
+ display_name=display_name,
tags=tags or [],
system_tags=system_tags or [],
metadata=metadata,
@@ -135,51 +137,78 @@ class QueueBLL(object):
self.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
return Queue.safe_update(company_id, queue_id, update_fields)
- def delete(self, company_id: str, user_id: str, queue_id: str, force: bool) -> None:
+ def _update_task_status_on_removal_from_queue(
+ self,
+ company_id: str,
+ user_id: str,
+ task_ids: Iterable[str],
+ queue_id: str,
+ reason: str
+ ) -> Sequence[str]:
+ from apiserver.bll.task import ChangeStatusRequest
+ tasks = []
+ for task_id in task_ids:
+ try:
+ task = Task.get(
+ company=company_id,
+ id=task_id,
+ execution__queue=queue_id,
+ _only=[
+ "id",
+ "company",
+ "status",
+ "enqueue_status",
+ "project",
+ ],
+ )
+ if not task:
+ continue
+
+ tasks.append(task.id)
+ ChangeStatusRequest(
+ task=task,
+ new_status=task.enqueue_status or TaskStatus.created,
+ status_reason=reason,
+ status_message="",
+ user_id=user_id,
+ force=True,
+ ).execute(
+ enqueue_status=None,
+ unset__execution__queue=1,
+ )
+ except Exception as ex:
+ log.error(
+ f"Failed updating task {task_id} status on removal from queue: {queue_id}, {str(ex)}"
+ )
+
+ return tasks
+
+ def delete(self, company_id: str, user_id: str, queue_id: str, force: bool) -> Sequence[str]:
"""
Delete the queue
:raise errors.bad_request.InvalidQueueId: if the queue is not found
:raise errors.bad_request.QueueNotEmpty: if the queue is not empty and 'force' not set
"""
- with translate_errors_context():
- queue = self.get_by_id(company_id=company_id, queue_id=queue_id)
- if queue.entries:
- if not force:
- raise errors.bad_request.QueueNotEmpty(
- "use force=true to delete", id=queue_id
- )
- from apiserver.bll.task import ChangeStatusRequest
-
- for item in queue.entries:
- try:
- task = Task.get(
- company=company_id,
- id=item.task,
- _only=[
- "id",
- "company",
- "status",
- "enqueue_status",
- "project",
- ],
- )
- if not task:
- continue
-
- ChangeStatusRequest(
- task=task,
- new_status=task.enqueue_status or TaskStatus.created,
- status_reason="Queue was deleted",
- status_message="",
- user_id=user_id,
- force=True,
- ).execute(enqueue_status=None)
- except Exception as ex:
- log.exception(
- f"Failed dequeuing task {item.task} from queue: {queue_id}"
- )
-
+ queue = self.get_by_id(company_id=company_id, queue_id=queue_id)
+ if not queue.entries:
queue.delete()
+ return []
+
+ if not force:
+ raise errors.bad_request.QueueNotEmpty(
+ "use force=true to delete", id=queue_id
+ )
+
+ tasks = self._update_task_status_on_removal_from_queue(
+ company_id=company_id,
+ user_id=user_id,
+ task_ids={item.task for item in queue.entries},
+ queue_id=queue_id,
+ reason=f"Queue {queue_id} was deleted",
+ )
+
+ queue.delete()
+ return tasks
def get_all(
self,
@@ -307,7 +336,36 @@ class QueueBLL(object):
return queue.entries[0]
- def remove_task(self, company_id: str, queue_id: str, task_id: str) -> int:
+ def clear_queue(
+ self,
+ company_id: str,
+ user_id: str,
+ queue_id: str,
+ ):
+ queue = Queue.objects(company=company_id, id=queue_id).first()
+ if not queue:
+ raise errors.bad_request.InvalidQueueId(
+ queue=queue_id
+ )
+
+ if not queue.entries:
+ return []
+
+ tasks = self._update_task_status_on_removal_from_queue(
+ company_id=company_id,
+ user_id=user_id,
+ task_ids={item.task for item in queue.entries},
+ queue_id=queue_id,
+ reason=f"Queue {queue_id} was cleared",
+ )
+
+ queue.update(entries=[])
+ queue.reload()
+ self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
+
+ return tasks
+
+ def remove_task(self, company_id: str, user_id: str, queue_id: str, task_id: str, update_task_status: bool = False) -> int:
"""
Removes the task from the queue and returns the number of removed items
:raise errors.bad_request.InvalidQueueOrTaskNotQueued: if the task is not found in the queue
@@ -322,6 +380,14 @@ class QueueBLL(object):
res = Queue.objects(entries__task=task_id, **query).update_one(
pull_all__entries=entries_to_remove, last_update=datetime.utcnow()
)
+ if res and update_task_status:
+ self._update_task_status_on_removal_from_queue(
+ company_id=company_id,
+ user_id=user_id,
+ task_ids=[task_id],
+ queue_id=queue_id,
+ reason=f"Task was removed from the queue {queue_id}",
+ )
queue.reload()
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
@@ -461,7 +527,7 @@ class QueueBLL(object):
[
{
"$match": {
- "company": {"$in": [None, "", company]},
+ "company": {"$in": ["", company]},
"_id": queue_id,
}
},
diff --git a/apiserver/bll/serving/__init__.py b/apiserver/bll/serving/__init__.py
new file mode 100644
index 0000000..24f3cb5
--- /dev/null
+++ b/apiserver/bll/serving/__init__.py
@@ -0,0 +1,376 @@
+from datetime import datetime, timedelta, timezone
+from enum import Enum, auto
+from operator import attrgetter
+from time import time
+from typing import Optional, Sequence, Union
+
+import attr
+from boltons.iterutils import chunked_iter, bucketize
+from pyhocon import ConfigTree
+
+from apiserver.apimodels.serving import (
+ ServingContainerEntry,
+ RegisterRequest,
+ StatusReportRequest,
+)
+from apiserver.apimodels.workers import MachineStats
+from apiserver.apierrors import errors
+from apiserver.config_repo import config
+from apiserver.redis_manager import redman
+from .stats import ServingStats
+
+
+log = config.logger(__file__)
+
+
+class ServingBLL:
+ def __init__(self, redis=None):
+ self.conf = config.get("services.serving", ConfigTree())
+ self.redis = redis or redman.connection("workers")
+
+ @staticmethod
+ def _get_url_key(company: str, url: str):
+ return f"serving_url_{company}_{url}"
+
+ @staticmethod
+ def _get_container_key(company: str, container_id: str) -> str:
+ """Build redis key from company and container_id"""
+ return f"serving_container_{company}_{container_id}"
+
+ def _save_serving_container_entry(self, entry: ServingContainerEntry):
+ self.redis.setex(
+ entry.key, timedelta(seconds=entry.register_timeout), entry.to_json()
+ )
+
+ url_key = self._get_url_key(entry.company_id, entry.endpoint_url)
+ expiration = int(time()) + entry.register_timeout
+ container_item = {entry.key: expiration}
+ self.redis.zadd(url_key, container_item)
+        # make sure the url set does not get stuck in redis
+        # indefinitely in case no more containers report to it
+ self.redis.expire(url_key, max(3600, entry.register_timeout))
+
+ def _get_serving_container_entry(
+ self, company_id: str, container_id: str
+ ) -> Optional[ServingContainerEntry]:
+ """
+ Get a container entry for the provided container ID.
+ """
+ key = self._get_container_key(company_id, container_id)
+ data = self.redis.get(key)
+ if not data:
+ return
+
+ try:
+ entry = ServingContainerEntry.from_json(data)
+ return entry
+ except Exception as e:
+ msg = "Failed parsing container entry"
+ log.exception(f"{msg}: {str(e)}")
+
+ def register_serving_container(
+ self,
+ company_id: str,
+ request: RegisterRequest,
+ ip: str = "",
+ ) -> ServingContainerEntry:
+ """
+ Register a serving container
+ """
+ now = datetime.now(timezone.utc)
+ key = self._get_container_key(company_id, request.container_id)
+ entry = ServingContainerEntry(
+ **request.to_struct(),
+ key=key,
+ company_id=company_id,
+ ip=ip,
+ register_time=now,
+ register_timeout=request.timeout,
+ last_activity_time=now,
+ )
+ self._save_serving_container_entry(entry)
+ return entry
+
+ def unregister_serving_container(
+ self,
+ company_id: str,
+ container_id: str,
+ ) -> None:
+ """
+ Unregister a serving container
+ """
+ entry = self._get_serving_container_entry(company_id, container_id)
+ if entry:
+ url_key = self._get_url_key(entry.company_id, entry.endpoint_url)
+ self.redis.zrem(url_key, entry.key)
+
+ key = self._get_container_key(company_id, container_id)
+ res = self.redis.delete(key)
+ if res:
+ return
+
+ if not self.conf.get("container_auto_unregister", True):
+ raise errors.bad_request.ContainerNotRegistered(container=container_id)
+
+ def container_status_report(
+ self,
+ company_id: str,
+ report: StatusReportRequest,
+ ip: str = "",
+ ) -> None:
+ """
+ Serving container status report
+ """
+ container_id = report.container_id
+ now = datetime.now(timezone.utc)
+ entry = self._get_serving_container_entry(company_id, container_id)
+ if entry:
+ ip = ip or entry.ip
+ register_time = entry.register_time
+ register_timeout = entry.register_timeout
+ else:
+ if not self.conf.get("container_auto_register", True):
+ raise errors.bad_request.ContainerNotRegistered(container=container_id)
+ register_time = now
+ register_timeout = int(
+ self.conf.get("default_container_timeout_sec", 10 * 60)
+ )
+
+ key = self._get_container_key(company_id, container_id)
+ entry = ServingContainerEntry(
+ **report.to_struct(),
+ key=key,
+ company_id=company_id,
+ ip=ip,
+ register_time=register_time,
+ register_timeout=register_timeout,
+ last_activity_time=now,
+ )
+ self._save_serving_container_entry(entry)
+ ServingStats.log_stats_to_es(entry)
+
+ def _get_all(
+ self,
+ company_id: str,
+ ) -> Sequence[ServingContainerEntry]:
+ keys = list(self.redis.scan_iter(self._get_container_key(company_id, "*")))
+ entries = []
+        for chunk in chunked_iter(keys, 1000):
+            data = self.redis.mget(chunk)
+ if not data:
+ continue
+ for d in data:
+ try:
+ entries.append(ServingContainerEntry.from_json(d))
+ except Exception as ex:
+ log.error(f"Failed parsing container entry {str(ex)}")
+
+ return entries
+
+ @attr.s(auto_attribs=True)
+ class Counter:
+ class AggType(Enum):
+ avg = auto()
+ max = auto()
+ total = auto()
+ count = auto()
+
+ name: str
+ field: str
+ agg_type: AggType
+ float_precision: int = None
+
+ _max: Union[int, float, datetime] = attr.field(init=False, default=None)
+ _total: Union[int, float] = attr.field(init=False, default=0)
+ _count: int = attr.field(init=False, default=0)
+
+ def add(self, entry: ServingContainerEntry):
+ value = getattr(entry, self.field, None)
+ if value is None:
+ return
+
+ self._count += 1
+ if self.agg_type == self.AggType.max:
+ self._max = value if self._max is None else max(self._max, value)
+ else:
+ self._total += value
+
+ def __call__(self):
+ if self.agg_type == self.AggType.count:
+ return self._count
+
+ if self.agg_type == self.AggType.max:
+ return self._max
+
+ if self.agg_type == self.AggType.total:
+ return self._total
+
+ if not self._count:
+ return None
+ avg = self._total / self._count
+ return (
+ round(avg, self.float_precision) if self.float_precision else round(avg)
+ )
+
+ def _get_summary(self, entries: Sequence[ServingContainerEntry]) -> dict:
+ counters = [
+ self.Counter(
+ name="uptime_sec",
+ field="uptime_sec",
+ agg_type=self.Counter.AggType.max,
+ ),
+ self.Counter(
+ name="requests",
+ field="requests_num",
+ agg_type=self.Counter.AggType.total,
+ ),
+ self.Counter(
+ name="requests_min",
+ field="requests_min",
+ agg_type=self.Counter.AggType.avg,
+ float_precision=2,
+ ),
+ self.Counter(
+ name="latency_ms",
+ field="latency_ms",
+ agg_type=self.Counter.AggType.avg,
+ ),
+ self.Counter(
+ name="last_update",
+ field="last_activity_time",
+ agg_type=self.Counter.AggType.max,
+ ),
+ ]
+ for entry in entries:
+ for counter in counters:
+ counter.add(entry)
+
+ first_entry = entries[0]
+ ret = {
+ "endpoint": first_entry.endpoint_name,
+ "model": first_entry.model_name,
+ "url": first_entry.endpoint_url,
+ "instances": len(entries),
+ **{counter.name: counter() for counter in counters},
+ }
+ ret["last_update"] = ret.get("last_update")
+ return ret
+
+ def get_endpoints(self, company_id: str):
+ """
+ Group instances by URL and return a summary for each URL.
+ Do not return data for "loading" instances that have no URL
+ """
+ entries = self._get_all(company_id)
+ by_url = bucketize(entries, key=attrgetter("endpoint_url"))
+ by_url.pop(None, None)
+ return [self._get_summary(url_entries) for url_entries in by_url.values()]
+
+ def _get_endpoint_entries(
+ self, company_id, endpoint_url: Union[str, None]
+ ) -> Sequence[ServingContainerEntry]:
+ url_key = self._get_url_key(company_id, endpoint_url)
+ timestamp = int(time())
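+ # zset scores hold container expiry timestamps; drop the entries that have already expired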
+ self.redis.zremrangebyscore(url_key, min=0, max=timestamp)
+ container_keys = {key.decode() for key in self.redis.zrange(url_key, 0, -1)}
+ if not container_keys:
+ return []
+
+ entries = []
+ found_keys = set()
+ data = self.redis.mget(container_keys) or []
+ for d in data:
+ try:
+ entry = ServingContainerEntry.from_json(d)
+ if entry.endpoint_url == endpoint_url:
+ entries.append(entry)
+ found_keys.add(entry.key)
+ except Exception as ex:
+ log.error(f"Failed parsing container entry {str(ex)}")
+
+ missing_keys = container_keys - found_keys
+ if missing_keys:
+ self.redis.zrem(url_key, *missing_keys)
+
+ return entries
+
+ def get_loading_instances(self, company_id: str):
+ entries = self._get_endpoint_entries(company_id, None)
+ return [
+ {
+ "id": entry.container_id,
+ "endpoint": entry.endpoint_name,
+ "url": entry.endpoint_url,
+ "model": entry.model_name,
+ "model_source": entry.model_source,
+ "model_version": entry.model_version,
+ "preprocess_artifact": entry.preprocess_artifact,
+ "input_type": entry.input_type,
+ "input_size": entry.input_size,
+ "uptime_sec": entry.uptime_sec,
+ "age_sec": int((datetime.now(timezone.utc) - entry.register_time).total_seconds()),
+ "last_update": entry.last_activity_time,
+ }
+ for entry in entries
+ ]
+
+ def get_endpoint_details(self, company_id, endpoint_url: str) -> dict:
+ entries = self._get_endpoint_entries(company_id, endpoint_url)
+ if not entries:
+ raise errors.bad_request.NoContainersForUrl(url=endpoint_url)
+
+ def get_machine_stats_data(machine_stats: MachineStats) -> dict:
+ ret = {"cpu_count": 0, "gpu_count": 0}
+ if not machine_stats:
+ return ret
+
+ for value, field in (
+ (machine_stats.cpu_usage, "cpu_count"),
+ (machine_stats.gpu_usage, "gpu_count"),
+ ):
+ if value is None:
+ continue
+ ret[field] = len(value) if isinstance(value, (list, tuple)) else 1
+
+ return ret
+
+ first_entry = entries[0]
+ return {
+ "endpoint": first_entry.endpoint_name,
+ "model": first_entry.model_name,
+ "url": first_entry.endpoint_url,
+ "preprocess_artifact": first_entry.preprocess_artifact,
+ "input_type": first_entry.input_type,
+ "input_size": first_entry.input_size,
+ "model_source": first_entry.model_source,
+ "model_version": first_entry.model_version,
+ "uptime_sec": max(e.uptime_sec for e in entries),
+ "last_update": max(e.last_activity_time for e in entries),
+ "instances": [
+ {
+ "id": entry.container_id,
+ "uptime_sec": entry.uptime_sec,
+ "requests": entry.requests_num,
+ "requests_min": entry.requests_min,
+ "latency_ms": entry.latency_ms,
+ "last_update": entry.last_activity_time,
+ "reference": [ref.to_struct() for ref in entry.reference]
+ if isinstance(entry.reference, list)
+ else entry.reference,
+ **get_machine_stats_data(entry.machine_stats),
+ }
+ for entry in entries
+ ],
+ }
diff --git a/apiserver/bll/serving/stats.py b/apiserver/bll/serving/stats.py
new file mode 100644
index 0000000..734d7f8
--- /dev/null
+++ b/apiserver/bll/serving/stats.py
@@ -0,0 +1,340 @@
+from collections import defaultdict
+from datetime import datetime, timezone
+from enum import Enum
+
+from typing import Tuple, Optional, Sequence
+
+from elasticsearch import Elasticsearch
+
+from apiserver.apimodels.serving import (
+ ServingContainerEntry,
+ GetEndpointMetricsHistoryRequest,
+ MetricType,
+)
+from apiserver.apierrors import errors
+from apiserver.utilities.dicts import nested_get
+from apiserver.bll.query import Builder as QueryBuilder
+from apiserver.config_repo import config
+from apiserver.es_factory import es_factory
+
+
+class _AggregationType(Enum):
+ avg = "avg"
+ sum = "sum"
+
+
+class ServingStats:
+ min_chart_interval = config.get("services.serving.min_chart_interval_sec", 40)
+ es: Elasticsearch = es_factory.connect("workers")
+
+ @classmethod
+ def _serving_stats_prefix(cls, company_id: str) -> str:
+ """Returns the es index prefix for the company"""
+ return f"serving_stats_{company_id.lower()}_"
+
+ @staticmethod
+ def _get_es_index_suffix():
+ """Get the index name suffix for storing current month data"""
+ return datetime.now(timezone.utc).strftime("%Y-%m")
+
+ @staticmethod
+ def _get_average_value(value) -> Tuple[Optional[float], Optional[int]]:
+ if value is None:
+ return None, None
+
+ if isinstance(value, (list, tuple)):
+ count = len(value)
+ if not count:
+ return None, None
+ return sum(value) / count, count
+
+ return value, 1
+
+ @classmethod
+ def log_stats_to_es(
+ cls,
+ entry: ServingContainerEntry,
+ ) -> int:
+ """
+ Write the serving container statistics to Elasticsearch
+ :return: The number of logged documents
+ """
+ company_id = entry.company_id
+ es_index = f"{cls._serving_stats_prefix(company_id)}{cls._get_es_index_suffix()}"
+
+ entry_data = entry.to_struct()
+ doc = {
+ "timestamp": es_factory.get_timestamp_millis(),
+ **{
+ field: entry_data.get(field)
+ for field in (
+ "container_id",
+ "company_id",
+ "endpoint_url",
+ "requests_num",
+ "requests_min",
+ "uptime_sec",
+ "latency_ms",
+ )
+ },
+ }
+
+ stats = entry_data.get("machine_stats")
+ if stats:
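+ # per-core / per-GPU usage lists are collapsed into an average value plus a unit count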
+ for category in ("cpu", "gpu"):
+ usage, num = cls._get_average_value(stats.get(f"{category}_usage"))
+ doc.update({f"{category}_usage": usage, f"{category}_num": num})
+
+ for category in ("memory", "gpu_memory"):
+ free, _ = cls._get_average_value(stats.get(f"{category}_free"))
+ used, _ = cls._get_average_value(stats.get(f"{category}_used"))
+ doc.update(
+ {
+ f"{category}_free": free,
+ f"{category}_used": used,
+ f"{category}_total": round((free or 0) + (used or 0), 3),
+ }
+ )
+
+ doc.update(
+ {
+ field: stats.get(field)
+ for field in ("disk_free_home", "network_rx", "network_tx")
+ }
+ )
+
+ cls.es.index(index=es_index, document=doc)
+
+ return 1
+
+ @staticmethod
+ def round_series(values: Sequence, koeff) -> list:
+ return [round(v * koeff, 2) if v else 0 for v in values]
+
+ _mb_to_gb = 1 / 1024
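+ # metric type -> (ES field, chart title, cross-instance aggregation, optional value multiplier)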
+ agg_fields = {
+ MetricType.requests: (
+ "requests_num",
+ "Number of Requests",
+ _AggregationType.sum,
+ None,
+ ),
+ MetricType.requests_min: (
+ "requests_min",
+ "Requests per Minute",
+ _AggregationType.sum,
+ None,
+ ),
+ MetricType.latency_ms: (
+ "latency_ms",
+ "Average Latency (ms)",
+ _AggregationType.avg,
+ None,
+ ),
+ MetricType.cpu_count: ("cpu_num", "CPU Count", _AggregationType.sum, None),
+ MetricType.gpu_count: ("gpu_num", "GPU Count", _AggregationType.sum, None),
+ MetricType.cpu_util: (
+ "cpu_usage",
+ "Average CPU Load (%)",
+ _AggregationType.avg,
+ None,
+ ),
+ MetricType.gpu_util: (
+ "gpu_usage",
+ "Average GPU Utilization (%)",
+ _AggregationType.avg,
+ None,
+ ),
+ MetricType.ram_total: (
+ "memory_total",
+ "RAM Total (GB)",
+ _AggregationType.sum,
+ _mb_to_gb,
+ ),
+ MetricType.ram_used: (
+ "memory_used",
+ "RAM Used (GB)",
+ _AggregationType.sum,
+ _mb_to_gb,
+ ),
+ MetricType.ram_free: (
+ "memory_free",
+ "RAM Free (GB)",
+ _AggregationType.sum,
+ _mb_to_gb,
+ ),
+ MetricType.gpu_ram_total: (
+ "gpu_memory_total",
+ "GPU RAM Total (GB)",
+ _AggregationType.sum,
+ _mb_to_gb,
+ ),
+ MetricType.gpu_ram_used: (
+ "gpu_memory_used",
+ "GPU RAM Used (GB)",
+ _AggregationType.sum,
+ _mb_to_gb,
+ ),
+ MetricType.gpu_ram_free: (
+ "gpu_memory_free",
+ "GPU RAM Free (GB)",
+ _AggregationType.sum,
+ _mb_to_gb,
+ ),
+ MetricType.network_rx: (
+ "network_rx",
+ "Network Throughput RX (MBps)",
+ _AggregationType.sum,
+ None,
+ ),
+ MetricType.network_tx: (
+ "network_tx",
+ "Network Throughput TX (MBps)",
+ _AggregationType.sum,
+ None,
+ ),
+ }
+
+ @classmethod
+ def get_endpoint_metrics(
+ cls,
+ company_id: str,
+ metrics_request: GetEndpointMetricsHistoryRequest,
+ ) -> dict:
+ from_date = metrics_request.from_date
+ to_date = metrics_request.to_date
+ if from_date >= to_date:
+ raise errors.bad_request.FieldsValueError(
+ "from_date must be less than to_date"
+ )
+
+ metric_type = metrics_request.metric_type
+ agg_data = cls.agg_fields.get(metric_type)
+ if not agg_data:
+ raise NotImplementedError(f"Charts for {metric_type} not implemented")
+
+ agg_field, title, agg_type, multiplier = agg_data
+ if agg_type == _AggregationType.sum:
+ instance_sum_type = "sum_bucket"
+ else:
+ instance_sum_type = "avg_bucket"
+
+ interval = max(metrics_request.interval, cls.min_chart_interval)
+ endpoint_url = metrics_request.endpoint_url
+ hist_ret = {
+ "computed_interval": interval,
+ "total": {
+ "title": title,
+ "dates": [],
+ "values": [],
+ },
+ "instances": {},
+ }
+ must_conditions = [
+ QueryBuilder.term("company_id", company_id),
+ QueryBuilder.term("endpoint_url", endpoint_url),
+ QueryBuilder.dates_range(from_date, to_date),
+ ]
+ query = {"bool": {"must": must_conditions}}
+ es_index = f"{cls._serving_stats_prefix(company_id)}*"
+ res = cls.es.search(
+ index=es_index,
+ size=0,
+ query=query,
+ aggs={"instances": {"terms": {"field": "container_id"}}},
+ )
+ instance_buckets = nested_get(res, ("aggregations", "instances", "buckets"))
+ if not instance_buckets:
+ return hist_ret
+
+ instance_keys = {ib["key"] for ib in instance_buckets}
+ must_conditions.append(QueryBuilder.terms("container_id", instance_keys))
+ query = {"bool": {"must": must_conditions}}
+ sample_func = "avg" if metric_type != MetricType.requests else "max"
+ aggs = {
+ "instances": {
+ "terms": {
+ "field": "container_id",
+ "size": max(len(instance_keys), 10),
+ },
+ "aggs": {
+ "sample": {sample_func: {"field": agg_field}},
+ },
+ },
+ "total_instances": {
+ instance_sum_type: {
+ "gap_policy": "insert_zeros",
+ "buckets_path": "instances>sample",
+ }
+ },
+ }
+ hist_params = {}
+ if metric_type == MetricType.requests:
+ hist_params["min_doc_count"] = 1
+ else:
+ hist_params["extended_bounds"] = {
+ "min": int(from_date) * 1000,
+ "max": int(to_date) * 1000,
+ }
+ aggs = {
+ "dates": {
+ "date_histogram": {
+ "field": "timestamp",
+ "fixed_interval": f"{interval}s",
+ **hist_params,
+ },
+ "aggs": aggs,
+ }
+ }
+
+ filter_path = None
+ if not metrics_request.instance_charts:
+ filter_path = "aggregations.dates.buckets.total_instances"
+
+ data = cls.es.search(
+ index=es_index,
+ size=0,
+ query=query,
+ aggs=aggs,
+ filter_path=filter_path,
+ )
+ agg_res = data.get("aggregations")
+ if not agg_res:
+ return hist_ret
+
+ dates_ = []
+ total = []
+ instances = defaultdict(list)
+ # remove last interval if it's incomplete. Allow 10% tolerance
+ last_valid_timestamp = (to_date - 0.9 * interval) * 1000
+ for point in agg_res["dates"]["buckets"]:
+ date_ = point["key"]
+ if date_ > last_valid_timestamp:
+ break
+ dates_.append(date_)
+ total.append(nested_get(point, ("total_instances", "value"), 0))
+ if metrics_request.instance_charts:
+ found_keys = set()
+ for instance in nested_get(point, ("instances", "buckets"), []):
+ instances[instance["key"]].append(
+ nested_get(instance, ("sample", "value"), 0)
+ )
+ found_keys.add(instance["key"])
+ for missing_key in instance_keys - found_keys:
+ instances[missing_key].append(0)
+
+ koeff = multiplier if multiplier else 1.0
+ hist_ret["total"]["dates"] = dates_
+ hist_ret["total"]["values"] = cls.round_series(total, koeff)
+ hist_ret["instances"] = {
+ key: {
+ "title": key,
+ "dates": dates_,
+ "values": cls.round_series(values, koeff),
+ }
+ for key, values in sorted(instances.items(), key=lambda p: p[0])
+ }
+
+ return hist_ret
diff --git a/apiserver/bll/storage/__init__.py b/apiserver/bll/storage/__init__.py
index 7c88689..0957700 100644
--- a/apiserver/bll/storage/__init__.py
+++ b/apiserver/bll/storage/__init__.py
@@ -1,14 +1,32 @@
+import json
+import os
+import tempfile
from copy import copy
+from datetime import datetime
+from typing import Optional, Sequence
+import attr
from boltons.cacheutils import cachedproperty
from clearml.backend_config.bucket_config import (
S3BucketConfigurations,
AzureContainerConfigurations,
GSBucketConfigurations,
+ AzureContainerConfig,
+ GSBucketConfig,
+ S3BucketConfig,
)
+from apiserver.apierrors import errors
+from apiserver.apimodels.storage import SetSettingsRequest
from apiserver.config_repo import config
-
+from apiserver.database.model.storage_settings import (
+ StorageSettings,
+ GoogleBucketSettings,
+ AWSSettings,
+ AzureStorageSettings,
+ GoogleStorageSettings,
+)
+from apiserver.database.utils import id as db_id
log = config.logger(__file__)
@@ -32,17 +50,224 @@ class StorageBLL:
def get_azure_settings_for_company(
self,
company_id: str,
+ db_settings: StorageSettings = None,
+ query_db: bool = True,
) -> AzureContainerConfigurations:
- return copy(self._default_azure_configs)
+ if not db_settings and query_db:
+ db_settings = (
+ StorageSettings.objects(company=company_id).only("azure").first()
+ )
+
+ if not db_settings or not db_settings.azure:
+ return copy(self._default_azure_configs)
+
+ azure = db_settings.azure
+ return AzureContainerConfigurations(
+ container_configs=[
+ AzureContainerConfig(**entry.to_proper_dict())
+ for entry in (azure.containers or [])
+ ]
+ )
def get_gs_settings_for_company(
self,
company_id: str,
+ db_settings: StorageSettings = None,
+ query_db: bool = True,
+ json_string: bool = False,
) -> GSBucketConfigurations:
- return copy(self._default_gs_configs)
+ if not db_settings and query_db:
+ db_settings = (
+ StorageSettings.objects(company=company_id).only("google").first()
+ )
+
+ if not db_settings or not db_settings.google:
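+ # no stored settings: fall back to the server defaults, optionally inlining GS credentials as JSON strings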
+ if not json_string:
+ return copy(self._default_gs_configs)
+
+ if self._default_gs_configs._buckets:
+ buckets = [
+ attr.evolve(
+ b,
+ credentials_json=self._assure_json_string(b.credentials_json),
+ )
+ for b in self._default_gs_configs._buckets
+ ]
+ else:
+ buckets = self._default_gs_configs._buckets
+
+ return GSBucketConfigurations(
+ buckets=buckets,
+ default_project=self._default_gs_configs._default_project,
+ default_credentials=self._assure_json_string(
+ self._default_gs_configs._default_credentials
+ ),
+ )
+
+ def get_bucket_config(bc: GoogleBucketSettings) -> GSBucketConfig:
+ data = bc.to_proper_dict()
+ if not json_string and bc.credentials_json:
+ data["credentials_json"] = self._assure_json_file(bc.credentials_json)
+ return GSBucketConfig(**data)
+
+ google = db_settings.google
+ buckets_configs = [get_bucket_config(b) for b in (google.buckets or [])]
+ return GSBucketConfigurations(
+ buckets=buckets_configs,
+ default_project=google.project,
+ default_credentials=google.credentials_json
+ if json_string
+ else self._assure_json_file(google.credentials_json),
+ )
def get_aws_settings_for_company(
self,
company_id: str,
+ db_settings: StorageSettings = None,
+ query_db: bool = True,
) -> S3BucketConfigurations:
- return copy(self._default_aws_configs)
+ if not db_settings and query_db:
+ db_settings = (
+ StorageSettings.objects(company=company_id).only("aws").first()
+ )
+ if not db_settings or not db_settings.aws:
+ return copy(self._default_aws_configs)
+
+ aws = db_settings.aws
+ buckets_configs = S3BucketConfig.from_list(
+ [b.to_proper_dict() for b in (aws.buckets or [])]
+ )
+ return S3BucketConfigurations(
+ buckets=buckets_configs,
+ default_key=aws.key,
+ default_secret=aws.secret,
+ default_region=aws.region,
+ default_use_credentials_chain=aws.use_credentials_chain,
+ default_token=aws.token,
+ default_extra_args={},
+ )
+
+ def _assure_json_file(self, name_or_content: str) -> str:
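+ # accepts either a path to a .json file or raw JSON content;
+ # raw content is written to a temp file so that callers always get a file path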
+ if not name_or_content:
+ return name_or_content
+
+ if name_or_content.endswith(".json") or os.path.exists(name_or_content):
+ return name_or_content
+
+ try:
+ json.loads(name_or_content)
+ except Exception:
+ return name_or_content
+
+ with tempfile.NamedTemporaryFile(
+ mode="wt", delete=False, suffix=".json"
+ ) as tmp:
+ tmp.write(name_or_content)
+
+ return tmp.name
+
+ def _assure_json_string(self, name_or_content: str) -> Optional[str]:
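+ # inverse of _assure_json_file: returns JSON content, reading it from file if a path was given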
+ if not name_or_content:
+ return name_or_content
+
+ try:
+ json.loads(name_or_content)
+ return name_or_content
+ except Exception:
+ pass
+
+ try:
+ with open(name_or_content) as fp:
+ return fp.read()
+ except Exception:
+ return None
+
+ def get_company_settings(self, company_id: str) -> dict:
+ db_settings = StorageSettings.objects(company=company_id).first()
+ aws = self.get_aws_settings_for_company(company_id, db_settings, query_db=False)
+ aws_dict = {
+ "key": aws._default_key,
+ "secret": aws._default_secret,
+ "token": aws._default_token,
+ "region": aws._default_region,
+ "use_credentials_chain": aws._default_use_credentials_chain,
+ "buckets": [attr.asdict(b) for b in aws._buckets],
+ }
+
+ gs = self.get_gs_settings_for_company(
+ company_id, db_settings, query_db=False, json_string=True
+ )
+ gs_dict = {
+ "project": gs._default_project,
+ "credentials_json": gs._default_credentials,
+ "buckets": [attr.asdict(b) for b in gs._buckets],
+ }
+
+ azure = self.get_azure_settings_for_company(company_id, db_settings)
+ azure_dict = {
+ "containers": [attr.asdict(ac) for ac in azure._container_configs],
+ }
+
+ return {
+ "aws": aws_dict,
+ "google": gs_dict,
+ "azure": azure_dict,
+ "last_update": db_settings.last_update if db_settings else None,
+ }
+
+ def set_company_settings(
+ self, company_id: str, settings: SetSettingsRequest
+ ) -> int:
+ update_dict = {}
+ if settings.aws:
+ update_dict["aws"] = {
+ **{
+ k: v
+ for k, v in settings.aws.to_struct().items()
+ if k in AWSSettings.get_fields()
+ }
+ }
+
+ if settings.azure:
+ update_dict["azure"] = {
+ **{
+ k: v
+ for k, v in settings.azure.to_struct().items()
+ if k in AzureStorageSettings.get_fields()
+ }
+ }
+
+ if settings.google:
+ update_dict["google"] = {
+ **{
+ k: v
+ for k, v in settings.google.to_struct().items()
+ if k in GoogleStorageSettings.get_fields()
+ }
+ }
+ cred_json = update_dict["google"].get("credentials_json")
+ if cred_json:
+ try:
+ json.loads(cred_json)
+ except Exception as ex:
+ raise errors.bad_request.ValidationError(
+ f"Invalid json credentials: {str(ex)}"
+ )
+
+ if not update_dict:
+ raise errors.bad_request.ValidationError("No settings were provided")
+
+ settings = StorageSettings.objects(company=company_id).only("id").first()
+ settings_id = settings.id if settings else db_id()
+ return StorageSettings.objects(id=settings_id).update(
+ upsert=True,
+ id=settings_id,
+ company=company_id,
+ last_update=datetime.utcnow(),
+ **update_dict,
+ )
+
+ def reset_company_settings(self, company_id: str, keys: Sequence[str]) -> int:
+ return StorageSettings.objects(company=company_id).update(
+ last_update=datetime.utcnow(), **{f"unset__{k}": 1 for k in keys}
+ )
diff --git a/apiserver/bll/task/hyperparams.py b/apiserver/bll/task/hyperparams.py
index eae25b2..de7b71e 100644
--- a/apiserver/bll/task/hyperparams.py
+++ b/apiserver/bll/task/hyperparams.py
@@ -193,7 +193,7 @@ class HyperParams:
pipeline = [
{
"$match": {
- "company": {"$in": [None, "", company_id]},
+ "company": {"$in": ["", company_id]},
"_id": {"$in": task_ids},
}
},
diff --git a/apiserver/bll/task/task_bll.py b/apiserver/bll/task/task_bll.py
index 6c99b23..cdce8af 100644
--- a/apiserver/bll/task/task_bll.py
+++ b/apiserver/bll/task/task_bll.py
@@ -39,6 +39,7 @@ from apiserver.database.utils import (
from apiserver.es_factory import es_factory
from apiserver.redis_manager import redman
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
+from apiserver.utilities.dicts import nested_set
from .artifacts import artifacts_prepare_for_save
from .param_utils import params_prepare_for_save
from .utils import (
@@ -163,18 +164,36 @@ class TaskBLL:
input_models: Optional[Sequence[TaskInputModel]] = None,
validate_references: bool = False,
new_project_name: str = None,
+ hyperparams_overrides: Optional[dict] = None,
+ configuration_overrides: Optional[dict] = None,
) -> Tuple[Task, dict]:
validate_tags(tags, system_tags)
- params_dict = {
- field: value
- for field, value in (
- ("hyperparams", hyperparams),
- ("configuration", configuration),
- )
- if value is not None
- }
+ task: Task = cls.get_by_id(
+ company_id=company_id, task_id=task_id, allow_public=True
+ )
- task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
+ params_dict = {}
+ if hyperparams:
+ params_dict["hyperparams"] = hyperparams
+ elif hyperparams_overrides:
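+ # merge the overrides into a copy of the task's existing hyperparams instead of replacing whole sections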
+ updated_hyperparams = {
+ sec: {k: value for k, value in sec_data.items()}
+ for sec, sec_data in (task.hyperparams or {}).items()
+ }
+ for section, section_data in hyperparams_overrides.items():
+ for key, value in section_data.items():
+ nested_set(updated_hyperparams, (section, key), value)
+ params_dict["hyperparams"] = updated_hyperparams
+
+ if configuration:
+ params_dict["configuration"] = configuration
+ elif configuration_overrides:
+ updated_configuration = {
+ k: value for k, value in (task.configuration or {}).items()
+ }
+ for key, value in configuration_overrides.items():
+ updated_configuration[key] = value
+ params_dict["configuration"] = updated_configuration
now = datetime.utcnow()
if input_models:
@@ -389,7 +408,7 @@ class TaskBLL:
task's last iteration value.
:param last_iteration_max: Last reported iteration. Use this to conditionally set a value only
if the current task's last iteration value is smaller than the provided value.
- :param last_scalar_values: Last reported metrics summary for scalar events (value, metric, variant).
+ :param last_scalar_events: Last reported metrics summary for scalar events (value, metric, variant).
:param last_events: Last reported metrics summary (value, metric, event type).
:param extra_updates: Extra task updates to include in this update call.
:return:
@@ -439,8 +458,13 @@ class TaskBLL:
return ret
@staticmethod
- def remove_task_from_all_queues(company_id: str, task_id: str) -> int:
- return Queue.objects(company=company_id, entries__task=task_id).update(
+ def remove_task_from_all_queues(
+ company_id: str, task_id: str, exclude: str = None
+ ) -> int:
+ more = {}
+ if exclude:
+ more["id__ne"] = exclude
+ return Queue.objects(company=company_id, entries__task=task_id, **more).update(
pull__entries__task=task_id, last_update=datetime.utcnow()
)
@@ -454,9 +478,10 @@ class TaskBLL:
status_reason: str,
remove_from_all_queues=False,
new_status=None,
+ new_status_for_aborted_task=None,
):
try:
- cls.dequeue(task, company_id, silent_fail=True)
+ cls.dequeue(task, company_id=company_id, user_id=user_id, silent_fail=True)
except APIError:
# dequeue may fail if the queue was deleted
pass
@@ -467,6 +492,9 @@ class TaskBLL:
if task.status not in [TaskStatus.queued, TaskStatus.in_progress]:
return {"updated": 0}
+ if new_status_for_aborted_task and task.status == TaskStatus.in_progress:
+ new_status = new_status_for_aborted_task
+
return ChangeStatusRequest(
task=task,
new_status=new_status or task.enqueue_status or TaskStatus.created,
@@ -477,7 +505,7 @@ class TaskBLL:
).execute(enqueue_status=None)
@classmethod
- def dequeue(cls, task: Task, company_id: str, silent_fail=False):
+ def dequeue(cls, task: Task, company_id: str, user_id: str, silent_fail=False):
"""
Dequeue the task from the queue
:param task: task to dequeue
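+ :param user_id: id of the user performing the dequeue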
@@ -504,6 +532,9 @@ class TaskBLL:
return {
"removed": queue_bll.remove_task(
- company_id=company_id, queue_id=task.execution.queue, task_id=task.id
+ company_id=company_id,
+ user_id=user_id,
+ queue_id=task.execution.queue,
+ task_id=task.id,
)
}
diff --git a/apiserver/bll/task/task_cleanup.py b/apiserver/bll/task/task_cleanup.py
index 59b73b2..c9b91a8 100644
--- a/apiserver/bll/task/task_cleanup.py
+++ b/apiserver/bll/task/task_cleanup.py
@@ -31,8 +31,8 @@ event_bll = EventBLL()
@attr.s(auto_attribs=True)
class TaskUrls:
model_urls: Sequence[str]
- event_urls: Sequence[str]
artifact_urls: Sequence[str]
+ event_urls: Sequence[str] = [] # left here in order not to break the api
def __add__(self, other: "TaskUrls"):
if not other:
@@ -40,7 +40,6 @@ class TaskUrls:
return TaskUrls(
model_urls=list(set(self.model_urls) | set(other.model_urls)),
- event_urls=list(set(self.event_urls) | set(other.event_urls)),
artifact_urls=list(set(self.artifact_urls) | set(other.artifact_urls)),
)
@@ -54,8 +53,23 @@ class CleanupResult:
updated_children: int
updated_models: int
deleted_models: int
+ deleted_model_ids: Set[str]
urls: TaskUrls = None
+ def to_res_dict(self, return_file_urls: bool) -> dict:
+ remove_fields = ["deleted_model_ids"]
+ if not return_file_urls:
+ remove_fields.append("urls")
+
+ # noinspection PyTypeChecker
+ res = attr.asdict(
+ self, filter=lambda attrib, value: attrib.name not in remove_fields
+ )
+ if not return_file_urls:
+ res["urls"] = None
+
+ return res
+
def __add__(self, other: "CleanupResult"):
if not other:
return self
@@ -65,6 +79,16 @@ class CleanupResult:
updated_models=self.updated_models + other.updated_models,
deleted_models=self.deleted_models + other.deleted_models,
urls=self.urls + other.urls if self.urls else other.urls,
+ deleted_model_ids=self.deleted_model_ids | other.deleted_model_ids,
+ )
+
+ @staticmethod
+ def empty():
+ return CleanupResult(
+ updated_children=0,
+ updated_models=0,
+ deleted_models=0,
+ deleted_model_ids=set(),
)
@@ -130,7 +154,7 @@ supported_storage_types.update(
)
-def _schedule_for_delete(
+def schedule_for_delete(
company: str,
user: str,
task_id: str,
@@ -197,15 +221,27 @@ def _schedule_for_delete(
return processed_urls
+def delete_task_events_and_collect_urls(
+ company: str, task_ids: Sequence[str], wait_for_delete: bool, model=False
+) -> Set[str]:
+ event_urls = collect_debug_image_urls(company, task_ids) | collect_plot_image_urls(
+ company, task_ids
+ )
+
+ event_bll.delete_task_events(
+ company, task_ids, model=model, wait_for_delete=wait_for_delete
+ )
+
+ return event_urls
+
+
def cleanup_task(
company: str,
user: str,
task: Task,
force: bool = False,
update_children=True,
- return_file_urls=False,
delete_output_models=True,
- delete_external_artifacts=True,
) -> CleanupResult:
"""
Validate task deletion and delete/modify all its output.
@@ -216,22 +252,16 @@ def cleanup_task(
published_models, draft_models, in_use_model_ids = verify_task_children_and_ouptuts(
task, force
)
- delete_external_artifacts = delete_external_artifacts and config.get(
- "services.async_urls_delete.enabled", True
- )
- event_urls, artifact_urls, model_urls = set(), set(), set()
- if return_file_urls or delete_external_artifacts:
- event_urls = collect_debug_image_urls(task.company, task.id)
- event_urls.update(collect_plot_image_urls(task.company, task.id))
- if task.execution and task.execution.artifacts:
- artifact_urls = {
- a.uri
- for a in task.execution.artifacts.values()
- if a.mode == ArtifactModes.output and a.uri
- }
- model_urls = {
- m.uri for m in draft_models if m.uri and m.id not in in_use_model_ids
+ artifact_urls = (
+ {
+ a.uri
+ for a in task.execution.artifacts.values()
+ if a.mode == ArtifactModes.output and a.uri
}
+ if task.execution and task.execution.artifacts
+ else set()
+ )
+ model_urls = {m.uri for m in draft_models if m.uri and m.id not in in_use_model_ids}
deleted_task_id = f"{deleted_prefix}{task.id}"
updated_children = 0
@@ -245,22 +275,15 @@ def cleanup_task(
deleted_models = 0
updated_models = 0
+ deleted_model_ids = set()
for models, allow_delete in ((draft_models, True), (published_models, False)):
if not models:
continue
if delete_output_models and allow_delete:
model_ids = list({m.id for m in models if m.id not in in_use_model_ids})
if model_ids:
- if return_file_urls or delete_external_artifacts:
- event_urls.update(collect_debug_image_urls(task.company, model_ids))
- event_urls.update(collect_plot_image_urls(task.company, model_ids))
-
- event_bll.delete_multi_task_events(
- task.company,
- model_ids,
- model=True,
- )
deleted_models += Model.objects(id__in=model_ids).delete()
+ deleted_model_ids.update(model_ids)
if in_use_model_ids:
Model.objects(id__in=list(in_use_model_ids)).update(
@@ -283,30 +306,15 @@ def cleanup_task(
set__last_changed_by=user,
)
- event_bll.delete_task_events(task.company, task.id, allow_locked=force)
-
- if delete_external_artifacts:
- scheduled = _schedule_for_delete(
- task_id=task.id,
- company=company,
- user=user,
- urls=event_urls | model_urls | artifact_urls,
- can_delete_folders=not in_use_model_ids and not published_models,
- )
- for urls in (event_urls, model_urls, artifact_urls):
- urls.difference_update(scheduled)
-
return CleanupResult(
deleted_models=deleted_models,
updated_children=updated_children,
updated_models=updated_models,
urls=TaskUrls(
- event_urls=list(event_urls),
artifact_urls=list(artifact_urls),
model_urls=list(model_urls),
- )
- if return_file_urls
- else None,
+ ),
+ deleted_model_ids=deleted_model_ids,
)
diff --git a/apiserver/bll/task/task_operations.py b/apiserver/bll/task/task_operations.py
index 75d93e3..7df9a79 100644
--- a/apiserver/bll/task/task_operations.py
+++ b/apiserver/bll/task/task_operations.py
@@ -22,7 +22,8 @@ from apiserver.database.model.task.task import (
TaskStatusMessage,
ArtifactModes,
Execution,
- DEFAULT_LAST_ITERATION, TaskType,
+ DEFAULT_LAST_ITERATION,
+ TaskType,
)
from apiserver.database.utils import get_options
from apiserver.service_repo.auth import Identity
@@ -85,6 +86,7 @@ def archive_task(
status_message=status_message,
status_reason=status_reason,
remove_from_all_queues=True,
+ new_status_for_aborted_task=TaskStatus.stopped,
)
except APIError:
# dequeue may fail if the task was not enqueued
@@ -99,7 +101,9 @@ def archive_task(
)
if include_pipeline_steps and (
- step_tasks := _get_pipeline_steps_for_controller_task(task, company_id, only=fields)
+ step_tasks := _get_pipeline_steps_for_controller_task(
+ task, company_id, only=fields
+ )
):
for step in step_tasks:
archive_task_core(step)
@@ -136,7 +140,9 @@ def unarchive_task(
)
if include_pipeline_steps and (
- step_tasks := _get_pipeline_steps_for_controller_task(task, company_id, only=fields)
+ step_tasks := _get_pipeline_steps_for_controller_task(
+ task, company_id, only=fields
+ )
):
for step in step_tasks:
unarchive_task_core(step)
@@ -204,12 +210,25 @@ def enqueue_task(
queue_name: str = None,
validate: bool = False,
force: bool = False,
+ update_execution_queue: bool = True,
) -> Tuple[int, dict]:
if queue_id and queue_name:
raise errors.bad_request.ValidationError(
"Either queue id or queue name should be provided"
)
+ task = get_task_with_write_access(
+ task_id=task_id, company_id=company_id, identity=identity
+ )
+ if not update_execution_queue:
+ if not (
+ task.status == TaskStatus.queued and task.execution and task.execution.queue
+ ):
+ raise errors.bad_request.ValidationError(
+ "Cannot skip setting execution queue for a task "
+ "that is not enqueued or does not have execution queue set"
+ )
+
if queue_name:
queue = queue_bll.get_by_name(
company_id=company_id, queue_name=queue_name, only=("id",)
@@ -222,23 +241,21 @@ def enqueue_task(
# try to get default queue
queue_id = queue_bll.get_default(company_id).id
- task = get_task_with_write_access(
- task_id=task_id, company_id=company_id, identity=identity
- )
-
user_id = identity.user
if validate:
TaskBLL.validate(task)
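+ # when re-enqueueing an already queued task, preserve the status it had before it was first enqueued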
+ before_enqueue_status = task.status
+ if task.status == TaskStatus.queued and task.enqueue_status:
+ before_enqueue_status = task.enqueue_status
res = ChangeStatusRequest(
task=task,
new_status=TaskStatus.queued,
status_reason=status_reason,
status_message=status_message,
- allow_same_state_transition=False,
force=force,
user_id=user_id,
- ).execute(enqueue_status=task.status)
+ ).execute(enqueue_status=before_enqueue_status)
try:
queue_bll.add_task(company_id=company_id, queue_id=queue_id, task_id=task.id)
@@ -255,12 +272,19 @@ def enqueue_task(
raise
# set the current queue ID in the task
- if task.execution:
- Task.objects(id=task_id).update(execution__queue=queue_id, multi=False)
- else:
- Task.objects(id=task_id).update(execution=Execution(queue=queue_id), multi=False)
+ if update_execution_queue:
+ if task.execution:
+ Task.objects(id=task_id).update(execution__queue=queue_id, multi=False)
+ else:
+ Task.objects(id=task_id).update(
+ execution=Execution(queue=queue_id), multi=False
+ )
+ nested_set(res, ("fields", "execution.queue"), queue_id)
- nested_set(res, ("fields", "execution.queue"), queue_id)
+ # make sure that the task is not queued in any other queue
+ TaskBLL.remove_task_from_all_queues(
+ company_id=company_id, task_id=task_id, exclude=queue_id
+ )
return 1, res
@@ -294,17 +318,13 @@ def delete_task(
identity: Identity,
move_to_trash: bool,
force: bool,
- return_file_urls: bool,
delete_output_models: bool,
status_message: str,
status_reason: str,
- delete_external_artifacts: bool,
include_pipeline_steps: bool,
) -> Tuple[int, Task, CleanupResult]:
user_id = identity.user
- task = get_task_with_write_access(
- task_id, company_id=company_id, identity=identity
- )
+ task = get_task_with_write_access(task_id, company_id=company_id, identity=identity)
if (
task.status != TaskStatus.created
@@ -318,7 +338,7 @@ def delete_task(
current=task.status,
)
- def delete_task_core(task_: Task, force_: bool):
+ def delete_task_core(task_: Task, force_: bool) -> CleanupResult:
try:
TaskBLL.dequeue_and_change_status(
task_,
@@ -337,9 +357,7 @@ def delete_task(
user=user_id,
task=task_,
force=force_,
- return_file_urls=return_file_urls,
delete_output_models=delete_output_models,
- delete_external_artifacts=delete_external_artifacts,
)
if move_to_trash:
@@ -353,11 +371,12 @@ def delete_task(
return res
task_ids = [task.id]
+ cleanup_res = CleanupResult.empty()
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(task, company_id)
):
for step in step_tasks:
- delete_task_core(step, True)
+ cleanup_res += delete_task_core(step, True)
task_ids.append(step.id)
- cleanup_res = delete_task_core(task, force)
+ cleanup_res += delete_task_core(task, force)
@@ -373,15 +392,11 @@ def reset_task(
company_id: str,
identity: Identity,
force: bool,
- return_file_urls: bool,
delete_output_models: bool,
clear_all: bool,
- delete_external_artifacts: bool,
) -> Tuple[dict, CleanupResult, dict]:
user_id = identity.user
- task = get_task_with_write_access(
- task_id, company_id=company_id, identity=identity
- )
+ task = get_task_with_write_access(task_id, company_id=company_id, identity=identity)
if not force and task.status == TaskStatus.published:
raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)
@@ -390,7 +405,9 @@ def reset_task(
updates = {}
try:
- dequeued = TaskBLL.dequeue(task, company_id, silent_fail=True)
+ dequeued = TaskBLL.dequeue(
+ task, company_id=company_id, user_id=user_id, silent_fail=True
+ )
except APIError:
# dequeue may fail if the task was not enqueued
pass
@@ -403,9 +420,7 @@ def reset_task(
task=task,
force=force,
update_children=False,
- return_file_urls=return_file_urls,
delete_output_models=delete_output_models,
- delete_external_artifacts=delete_external_artifacts,
)
updates.update(
@@ -466,9 +481,7 @@ def publish_task(
status_reason: str = "",
) -> dict:
user_id = identity.user
- task = get_task_with_write_access(
- task_id, company_id=company_id, identity=identity
- )
+ task = get_task_with_write_access(task_id, company_id=company_id, identity=identity)
if not force:
validate_status_change(task.status, TaskStatus.published)
@@ -566,7 +579,9 @@ def stop_task(
if set_stopped:
if is_queued:
try:
- TaskBLL.dequeue(task_, company_id=company_id, silent_fail=True)
+ TaskBLL.dequeue(
+ task_, company_id=company_id, user_id=user_id, silent_fail=True
+ )
except APIError:
# dequeue may fail if the task was not enqueued
pass
@@ -587,7 +602,9 @@ def stop_task(
).execute()
if include_pipeline_steps and (
- step_tasks := _get_pipeline_steps_for_controller_task(task, company_id, only=fields)
+ step_tasks := _get_pipeline_steps_for_controller_task(
+ task, company_id, only=fields
+ )
):
for step in step_tasks:
stop_task_core(step, True)
diff --git a/apiserver/bll/task/utils.py b/apiserver/bll/task/utils.py
index 5cb9627..3366c4c 100644
--- a/apiserver/bll/task/utils.py
+++ b/apiserver/bll/task/utils.py
@@ -144,7 +144,12 @@ state_machine = {
TaskStatus.publishing,
TaskStatus.stopped,
},
- TaskStatus.failed: {TaskStatus.created, TaskStatus.stopped, TaskStatus.published},
+ TaskStatus.failed: {
+ TaskStatus.created,
+ TaskStatus.stopped,
+ TaskStatus.published,
+ TaskStatus.queued,
+ },
TaskStatus.publishing: {TaskStatus.published},
TaskStatus.published: set(),
TaskStatus.completed: {
@@ -177,7 +182,7 @@ def get_many_tasks_for_writing(
throw_on_forbidden: bool = True,
) -> Sequence[Task]:
if only:
- missing = [f for f in ("company", ) if f not in only]
+ missing = [f for f in ("company",) if f not in only]
if missing:
only = [*only, *missing]
@@ -230,7 +235,7 @@ def get_task_for_update(
task_id: str,
identity: Identity,
allow_all_statuses: bool = False,
- force: bool = False
+ force: bool = False,
) -> Task:
"""
Loads only task id and return the task only if it is updatable (status == 'created')
@@ -286,13 +291,62 @@ def get_last_metric_updates(
new_metrics = []
+ def add_last_metric_mean_update(
+ metric_path: str,
+ metric_count: int,
+ metric_total: float,
+ ):
+ """
+ Update new mean field based on the value in db and new data
+ The count field is updated here too and not with inc__ so that
+ it will not get updated in the db earlier than the corresponding mean
+ """
+ metric_path = metric_path.replace("__", ".")
+ mean_value_field = f"{metric_path}.mean_value"
+ count_field = f"{metric_path}.count"
+ raw_updates[mean_value_field] = {
+ "$round": [
+ {
+ "$divide": [
+ {
+ "$add": [
+ {
+ "$multiply": [
+ {"$ifNull": [f"${mean_value_field}", 0]},
+ {"$ifNull": [f"${count_field}", 0]},
+ ]
+ },
+ metric_total,
+ ]
+ },
+ {
+ "$add": [
+ {"$ifNull": [f"${count_field}", 0]},
+ metric_count,
+ ]
+ },
+ ]
+ },
+ 2,
+ ]
+ }
+ raw_updates[count_field] = {
+ "$add": [
+ {"$ifNull": [f"${count_field}", 0]},
+ metric_count,
+ ]
+ }
+
def add_last_metric_conditional_update(
- metric_path: str, metric_value, iter_value: int, is_min: bool
+ metric_path: str, metric_value, iter_value: int, is_min: bool, is_first: bool
):
"""
Build an aggregation for an atomic update of the min or max value and the corresponding iteration
"""
- if is_min:
+ if is_first:
+ field_prefix = "first"
+ op = None
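+ # no comparison op: first_value is only written when the field does not exist yet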
+ elif is_min:
field_prefix = "min"
op = "$gt"
else:
@@ -300,18 +354,23 @@ def get_last_metric_updates(
op = "$lt"
value_field = f"{metric_path}__{field_prefix}_value".replace("__", ".")
- condition = {
- "$or": [
- {"$lte": [f"${value_field}", None]},
- {op: [f"${value_field}", metric_value]},
- ]
- }
+ exists = {"$lte": [f"${value_field}", None]}
+ if op:
+ condition = {
+ "$or": [
+ exists,
+ {op: [f"${value_field}", metric_value]},
+ ]
+ }
+ else:
+ condition = exists
+
raw_updates[value_field] = {
"$cond": [condition, metric_value, f"${value_field}"]
}
- value_iteration_field = f"{metric_path}__{field_prefix}_value_iteration".replace(
- "__", "."
+ value_iteration_field = (
+ f"{metric_path}__{field_prefix}_value_iteration".replace("__", ".")
)
raw_updates[value_iteration_field] = {
"$cond": [condition, iter_value, f"${value_iteration_field}"]
@@ -328,15 +387,25 @@ def get_last_metric_updates(
new_metrics.append(metric)
path = f"last_metrics__{metric_key}__{variant_key}"
for key, value in variant_data.items():
- if key in ("min_value", "max_value"):
+ if key in ("min_value", "max_value", "first_value"):
add_last_metric_conditional_update(
metric_path=path,
metric_value=value,
iter_value=variant_data.get(f"{key}_iter", 0),
is_min=(key == "min_value"),
+ is_first=(key == "first_value"),
)
- elif key in ("metric", "variant", "value"):
+ elif key in ("metric", "variant", "value", "x_axis_label"):
extra_updates[f"set__{path}__{key}"] = value
+ count = variant_data.get("count")
+ total = variant_data.get("total")
+ if count is not None and total is not None:
+ add_last_metric_mean_update(
+ metric_path=path,
+ metric_count=count,
+ metric_total=total,
+ )
+
if new_metrics:
extra_updates["add_to_set__unique_metrics"] = new_metrics
diff --git a/apiserver/bll/user/__init__.py b/apiserver/bll/user/__init__.py
index f47221b..46a7f85 100644
--- a/apiserver/bll/user/__init__.py
+++ b/apiserver/bll/user/__init__.py
@@ -2,6 +2,7 @@ from datetime import datetime
from apiserver.apierrors import errors
from apiserver.apimodels.users import CreateRequest
+from apiserver.config.info import get_version
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.user import User
@@ -14,7 +15,11 @@ class UserBLL:
if user_id and User.objects(id=user_id).only("id"):
raise errors.bad_request.UserIdExists(id=user_id)
- user = User(**request.to_struct(), created=datetime.utcnow())
+ user = User(
+ **request.to_struct(),
+ created=datetime.utcnow(),
+ created_in_version=get_version(),
+ )
user.save(force_insert=True)
@staticmethod
diff --git a/apiserver/bll/workers/__init__.py b/apiserver/bll/workers/__init__.py
index 1caec6c..4eec5a0 100644
--- a/apiserver/bll/workers/__init__.py
+++ b/apiserver/bll/workers/__init__.py
@@ -297,6 +297,7 @@ class WorkerBLL:
{
"$project": {
"name": 1,
+ "display_name": 1,
"next_entry": {"$arrayElemAt": ["$entries", 0]},
"num_entries": {"$size": "$entries"},
}
@@ -330,6 +331,7 @@ class WorkerBLL:
if not info:
continue
entry.name = info.get("name", None)
+ entry.display_name = info.get("display_name", None)
entry.num_tasks = info.get("num_entries", 0)
task_id = nested_get(info, ("next_entry", "task"))
if task_id:
diff --git a/apiserver/bll/workers/stats.py b/apiserver/bll/workers/stats.py
index fa3abb6..a1033a2 100644
--- a/apiserver/bll/workers/stats.py
+++ b/apiserver/bll/workers/stats.py
@@ -73,9 +73,13 @@ class WorkerStats:
Buckets with no metrics are not returned
Note: all the statistics are retrieved as one ES query
"""
- if request.from_date >= request.to_date:
+ from_date = request.from_date
+ to_date = request.to_date
+ if from_date >= to_date:
raise bad_request.FieldsValueError("from_date must be less than to_date")
+ interval = max(request.interval, self.min_chart_interval)
+
def get_dates_agg() -> dict:
es_to_agg_types = (
("avg", AggregationType.avg.value),
@@ -87,8 +91,11 @@ class WorkerStats:
"dates": {
"date_histogram": {
"field": "timestamp",
- "fixed_interval": f"{request.interval}s",
- "min_doc_count": 1,
+ "fixed_interval": f"{interval}s",
+ "extended_bounds": {
+ "min": int(from_date) * 1000,
+ "max": int(to_date) * 1000,
+ }
},
"aggs": {
agg_type: {es_agg: {"field": "value"}}
@@ -120,7 +127,7 @@ class WorkerStats:
}
query_terms = [
- QueryBuilder.dates_range(request.from_date, request.to_date),
+ QueryBuilder.dates_range(from_date, to_date),
QueryBuilder.terms("metric", {item.key for item in request.items}),
]
if request.worker_ids:
@@ -130,16 +137,16 @@ class WorkerStats:
with translate_errors_context():
data = self._search_company_stats(company_id, es_req)
- return self._extract_results(data, request.items, request.split_by_variant)
+ cutoff_date = (to_date - 0.9 * interval) * 1000 # do not return the point for the incomplete last interval
+ return self._extract_results(data, request.items, request.split_by_variant, cutoff_date)
@staticmethod
def _extract_results(
- data: dict, request_items: Sequence[StatItem], split_by_variant: bool
+ data: dict, request_items: Sequence[StatItem], split_by_variant: bool, cutoff_date
) -> dict:
"""
Clean results returned from elastic search (remove "aggregations", "buckets" etc.),
leave only aggregation types requested by the user and return a clean dictionary
- and return a "clean" dictionary of
:param data: aggregation data retrieved from ES
:param request_items: aggs types requested by the user
:param split_by_variant: if False then aggregate by metric type, otherwise metric type + variant
@@ -157,7 +164,7 @@ class WorkerStats:
return {
"date": date["key"],
"count": date["doc_count"],
- **{agg: date[agg]["value"] for agg in aggs_per_metric[metric_key]},
+ **{agg: date[agg]["value"] or 0.0 for agg in aggs_per_metric[metric_key]},
}
def extract_metric_results(
@@ -166,7 +173,7 @@ class WorkerStats:
return [
extract_date_stats(date, metric_key)
for date in metric_or_variant["dates"]["buckets"]
- if date["doc_count"]
+ if date["key"] <= cutoff_date
]
def extract_variant_results(metric: dict) -> dict:
diff --git a/apiserver/config/default/services/serving.conf b/apiserver/config/default/services/serving.conf
new file mode 100644
index 0000000..4279e00
--- /dev/null
+++ b/apiserver/config/default/services/serving.conf
@@ -0,0 +1,7 @@
+default_container_timeout_sec: 600
+# Auto-register unknown serving containers on status reports and other calls
+container_auto_register: true
+# Assume unknown serving containers have unregistered (i.e. do not raise an unregistered error)
+container_auto_unregister: true
+# The minimal sampling interval for serving model monitoring charts
+min_chart_interval_sec: 40
diff --git a/apiserver/database/__init__.py b/apiserver/database/__init__.py
index 401034f..abf9e16 100644
--- a/apiserver/database/__init__.py
+++ b/apiserver/database/__init__.py
@@ -37,6 +37,8 @@ OVERRIDE_QUERY_ENV_KEY = "CLEARML_MONGODB_SERVICE_QUERY"
class DatabaseEntry(models.Base):
host = StringField(required=True)
alias = StringField()
+ name = StringField()
+ db = StringField()
class DatabaseFactory:
@@ -78,10 +80,13 @@ class DatabaseFactory:
missing.append(key)
continue
- entry = cls._create_db_entry(alias=alias, settings=db_entries.get(key))
+ settings = {**db_entries.get(key)}
+ if not any(field in settings for field in ("name", "db")):
+ settings["name"] = key
+ entry = cls._create_db_entry(alias=alias, settings=settings)
if override_connection_string:
- con_str = f"{override_connection_string.rstrip('/')}/{key}"
+ con_str = override_connection_string
log.info(f"Using override mongodb connection string for {alias}: {con_str}")
entry.host = con_str
else:
diff --git a/apiserver/database/model/base.py b/apiserver/database/model/base.py
index 238b53c..0d5c5cd 100644
--- a/apiserver/database/model/base.py
+++ b/apiserver/database/model/base.py
@@ -1,5 +1,5 @@
import re
-from collections import namedtuple, defaultdict
+from collections import defaultdict
from datetime import datetime
from functools import reduce, partial
from typing import (
@@ -107,7 +107,18 @@ class GetMixin(PropsMixin):
("_any_", "_or_"): lambda a, b: a | b,
("_all_", "_and_"): lambda a, b: a & b,
}
- MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
+
+ @attr.s(auto_attribs=True)
+ class MultiFieldParameters:
+ fields: Sequence[str]
+ pattern: str = None
+ datetime: Union[list, str] = None
+
+ def __attrs_post_init__(self):
+ if not any(f is not None for f in (self.pattern, self.datetime)):
+ raise ValueError("Either 'pattern' or 'datetime' should be provided")
+ if all(f is not None for f in (self.pattern, self.datetime)):
+ raise ValueError("Only one of the 'pattern' and 'datetime' can be provided")
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
_field_collation_overrides = {}
@@ -323,6 +334,8 @@ class GetMixin(PropsMixin):
specific rules on handling values). Only items matching ALL of these conditions will be retrieved.
- : {fields: [, , ...], pattern: } Will query for items where any or all
provided fields match the provided pattern.
+ - : {fields: [, , ...], datetime: } Will query for items where any or all
+ provided datetime fields match the provided condition.
:return: mongoengine.Q query object
"""
return cls._prepare_query_no_company(
@@ -376,6 +389,46 @@ class GetMixin(PropsMixin):
return cls._try_convert_to_numeric(value)
return value
+ @classmethod
+ def _get_dates_query(cls, field: str, data: Union[list, str]) -> Union[Q, dict]:
+ """
+ Return a dates query for the field
+ If the data is a two-value array and none of the values starts with a date comparison operator,
+ return the simplified range query;
+ otherwise return a dictionary of date conditions
+ """
+ if not isinstance(data, list):
+ data = [data]
+
+ if len(data) == 2 and not any(
+ d.startswith(mod)
+ for d in data
+ if d is not None
+ for mod in ACCESS_MODIFIER
+ ):
+ return cls.get_range_field_query(field, data)
+
+ dict_query = {}
+ for d in data:
+ m = ACCESS_REGEX.match(d)
+ if not m:
+ continue
+
+ try:
+ value = parse_datetime(m.group("value"))
+ prefix = m.group("prefix")
+ modifier = ACCESS_MODIFIER.get(prefix)
+ f = (
+ field
+ if not modifier
+ else "__".join((field, modifier))
+ )
+ dict_query[f] = value
+ except (ValueError, OverflowError):
+ pass
+
+ return dict_query
+
@classmethod
def _prepare_query_no_company(
cls, parameters=None, parameters_options=QueryParameterOptions()
@@ -446,33 +499,11 @@ class GetMixin(PropsMixin):
for field in opts.datetime_fields or []:
data = parameters.pop(field, None)
if data is not None:
- if not isinstance(data, list):
- data = [data]
- # date time fields also support simplified range queries. Check if this is the case
- if len(data) == 2 and not any(
- d.startswith(mod)
- for d in data
- if d is not None
- for mod in ACCESS_MODIFIER
- ):
- query &= cls.get_range_field_query(field, data)
- else:
- for d in data: # type: str
- m = ACCESS_REGEX.match(d)
- if not m:
- continue
- try:
- value = parse_datetime(m.group("value"))
- prefix = m.group("prefix")
- modifier = ACCESS_MODIFIER.get(prefix)
- f = (
- field
- if not modifier
- else "__".join((field, modifier))
- )
- dict_query[f] = value
- except (ValueError, OverflowError):
- pass
+ dates_q = cls._get_dates_query(field, data)
+ if isinstance(dates_q, Q):
+ query &= dates_q
+ elif isinstance(dates_q, dict):
+ dict_query.update(dates_q)
for field, value in parameters.items():
for keys, func in cls._multi_field_param_prefix.items():
@@ -484,27 +515,40 @@ class GetMixin(PropsMixin):
raise MakeGetAllQueryError("incorrect field format", field)
if not data.fields:
break
- if any("._" in f for f in data.fields):
- q = reduce(
- lambda a, x: func(
- a,
- RegexQ(
- __raw__={
- x: {"$regex": data.pattern, "$options": "i"}
- }
+ if data.pattern is not None:
+ if any("._" in f for f in data.fields):
+ q = reduce(
+ lambda a, x: func(
+ a,
+ RegexQ(
+ __raw__={
+ x: {"$regex": data.pattern, "$options": "i"}
+ }
+ ),
),
- ),
- data.fields,
- RegexQ(),
- )
+ data.fields,
+ RegexQ(),
+ )
+ else:
+ regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
+ sep_fields = [f.replace(".", "__") for f in data.fields]
+ q = reduce(
+ lambda a, x: func(a, RegexQ(**{x: regex})),
+ sep_fields,
+ RegexQ(),
+ )
else:
- regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
- sep_fields = [f.replace(".", "__") for f in data.fields]
- q = reduce(
- lambda a, x: func(a, RegexQ(**{x: regex})),
- sep_fields,
- RegexQ(),
- )
+ date_fields = [field for field in data.fields if field in opts.datetime_fields]
+ if not date_fields:
+ break
+
+ q = Q()
+ for date_f in date_fields:
+ dates_q = cls._get_dates_query(date_f, data.datetime)
+ if isinstance(dates_q, dict):
+ dates_q = RegexQ(**dates_q)
+ q = func(q, dates_q)
+
query = query & q
except APIError:
raise
@@ -1394,7 +1438,7 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
else:
items = list(
cls.objects(
- id__in=ids, company__in=(None, ""), company_origin=company_id
+ id__in=ids, company="", company_origin=company_id
).only("id")
)
update: dict = dict(set__company=company_id, unset__company_origin=1)
diff --git a/apiserver/database/model/model.py b/apiserver/database/model/model.py
index 7516312..977e3e4 100644
--- a/apiserver/database/model/model.py
+++ b/apiserver/database/model/model.py
@@ -37,10 +37,18 @@ class Model(AttributedDocument):
"project",
"task",
"last_update",
- ("company", "framework"),
+ ("company", "last_update"),
("company", "name"),
- ("company", "user"),
("company", "uri"),
+ # distinct queries support
+ ("company", "tags"),
+ ("company", "system_tags"),
+ ("company", "project", "tags"),
+ ("company", "project", "system_tags"),
+ ("company", "user"),
+ ("company", "project", "user"),
+ ("company", "framework"),
+ ("company", "project", "framework"),
{
"name": "%s.model.main_text_index" % Database.backend,
"fields": ["$name", "$id", "$comment", "$parent", "$task", "$project"],
@@ -71,8 +79,8 @@ class Model(AttributedDocument):
"parent",
"metadata.*",
),
- range_fields=("last_metrics.*", "last_iteration"),
- datetime_fields=("last_update",),
+ range_fields=("created", "last_metrics.*", "last_iteration"),
+ datetime_fields=("last_update", "last_change"),
)
id = StringField(primary_key=True)
diff --git a/apiserver/database/model/queue.py b/apiserver/database/model/queue.py
index aab7b49..f8d66db 100644
--- a/apiserver/database/model/queue.py
+++ b/apiserver/database/model/queue.py
@@ -47,6 +47,7 @@ class Queue(DbModelMixin, Document):
name = StrippedStringField(
required=True, unique_with="company", min_length=3, user_set_allowed=True
)
+ display_name = StringField(user_set_allowed=True)
company = StringField(required=True, reference_field=Company)
created = DateTimeField(required=True)
tags = SafeSortedListField(
diff --git a/apiserver/database/model/storage_settings.py b/apiserver/database/model/storage_settings.py
new file mode 100644
index 0000000..7ab7dd3
--- /dev/null
+++ b/apiserver/database/model/storage_settings.py
@@ -0,0 +1,76 @@
+from mongoengine import (
+ Document,
+ EmbeddedDocument,
+ StringField,
+ DateTimeField,
+ EmbeddedDocumentListField,
+ EmbeddedDocumentField,
+ BooleanField,
+)
+
+from apiserver.database import Database, strict
+from apiserver.database.model import DbModelMixin
+from apiserver.database.model.base import ProperDictMixin
+
+class AWSBucketSettings(EmbeddedDocument, ProperDictMixin):
+ bucket = StringField()
+ subdir = StringField()
+ host = StringField()
+ key = StringField()
+ secret = StringField()
+ token = StringField()
+ multipart = BooleanField()
+ acl = StringField()
+ secure = BooleanField()
+ region = StringField()
+ verify = BooleanField()
+ use_credentials_chain = BooleanField()
+
+
+class AWSSettings(EmbeddedDocument, DbModelMixin):
+ key = StringField()
+ secret = StringField()
+ region = StringField()
+ token = StringField()
+ use_credentials_chain = BooleanField()
+ buckets = EmbeddedDocumentListField(AWSBucketSettings)
+
+
+class GoogleBucketSettings(EmbeddedDocument, ProperDictMixin):
+ bucket = StringField()
+ subdir = StringField()
+ project = StringField()
+ credentials_json = StringField()
+
+
+class GoogleStorageSettings(EmbeddedDocument, DbModelMixin):
+ project = StringField()
+ credentials_json = StringField()
+ buckets = EmbeddedDocumentListField(GoogleBucketSettings)
+
+
+class AzureStorageContainerSettings(EmbeddedDocument, ProperDictMixin):
+ account_name = StringField(required=True)
+ account_key = StringField(required=True)
+ container_name = StringField()
+
+
+class AzureStorageSettings(EmbeddedDocument, DbModelMixin):
+ containers = EmbeddedDocumentListField(AzureStorageContainerSettings)
+
+
+class StorageSettings(DbModelMixin, Document):
+ meta = {
+ "db_alias": Database.backend,
+ "strict": strict,
+ "indexes": [
+ "company"
+ ],
+ }
+
+ id = StringField(primary_key=True)
+ company = StringField(required=True, unique=True)
+ last_update = DateTimeField()
+ aws: AWSSettings = EmbeddedDocumentField(AWSSettings)
+ google: GoogleStorageSettings = EmbeddedDocumentField(GoogleStorageSettings)
+ azure: AzureStorageSettings = EmbeddedDocumentField(AzureStorageSettings)
diff --git a/apiserver/database/model/task/metrics.py b/apiserver/database/model/task/metrics.py
index f94a6f2..af3488c 100644
--- a/apiserver/database/model/task/metrics.py
+++ b/apiserver/database/model/task/metrics.py
@@ -5,6 +5,7 @@ from mongoengine import (
LongField,
EmbeddedDocumentField,
IntField,
+ FloatField,
)
from apiserver.database.fields import SafeMapField
@@ -23,6 +24,11 @@ class MetricEvent(EmbeddedDocument):
min_value_iteration = IntField()
max_value = DynamicField() # for backwards compatibility reasons
max_value_iteration = IntField()
+ first_value = FloatField()
+ first_value_iteration = IntField()
+ count = IntField()
+ mean_value = FloatField()
+ x_axis_label = StringField()
class EventStats(EmbeddedDocument):
diff --git a/apiserver/database/model/task/task.py b/apiserver/database/model/task/task.py
index 5d220a7..42c7816 100644
--- a/apiserver/database/model/task/task.py
+++ b/apiserver/database/model/task/task.py
@@ -183,9 +183,8 @@ class Task(AttributedDocument):
"status_changed",
"models.input.model",
("company", "name"),
- ("company", "user"),
("company", "status", "type"),
- ("company", "system_tags", "last_update"),
+ ("company", "last_update", "system_tags"),
("company", "type", "system_tags", "status"),
("company", "project", "type", "system_tags", "status"),
("status", "last_update"), # for maintenance tasks
@@ -193,6 +192,17 @@ class Task(AttributedDocument):
"fields": ["company", "project"],
"collation": AttributedDocument._numeric_locale,
},
+ # distinct queries support
+ ("company", "tags"),
+ ("company", "system_tags"),
+ ("company", "project", "tags"),
+ ("company", "project", "system_tags"),
+ ("company", "user"),
+ ("company", "project", "user"),
+ ("company", "parent"),
+ ("company", "project", "parent"),
+ ("company", "type"),
+ ("company", "project", "type"),
{
"name": "%s.task.main_text_index" % Database.backend,
"fields": [
@@ -233,8 +243,8 @@ class Task(AttributedDocument):
"execution.queue",
"models.input.model",
),
- range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
- datetime_fields=("status_changed", "last_update"),
+ range_fields=("created", "started", "active_duration", "last_metrics.*", "last_iteration"),
+ datetime_fields=("status_changed", "last_update", "last_change"),
pattern_fields=("name", "comment", "report"),
fields=("runtime.*",),
)
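
The index changes above target `distinct`-style queries: each new compound index leads with the equality fields (`company`, optionally `project`) followed by the field being enumerated, so MongoDB can answer from the index alone. A rough illustration with pymongo (connection string and ids are placeholders):

```python
from pymongo import MongoClient

client = MongoClient("mongodb://mongo:27017")  # placeholder address
tasks = client.backend.task

# Served by the new ("company", "project", "user") index: MongoDB can walk
# the index (DISTINCT_SCAN) instead of fetching the matching documents
users = tasks.distinct("user", {"company": "<company-id>", "project": "<project-id>"})
print(sorted(users))
```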
diff --git a/apiserver/database/model/user.py b/apiserver/database/model/user.py
index 9281908..4728804 100644
--- a/apiserver/database/model/user.py
+++ b/apiserver/database/model/user.py
@@ -20,4 +20,5 @@ class User(DbModelMixin, Document):
given_name = StringField(user_set_allowed=True)
avatar = StringField()
preferences = DynamicField(default="", exclude_by_default=True)
+ created_in_version = StringField()
created = DateTimeField()
diff --git a/apiserver/database/utils.py b/apiserver/database/utils.py
index 7ee28c7..de4e57e 100644
--- a/apiserver/database/utils.py
+++ b/apiserver/database/utils.py
@@ -121,8 +121,8 @@ def init_cls_from_base(cls, instance):
)
-def get_company_or_none_constraint(company=None):
- return Q(company__in=(company, None, "")) | Q(company__exists=False)
+def get_company_or_none_constraint(company=""):
+ return Q(company__in=list({company, ""}))
def field_does_not_exist(field: str, empty_value=None, is_list=False) -> Q:
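
The rewritten constraint tightens the match: documents whose `company` is `None` or missing no longer qualify, only the given company id or the empty string (used for system-wide entries). A usage sketch; the model and the name filter are illustrative:

```python
from mongoengine import Q

from apiserver.database.model.queue import Queue
from apiserver.database.utils import get_company_or_none_constraint

# Matches queues owned by the company plus shared entries stored with
# company == ""; entries with a missing or None company no longer match
queues = Queue.objects(
    get_company_or_none_constraint("<company-id>") & Q(name__istartswith="default")
)
```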
diff --git a/apiserver/documentation/api_versions.md b/apiserver/documentation/api_versions.md
index 2f900ba..5b987e1 100644
--- a/apiserver/documentation/api_versions.md
+++ b/apiserver/documentation/api_versions.md
@@ -2,6 +2,7 @@
| Release | ApiVersion |
|---------|------------|
+| v1.17 | 2.31 |
| v1.16 | 2.30 |
| v1.15 | 2.29 |
| v1.14 | 2.28 |
diff --git a/apiserver/elastic/index_templates/workers/serving_stats.json b/apiserver/elastic/index_templates/workers/serving_stats.json
new file mode 100644
index 0000000..b1e255d
--- /dev/null
+++ b/apiserver/elastic/index_templates/workers/serving_stats.json
@@ -0,0 +1,79 @@
+{
+ "index_patterns": "serving_stats_*",
+ "template": {
+ "settings": {
+ "number_of_replicas": 0,
+ "number_of_shards": 1
+ },
+ "mappings": {
+ "_source": {
+ "enabled": true
+ },
+ "properties": {
+ "timestamp": {
+ "type": "date"
+ },
+ "container_id": {
+ "type": "keyword"
+ },
+ "company_id": {
+ "type": "keyword"
+ },
+ "endpoint_url": {
+ "type": "keyword"
+ },
+ "requests_num": {
+ "type": "integer"
+ },
+ "requests_min": {
+ "type": "float"
+ },
+ "uptime_sec": {
+ "type": "integer"
+ },
+ "latency_ms": {
+ "type": "integer"
+ },
+ "cpu_usage": {
+ "type": "float"
+ },
+ "cpu_num": {
+ "type": "integer"
+ },
+ "gpu_usage": {
+ "type": "float"
+ },
+ "gpu_num": {
+ "type": "integer"
+ },
+ "memory_used": {
+ "type": "float"
+ },
+ "memory_free": {
+ "type": "float"
+ },
+ "memory_total": {
+ "type": "float"
+ },
+ "gpu_memory_used": {
+ "type": "float"
+ },
+ "gpu_memory_free": {
+ "type": "float"
+ },
+ "gpu_memory_total": {
+ "type": "float"
+ },
+ "disk_free_home": {
+ "type": "float"
+ },
+ "network_rx": {
+ "type": "float"
+ },
+ "network_tx": {
+ "type": "float"
+ }
+ }
+ }
+ }
+}
\ No newline at end of file
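
The file uses the composable index template format (top-level `index_patterns` plus a `template` object), so it can be registered as-is through the 8.x Elasticsearch client. A sketch, assuming a reachable Elasticsearch instance at a placeholder address:

```python
import json

from elasticsearch import Elasticsearch

es = Elasticsearch("http://elastic:9200")  # placeholder address

with open("apiserver/elastic/index_templates/workers/serving_stats.json") as f:
    body = json.load(f)

# The file's top-level keys map directly onto the composable
# index template API parameters
es.indices.put_index_template(name="serving_stats", **body)
```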
diff --git a/apiserver/fix_mongo_urls.py b/apiserver/fix_mongo_urls.py
new file mode 100644
index 0000000..706f791
--- /dev/null
+++ b/apiserver/fix_mongo_urls.py
@@ -0,0 +1,122 @@
+import logging
+from argparse import (
+ ArgumentDefaultsHelpFormatter,
+ ArgumentParser,
+ ArgumentTypeError,
+)
+
+from pymongo import MongoClient
+from pymongo.collection import Collection
+from pymongo.database import Database
+
+
+logging.getLogger().setLevel(logging.INFO)
+
+
+def fix_mongo_urls(mongo_host: str, host_source: str, host_target: str):
+ logging.info(f"Connecting to Mongo on {mongo_host}")
+ client = MongoClient(host=mongo_host)
+ backend_db: Database = client.backend
+
+ def get_updated_uri(uri: str):
+ if not uri or not uri.startswith(host_source):
+ return
+ relative_url = uri[len(host_source) :]
+ return f"{host_target.rstrip('/')}/{relative_url.lstrip('/')}"
+
+ host_source = normalise_host(host_source)
+ host_target = normalise_host(host_target)
+ model_collection: Collection = backend_db.get_collection("model")
+ if model_collection is not None:
+ logging.info("Updating model uris")
+ models_count = model_collection.count_documents({})
+ updated_models = 0
+ for model in model_collection.find(
+ {"uri": {"$regex": "^{}".format(host_source)}}, projection=["uri"]
+ ):
+ updated_uri = get_updated_uri(model.get("uri"))
+ if updated_uri:
+ result = model_collection.update_one(
+ {"_id": model["_id"]}, {"$set": {"uri": updated_uri}}
+ )
+ updated_models += result.modified_count
+
+ logging.info(f"Updated {updated_models} models from {models_count}")
+
+ task_collection: Collection = backend_db.get_collection("task")
+ if task_collection is not None:
+ logging.info("Updating task uris")
+ tasks_count = task_collection.count_documents({})
+ updated_tasks = 0
+ for task in task_collection.find(
+ {"execution.artifacts": {"$exists": 1, "$ne": {}}},
+ projection=["execution.artifacts"],
+ ):
+ artifacts = task.get("execution", {}).get("artifacts")
+ if not artifacts:
+ continue
+
+ uri_updated = False
+ for artifact in artifacts.values():
+ updated_uri = get_updated_uri(artifact.get("uri"))
+ if updated_uri:
+ artifact["uri"] = updated_uri
+ uri_updated = True
+
+ if uri_updated:
+ result = task_collection.update_one(
+ {"_id": task["_id"]}, {"$set": {"execution.artifacts": artifacts}}
+ )
+ updated_tasks += result.modified_count
+
+ logging.info(f"Updated {updated_tasks} tasks from {tasks_count}")
+
+
+def normalise_host(host):
+ if not host.endswith("/"):
+ return host
+ return host[:-1]
+
+
+def main():
+ def valid_url_prefix(url: str):
+ if "://" not in url:
+ raise ArgumentTypeError("url schema is missing")
+ return url
+
+ parser = ArgumentParser(
+ description=__doc__, formatter_class=ArgumentDefaultsHelpFormatter
+ )
+ parser.add_argument(
+ "--mongo-host",
+ "-mh",
+ type=str,
+ default="mongodb://mongo:27017",
+ help="Mongo server host. The default is mongodb://mongo:27017",
+ )
+ parser.add_argument(
+ "--host-source",
+ "-hs",
+ type=valid_url_prefix,
+ required=True,
+ help="Source host for the files uploaded to the fileserver (in the form http://:)",
+ )
+ parser.add_argument(
+ "--host-target",
+ "-ht",
+ type=valid_url_prefix,
+ required=True,
+ help="Target host for the files uploaded to the fileserver (in the form http://:)",
+ )
+ args = parser.parse_args()
+
+ fix_mongo_urls(
+ mongo_host=args.mongo_host,
+ host_source=args.host_source,
+ host_target=args.host_target,
+ )
+ logging.info("Completed successfully")
+
+
+if __name__ == "__main__":
+ main()
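
Besides the CLI entry point, the updater can be driven directly; it rewrites the `uri` field on model documents and every artifact URI on task documents whose prefix matches `host_source`. All values below are placeholders:

```python
from apiserver.fix_mongo_urls import fix_mongo_urls

fix_mongo_urls(
    mongo_host="mongodb://mongo:27017",
    host_source="http://old-fileserver:8081",
    host_target="http://new-fileserver:8081",
)
```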
diff --git a/apiserver/mongo/initialize/migration.py b/apiserver/mongo/initialize/migration.py
index 8837604..081d44d 100644
--- a/apiserver/mongo/initialize/migration.py
+++ b/apiserver/mongo/initialize/migration.py
@@ -8,13 +8,16 @@ import pymongo.database
from mongoengine.connection import get_db
from packaging.version import Version, parse
+from apiserver.config_repo import config
from apiserver.database import utils
from apiserver.database import Database
from apiserver.database.model.version import Version as DatabaseVersion
+from apiserver.utilities.dicts import nested_get
_migrations = "migrations"
_parent_dir = Path(__file__).resolve().parents[1]
_migration_dir = _parent_dir / _migrations
+log = config.logger(__file__)
def check_mongo_empty() -> bool:
@@ -41,6 +44,26 @@ def get_last_server_version() -> Version:
return previous_versions[0] if previous_versions else Version("0.0.0")
+def _ensure_mongodb_version():
+ db: pymongo.database.Database = get_db(Database.backend)
+ db_version = db.client.server_info()["version"]
+ if not db_version.startswith("6.0"):
+ log.warning(f"Database version should be 6.0.x. Instead: {str(db_version)}")
+ return
+
+ res = db.client.admin.command({"getParameter": 1, "featureCompatibilityVersion": 1})
+ version = nested_get(res, ("featureCompatibilityVersion", "version"))
+ if version == "6.0":
+ return
+ if version != "5.0":
+ log.warning(f"Cannot upgrade DB version. Should be 5.0. {str(res)}")
+ return
+
+ log.info("Upgrading db version from 5.0 to 6.0")
+ res = db.client.admin.command({"setFeatureCompatibilityVersion": "6.0"})
+ log.info(res)
+
+
def _apply_migrations(log: Logger):
"""
Apply migrations as found in the migration dir.
@@ -50,6 +73,8 @@ def _apply_migrations(log: Logger):
log.info(f"Started mongodb migrations")
+ _ensure_mongodb_version()
+
if not _migration_dir.is_dir():
raise ValueError(f"Invalid migration dir {_migration_dir}")
diff --git a/apiserver/mongo/initialize/pre_populate.py b/apiserver/mongo/initialize/pre_populate.py
index 9f80bd1..0cc6fdc 100644
--- a/apiserver/mongo/initialize/pre_populate.py
+++ b/apiserver/mongo/initialize/pre_populate.py
@@ -28,10 +28,11 @@ from urllib.parse import unquote, urlparse
from uuid import uuid4, UUID, uuid5
from zipfile import ZipFile, ZIP_BZIP2
+import attr
import mongoengine
from boltons.iterutils import chunked_iter, first
from furl import furl
-from mongoengine import Q
+from mongoengine import Q, Document
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_common import EventType
@@ -61,6 +62,8 @@ from apiserver.utilities import json
from apiserver.utilities.dicts import nested_get, nested_set, nested_delete
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
+replace_s3_scheme = os.getenv("CLEARML_REPLACE_S3_SCHEME")
+
class PrePopulate:
module_name_prefix = "apiserver."
@@ -84,6 +87,11 @@ class PrePopulate:
user_cls: Type[User]
auth_user_cls: Type[AuthUser]
+ @attr.s(auto_attribs=True)
+ class ParentPrefix:
+ prefix: str
+ path: Sequence[str]
+
# noinspection PyTypeChecker
@classmethod
def _init_entity_types(cls):
@@ -469,20 +477,35 @@ class PrePopulate:
@classmethod
def _check_projects_hierarchy(cls, projects: Set[Project]):
"""
- For any exported project all its parents up to the root should be present
+ For projects that are not exported from the root,
+ fix their parent paths to exclude the parents that were not exported
"""
if not projects:
return
project_ids = {p.id for p in projects}
- orphans = [p.id for p in projects if p.parent and p.parent not in project_ids]
+ orphans = [p for p in projects if p.parent and p.parent not in project_ids]
if not orphans:
return
- print(
- f"ERROR: the following projects are exported without their parents: {orphans}"
- )
- exit(1)
+ prefixes = [
+ cls.ParentPrefix(prefix=f"{project.name.rpartition('/')[0]}/", path=project.path)
+ for project in orphans
+ ]
+ prefixes.sort(key=lambda p: len(p.path), reverse=True)
+ for project in projects:
+ prefix = first(pref for pref in prefixes if project.path[:len(pref.path)] == pref.path)
+ if not prefix:
+ continue
+ project.path = project.path[len(prefix.path):]
+ if not project.path:
+ project.parent = None
+ project.name = project.name.removeprefix(prefix.prefix)
+
@classmethod
def _resolve_entities(
@@ -491,6 +514,7 @@ class PrePopulate:
projects: Sequence[str] = None,
task_statuses: Sequence[str] = None,
) -> Dict[Type[mongoengine.Document], Set[mongoengine.Document]]:
+ # noinspection PyTypeChecker
entities: Dict[Any] = defaultdict(set)
if projects:
@@ -539,6 +563,7 @@ class PrePopulate:
print("Reading models...")
entities[cls.model_cls] = set(cls.model_cls.objects(id__in=list(model_ids)))
+ # noinspection PyTypeChecker
return entities
@classmethod
@@ -643,8 +668,9 @@ class PrePopulate:
@staticmethod
def _get_fixed_url(url: Optional[str]) -> Optional[str]:
- if not (url and url.lower().startswith("s3://")):
+ if not (replace_s3_scheme and url and url.lower().startswith("s3://")):
return url
+
try:
fixed = furl(url)
fixed.scheme = "https"
@@ -983,8 +1009,10 @@ class PrePopulate:
module = importlib.import_module(module_name)
return getattr(module, class_name)
- @staticmethod
- def _upgrade_project_data(project_data: dict) -> dict:
+ @classmethod
+ def _upgrade_project_data(cls, project_data: dict) -> dict:
+ cls._remove_incompatible_fields(cls.project_cls, project_data)
+
if not project_data.get("basename"):
name: str = project_data["name"]
_, _, basename = name.rpartition("/")
@@ -992,8 +1020,10 @@ class PrePopulate:
return project_data
- @staticmethod
- def _upgrade_model_data(model_data: dict) -> dict:
+ @classmethod
+ def _upgrade_model_data(cls, model_data: dict) -> dict:
+ cls._remove_incompatible_fields(cls.model_cls, model_data)
+
metadata_key = "metadata"
metadata = model_data.get(metadata_key)
if isinstance(metadata, list):
@@ -1006,7 +1036,13 @@ class PrePopulate:
return model_data
@staticmethod
- def _upgrade_task_data(task_data: dict) -> dict:
+ def _remove_incompatible_fields(cls_: Type[Document], data: dict):
+ for field in ("company_origin",):
+ if field not in cls_._db_field_map:
+ data.pop(field, None)
+
+ @classmethod
+ def _upgrade_task_data(cls, task_data: dict) -> dict:
"""
Migrate from execution/parameters and model_desc to hyperparams and configuration fields
Upgrade artifacts list to dict
@@ -1015,6 +1051,8 @@ class PrePopulate:
:param task_data: Upgraded in place
:return: The upgraded task data
"""
+ cls._remove_incompatible_fields(cls.task_cls, task_data)
+
for old_param_field, new_param_field, default_section in (
("execution.parameters", "hyperparams", hyperparams_default_section),
("execution.model_desc", "configuration", None),
@@ -1133,7 +1171,7 @@ class PrePopulate:
if isinstance(doc, cls.task_cls):
tasks.append(doc)
- cls.event_bll.delete_task_events(company_id, doc.id, allow_locked=True)
+ cls.event_bll.delete_task_events(company_id, doc.id, wait_for_delete=True)
if tasks:
return tasks
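
The hierarchy fix above replaces the old hard failure: instead of refusing to export projects whose ancestors are missing, the missing ancestors are stripped from each descendant's `path` and name prefix. A standalone illustration of that trimming with plain values (ids and names are made up):

```python
# An exported subtree rooted at "top/child", where "top" itself is not
# exported: its id prefix and name prefix get stripped from descendants
path = ["top-id", "child-id"]   # ancestor ids of "top/child/grandchild"
name = "top/child/grandchild"

prefix_path = ["top-id"]        # path of the orphaned root's missing parents
prefix = "top/"                 # name prefix derived from the orphan's name

if path[: len(prefix_path)] == prefix_path:
    path = path[len(prefix_path):]
    name = name.removeprefix(prefix)

print(path, name)  # ['child-id'] child/grandchild
```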
diff --git a/apiserver/requirements.txt b/apiserver/requirements.txt
index 0a2fa36..e39c07a 100644
--- a/apiserver/requirements.txt
+++ b/apiserver/requirements.txt
@@ -6,7 +6,7 @@ boto3>=1.26
boto3-stubs[s3]>=1.26
clearml>=1.10.3
dpath>=1.4.2,<2.0
-elasticsearch==8.12.0
+elasticsearch==8.17.0
fastjsonschema>=2.8
flask-compress>=1.4.0
flask-cors>=3.0.5
@@ -19,15 +19,15 @@ jinja2
jsonmodels>=2.3
jsonschema>=2.6.0
luqum>=0.10.0
-mongoengine==0.27.0
+mongoengine==0.29.1
nested_dict>=1.61
packaging==20.3
psutil>=5.6.5
pyhocon>=0.3.35
pyjwt>=2.4.0
-pymongo==4.6.3
+pymongo==4.10.1
python-rapidjson>=0.6.3
-redis>=4.5.4,<5
+redis==5.2.1
requests>=2.13.0
semantic_version>=2.8.3,<3
setuptools>=65.5.1
diff --git a/apiserver/schema/services/_common.conf b/apiserver/schema/services/_common.conf
index 6982e52..f075047 100644
--- a/apiserver/schema/services/_common.conf
+++ b/apiserver/schema/services/_common.conf
@@ -74,7 +74,11 @@ multi_field_pattern_data {
type: object
properties {
pattern {
- description: "Pattern string (regex)"
+ description: "Pattern string (regex). Either 'pattern' or 'datetime' should be specified"
+ type: string
+ }
+ datetime {
+ description: "Date time conditions (applicable only to datetime fields). Either 'pattern' or 'datetime' should be specified"
type: string
}
fields {
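
With this change a multi-field filter can carry either a regex `pattern` or a `datetime` condition string, but not both. A hypothetical request fragment for a `get_all_ex`-style call; the exact datetime-operator syntax shown here is an assumption, not confirmed by this diff:

```python
# Hypothetical filter payloads using the extended multi_field_pattern_data
pattern_filter = {
    "_any_": {"pattern": "resnet.*", "fields": ["name", "comment"]},
}
datetime_filter = {
    "_any_": {"datetime": ">=2024-01-01T00:00:00", "fields": ["last_update"]},
}
```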
diff --git a/apiserver/schema/services/_tasks_common.conf b/apiserver/schema/services/_tasks_common.conf
index 2afffe7..3e7e587 100644
--- a/apiserver/schema/services/_tasks_common.conf
+++ b/apiserver/schema/services/_tasks_common.conf
@@ -283,6 +283,26 @@ last_metrics_event {
description: "The iteration at which the maximum value was reported"
type: integer
}
+ first_value {
+ description: "First value reported"
+ type: number
+ }
+ first_value_iteration {
+ description: "The iteration at which the first value was reported"
+ type: integer
+ }
+ mean_value {
+ description: "The mean value"
+ type: number
+ }
+ count {
+ description: "The total count of reported values"
+ type: integer
+ }
+ x_axis_label {
+ description: "The user-defined value for the X-Axis name stored with the event"
+ type: string
+ }
}
}
last_metrics_variants {
diff --git a/apiserver/schema/services/_workers_common.conf b/apiserver/schema/services/_workers_common.conf
new file mode 100644
index 0000000..034abd3
--- /dev/null
+++ b/apiserver/schema/services/_workers_common.conf
@@ -0,0 +1,67 @@
+machine_stats {
+ type: object
+ properties {
+ cpu_usage {
+ description: "Average CPU usage per core"
+ type: array
+ items { type: number }
+ }
+ gpu_usage {
+ description: "Average GPU usage per GPU card"
+ type: array
+ items { type: number }
+ }
+ memory_used {
+ description: "Used memory MBs"
+ type: number
+ }
+ memory_free {
+ description: "Free memory MBs"
+ type: number
+ }
+ gpu_memory_free {
+ description: "GPU free memory MBs"
+ type: array
+ items { type: number }
+ }
+ gpu_memory_used {
+ description: "GPU used memory MBs"
+ type: array
+ items { type: number }
+ }
+ network_tx {
+ description: "Mbytes per second"
+ type: number
+ }
+ network_rx {
+ description: "Mbytes per second"
+ type: number
+ }
+ disk_free_home {
+ description: "Free space in % of /home drive"
+ type: number
+ }
+ disk_free_temp {
+ description: "Free space in % of /tmp drive"
+ type: number
+ }
+ disk_read {
+ description: "Mbytes read per second"
+ type: number
+ }
+ disk_write {
+ description: "Mbytes write per second"
+ type: number
+ }
+ cpu_temperature {
+ description: "CPU temperature"
+ type: array
+ items { type: number }
+ }
+ gpu_temperature {
+ description: "GPU temperature"
+ type: array
+ items { type: number }
+ }
+ }
+}
\ No newline at end of file
diff --git a/apiserver/schema/services/events.conf b/apiserver/schema/services/events.conf
index d885673..4d909a4 100644
--- a/apiserver/schema/services/events.conf
+++ b/apiserver/schema/services/events.conf
@@ -27,13 +27,17 @@ _definitions {
type: string
}
variant {
- description: "E.g. 'class_1', 'total', 'average"
+ description: "E.g. 'class_1', 'total', 'average'"
type: string
}
value {
description: ""
type: number
}
+ x_axis_label {
+ description: "Custom X-Axis label to be used when displaying the scalars histogram"
+ type: string
+ }
}
}
metrics_vector_event {
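
Scalar events can now carry a custom X-axis label that the server stores alongside the metric (and surfaces through `last_metrics`, per the `_tasks_common.conf` change above). A sketch of what such an event payload might look like; the surrounding fields follow the usual scalar-event shape and the ids are placeholders:

```python
scalar_event = {
    "type": "training_stats_scalar",
    "task": "<task-id>",
    "metric": "loss",
    "variant": "total",
    "value": 0.127,
    "iter": 300,                         # iteration, per the scalar event shape
    "x_axis_label": "tokens processed",  # new in this change
}
```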
diff --git a/apiserver/schema/services/models.conf b/apiserver/schema/services/models.conf
index 574a2fd..a9d02e7 100644
--- a/apiserver/schema/services/models.conf
+++ b/apiserver/schema/services/models.conf
@@ -1,20 +1,6 @@
_description: """This service provides a management interface for models (results of training tasks) stored in the system."""
_definitions {
include "_tasks_common.conf"
- multi_field_pattern_data {
- type: object
- properties {
- pattern {
- description: "Pattern string (regex)"
- type: string
- }
- fields {
- description: "List of field names"
- type: array
- items { type: string }
- }
- }
- }
model {
type: object
properties {
diff --git a/apiserver/schema/services/projects.conf b/apiserver/schema/services/projects.conf
index 168cd7c..06b26d6 100644
--- a/apiserver/schema/services/projects.conf
+++ b/apiserver/schema/services/projects.conf
@@ -1,20 +1,6 @@
_description: "Provides support for defining Projects containing Tasks, Models and Dataset Versions."
_definitions {
include "_common.conf"
- multi_field_pattern_data {
- type: object
- properties {
- pattern {
- description: "Pattern string (regex)"
- type: string
- }
- fields {
- description: "List of field names"
- type: array
- items { type: string }
- }
- }
- }
project {
type: object
properties {
diff --git a/apiserver/schema/services/queues.conf b/apiserver/schema/services/queues.conf
index 7617fde..96d5a62 100644
--- a/apiserver/schema/services/queues.conf
+++ b/apiserver/schema/services/queues.conf
@@ -50,6 +50,10 @@ _definitions {
description: "Queue name"
type: string
}
+ display_name {
+ description: "Display name"
+ type: string
+ }
user {
description: "Associated user id"
type: string
@@ -324,7 +328,7 @@ create {
}
}
"2.13": ${create."2.4"} {
- metadata {
+ request.properties.metadata {
description: "Queue metadata"
type: object
additionalProperties {
@@ -332,6 +336,12 @@ create {
}
}
}
+ "2.31": ${create."2.13"} {
+ request.properties.display_name {
+ description: "Display name"
+ type: string
+ }
+ }
}
update {
"2.4" {
@@ -377,7 +387,7 @@ update {
}
}
"2.13": ${update."2.4"} {
- metadata {
+ request.properties.metadata {
description: "Queue metadata"
type: object
additionalProperties {
@@ -385,6 +395,12 @@ update {
}
}
}
+ "2.31": ${update."2.13"} {
+ request.properties.display_name {
+ description: "Display name"
+ type: string
+ }
+ }
}
delete {
"2.4" {
@@ -447,6 +463,13 @@ add_task {
}
}
}
+ "2.31": ${add_task."2.4"} {
+ request.properties.update_execution_queue {
+ description: If set to false then the task 'execution.queue' is not updated
+ type: boolean
+ default: true
+ }
+ }
}
get_next_task {
"2.4" {
@@ -530,8 +553,41 @@ remove_task {
}
}
}
+ "2.31": ${remove_task."2.4"} {
+ request.properties {
+ update_task_status {
+ type: boolean
+ default: false
+ description: If set to 'true' then change the removed task status to the one it had prior to enqueuing or 'created'
+ }
+ }
+ }
+}
+clear_queue {
+ "2.31" {
+ description: Remove all tasks from the queue and change their statuses to what they were prior to enqueuing or 'created'
+ request {
+ type: object
+ required: [queue]
+ properties {
+ queue {
+ description: "Queue id"
+ type: string
+ }
+ }
+ }
+ response {
+ type: object
+ properties {
+ removed_tasks {
+ description: IDs of the removed tasks
+ type: array
+ items {type: string}
+ }
+ }
+ }
+ }
}
-
move_task_forward: {
"2.4" {
description: "Moves a task entry one step forward towards the top of the queue."
diff --git a/apiserver/schema/services/serving.conf b/apiserver/schema/services/serving.conf
new file mode 100644
index 0000000..e280059
--- /dev/null
+++ b/apiserver/schema/services/serving.conf
@@ -0,0 +1,437 @@
+_description: "Serving apis"
+_definitions {
+ include "_workers_common.conf"
+ reference_item {
+ type: object
+ required = [type, value]
+ properties {
+ type {
+ description: The type of the reference item
+ type: string
+ enum: [app_id, app_instance, model, task, url]
+ }
+ value {
+ description: The reference item value
+ type: string
+ }
+ }
+ }
+ reference {
+ description: Array of reference items provided by the container instance. Can contain multiple reference items with the same type
+ type: array
+ items: ${_definitions.reference_item}
+ }
+ serving_model_report {
+ type: object
+ required: [container_id, endpoint_name, model_name]
+ properties {
+ container_id {
+ type: string
+ description: Container ID. Should uniquely identify a specific container instance
+ }
+ endpoint_name {
+ type: string
+ description: Endpoint name
+ }
+ endpoint_url {
+ type: string
+ description: Endpoint URL
+ }
+ model_name {
+ type: string
+ description: Model name
+ }
+ model_source {
+ type: string
+ description: Model source
+ }
+ model_version {
+ type: string
+ description: Model version
+ }
+ preprocess_artifact {
+ type: string
+ description: Preprocess Artifact
+ }
+ input_type {
+ type: string
+ description: Input type
+ }
+ input_size {
+ type: string
+ description: Input size
+ }
+ reference: ${_definitions.reference}
+ }
+ }
+ endpoint_stats {
+ type: object
+ properties {
+ endpoint {
+ type: string
+ description: Endpoint name
+ }
+ model {
+ type: string
+ description: Model name
+ }
+ url {
+ type: string
+ description: Model url
+ }
+ instances {
+ type: integer
+ description: The number of model serving instances
+ }
+ uptime_sec {
+ type: integer
+ description: Maximum model instance uptime in seconds
+ }
+ requests {
+ type: integer
+ description: Total requests processed by model instances
+ }
+ requests_min {
+ type: number
+ description: Average request rate of model instances per minute
+ }
+ latency_ms {
+ type: integer
+ description: Average latency of model instances in ms
+ }
+ last_update {
+ type: string
+ format: "date-time"
+ description: The latest time when one of the model instances was updated
+ }
+ }
+ }
+ container_instance_stats {
+ type: object
+ properties {
+ id {
+ type: string
+ description: Container ID
+ }
+ uptime_sec {
+ type: integer
+ description: Uptime in seconds
+ }
+ requests {
+ type: integer
+ description: Number of requests
+ }
+ requests_min {
+ type: number
+ description: Average requests per minute
+ }
+ latency_ms {
+ type: integer
+ description: Average request latency in ms
+ }
+ last_update {
+ type: string
+ format: "date-time"
+ description: The latest time when the container instance sent an update
+ }
+ cpu_count {
+ type: integer
+ description: CPU Count
+ }
+ gpu_count {
+ type: integer
+ description: GPU Count
+ }
+ reference: ${_definitions.reference}
+
+ }
+ }
+ serving_model_info {
+ type: object
+ properties {
+ endpoint {
+ type: string
+ description: Endpoint name
+ }
+ model {
+ type: string
+ description: Model name
+ }
+ url {
+ type: string
+ description: Model url
+ }
+ model_source {
+ type: string
+ description: Model source
+ }
+ model_version {
+ type: string
+ description: Model version
+ }
+ preprocess_artifact {
+ type: string
+ description: Preprocess Artifact
+ }
+ input_type {
+ type: string
+ description: Input type
+ }
+ input_size {
+ type: string
+ description: Input size
+ }
+ }
+ }
+ container_info: ${_definitions.serving_model_info} {
+ properties {
+ id {
+ type: string
+ description: Container ID
+ }
+ uptime_sec {
+ type: integer
+ description: Model instance uptime in seconds
+ }
+ last_update {
+ type: string
+ format: "date-time"
+ description: The latest time when the container instance sent an update
+ }
+ age_sec {
+ type: integer
+ description: Amount of seconds since the container registration
+ }
+ }
+ }
+ metrics_history_series {
+ type: object
+ properties {
+ title {
+ type: string
+ description: "The title of the series"
+ }
+ dates {
+ type: array
+ description: "List of timestamps (in seconds from epoch) in the acceding order. The timestamps are separated by the requested interval."
+ items {type: integer}
+ }
+ values {
+ type: array
+ description: "List of values corresponding to the timestamps in the dates list."
+ items {type: number}
+ }
+ }
+ }
+}
+register_container {
+ "2.31" {
+ description: Register container
+ request = ${_definitions.serving_model_report} {
+ properties {
+ timeout {
+ description: "Registration timeout in seconds. If timeout seconds have passed since the service container last call to register or status_report, the container is automatically removed from the list of registered containers."
+ type: integer
+ default: 600
+ }
+ }
+ }
+ response {
+ type: object
+ additionalProperties: false
+ }
+ }
+}
+unregister_container {
+ "2.31" {
+ description: Unregister container
+ request {
+ type: object
+ required: [container_id]
+ properties {
+ container_id {
+ type: string
+ description: Container ID
+ }
+ }
+ }
+ response {
+ type: object
+ additionalProperties: false
+ }
+ }
+}
+container_status_report {
+ "2.31" {
+ description: Container status report
+ request = ${_definitions.serving_model_report} {
+ properties {
+ uptime_sec {
+ type: integer
+ description: Uptime in seconds
+ }
+ requests_num {
+ type: integer
+ description: Number of requests
+ }
+ requests_min {
+ type: number
+ description: Average requests per minute
+ }
+ latency_ms {
+ type: integer
+ description: Average request latency in ms
+ }
+ machine_stats {
+ description: "The machine statistics"
+ "$ref": "#/definitions/machine_stats"
+ }
+ }
+ }
+ response {
+ type: object
+ additionalProperties: false
+ }
+ }
+}
+get_endpoints {
+ "2.31" {
+ description: Get all the registered endpoints
+ request {
+ type: object
+ additionalProperties: false
+ }
+ response {
+ type: object
+ properties {
+ endpoints {
+ type: array
+ items { "$ref": "#/definitions/endpoint_stats" }
+ }
+ }
+ }
+ }
+}
+get_loading_instances {
+ "2.31" {
+ description: "Get loading instances (enpoint_url not set yet)"
+ request {
+ type: object
+ additionalProperties: false
+ }
+ response {
+ type: object
+ properties {
+ instances {
+ type: array
+ items { "$ref": "#/definitions/container_info" }
+ }
+ }
+ }
+ }
+}
+get_endpoint_details {
+ "2.31" {
+ description: Get endpoint details
+ request {
+ type: object
+ required: [endpoint_url]
+ properties {
+ endpoint_url {
+ type: string
+ description: Endpoint URL
+ }
+ }
+ }
+ response: ${_definitions.serving_model_info} {
+ properties {
+ uptime_sec {
+ type: integer
+ description: Maximum model instance uptime in seconds
+ }
+ last_update {
+ type: string
+ format: "date-time"
+ description: The latest time when one of the model instances was updated
+ }
+ instances {
+ type: array
+ items {"$ref": "#/definitions/container_instance_stats"}
+ }
+ }
+ }
+ }
+}
+get_endpoint_metrics_history {
+ "2.31" {
+ description: Get endpoint charts
+ request {
+ type: object
+ required: [endpoint_url, from_date, to_date, interval]
+ properties {
+ endpoint_url {
+ description: Endpoint URL
+ type: string
+ }
+ from_date {
+ description: "Starting time (in seconds from epoch) for collecting statistics"
+ type: number
+ }
+ to_date {
+ description: "Ending time (in seconds from epoch) for collecting statistics"
+ type: number
+ }
+ interval {
+ description: "Time interval in seconds for a single statistics point. The minimal value is 1"
+ type: integer
+ }
+ metric_type {
+ description: The type of the metrics to return on the chart
+ type: string
+ default: requests
+ enum: [
+ requests
+ requests_min
+ latency_ms
+ cpu_count
+ gpu_count
+ cpu_util
+ gpu_util
+ ram_total
+ ram_used
+ ram_free
+ gpu_ram_total
+ gpu_ram_used
+ gpu_ram_free
+ network_rx
+ network_tx
+ ]
+ }
+ instance_charts {
+ type: boolean
+ default: true
+ description: If set then return instance charts and total. Otherwise total only
+ }
+ }
+ }
+ response {
+ type: object
+ properties {
+ computed_interval {
+ description: The interval that was actually used for the histogram. May be larger than the requested one
+ type: integer
+ }
+ total: ${_definitions.metrics_history_series} {
+ description: The total histogram
+ }
+ instances {
+ description: Instance charts
+ type: object
+ additionalProperties: ${_definitions.metrics_history_series}
+ }
+ }
+ }
+ }
+}
\ No newline at end of file
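
A serving container is expected to register once and then send periodic status reports keyed by its `container_id`; if `timeout` seconds pass without a report, the registration lapses. A minimal client-side sketch (address, credentials, and payload values are placeholders):

```python
import requests

API = "http://localhost:8008"            # placeholder apiserver address
AUTH = ("<access-key>", "<secret-key>")  # placeholder credentials

report = {
    "container_id": "serving-0",
    "endpoint_name": "classifier",
    "endpoint_url": "/serve/classifier",
    "model_name": "resnet50",
}

# Register once; the container is dropped if no report arrives within timeout
requests.post(
    f"{API}/serving.register_container", json={**report, "timeout": 600}, auth=AUTH
).raise_for_status()

# Then report periodically with usage statistics
requests.post(
    f"{API}/serving.container_status_report",
    json={**report, "uptime_sec": 120, "requests_num": 42, "latency_ms": 7},
    auth=AUTH,
).raise_for_status()
```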
diff --git a/apiserver/schema/services/storage.conf b/apiserver/schema/services/storage.conf
new file mode 100644
index 0000000..c7166ae
--- /dev/null
+++ b/apiserver/schema/services/storage.conf
@@ -0,0 +1,242 @@
+_description: """This service provides storage settings managmement"""
+_default {
+ internal: true
+}
+
+_definitions {
+ include "_common.conf"
+ aws_bucket {
+ type: object
+ description: Settings per S3 bucket
+ properties {
+ bucket {
+ description: The name of the bucket
+ type: string
+ }
+ subdir {
+ description: The path to match
+ type: string
+ }
+ host {
+ description: Host address (for minio servers)
+ type: string
+ }
+ key {
+ description: Access key
+ type: string
+ }
+ secret {
+ description: Secret key
+ type: string
+ }
+ token {
+ description: Access token
+ type: string
+ }
+ multipart {
+ description: Multipart upload
+ type: boolean
+ default: true
+ }
+ acl {
+ description: ACL
+ type: string
+ }
+ secure {
+ description: Use SSL connection
+ type: boolean
+ default: true
+ }
+ region {
+ description: AWS Region
+ type: string
+ }
+ verify {
+ description: Verify server certificate
+ type: boolean
+ default: true
+ }
+ use_credentials_chain {
+ description: Use host configured credentials
+ type: boolean
+ default: false
+ }
+ }
+ }
+ aws {
+ type: object
+ description: AWS S3 storage settings
+ properties {
+ key {
+ description: Access key
+ type: string
+ }
+ secret {
+ description: Secret key
+ type: string
+ }
+ region {
+ description: AWS region
+ type: string
+ }
+ token {
+ description: Access token
+ type: string
+ }
+ use_credentials_chain {
+ description: If set then use host credentials
+ type: boolean
+ default: false
+ }
+ buckets {
+ description: Credential settings per bucket
+ type: array
+ items {"$ref": "#/definitions/aws_bucket"}
+ }
+ }
+ }
+ google_bucket {
+ type: object
+ description: Settings per Google storage bucket
+ properties {
+ bucket {
+ description: The name of the bucket
+ type: string
+ }
+ project {
+ description: The name of the project
+ type: string
+ }
+ subdir {
+ description: The path to match
+ type: string
+ }
+ credentials_json {
+ description: The contents of the credentials json file
+ type: string
+ }
+ }
+ }
+ google {
+ type: object
+ description: Google storage settings
+ properties {
+ project {
+ description: Project name
+ type: string
+ }
+ credentials_json {
+ description: The contents of the credentials json file
+ type: string
+ }
+ buckets {
+ description: Credentials per bucket
+ type: array
+ items {"$ref": "#/definitions/google_bucket"}
+ }
+ }
+ }
+ azure_container {
+ type: object
+ description: Azure container settings
+ properties {
+ account_name {
+ description: Account name
+ type: string
+ }
+ account_key {
+ description: Account key
+ type: string
+ }
+ container_name {
+ description: The name of the container
+ type: string
+ }
+ }
+ }
+ azure {
+ type: object
+ description: Azure storage settings
+ properties {
+ containers {
+ description: Credentials per container
+ type: array
+ items {"$ref": "#/definitions/azure_container"}
+ }
+ }
+ }
+}
+
+set_settings {
+ "2.31" {
+ description: Set Storage settings
+ request {
+ type: object
+ properties {
+ aws {"$ref": "#/definitions/aws"}
+ google {"$ref": "#/definitions/google"}
+ azure {"$ref": "#/definitions/azure"}
+ }
+ }
+ response {
+ type: object
+ properties {
+ updated {
+ description: "Number of settings documents updated (0 or 1)"
+ type: integer
+ enum: [0, 1]
+ }
+ }
+ }
+ }
+}
+reset_settings {
+ "2.31" {
+ description: Reset selected storage settings
+ request {
+ type: object
+ properties {
+ keys {
+ description: The names of the settings to delete
+ type: array
+ items {
+ type: string
+ enum: ["azure", "aws", "google"]
+ }
+ }
+ }
+ }
+ response {
+ type: object
+ properties {
+ updated {
+ description: "Number of settings documents updated (0 or 1)"
+ type: integer
+ enum: [0, 1]
+ }
+ }
+ }
+ }
+}
+get_settings {
+ "2.22" {
+ description: Get storage settings
+ request {
+ type: object
+ additionalProperties: false
+ }
+ response {
+ type: object
+ properties {
+ last_update {
+ description: "Settings last update time (UTC) "
+ type: string
+ format: "date-time"
+ }
+ aws {"$ref": "#/definitions/aws"}
+ google {"$ref": "#/definitions/google"}
+ azure {"$ref": "#/definitions/azure"}
+ }
+ }
+ }
+}
\ No newline at end of file
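
The storage service is declared `internal`, so whether regular API credentials can call it is not established by this diff; the payload shape, however, follows the definitions above. An illustrative `set_settings` body (all values are placeholders):

```python
# Example set_settings request body matching the schema above
storage_settings = {
    "aws": {
        "key": "<access-key>",
        "secret": "<secret-key>",
        "region": "us-east-1",
        "use_credentials_chain": False,
        "buckets": [
            {"bucket": "my-bucket", "subdir": "clearml", "secure": True},
        ],
    },
    "azure": {
        "containers": [
            {"account_name": "<account>", "account_key": "<key>", "container_name": "artifacts"},
        ],
    },
}
```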
diff --git a/apiserver/schema/services/tasks.conf b/apiserver/schema/services/tasks.conf
index 4d54f66..7a37da9 100644
--- a/apiserver/schema/services/tasks.conf
+++ b/apiserver/schema/services/tasks.conf
@@ -1507,6 +1507,13 @@ Fails if the following parameters in the task were not filled:
type: boolean
}
}
+ "2.31": ${enqueue."2.22"} {
+ request.properties.update_execution_queue {
+ description: If set to false then the task 'execution.queue' is not updated. This can only be done for a task that is already enqueued
+ type: boolean
+ default: true
+ }
+ }
}
enqueue_many {
"2.13": ${_definitions.change_many_request} {
diff --git a/apiserver/schema/services/users.conf b/apiserver/schema/services/users.conf
index ed3619d..f4512ff 100644
--- a/apiserver/schema/services/users.conf
+++ b/apiserver/schema/services/users.conf
@@ -147,7 +147,7 @@ get_current_user {
description: Getting started info
additionalProperties: true
}
- created {
+ user.properties.created {
type: string
description: User creation time
format: date-time
@@ -166,6 +166,14 @@ get_current_user {
}
}
}
+ "2.31": ${get_current_user."2.26"} {
+ response.properties {
+ user.properties.created_in_version {
+ type: string
+ description: Server version at user creation time
+ }
+ }
+ }
}
get_all_ex {
diff --git a/apiserver/schema/services/workers.conf b/apiserver/schema/services/workers.conf
index 42f6fd8..9eeca7f 100644
--- a/apiserver/schema/services/workers.conf
+++ b/apiserver/schema/services/workers.conf
@@ -1,5 +1,6 @@
_description: "Provides an API for worker machines, allowing workers to report status and get tasks for execution"
_definitions {
+ include "_workers_common.conf"
metrics_category {
type: object
properties {
@@ -193,6 +194,10 @@ _definitions {
queue_entry = ${_definitions.id_name_entry} {
properties {
+ display_name {
+ description: "Display name for the queue (if defined)"
+ type: string
+ }
next_task {
description: "Next task in the queue"
"$ref": "#/definitions/id_name_entry"
@@ -203,74 +208,6 @@ _definitions {
}
}
}
-
- machine_stats {
- type: object
- properties {
- cpu_usage {
- description: "Average CPU usage per core"
- type: array
- items { type: number }
- }
- gpu_usage {
- description: "Average GPU usage per GPU card"
- type: array
- items { type: number }
- }
- memory_used {
- description: "Used memory MBs"
- type: integer
- }
- memory_free {
- description: "Free memory MBs"
- type: integer
- }
- gpu_memory_free {
- description: "GPU free memory MBs"
- type: array
- items { type: integer }
- }
- gpu_memory_used {
- description: "GPU used memory MBs"
- type: array
- items { type: integer }
- }
- network_tx {
- description: "Mbytes per second"
- type: integer
- }
- network_rx {
- description: "Mbytes per second"
- type: integer
- }
- disk_free_home {
- description: "Mbytes free space of /home drive"
- type: integer
- }
- disk_free_temp {
- description: "Mbytes free space of /tmp drive"
- type: integer
- }
- disk_read {
- description: "Mbytes read per second"
- type: integer
- }
- disk_write {
- description: "Mbytes write per second"
- type: integer
- }
- cpu_temperature {
- description: "CPU temperature"
- type: array
- items { type: number }
- }
- gpu_temperature {
- description: "GPU temperature"
- type: array
- items { type: number }
- }
- }
- }
}
get_all {
"2.4" {
diff --git a/apiserver/server_init/request_handlers.py b/apiserver/server_init/request_handlers.py
index 6e15d10..594bab4 100644
--- a/apiserver/server_init/request_handlers.py
+++ b/apiserver/server_init/request_handlers.py
@@ -2,9 +2,11 @@ import unicodedata
import urllib.parse
from functools import partial
+from boltons.iterutils import first
from flask import request, Response, redirect
from werkzeug.datastructures import ImmutableMultiDict
from werkzeug.exceptions import BadRequest
+from werkzeug.http import quote_header_value
from apiserver.apierrors import APIError
from apiserver.apierrors.base import BaseError
@@ -21,12 +23,26 @@ log = config.logger(__file__)
class RequestHandlers:
_request_strip_prefix = config.get("apiserver.request.strip_prefix", None)
_server_header = config.get("apiserver.response.headers.server", "clearml")
+ _basic_cookie_settings = config.get("apiserver.auth.cookies")
_custom_cookie_settings = {
c["name"]: c["settings"]
for c in config.get("apiserver.auth.custom_cookies", {}).values()
if c.get("enabled") and c.get("settings")
}
+ def _get_cookie_settings(self, cookie_key=None):
+ settings = (
+ self._custom_cookie_settings.get(cookie_key) or self._basic_cookie_settings
+ ).copy()
+ if isinstance(settings["domain"], list):
+ host_without_port, _, _ = request.host.partition(":")
+ domain = first(
+ settings["domain"],
+ key=lambda d: host_without_port.endswith(d) if d else False,
+ )
+ settings["domain"] = domain
+ return settings
+
def before_request(self):
if request.method == "OPTIONS":
return "", 200
@@ -54,17 +70,18 @@ class RequestHandlers:
if call.result.filename:
# make sure that downloaded files are not cached by the client
disable_cache = True
+ download_name = call.result.filename
try:
- call.result.filename.encode("ascii")
+ download_name.encode("ascii")
except UnicodeEncodeError:
- simple = unicodedata.normalize("NFKD", call.result.filename)
+ simple = unicodedata.normalize("NFKD", download_name)
simple = simple.encode("ascii", "ignore").decode("ascii")
# safe = RFC 5987 attr-char
- quoted = urllib.parse.quote(call.result.filename, safe="")
- filenames = f"filename={simple}; filename*=UTF-8''{quoted}"
+ quoted = urllib.parse.quote(download_name, safe="")
+ filenames = f"filename={quote_header_value(simple)}; filename*=UTF-8''{quoted}"
else:
- filenames = f"filename={call.result.filename}"
- headers = {"Content-Disposition": "attachment; " + filenames}
+ filenames = f"filename={quote_header_value(download_name)}"
+ headers = {f"Content-Disposition": f"attachment; {filenames}"}
response = Response(
content,
@@ -78,10 +95,7 @@ class RequestHandlers:
if call.result.cookies:
for key, value in call.result.cookies.items():
- kwargs = (
- self._custom_cookie_settings.get(key)
- or config.get("apiserver.auth.cookies")
- ).copy()
+ kwargs = self._get_cookie_settings(key)
if value is None:
# Removing a cookie
kwargs["max_age"] = 0
diff --git a/apiserver/service_repo/service_repo.py b/apiserver/service_repo/service_repo.py
index 643b2d8..17e8928 100644
--- a/apiserver/service_repo/service_repo.py
+++ b/apiserver/service_repo/service_repo.py
@@ -39,7 +39,7 @@ class ServiceRepo(object):
"""If the check is set, parsing will fail for endpoint request with the version that is grater than the current
maximum """
- _max_version = PartialVersion("2.30")
+ _max_version = PartialVersion("2.31")
""" Maximum version number (the highest min_version value across all endpoints) """
_endpoint_exp = (
diff --git a/apiserver/services/events.py b/apiserver/services/events.py
index 724a937..5f94219 100644
--- a/apiserver/services/events.py
+++ b/apiserver/services/events.py
@@ -42,6 +42,7 @@ from apiserver.apimodels.events import (
LegacyMultiTaskEventsRequest,
)
from apiserver.bll.event import EventBLL
+from apiserver.bll.event.event_bll import LOCKED_TASK_STATUSES
from apiserver.bll.event.event_common import EventType, MetricVariants, TaskCompanies
from apiserver.bll.event.events_iterator import Scroll
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
@@ -52,6 +53,7 @@ from apiserver.config_repo import config
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
from apiserver.service_repo import APICall, endpoint
+from apiserver.service_repo.auth import Identity
from apiserver.utilities import json, extract_properties_to_lists
task_bll = TaskBLL()
@@ -488,6 +490,7 @@ def scalar_metrics_iter_histogram(
samples=request.samples,
key=request.key,
metric_variants=_get_metric_variants_from_request(request.metrics),
+ model_events=request.model_events,
)
call.result.data = metrics
@@ -538,12 +541,13 @@ def multi_task_scalar_metrics_iter_histogram(
samples=request.samples,
key=request.key,
metric_variants=_get_metric_variants_from_request(request.metrics),
+ model_events=request.model_events,
)
)
def _get_single_value_metrics_response(
- companies: TaskCompanies, value_metrics: Mapping[str, dict]
+ companies: TaskCompanies, value_metrics: Mapping[str, Sequence[dict]]
) -> Sequence[dict]:
task_names = {
task.id: task.name for task in itertools.chain.from_iterable(companies.values())
@@ -1001,30 +1005,64 @@ def get_multi_task_metrics(call: APICall, company_id, request: MultiTaskMetricsR
call.result.data = {"metrics": sorted(res, key=itemgetter("metric"))}
+def _validate_task_for_events_update(
+ company_id: str, task_id: str, identity: Identity, allow_locked: bool
+):
+ task = get_task_with_write_access(
+ task_id=task_id,
+ company_id=company_id,
+ identity=identity,
+ only=("id", "status"),
+ )
+ if not allow_locked and task.status in LOCKED_TASK_STATUSES:
+ raise errors.bad_request.InvalidTaskId(
+ replacement_msg="Cannot update events for a published task",
+ company=company_id,
+ id=task_id,
+ )
+
+
@endpoint("events.delete_for_task")
def delete_for_task(call, company_id, request: TaskRequest):
task_id = request.task
allow_locked = call.data.get("allow_locked", False)
- get_task_with_write_access(
- task_id=task_id, company_id=company_id, identity=call.identity, only=("id",)
+ _validate_task_for_events_update(
+ company_id=company_id,
+ task_id=task_id,
+ identity=call.identity,
+ allow_locked=allow_locked,
)
+
call.result.data = dict(
- deleted=event_bll.delete_task_events(
- company_id, task_id, allow_locked=allow_locked
- )
+ deleted=event_bll.delete_task_events(company_id, task_id, wait_for_delete=True)
)
+def _validate_model_for_events_update(
+ company_id: str, model_id: str, allow_locked: bool
+):
+ model = model_bll.assert_exists(company_id, model_id, only=("id", "ready"))[0]
+ if not allow_locked and model.ready:
+ raise errors.bad_request.InvalidModelId(
+ replacement_msg="Cannot update events for a published model",
+ company=company_id,
+ id=model_id,
+ )
+
+
@endpoint("events.delete_for_model")
def delete_for_model(call: APICall, company_id: str, request: ModelRequest):
model_id = request.model
allow_locked = call.data.get("allow_locked", False)
- model_bll.assert_exists(company_id, model_id, return_models=False)
+ _validate_model_for_events_update(
+ company_id=company_id, model_id=model_id, allow_locked=allow_locked
+ )
+
call.result.data = dict(
deleted=event_bll.delete_task_events(
- company_id, model_id, allow_locked=allow_locked, model=True
+ company_id, model_id, model=True, wait_for_delete=True
)
)
@@ -1033,14 +1071,17 @@ def delete_for_model(call: APICall, company_id: str, request: ModelRequest):
def clear_task_log(call: APICall, company_id: str, request: ClearTaskLogRequest):
task_id = request.task
- get_task_with_write_access(
- task_id=task_id, company_id=company_id, identity=call.identity, only=("id",)
+ _validate_task_for_events_update(
+ company_id=company_id,
+ task_id=task_id,
+ identity=call.identity,
+ allow_locked=request.allow_locked,
)
+
call.result.data = dict(
deleted=event_bll.clear_task_log(
company_id=company_id,
task_id=task_id,
- allow_locked=request.allow_locked,
threshold_sec=request.threshold_sec,
exclude_metrics=request.exclude_metrics,
include_metrics=request.include_metrics,
diff --git a/apiserver/services/models.py b/apiserver/services/models.py
index b476bd6..c612bc9 100644
--- a/apiserver/services/models.py
+++ b/apiserver/services/models.py
@@ -27,12 +27,17 @@ from apiserver.apimodels.models import (
UpdateModelRequest,
)
from apiserver.apimodels.tasks import UpdateTagsRequest
+from apiserver.bll.event import EventBLL
from apiserver.bll.model import ModelBLL, Metadata
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.bll.task import TaskBLL
+from apiserver.bll.task.task_cleanup import (
+ schedule_for_delete,
+ delete_task_events_and_collect_urls,
+)
from apiserver.bll.task.task_operations import publish_task
-from apiserver.bll.task.utils import get_task_with_write_access
+from apiserver.bll.task.utils import get_task_with_write_access, deleted_prefix
from apiserver.bll.util import run_batch_operation
from apiserver.config_repo import config
from apiserver.database.model import validate_id
@@ -64,6 +69,7 @@ from apiserver.services.utils import (
log = config.logger(__file__)
org_bll = OrgBLL()
project_bll = ProjectBLL()
+event_bll = EventBLL()
def conform_model_data(call: APICall, model_data: Union[Sequence[dict], dict]):
@@ -182,7 +188,12 @@ def get_all(call: APICall, company_id, _):
def get_frameworks(call: APICall, company_id, request: GetFrameworksRequest):
call.result.data = {
"frameworks": sorted(
- project_bll.get_model_frameworks(company_id, project_ids=request.projects)
+ filter(
+ None,
+ project_bll.get_model_frameworks(
+ company_id, project_ids=request.projects
+ ),
+ )
)
}
@@ -216,6 +227,9 @@ last_update_fields = (
def parse_model_fields(call, valid_fields):
+ task_id = call.data.get("task")
+ if isinstance(task_id, str) and task_id.startswith(deleted_prefix):
+ call.data.pop("task")
fields = parse_from_call(call.data, valid_fields, Model.get_fields())
conform_tag_fields(call, fields, validate=True)
escape_metadata(fields)
@@ -555,16 +569,67 @@ def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest):
)
+def _delete_model_events(
+ company_id: str,
+ user_id: str,
+ models: Sequence[Model],
+ delete_external_artifacts: bool,
+ sync_delete: bool,
+):
+ if not models:
+ return
+ model_ids = [m.id for m in models]
+ delete_external_artifacts = delete_external_artifacts and config.get(
+ "services.async_urls_delete.enabled", True
+ )
+ if delete_external_artifacts:
+ model_urls = {m.uri for m in models if m.uri}
+ if model_urls:
+ schedule_for_delete(
+ task_id=model_ids[0],
+ company=company_id,
+ user=user_id,
+ urls=model_urls,
+ can_delete_folders=False,
+ )
+
+ event_urls = delete_task_events_and_collect_urls(
+ company=company_id,
+ task_ids=model_ids,
+ model=True,
+ wait_for_delete=sync_delete,
+ )
+ if event_urls:
+ schedule_for_delete(
+ task_id=model_ids[0],
+ company=company_id,
+ user=user_id,
+ urls=event_urls,
+ can_delete_folders=False,
+ )
+
+ event_bll.delete_task_events(
+ company_id, model_ids, model=True, wait_for_delete=sync_delete
+ )
+
+
@endpoint("models.delete", request_data_model=DeleteModelRequest)
def delete(call: APICall, company_id, request: DeleteModelRequest):
+ user_id = call.identity.user
del_count, model = ModelBLL.delete_model(
model_id=request.model,
company_id=company_id,
- user_id=call.identity.user,
+ user_id=user_id,
force=request.force,
- delete_external_artifacts=request.delete_external_artifacts,
)
if del_count:
+ _delete_model_events(
+ company_id=company_id,
+ user_id=user_id,
+ models=[model],
+ delete_external_artifacts=request.delete_external_artifacts,
+ sync_delete=True,
+ )
_reset_cached_tags(
company_id, projects=[model.project] if model.project else []
)
@@ -577,27 +642,38 @@ def delete(call: APICall, company_id, request: DeleteModelRequest):
request_data_model=ModelsDeleteManyRequest,
response_data_model=BatchResponse,
)
-def delete(call: APICall, company_id, request: ModelsDeleteManyRequest):
+def delete_many(call: APICall, company_id, request: ModelsDeleteManyRequest):
+ user_id = call.identity.user
+
results, failures = run_batch_operation(
func=partial(
ModelBLL.delete_model,
company_id=company_id,
user_id=call.identity.user,
force=request.force,
- delete_external_artifacts=request.delete_external_artifacts,
),
ids=request.ids,
)
- if results:
- projects = set(model.project for _, (_, model) in results)
+ succeeded = []
+ deleted_models = []
+ for _id, (deleted, model) in results:
+ succeeded.append(dict(id=_id, deleted=bool(deleted), url=model.uri))
+ deleted_models.append(model)
+
+ if deleted_models:
+ _delete_model_events(
+ company_id=company_id,
+ user_id=user_id,
+ models=deleted_models,
+ delete_external_artifacts=request.delete_external_artifacts,
+ sync_delete=False,
+ )
+ projects = set(model.project for model in deleted_models)
_reset_cached_tags(company_id, projects=list(projects))
call.result.data_model = BatchResponse(
- succeeded=[
- dict(id=_id, deleted=bool(deleted), url=model.uri)
- for _id, (deleted, model) in results
- ],
+ succeeded=succeeded,
failed=failures,
)
@@ -684,10 +760,11 @@ def move(call: APICall, company_id: str, request: MoveRequest):
@endpoint("models.update_tags")
-def update_tags(_, company_id: str, request: UpdateTagsRequest):
+def update_tags(call: APICall, company_id: str, request: UpdateTagsRequest):
return {
"updated": org_bll.edit_entity_tags(
company_id=company_id,
+ user_id=call.identity.user,
entity_cls=Model,
entity_ids=request.ids,
add_tags=request.add_tags,
diff --git a/apiserver/services/organization.py b/apiserver/services/organization.py
index 3f3868b..8c3e809 100644
--- a/apiserver/services/organization.py
+++ b/apiserver/services/organization.py
@@ -301,7 +301,7 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAllReque
future = pool.submit(get_fn, page, min(page_size, items_left))
with StringIO() as fp:
- writer = csv.writer(fp)
+ writer = csv.writer(fp, quoting=csv.QUOTE_NONNUMERIC)
if page == 1:
fp.write("\ufeff") # utf-8 signature
writer.writerow(field_mappings)
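
Quoting every non-numeric cell makes the type of each column explicit to CSV consumers: text fields (including ids that look like numbers) arrive quoted, while genuinely numeric cells stay bare. A small demonstration:

```python
import csv
from io import StringIO

fp = StringIO()
writer = csv.writer(fp, quoting=csv.QUOTE_NONNUMERIC)
# The string "0012345" stays quoted text (leading zeros survive for
# consumers that honor quoting); the float is written unquoted
writer.writerow(["experiment-42", "0012345", 0.93])
print(fp.getvalue())  # "experiment-42","0012345",0.93
```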
diff --git a/apiserver/services/pipelines.py b/apiserver/services/pipelines.py
index 6791ac8..6efd10b 100644
--- a/apiserver/services/pipelines.py
+++ b/apiserver/services/pipelines.py
@@ -1,8 +1,6 @@
import re
from functools import partial
-import attr
-
from apiserver.apierrors.errors.bad_request import CannotRemoveAllRuns
from apiserver.apimodels.pipelines import (
StartPipelineRequest,
@@ -18,6 +16,7 @@ from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskType
from apiserver.service_repo import APICall, endpoint
from apiserver.utilities.dicts import nested_get
+from .tasks import _delete_task_events
org_bll = OrgBLL()
project_bll = ProjectBLL()
@@ -62,21 +61,31 @@ def delete_runs(call: APICall, company_id: str, request: DeleteRunsRequest):
identity=call.identity,
move_to_trash=False,
force=True,
- return_file_urls=False,
delete_output_models=True,
status_message="",
status_reason="Pipeline run deleted",
- delete_external_artifacts=True,
include_pipeline_steps=True,
),
ids=list(ids),
)
succeeded = []
+ tasks = {}
if results:
for _id, (deleted, task, cleanup_res) in results:
+ if deleted:
+ tasks[_id] = cleanup_res
succeeded.append(
- dict(id=_id, deleted=bool(deleted), **attr.asdict(cleanup_res))
+ dict(id=_id, deleted=bool(deleted), **cleanup_res.to_res_dict(False))
+ )
+
+ if tasks:
+ _delete_task_events(
+ company_id=company_id,
+ user_id=call.identity.user,
+ tasks=tasks,
+ delete_external_artifacts=True,
+ sync_delete=True,
)
call.result.data = dict(succeeded=succeeded, failed=failures)
@@ -101,7 +110,7 @@ def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest
company_id=company_id,
user_id=call.identity.user,
task_id=request.task,
- hyperparams=hyperparams,
+ hyperparams_overrides=hyperparams,
)
_update_task_name(task)
diff --git a/apiserver/services/projects.py b/apiserver/services/projects.py
index 5d45519..2510151 100644
--- a/apiserver/services/projects.py
+++ b/apiserver/services/projects.py
@@ -370,6 +370,7 @@ def delete(call: APICall, company_id: str, request: DeleteRequest):
delete_external_artifacts=request.delete_external_artifacts,
)
_reset_cached_tags(company_id, projects=list(affected_projects))
+ # noinspection PyTypeChecker
call.result.data = {**attr.asdict(res)}
diff --git a/apiserver/services/queues.py b/apiserver/services/queues.py
index c9fdf36..73c291f 100644
--- a/apiserver/services/queues.py
+++ b/apiserver/services/queues.py
@@ -20,13 +20,15 @@ from apiserver.apimodels.queues import (
GetNextTaskRequest,
GetByIdRequest,
GetAllRequest,
+ AddTaskRequest,
+ RemoveTaskRequest,
)
from apiserver.bll.model import Metadata
from apiserver.bll.queue import QueueBLL
from apiserver.bll.queue.queue_bll import MOVE_FIRST, MOVE_LAST
from apiserver.bll.workers import WorkerBLL
from apiserver.config_repo import config
-from apiserver.database.model.task.task import Task
+from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.service_repo import APICall, endpoint
from apiserver.services.utils import (
conform_tag_fields,
@@ -46,7 +48,7 @@ def conform_queue_data(call: APICall, queue_data: Union[Sequence[dict], dict]):
unescape_metadata(call, queue_data)
-@endpoint("queues.get_by_id", min_version="2.4", request_data_model=GetByIdRequest)
+@endpoint("queues.get_by_id", min_version="2.4")
def get_by_id(call: APICall, company_id, request: GetByIdRequest):
queue = queue_bll.get_by_id(
company_id, request.queue, max_task_entries=request.max_task_entries
@@ -111,7 +113,7 @@ def get_all(call: APICall, company: str, request: GetAllRequest):
call.result.data = {"queues": queues, **ret_params}
-@endpoint("queues.create", min_version="2.4", request_data_model=CreateRequest)
+@endpoint("queues.create", min_version="2.4")
def create(call: APICall, company_id, request: CreateRequest):
tags, system_tags = conform_tags(
call, request.tags, request.system_tags, validate=True
@@ -119,6 +121,7 @@ def create(call: APICall, company_id, request: CreateRequest):
queue = queue_bll.create(
company_id=company_id,
name=request.name,
+ display_name=request.display_name,
tags=tags,
system_tags=system_tags,
metadata=Metadata.metadata_from_api(request.metadata),
@@ -129,60 +132,82 @@ def create(call: APICall, company_id, request: CreateRequest):
@endpoint(
"queues.update",
min_version="2.4",
- request_data_model=UpdateRequest,
response_data_model=UpdateResponse,
)
-def update(call: APICall, company_id, req_model: UpdateRequest):
+def update(call: APICall, company_id, request: UpdateRequest):
data = call.data_model_for_partial_update
conform_tag_fields(call, data, validate=True)
escape_metadata(data)
updated, fields = queue_bll.update(
- company_id=company_id, queue_id=req_model.queue, **data
+ company_id=company_id, queue_id=request.queue, **data
)
conform_queue_data(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
-@endpoint("queues.delete", min_version="2.4", request_data_model=DeleteRequest)
-def delete(call: APICall, company_id, req_model: DeleteRequest):
+@endpoint("queues.delete", min_version="2.4")
+def delete(call: APICall, company_id, request: DeleteRequest):
queue_bll.delete(
company_id=company_id,
user_id=call.identity.user,
- queue_id=req_model.queue,
- force=req_model.force,
+ queue_id=request.queue,
+ force=request.force,
)
call.result.data = {"deleted": 1}
-@endpoint("queues.add_task", min_version="2.4", request_data_model=TaskRequest)
-def add_task(call: APICall, company_id, req_model: TaskRequest):
- call.result.data = {
- "added": queue_bll.add_task(
- company_id=company_id, queue_id=req_model.queue, task_id=req_model.task
+@endpoint("queues.add_task", min_version="2.4")
+def add_task(call: APICall, company_id, request: AddTaskRequest):
+ added = queue_bll.add_task(
+ company_id=company_id, queue_id=request.queue, task_id=request.task
+ )
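+    # keep the task's execution.queue field in sync with its new queue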
+ if added and request.update_execution_queue:
+ Task.objects(id=request.task).update(
+ execution__queue=request.queue, multi=False
)
- }
+ call.result.data = {"added": added}
-@endpoint("queues.get_next_task", request_data_model=GetNextTaskRequest)
+@endpoint("queues.get_next_task")
def get_next_task(call: APICall, company_id, request: GetNextTaskRequest):
entry = queue_bll.get_next_task(
company_id=company_id, queue_id=request.queue, task_id=request.task
)
if entry:
data = {"entry": entry.to_proper_dict()}
- if request.get_task_info:
- task = Task.objects(id=entry.task).only("company", "user").first()
- if task:
+ task = Task.objects(id=entry.task).only("company", "user", "status").first()
+ if task:
+            # fix a race condition in which an agent can abort the task
+            # after it was already placed in a queue
+ if task.status == TaskStatus.stopped:
+ task.update(status=TaskStatus.queued)
+
+ if request.get_task_info:
data["task_info"] = {"company": task.company, "user": task.user}
call.result.data = data
-@endpoint("queues.remove_task", min_version="2.4", request_data_model=TaskRequest)
-def remove_task(call: APICall, company_id, req_model: TaskRequest):
+@endpoint("queues.remove_task", min_version="2.4")
+def remove_task(call: APICall, company_id, request: RemoveTaskRequest):
call.result.data = {
"removed": queue_bll.remove_task(
- company_id=company_id, queue_id=req_model.queue, task_id=req_model.task
+ company_id=company_id,
+ user_id=call.identity.user,
+ queue_id=request.queue,
+ task_id=request.task,
+ update_task_status=request.update_task_status,
+ )
+ }
+
+
+@endpoint("queues.clear_queue")
+def clear_queue(call: APICall, company_id, request: QueueRequest):
+ call.result.data = {
+ "removed_tasks": queue_bll.clear_queue(
+ company_id=company_id,
+ user_id=call.identity.user,
+ queue_id=request.queue,
)
}
@@ -190,16 +215,15 @@ def remove_task(call: APICall, company_id, req_model: TaskRequest):
@endpoint(
"queues.move_task_forward",
min_version="2.4",
- request_data_model=MoveTaskRequest,
response_data_model=MoveTaskResponse,
)
-def move_task_forward(call: APICall, company_id, req_model: MoveTaskRequest):
+def move_task_forward(call: APICall, company_id, request: MoveTaskRequest):
call.result.data_model = MoveTaskResponse(
position=queue_bll.reposition_task(
company_id=company_id,
- queue_id=req_model.queue,
- task_id=req_model.task,
- move_count=-req_model.count,
+ queue_id=request.queue,
+ task_id=request.task,
+ move_count=-request.count,
)
)
@@ -207,16 +231,15 @@ def move_task_forward(call: APICall, company_id, req_model: MoveTaskRequest):
@endpoint(
"queues.move_task_backward",
min_version="2.4",
- request_data_model=MoveTaskRequest,
response_data_model=MoveTaskResponse,
)
-def move_task_backward(call: APICall, company_id, req_model: MoveTaskRequest):
+def move_task_backward(call: APICall, company_id, request: MoveTaskRequest):
call.result.data_model = MoveTaskResponse(
position=queue_bll.reposition_task(
company_id=company_id,
- queue_id=req_model.queue,
- task_id=req_model.task,
- move_count=req_model.count,
+ queue_id=request.queue,
+ task_id=request.task,
+ move_count=request.count,
)
)
@@ -224,15 +247,14 @@ def move_task_backward(call: APICall, company_id, req_model: MoveTaskRequest):
@endpoint(
"queues.move_task_to_front",
min_version="2.4",
- request_data_model=TaskRequest,
response_data_model=MoveTaskResponse,
)
-def move_task_to_front(call: APICall, company_id, req_model: TaskRequest):
+def move_task_to_front(call: APICall, company_id, request: TaskRequest):
call.result.data_model = MoveTaskResponse(
position=queue_bll.reposition_task(
company_id=company_id,
- queue_id=req_model.queue,
- task_id=req_model.task,
+ queue_id=request.queue,
+ task_id=request.task,
move_count=MOVE_FIRST,
)
)
@@ -241,15 +263,14 @@ def move_task_to_front(call: APICall, company_id, req_model: TaskRequest):
@endpoint(
"queues.move_task_to_back",
min_version="2.4",
- request_data_model=TaskRequest,
response_data_model=MoveTaskResponse,
)
-def move_task_to_back(call: APICall, company_id, req_model: TaskRequest):
+def move_task_to_back(call: APICall, company_id, request: TaskRequest):
call.result.data_model = MoveTaskResponse(
position=queue_bll.reposition_task(
company_id=company_id,
- queue_id=req_model.queue,
- task_id=req_model.task,
+ queue_id=request.queue,
+ task_id=request.task,
move_count=MOVE_LAST,
)
)
@@ -258,7 +279,6 @@ def move_task_to_back(call: APICall, company_id, req_model: TaskRequest):
@endpoint(
"queues.get_queue_metrics",
min_version="2.4",
- request_data_model=GetMetricsRequest,
response_data_model=GetMetricsResponse,
)
def get_queue_metrics(
diff --git a/apiserver/services/reports.py b/apiserver/services/reports.py
index de4ba60..9799f12 100644
--- a/apiserver/services/reports.py
+++ b/apiserver/services/reports.py
@@ -282,6 +282,7 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
metric_variants=_get_metric_variants_from_request(
request.scalar_metrics_iter_histogram.metrics
),
+ model_events=request.model_events,
)
if request.single_value_metrics:
diff --git a/apiserver/services/serving.py b/apiserver/services/serving.py
new file mode 100644
index 0000000..56455e8
--- /dev/null
+++ b/apiserver/services/serving.py
@@ -0,0 +1,69 @@
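+# API endpoints for model serving: container registration and unregistration,
+# status reporting, and endpoint details/metrics queries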
+from apiserver.apimodels.serving import (
+ RegisterRequest,
+ UnregisterRequest,
+ StatusReportRequest,
+ GetEndpointDetailsRequest,
+ GetEndpointMetricsHistoryRequest,
+)
+from apiserver.apierrors import errors
+from apiserver.service_repo import endpoint, APICall
+from apiserver.bll.serving import ServingBLL, ServingStats
+
+
+serving_bll = ServingBLL()
+
+
+@endpoint("serving.register_container")
+def register_container(call: APICall, company: str, request: RegisterRequest):
+ serving_bll.register_serving_container(
+ company_id=company, ip=call.real_ip, request=request
+ )
+
+
+@endpoint("serving.unregister_container")
+def unregister_container(_: APICall, company: str, request: UnregisterRequest):
+ serving_bll.unregister_serving_container(
+ company_id=company, container_id=request.container_id
+ )
+
+
+@endpoint("serving.container_status_report")
+def container_status_report(call: APICall, company: str, request: StatusReportRequest):
+ if not request.endpoint_url:
+ raise errors.bad_request.ValidationError(
+ "Missing required field 'endpoint_url'"
+ )
+ serving_bll.container_status_report(
+ company_id=company,
+ ip=call.real_ip,
+ report=request,
+ )
+
+
+@endpoint("serving.get_endpoints")
+def get_endpoints(call: APICall, company: str, _):
+ call.result.data = {"endpoints": serving_bll.get_endpoints(company)}
+
+
+@endpoint("serving.get_loading_instances")
+def get_loading_instances(call: APICall, company: str, _):
+ call.result.data = {"instances": serving_bll.get_loading_instances(company)}
+
+
+@endpoint("serving.get_endpoint_details")
+def get_endpoint_details(
+ call: APICall, company: str, request: GetEndpointDetailsRequest
+):
+ call.result.data = serving_bll.get_endpoint_details(
+ company_id=company, endpoint_url=request.endpoint_url
+ )
+
+
+@endpoint("serving.get_endpoint_metrics_history")
+def get_endpoint_metrics_history(
+ call: APICall, company: str, request: GetEndpointMetricsHistoryRequest
+):
+ call.result.data = ServingStats.get_endpoint_metrics(
+ company_id=company,
+ metrics_request=request,
+ )
diff --git a/apiserver/services/storage.py b/apiserver/services/storage.py
new file mode 100644
index 0000000..d20b362
--- /dev/null
+++ b/apiserver/services/storage.py
@@ -0,0 +1,22 @@
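+# API endpoints for getting, setting and resetting per-company storage settings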
+from apiserver.apimodels.storage import ResetSettingsRequest, SetSettingsRequest
+from apiserver.bll.storage import StorageBLL
+from apiserver.service_repo import endpoint, APICall
+
+storage_bll = StorageBLL()
+
+
+@endpoint("storage.get_settings")
+def get_settings(call: APICall, company: str, _):
+ call.result.data = {"settings": storage_bll.get_company_settings(company)}
+
+
+@endpoint("storage.set_settings")
+def set_settings(call: APICall, company: str, request: SetSettingsRequest):
+ call.result.data = {"updated": storage_bll.set_company_settings(company, request)}
+
+
+@endpoint("storage.reset_settings")
+def reset_settings(call: APICall, company: str, request: ResetSettingsRequest):
+ call.result.data = {
+ "updated": storage_bll.reset_company_settings(company, request.keys)
+ }
diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py
index 7d02662..eb32f5e 100644
--- a/apiserver/services/tasks.py
+++ b/apiserver/services/tasks.py
@@ -1,9 +1,9 @@
+import itertools
from copy import deepcopy
from datetime import datetime
from functools import partial
-from typing import Sequence, Union, Tuple
+from typing import Sequence, Union, Tuple, Mapping
-import attr
from mongoengine import EmbeddedDocument, Q
from mongoengine.queryset.transform import COMPARISON_OPERATORS
from pymongo import UpdateOne
@@ -80,6 +80,11 @@ from apiserver.bll.task import (
TaskBLL,
ChangeStatusRequest,
)
+from apiserver.bll.task.task_cleanup import (
+ delete_task_events_and_collect_urls,
+ schedule_for_delete,
+ CleanupResult,
+)
from apiserver.bll.task.artifacts import (
artifacts_prepare_for_save,
artifacts_unprepare_from_saved,
@@ -109,6 +114,7 @@ from apiserver.bll.task.utils import (
get_task_with_write_access,
)
from apiserver.bll.util import run_batch_operation, update_project_time
+from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
from apiserver.database.model.project import Project
@@ -295,9 +301,7 @@ def get_types(call: APICall, company_id, request: GetTypesRequest):
}
-@endpoint(
- "tasks.stop", response_data_model=UpdateResponse
-)
+@endpoint("tasks.stop", response_data_model=UpdateResponse)
def stop(call: APICall, company_id, request: StopRequest):
"""
stop
@@ -919,6 +923,11 @@ def delete_configuration(
response_data_model=EnqueueResponse,
)
def enqueue(call: APICall, company_id, request: EnqueueRequest):
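+    # verify_watched_queue inspects the execution.queue field updated by the
+    # enqueue operation, so it cannot work with update_execution_queue=False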
+ if request.verify_watched_queue and not request.update_execution_queue:
+ raise errors.bad_request.ValidationError(
+ "verify_watched_queue cannot be used with update_execution_queue=False"
+ )
+
queued, res = enqueue_task(
task_id=request.task,
company_id=company_id,
@@ -928,6 +937,7 @@ def enqueue(call: APICall, company_id, request: EnqueueRequest):
status_reason=request.status_reason,
queue_name=request.queue_name,
force=request.force,
+ update_execution_queue=request.update_execution_queue,
)
if request.verify_watched_queue:
res_queue = nested_get(res, ("fields", "execution.queue"))
@@ -1016,21 +1026,98 @@ def dequeue_many(call: APICall, company_id, request: DequeueManyRequest):
)
+def _delete_task_events(
+ company_id: str,
+ user_id: str,
+ tasks: Mapping[str, CleanupResult],
+ delete_external_artifacts: bool,
+ sync_delete: bool,
+):
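+    # Delete the events of the given tasks and of any models deleted with them.
+    # When external artifact deletion is enabled, the collected event and
+    # artifact urls are scheduled for the async_delete service; otherwise the
+    # events are deleted directly, optionally waiting for completion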
+ if not tasks:
+ return
+ task_ids = list(tasks)
+ deleted_model_ids = set(
+ itertools.chain.from_iterable(
+ cr.deleted_model_ids for cr in tasks.values() if cr.deleted_model_ids
+ )
+ )
+
+ delete_external_artifacts = delete_external_artifacts and config.get(
+ "services.async_urls_delete.enabled", True
+ )
+ if delete_external_artifacts:
+ for t_id, cleanup_res in tasks.items():
+ urls = set(cleanup_res.urls.model_urls) | set(
+ cleanup_res.urls.artifact_urls
+ )
+ if urls:
+ schedule_for_delete(
+ task_id=t_id,
+ company=company_id,
+ user=user_id,
+ urls=urls,
+ can_delete_folders=False,
+ )
+
+ event_urls = delete_task_events_and_collect_urls(
+ company=company_id,
+ task_ids=task_ids,
+ wait_for_delete=sync_delete,
+ )
+ if deleted_model_ids:
+ event_urls.update(
+ delete_task_events_and_collect_urls(
+ company=company_id,
+ task_ids=list(deleted_model_ids),
+ model=True,
+ wait_for_delete=sync_delete,
+ )
+ )
+
+ if event_urls:
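+        # event urls from all the tasks are scheduled under the first task id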
+ schedule_for_delete(
+ task_id=task_ids[0],
+ company=company_id,
+ user=user_id,
+ urls=event_urls,
+ can_delete_folders=False,
+ )
+ else:
+ event_bll.delete_task_events(company_id, task_ids, wait_for_delete=sync_delete)
+ if deleted_model_ids:
+ event_bll.delete_task_events(
+ company_id,
+ list(deleted_model_ids),
+ model=True,
+ wait_for_delete=sync_delete,
+ )
+
+
@endpoint(
"tasks.reset", request_data_model=ResetRequest, response_data_model=ResetResponse
)
def reset(call: APICall, company_id, request: ResetRequest):
+ task_id = request.task
dequeued, cleanup_res, updates = reset_task(
- task_id=request.task,
+ task_id=task_id,
company_id=company_id,
identity=call.identity,
force=request.force,
- return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
clear_all=request.clear_all,
- delete_external_artifacts=request.delete_external_artifacts,
)
- res = ResetResponse(**updates, **attr.asdict(cleanup_res), dequeued=dequeued)
+ _delete_task_events(
+ company_id=company_id,
+ user_id=call.identity.user,
+ tasks={task_id: cleanup_res},
+ delete_external_artifacts=request.delete_external_artifacts,
+ sync_delete=True,
+ )
+ res = ResetResponse(
+ **updates,
+ **cleanup_res.to_res_dict(request.return_file_urls),
+ dequeued=dequeued,
+ )
call.result.data_model = res
@@ -1046,25 +1133,33 @@ def reset_many(call: APICall, company_id, request: ResetManyRequest):
company_id=company_id,
identity=call.identity,
force=request.force,
- return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
clear_all=request.clear_all,
- delete_external_artifacts=request.delete_external_artifacts,
),
ids=request.ids,
)
succeeded = []
+ tasks = {}
for _id, (dequeued, cleanup, res) in results:
+ tasks[_id] = cleanup
succeeded.append(
ResetBatchItem(
id=_id,
dequeued=bool(dequeued.get("removed")) if dequeued else False,
- **attr.asdict(cleanup),
+ **cleanup.to_res_dict(request.return_file_urls),
**res,
)
)
+ _delete_task_events(
+ company_id=company_id,
+ user_id=call.identity.user,
+ tasks=tasks,
+ delete_external_artifacts=request.delete_external_artifacts,
+ sync_delete=False,
+ )
+
call.result.data_model = ResetManyResponse(
succeeded=succeeded,
failed=failures,
@@ -1160,16 +1255,23 @@ def delete(call: APICall, company_id, request: DeleteRequest):
identity=call.identity,
move_to_trash=request.move_to_trash,
force=request.force,
- return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
status_message=request.status_message,
status_reason=request.status_reason,
- delete_external_artifacts=request.delete_external_artifacts,
include_pipeline_steps=request.include_pipeline_steps,
)
if deleted:
+ _delete_task_events(
+ company_id=company_id,
+ user_id=call.identity.user,
+ tasks={request.task: cleanup_res},
+ delete_external_artifacts=request.delete_external_artifacts,
+ sync_delete=True,
+ )
_reset_cached_tags(company_id, projects=[task.project] if task.project else [])
- call.result.data = dict(deleted=bool(deleted), **attr.asdict(cleanup_res))
+ call.result.data = dict(
+ deleted=bool(deleted), **cleanup_res.to_res_dict(request.return_file_urls)
+ )
@endpoint("tasks.delete_many", request_data_model=DeleteManyRequest)
@@ -1181,25 +1283,42 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest):
identity=call.identity,
move_to_trash=request.move_to_trash,
force=request.force,
- return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
status_message=request.status_message,
status_reason=request.status_reason,
- delete_external_artifacts=request.delete_external_artifacts,
include_pipeline_steps=request.include_pipeline_steps,
),
ids=request.ids,
)
+ succeeded = []
+ tasks = {}
if results:
- projects = set(task.project for _, (_, task, _) in results)
+ projects = set()
+ for _id, (deleted, task, cleanup_res) in results:
+ if deleted:
+ projects.add(task.project)
+ tasks[_id] = cleanup_res
+ succeeded.append(
+ dict(
+ id=_id,
+ deleted=bool(deleted),
+ **cleanup_res.to_res_dict(request.return_file_urls),
+ )
+ )
+
+ if tasks:
+ _delete_task_events(
+ company_id=company_id,
+ user_id=call.identity.user,
+ tasks=tasks,
+ delete_external_artifacts=request.delete_external_artifacts,
+ sync_delete=False,
+ )
_reset_cached_tags(company_id, projects=list(projects))
call.result.data = dict(
- succeeded=[
- dict(id=_id, deleted=bool(deleted), **attr.asdict(cleanup_res))
- for _id, (deleted, _, cleanup_res) in results
- ],
+ succeeded=succeeded,
failed=failures,
)
@@ -1359,10 +1478,11 @@ def move(call: APICall, company_id: str, request: MoveRequest):
"project or project_name is required"
)
- _assert_writable_tasks(company_id, call.identity, request.ids)
- updated_projects = set(
- t.project for t in Task.objects(id__in=request.ids).only("project") if t.project
+ tasks = _assert_writable_tasks(
+ company_id, call.identity, request.ids, only=("id", "project")
)
+ updated_projects = set(t.project for t in tasks if t.project)
+
project_id = project_bll.move_under_project(
entity_cls=Task,
user=call.identity.user,
@@ -1385,6 +1505,7 @@ def update_tags(call: APICall, company_id: str, request: UpdateTagsRequest):
return {
"updated": org_bll.edit_entity_tags(
company_id=company_id,
+ user_id=call.identity.user,
entity_cls=Task,
entity_ids=request.ids,
add_tags=request.add_tags,
diff --git a/apiserver/tests/automated/test_models.py b/apiserver/tests/automated/test_models.py
index 36dd7c5..7e782b7 100644
--- a/apiserver/tests/automated/test_models.py
+++ b/apiserver/tests/automated/test_models.py
@@ -1,3 +1,5 @@
+import unittest
+
from apiserver.apierrors import errors
from apiserver.apierrors.errors.bad_request import InvalidModelId
from apiserver.tests.automated import TestService
@@ -236,6 +238,23 @@ class TestModelsService(TestService):
res = self.api.models.get_frameworks(projects=[project])
self.assertEqual([], res.frameworks)
+ @unittest.skip(
+ """This test requires the following setting
+ CLEARML__services__async_urls_delete__fileserver__url_prefixes=["https://files.allegro-master.hosted.allegro.ai"
+ Check the test results in the logs of async_delete service
+ """
+ )
+ def test_delete_many_with_files(self):
+ models = [
+ self._create_model(
+ name=f"delete model test{idx}",
+ uri=f"https://files.allegro-master.hosted.allegro.ai/models/test{idx}.txt"
+ )
+ for idx in range(2)
+ ]
+ self.api.models.delete_many(ids=models)
+
def test_make_public(self):
m1 = self._create_model(name="public model test")
@@ -277,7 +296,7 @@ class TestModelsService(TestService):
service="models",
delete_params=dict(can_fail=True, force=True),
name=kwargs.pop("name", "test"),
- uri=kwargs.pop("name", "file:///a"),
+ uri=kwargs.pop("uri", "file:///a"),
labels=kwargs.pop("labels", {}),
**kwargs,
)
diff --git a/apiserver/tests/automated/test_pipelines.py b/apiserver/tests/automated/test_pipelines.py
index 3c1e0f9..ec75144 100644
--- a/apiserver/tests/automated/test_pipelines.py
+++ b/apiserver/tests/automated/test_pipelines.py
@@ -5,6 +5,18 @@ from apiserver.tests.automated import TestService
class TestPipelines(TestService):
+    task_hyperparams = {
+        "properties": {
+            "version": {
+                "section": "properties",
+                "name": "version",
+                "type": "str",
+                "value": "3.2",
+            }
+        }
+    }
+
def test_controller_operations(self):
task_name = "pipelines test"
project, task = self._temp_project_and_task(name=task_name)
@@ -82,14 +94,17 @@ class TestPipelines(TestService):
self.assertEqual(pipeline.status, "queued")
self.assertEqual(pipeline.project.id, project)
self.assertEqual(
- pipeline.hyperparams.Args,
+ pipeline.hyperparams,
{
- a["name"]: {
- "section": "Args",
- "name": a["name"],
- "value": a["value"],
- }
- for a in args
+ "Args": {
+ a["name"]: {
+ "section": "Args",
+ "name": a["name"],
+ "value": a["value"],
+ }
+ for a in args
+ },
+ **self.task_hyperparams,
},
)
@@ -124,6 +139,7 @@ class TestPipelines(TestService):
type="controller",
project=project,
system_tags=["pipeline"],
+ hyperparams=self.task_hyperparams,
),
)
diff --git a/apiserver/tests/automated/test_queues.py b/apiserver/tests/automated/test_queues.py
index 1959d91..9ba30aa 100644
--- a/apiserver/tests/automated/test_queues.py
+++ b/apiserver/tests/automated/test_queues.py
@@ -40,6 +40,53 @@ class TestQueues(TestService):
)
self.assertMetricQueues(res["queues"], queue_id)
+ def test_add_remove_clear(self):
+ queue1 = self._temp_queue("TestTempQueue1")
+ queue2 = self._temp_queue("TestTempQueue2")
+
+ task_names = ["TempDevTask1", "TempDevTask2"]
+ tasks = [self._temp_task(name) for name in task_names]
+
+ for task in tasks:
+ self.api.tasks.enqueue(task=task, queue=queue1)
+
+ # remove task with and without status update
+ res = self.api.queues.remove_task(task=tasks[0], queue=queue1)
+ self.assertEqual(res.removed, 1)
+ res = self.api.tasks.get_by_id(task=tasks[0])
+ self.assertEqual(res.task.status, "queued")
+ self.assertEqual(res.task.execution.queue, queue1)
+
+ res = self.api.queues.remove_task(task=tasks[1], queue=queue1, update_task_status=True)
+ self.assertEqual(res.removed, 1)
+ res = self.api.tasks.get_by_id(task=tasks[1])
+ self.assertEqual(res.task.status, "created")
+
+ res = self.api.queues.get_by_id(queue=queue1)
+ self.assertQueueTasks(res.queue, [])
+
+ # add task
+ res = self.api.queues.add_task(queue=queue2, task=tasks[0])
+ self.assertEqual(res.added, 1)
+ res = self.api.tasks.get_by_id(task=tasks[0])
+ self.assertEqual(res.task.status, "queued")
+ self.assertEqual(res.task.execution.queue, queue2)
+
+ res = self.api.queues.get_by_id(queue=queue2)
+ self.assertQueueTasks(res.queue, [tasks[0]])
+
+ # clear queue
+ res = self.api.queues.clear_queue(queue=queue1)
+ self.assertEqual(res.removed_tasks, [])
+ res = self.api.queues.clear_queue(queue=queue2)
+ self.assertEqual(res.removed_tasks, [tasks[0]])
+
+ res = self.api.tasks.get_by_id(task=tasks[0])
+ self.assertEqual(res.task.status, "created")
+
+ res = self.api.queues.get_by_id(queue=queue2)
+ self.assertQueueTasks(res.queue, [])
+
def test_hidden_queues(self):
hidden_name = "TestHiddenQueue"
hidden_queue = self._temp_queue(hidden_name, system_tags=["k8s-glue"])
@@ -212,12 +259,19 @@ class TestQueues(TestService):
def test_get_all_ex(self):
queue_name = "TestTempQueue1"
+ queue_display_name = "Test display name"
queue_tags = ["Test1", "Test2"]
- queue = self._temp_queue(queue_name, tags=queue_tags)
+ queue = self._temp_queue(queue_name, display_name=queue_display_name, tags=queue_tags)
res = self.api.queues.get_all_ex(name="TestTempQueue*").queues
self.assertQueue(
- res, queue_id=queue, name=queue_name, tags=queue_tags, tasks=[], workers=[]
+ res,
+ queue_id=queue,
+ display_name=queue_display_name,
+ name=queue_name,
+ tags=queue_tags,
+ tasks=[],
+ workers=[],
)
tasks = [
@@ -232,6 +286,7 @@ class TestQueues(TestService):
res,
queue_id=queue,
name=queue_name,
+ display_name=queue_display_name,
tags=queue_tags,
tasks=tasks,
workers=workers,
@@ -259,6 +314,7 @@ class TestQueues(TestService):
queues: Sequence[AttrDict],
queue_id: str,
name: str,
+ display_name: str,
tags: Sequence[str],
tasks: Sequence[dict],
workers: Sequence[dict],
@@ -267,15 +323,33 @@ class TestQueues(TestService):
assert queue.last_update
self.assertEqualNoOrder(queue.tags, tags)
self.assertEqual(queue.name, name)
- self.assertQueueTasks(queue, tasks)
- self.assertQueueWorkers(queue, workers)
+ self.assertEqual(queue.display_name, display_name)
+ self.assertQueueTasks(queue, tasks, name, display_name)
+ self.assertQueueWorkers(queue, workers, name, display_name)
def assertTaskTags(self, task, system_tags):
res = self.api.tasks.get_by_id(task=task)
self.assertSequenceEqual(res.task.system_tags, system_tags)
- def assertQueueTasks(self, queue: AttrDict, tasks: Sequence):
+ def assertQueueTasks(
+ self,
+ queue: AttrDict,
+ tasks: Sequence,
+ queue_name: str = None,
+ display_queue_name: str = None,
+ ):
self.assertEqual([e.task for e in queue.entries], tasks)
+ if queue_name:
+ for task in tasks:
+ execution = self.api.tasks.get_by_id_ex(
+ id=[task["id"]],
+ only_fields=[
+ "execution.queue.name",
+ "execution.queue.display_name",
+ ],
+ ).tasks[0].execution
+ self.assertEqual(execution.queue.name, queue_name)
+ self.assertEqual(execution.queue.display_name, display_queue_name)
def assertGetNextTasks(self, queue, tasks):
for task_id in tasks:
@@ -283,11 +357,28 @@ class TestQueues(TestService):
self.assertEqual(res.entry.task, task_id)
assert not self.api.queues.get_next_task(queue=queue)
- def assertQueueWorkers(self, queue: AttrDict, workers: Sequence[dict]):
+ def assertQueueWorkers(
+ self,
+ queue: AttrDict,
+ workers: Sequence[dict],
+ queue_name: str = None,
+ display_queue_name: str = None,
+ ):
sort_key = itemgetter("name")
self.assertEqual(
sorted(queue.workers, key=sort_key), sorted(workers, key=sort_key)
)
+ if not workers:
+ return
+
+ res = self.api.workers.get_all()
+ worker_ids = {w["key"] for w in workers}
+ found = [w for w in res.workers if w.key in worker_ids]
+ self.assertEqual(len(found), len(worker_ids))
+ for worker in found:
+            for worker_queue in worker.queues:
+                self.assertEqual(worker_queue.name, queue_name)
+                self.assertEqual(worker_queue.display_name, display_queue_name)
def _temp_queue(self, queue_name, **kwargs):
return self.create_temp("queues", name=queue_name, **kwargs)
diff --git a/apiserver/tests/automated/test_reports.py b/apiserver/tests/automated/test_reports.py
index 654391f..1b99dcb 100644
--- a/apiserver/tests/automated/test_reports.py
+++ b/apiserver/tests/automated/test_reports.py
@@ -12,7 +12,7 @@ class TestReports(TestService):
def _delete_project(self, name):
existing_project = first(
self.api.projects.get_all_ex(
- name=f"^{re.escape(name)}$", search_hidden=True
+ name=f"^{re.escape(name)}$", search_hidden=True, allow_public=False
).projects
)
if existing_project:
@@ -34,10 +34,10 @@ class TestReports(TestService):
self.assertEqual(set(task.tags), set(tags))
self.assertEqual(task.type, "report")
self.assertEqual(set(task.system_tags), {"hidden", "reports"})
- projects = self.api.projects.get_all_ex(name=r"^\.reports$").projects
+ projects = self.api.projects.get_all_ex(name=r"^\.reports$", allow_public=False).projects
self.assertEqual(len(projects), 0)
project = self.api.projects.get_all_ex(
- name=r"^\.reports$", search_hidden=True
+ name=r"^\.reports$", search_hidden=True, allow_public=False
).projects[0]
self.assertEqual(project.id, task.project.id)
self.assertEqual(set(project.system_tags), {"hidden", "reports"})
@@ -108,6 +108,7 @@ class TestReports(TestService):
include_stats=True,
check_own_contents=True,
search_hidden=True,
+ allow_public=False,
).projects
self.assertEqual(len(projects), 1)
p = projects[0]
@@ -120,6 +121,7 @@ class TestReports(TestService):
include_stats=True,
check_own_contents=True,
search_hidden=True,
+ allow_public=False,
).projects
self.assertEqual(len(projects), 1)
p = projects[0]
diff --git a/apiserver/tests/automated/test_serving.py b/apiserver/tests/automated/test_serving.py
new file mode 100644
index 0000000..b24f5b7
--- /dev/null
+++ b/apiserver/tests/automated/test_serving.py
@@ -0,0 +1,124 @@
+from time import time, sleep
+
+from apiserver.apierrors import errors
+from apiserver.tests.automated import TestService
+
+
+class TestServing(TestService):
+ def test_status_report(self):
+ container_id1 = "container_1"
+ container_id2 = "container_2"
+ url = "http://test_url"
+ reference = [
+ {"type": "app_id", "value": "test"},
+ {"type": "app_instance", "value": "abd478c8"},
+ {"type": "model", "value": "262829d3"},
+ {"type": "model", "value": "7ea29c04"},
+ ]
+ container_infos = [
+ {
+ "container_id": container_id, # required
+ "endpoint_name": "my endpoint", # required
+ "endpoint_url": url, # can be omitted for register but required for status report
+ "model_name": "my model", # required
+ "model_source": "s3//my_bucket", # optional right now
+ "model_version": "3.1.0", # optional right now
+ "preprocess_artifact": "some string here", # optional right now
+ "input_type": "another string here", # optional right now
+ "input_size": 9_000_000, # optional right now, bytes
+ "tags": ["tag1", "tag2"], # optional
+ "system_tags": None, # optional
+ **({"reference": reference} if container_id == container_id1 else {}),
+ }
+ for container_id in (container_id1, container_id2)
+ ]
+
+ # registering instances
+ for container_info in container_infos:
+ self.api.serving.register_container(
+ **container_info,
+ timeout=100, # expiration timeout in seconds. Optional, the default value is 600
+ )
+ for idx, container_info in enumerate(container_infos):
+ mul = idx + 1
+ self.api.serving.container_status_report(
+ **container_info,
+ uptime_sec=1000 * mul,
+ requests_num=1000 * mul,
+ requests_min=5 * mul, # requests per minute
+ latency_ms=100 * mul, # average latency
+ machine_stats={ # the same structure here as used by worker status_reports
+ "cpu_usage": [10, 20],
+ "memory_used": 50 * 1024,
+ },
+ )
+
+ # getting endpoints and endpoint details
+ endpoints = self.api.serving.get_endpoints().endpoints
+ self.assertTrue(any(e for e in endpoints if e.url == url))
+ details = self.api.serving.get_endpoint_details(endpoint_url=url)
+ self.assertEqual(details.url, url)
+ self.assertEqual(details.uptime_sec, 2000)
+ self.assertEqual(
+ {
+ inst.id: [
+ inst[field]
+ for field in (
+ "uptime_sec",
+ "requests",
+ "requests_min",
+ "latency_ms",
+ "cpu_count",
+ "gpu_count",
+ "reference",
+ )
+ ]
+ for inst in details.instances
+ },
+ {
+ "container_1": [1000, 1000, 5, 100, 2, 0, reference],
+ "container_2": [2000, 2000, 10, 200, 2, 0, []],
+ },
+ )
+ # make sure that the first call did not invalidate anything
+ new_details = self.api.serving.get_endpoint_details(endpoint_url=url)
+ self.assertEqual(details, new_details)
+
+ # charts
+        sleep(5)  # give ES time to accommodate the data
+ to_date = int(time()) + 40
+ from_date = to_date - 100
+ for metric_type, title, value in (
+ (None, "Number of Requests", 3000),
+ ("requests_min", "Requests per Minute", 15),
+ ("latency_ms", "Average Latency (ms)", 150),
+ ("cpu_count", "CPU Count", 4),
+ ("cpu_util", "Average CPU Load (%)", 15),
+ ("ram_used", "RAM Used (GB)", 100.0),
+ ):
+ res = self.api.serving.get_endpoint_metrics_history(
+ endpoint_url=url,
+ from_date=from_date,
+ to_date=to_date,
+ interval=1,
+ **({"metric_type": metric_type} if metric_type else {}),
+ )
+ self.assertEqual(res.computed_interval, 40)
+ self.assertEqual(res.total.title, title)
+ length = len(res.total.dates)
+ self.assertTrue(3 >= length >= 1)
+ self.assertEqual(len(res.total["values"]), length)
+ self.assertIn(value, res.total["values"])
+ self.assertEqual(set(res.instances), {container_id1, container_id2})
+ for inst in res.instances.values():
+ self.assertEqual(inst.dates, res.total.dates)
+ self.assertEqual(len(inst["values"]), length)
+
+ # unregistering containers
+ for container_id in (container_id1, container_id2):
+ self.api.serving.unregister_container(container_id=container_id)
+ endpoints = self.api.serving.get_endpoints().endpoints
+ self.assertFalse(any(e for e in endpoints if e.url == url))
+
+ with self.api.raises(errors.bad_request.NoContainersForUrl):
+ self.api.serving.get_endpoint_details(endpoint_url=url)
diff --git a/apiserver/tests/automated/test_subprojects.py b/apiserver/tests/automated/test_subprojects.py
index 0158391..e5de764 100644
--- a/apiserver/tests/automated/test_subprojects.py
+++ b/apiserver/tests/automated/test_subprojects.py
@@ -15,7 +15,7 @@ class TestSubProjects(TestService):
def test_dataset_stats(self):
project = self._temp_project(name="Dataset test", system_tags=["dataset"])
res = self.api.organization.get_entities_count(
- datasets={"system_tags": ["dataset"]}
+ datasets={"system_tags": ["dataset"]}, allow_public=False,
)
self.assertEqual(res.datasets, 1)
@@ -439,6 +439,15 @@ class TestSubProjects(TestService):
self.assertEqual(res2.own_tasks, 0)
self.assertEqual(res2.own_models, 0)
+ def test_public_names_clash(self):
+        # cannot create a project whose name matches an existing public project
+        with self.api.raises(errors.bad_request.PublicProjectExists):
+            self._temp_project(name="ClearML Examples")
+
+        # cannot create a subproject under a public project
+        with self.api.raises(errors.bad_request.PublicProjectExists):
+            self._temp_project(name="ClearML Examples/my project")
+
def test_get_all_with_stats(self):
project4, _ = self._temp_project_with_tasks(name="project1/project3/project4")
project5, _ = self._temp_project_with_tasks(name="project1/project3/project5")
diff --git a/apiserver/tests/automated/test_task_events.py b/apiserver/tests/automated/test_task_events.py
index 0be5d50..337604e 100644
--- a/apiserver/tests/automated/test_task_events.py
+++ b/apiserver/tests/automated/test_task_events.py
@@ -217,7 +217,10 @@ class TestTaskEvents(TestService):
self.assertEqual(iter_count - 1, metric_data.max_value_iteration)
self.assertEqual(0, metric_data.min_value)
self.assertEqual(0, metric_data.min_value_iteration)
-
+ self.assertEqual(0, metric_data.first_value_iteration)
+ self.assertEqual(0, metric_data.first_value)
+ self.assertEqual(iter_count, metric_data.count)
+        self.assertEqual(sum(range(iter_count)) / iter_count, metric_data.mean_value)
res = self.api.events.get_task_latest_scalar_values(task=task)
self.assertEqual(iter_count - 1, res.last_iter)
@@ -243,6 +246,7 @@ class TestTaskEvents(TestService):
"variant": f"Variant{variant_idx}",
"value": iteration,
"model_event": True,
+ "x_axis_label": f"Label_{metric_idx}_{variant_idx}"
}
for iteration in range(2)
for metric_idx in range(5)
@@ -271,6 +275,7 @@ class TestTaskEvents(TestService):
variant_data = metric_data.Variant0
self.assertEqual(variant_data.x, [0, 1])
self.assertEqual(variant_data.y, [0.0, 1.0])
+ self.assertEqual(variant_data.x_axis_label, "Label_0_0")
model_data = self.api.models.get_all_ex(
id=[model], only_fields=["last_metrics", "last_iteration"]
@@ -282,6 +287,7 @@ class TestTaskEvents(TestService):
self.assertEqual(1, metric_data.max_value_iteration)
self.assertEqual(0, metric_data.min_value)
self.assertEqual(0, metric_data.min_value_iteration)
+ self.assertEqual("Label_4_4", metric_data.x_axis_label)
self._assert_log_events(task=task, expected_total=1)
diff --git a/apiserver/tests/automated/test_tasks_delete.py b/apiserver/tests/automated/test_tasks_delete.py
index 18533bc..a657402 100644
--- a/apiserver/tests/automated/test_tasks_delete.py
+++ b/apiserver/tests/automated/test_tasks_delete.py
@@ -59,7 +59,7 @@ class TestTasksResetDelete(TestService):
event_urls.update(self.send_model_events(model))
res = self.assert_delete_task(task, force=True, return_file_urls=True)
self.assertEqual(set(res.urls.model_urls), draft_model_urls)
- self.assertEqual(set(res.urls.event_urls), event_urls)
+ self.assertFalse(set(res.urls.event_urls)) # event urls are not returned anymore
self.assertEqual(set(res.urls.artifact_urls), artifact_urls)
def test_reset(self):
@@ -84,7 +84,7 @@ class TestTasksResetDelete(TestService):
) = self.create_task_with_data()
res = self.api.tasks.reset(task=task, force=True, return_file_urls=True)
self.assertEqual(set(res.urls.model_urls), draft_model_urls)
- self.assertEqual(set(res.urls.event_urls), event_urls)
+ self.assertFalse(res.urls.event_urls) # event urls are not returned anymore
self.assertEqual(set(res.urls.artifact_urls), artifact_urls)
def test_model_delete(self):
@@ -124,7 +124,7 @@ class TestTasksResetDelete(TestService):
self.assertEqual(res.disassociated_tasks, 0)
self.assertEqual(res.deleted_tasks, 1)
self.assertEqual(res.deleted_models, 2)
- self.assertEqual(set(res.urls.event_urls), event_urls)
+ self.assertFalse(set(res.urls.event_urls)) # event urls are not returned anymore
self.assertEqual(set(res.urls.artifact_urls), artifact_urls)
with self.api.raises(errors.bad_request.InvalidTaskId):
self.api.tasks.get_by_id(task=task)
diff --git a/apiserver/tests/automated/test_tasks_filtering.py b/apiserver/tests/automated/test_tasks_filtering.py
index 5b52564..72e12fb 100644
--- a/apiserver/tests/automated/test_tasks_filtering.py
+++ b/apiserver/tests/automated/test_tasks_filtering.py
@@ -71,6 +71,16 @@ class TestTasksFiltering(TestService):
).tasks
self.assertFalse(set(tasks).issubset({t.id for t in res}))
+ # _any_/_all_ queries
+ res = self.api.tasks.get_all_ex(
+ **{"_any_": {"datetime": f">={now.isoformat()}", "fields": ["last_update", "status_changed"]}}
+ ).tasks
+ self.assertTrue(set(tasks).issubset({t.id for t in res}))
+ res = self.api.tasks.get_all_ex(
+ **{"_all_": {"datetime": f">={now.isoformat()}", "fields": ["last_update", "status_changed"]}}
+ ).tasks
+ self.assertFalse(set(tasks).issubset({t.id for t in res}))
+
# simplified range syntax
res = self.api.tasks.get_all_ex(last_update=[now.isoformat(), None]).tasks
self.assertTrue(set(tasks).issubset({t.id for t in res}))
@@ -80,6 +90,15 @@ class TestTasksFiltering(TestService):
).tasks
self.assertFalse(set(tasks).issubset({t.id for t in res}))
+ res = self.api.tasks.get_all_ex(
+ **{"_any_": {"datetime": [now.isoformat(), None], "fields": ["last_update", "status_changed"]}}
+ ).tasks
+ self.assertTrue(set(tasks).issubset({t.id for t in res}))
+ res = self.api.tasks.get_all_ex(
+ **{"_all_": {"datetime": [now.isoformat(), None], "fields": ["last_update", "status_changed"]}}
+ ).tasks
+ self.assertFalse(set(tasks).issubset({t.id for t in res}))
+
def test_range_queries(self):
tasks = [self.temp_task() for _ in range(5)]
now = datetime.utcnow()
diff --git a/apiserver/tests/automated/test_workers.py b/apiserver/tests/automated/test_workers.py
index 3108e0c..cb05724 100644
--- a/apiserver/tests/automated/test_workers.py
+++ b/apiserver/tests/automated/test_workers.py
@@ -116,7 +116,7 @@ class TestWorkersService(TestService):
if w == workers[0]:
data["task"] = task_id
self.api.workers.status_report(**data)
- timestamp += 1000
+ timestamp += 60*1000
return workers
@@ -151,7 +151,7 @@ class TestWorkersService(TestService):
time.sleep(5) # give to ES time to refresh
from_date = start
- to_date = start + 10
+ to_date = start + 40*10
# no variants
res = self.api.workers.get_stats(
items=[
@@ -180,7 +180,7 @@ class TestWorkersService(TestService):
self.assertEqual(
set(stat.aggregation for stat in metric.stats), metric_stats
)
- self.assertEqual(len(metric.dates), 4 if worker.worker == workers[0] else 2)
+ self.assertTrue(11 >= len(metric.dates) >= 10)
# split by variants
res = self.api.workers.get_stats(
@@ -199,7 +199,7 @@ class TestWorkersService(TestService):
set(metric.variant for metric in worker.metrics),
{"0", "1"} if worker.worker == workers[0] else {"0"},
)
- self.assertEqual(len(metric.dates), 4 if worker.worker == workers[0] else 2)
+ self.assertTrue(11 >= len(metric.dates) >= 10)
res = self.api.workers.get_stats(
items=[dict(key="cpu_usage", aggregation="avg")],
@@ -216,25 +216,25 @@ class TestWorkersService(TestService):
def test_get_activity_report(self):
# test no workers data
# run on an empty es db since we have no way
- # to pass non existing workers to this api
+ # to pass non-existing workers to this api
# res = self.api.workers.get_activity_report(
# from_timestamp=from_timestamp.timestamp(),
# to_timestamp=to_timestamp.timestamp(),
# interval=20,
# )
start = int(time.time())
- self._simulate_workers(int(time.time()))
+ self._simulate_workers(start)
time.sleep(5) # give to es time to refresh
# no variants
res = self.api.workers.get_activity_report(
- from_date=start, to_date=start + 10, interval=2
+ from_date=start, to_date=start + 10*40, interval=2
)
- self.assertWorkerSeries(res["total"], 2, 5)
- self.assertWorkerSeries(res["active"], 1, 5)
+ self.assertWorkerSeries(res["total"], 2, 10)
+ self.assertWorkerSeries(res["active"], 1, 10)
def assertWorkerSeries(self, series_data: dict, count: int, size: int):
self.assertEqual(len(series_data["dates"]), size)
self.assertEqual(len(series_data["counts"]), size)
- self.assertTrue(any(c == count for c in series_data["counts"]))
- self.assertTrue(all(c <= count for c in series_data["counts"]))
+ # self.assertTrue(any(c == count for c in series_data["counts"]))
+ # self.assertTrue(all(c <= count for c in series_data["counts"]))
diff --git a/apiserver/utilities/dicts.py b/apiserver/utilities/dicts.py
index 3850f6e..e794dec 100644
--- a/apiserver/utilities/dicts.py
+++ b/apiserver/utilities/dicts.py
@@ -1,4 +1,4 @@
-from typing import Sequence, Tuple, Any, Union, Callable, Optional, Mapping
+from typing import Sequence, Tuple, Any, Union, Callable, Optional, Protocol
def flatten_nested_items(
@@ -35,8 +35,13 @@ def deep_merge(source: dict, override: dict) -> dict:
return source
+class GetItem(Protocol):
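+    # structural type: anything supporting item access, not only Mapping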
+ def __getitem__(self, key: Any) -> Any:
+ pass
+
+
def nested_get(
- dictionary: Mapping,
+ dictionary: GetItem,
path: Sequence[str],
default: Optional[Union[Any, Callable]] = None,
) -> Any:
diff --git a/apiserver/version.py b/apiserver/version.py
index 638c121..8c0d5d5 100644
--- a/apiserver/version.py
+++ b/apiserver/version.py
@@ -1 +1 @@
-__version__ = "1.16.0"
+__version__ = "2.0.0"
diff --git a/docker/docker-compose-win10.yml b/docker/docker-compose-win10.yml
index 6f52412..77b9ea2 100644
--- a/docker/docker-compose-win10.yml
+++ b/docker/docker-compose-win10.yml
@@ -58,7 +58,7 @@ services:
nofile:
soft: 65536
hard: 65536
- image: docker.elastic.co/elasticsearch/elasticsearch:7.17.18
+ image: elasticsearch:8.17.0
restart: unless-stopped
volumes:
- c:/opt/clearml/data/elastic_7:/usr/share/elasticsearch/data
@@ -87,7 +87,7 @@ services:
networks:
- backend
container_name: clearml-mongo
- image: mongo:4.4.29
+ image: mongo:6.0.19
restart: unless-stopped
command: --setParameter internalQueryMaxBlockingSortMemoryUsageBytes=196100200
volumes:
@@ -98,7 +98,7 @@ services:
networks:
- backend
container_name: clearml-redis
- image: redis:6.2
+ image: redis:7.4.1
restart: unless-stopped
volumes:
- c:/opt/clearml/data/redis:/data
diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml
index 07f7f43..f2043ab 100644
--- a/docker/docker-compose.yml
+++ b/docker/docker-compose.yml
@@ -60,7 +60,7 @@ services:
nofile:
soft: 65536
hard: 65536
- image: docker.elastic.co/elasticsearch/elasticsearch:7.17.18
+ image: elasticsearch:8.17.0
restart: unless-stopped
volumes:
- /opt/clearml/data/elastic_7:/usr/share/elasticsearch/data
@@ -88,7 +88,7 @@ services:
networks:
- backend
container_name: clearml-mongo
- image: mongo:4.4.29
+ image: mongo:6.0.19
restart: unless-stopped
command: --setParameter internalQueryMaxBlockingSortMemoryUsageBytes=196100200
volumes:
@@ -99,7 +99,7 @@ services:
networks:
- backend
container_name: clearml-redis
- image: redis:6.2
+ image: redis:7.4.1
restart: unless-stopped
volumes:
- /opt/clearml/data/redis:/data
diff --git a/fileserver/LICENSE b/fileserver/LICENSE
index 24d07bf..19931a1 100644
--- a/fileserver/LICENSE
+++ b/fileserver/LICENSE
@@ -1,7 +1,7 @@
Server Side Public License
VERSION 1, OCTOBER 16, 2018
- Copyright © 2019 allegro.ai, Inc.
+ Copyright © 2025 ClearML Inc.
Everyone is permitted to copy and distribute verbatim copies of this
license document, but changing it is not allowed.
diff --git a/fileserver/fileserver.py b/fileserver/fileserver.py
index a895580..bdc965c 100644
--- a/fileserver/fileserver.py
+++ b/fileserver/fileserver.py
@@ -7,6 +7,7 @@ import urllib.parse
from argparse import ArgumentParser
from collections import defaultdict
from pathlib import Path
+from typing import Optional
from boltons.iterutils import first
from flask import Flask, request, send_from_directory, abort, Response
@@ -113,8 +114,12 @@ def download(path):
return response
-def _get_full_path(path: str) -> Path:
- return Path(safe_join(os.fspath(app.config["UPLOAD_FOLDER"]), os.fspath(path)))
+def _get_full_path(path: str) -> Optional[Path]:
+ path_str = safe_join(os.fspath(app.config["UPLOAD_FOLDER"]), os.fspath(path))
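+    # safe_join returns None when the joined path would escape the base folder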
+ if path_str is None:
+        return None
+
+ return Path(path_str)
@app.route("/", methods=["DELETE"])
@@ -123,7 +128,7 @@ def delete(path):
auth_handler.validate(request)
full_path = _get_full_path(path)
- if not full_path.exists() or not full_path.is_file():
+ if not (full_path and full_path.exists() and full_path.is_file()):
log.error(f"Error deleting file {str(full_path)}. Not found or not a file")
abort(Response(f"File {str(path)} not found", 404))
@@ -161,7 +166,7 @@ def batch_delete():
full_path = _get_full_path(path)
- if not full_path.exists():
+ if not (full_path and full_path.exists()):
record_error("Not found", file, path)
continue
diff --git a/fileserver/requirements.txt b/fileserver/requirements.txt
index 8395cab..28f978e 100644
--- a/fileserver/requirements.txt
+++ b/fileserver/requirements.txt
@@ -5,7 +5,7 @@ flask-cors>=3.0.5
flask>=2.3.3
gunicorn>=20.1.0
pyhocon>=0.3.35
-redis>=4.5.4,<5
+redis==5.2.1
setuptools>=65.5.1
urllib3>=1.26.18
werkzeug>=3.0.1
\ No newline at end of file
diff --git a/upgrade/1_17_to_2_0/docker-compose-win10.yml b/upgrade/1_17_to_2_0/docker-compose-win10.yml
new file mode 100644
index 0000000..f578182
--- /dev/null
+++ b/upgrade/1_17_to_2_0/docker-compose-win10.yml
@@ -0,0 +1,158 @@
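+# Interim configuration for upgrading ClearML Server 1.17 to 2.0: MongoDB must
+# be upgraded one major version at a time, hence the intermediate mongo:5.0 image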
+version: "3.6"
+services:
+
+ apiserver:
+ command:
+ - apiserver
+ container_name: clearml-apiserver
+ image: allegroai/clearml:1.17.1-554
+ restart: unless-stopped
+ volumes:
+ - c:/opt/clearml/logs:/var/log/clearml
+ - c:/opt/clearml/config:/opt/clearml/config
+ - c:/opt/clearml/data/fileserver:/mnt/fileserver
+ depends_on:
+ - redis
+ - mongo
+ - elasticsearch
+ - fileserver
+ environment:
+ CLEARML_ELASTIC_SERVICE_HOST: elasticsearch
+ CLEARML_ELASTIC_SERVICE_PORT: 9200
+ CLEARML_MONGODB_SERVICE_HOST: mongo
+ CLEARML_MONGODB_SERVICE_PORT: 27017
+ CLEARML_REDIS_SERVICE_HOST: redis
+ CLEARML_REDIS_SERVICE_PORT: 6379
+ CLEARML_SERVER_DEPLOYMENT_TYPE: win10
+ CLEARML__apiserver__pre_populate__enabled: "true"
+ CLEARML__apiserver__pre_populate__zip_files: "/opt/clearml/db-pre-populate"
+ CLEARML__apiserver__pre_populate__artifacts_path: "/mnt/fileserver"
+ CLEARML__services__async_urls_delete__enabled: "true"
+ CLEARML__services__async_urls_delete__fileserver__url_prefixes: "[${CLEARML_FILES_HOST:-}]"
+ ports:
+ - "8008:8008"
+ networks:
+ - backend
+ - frontend
+
+ elasticsearch:
+ networks:
+ - backend
+ container_name: clearml-elastic
+ environment:
+ bootstrap.memory_lock: "true"
+ cluster.name: clearml
+ cluster.routing.allocation.node_initial_primaries_recoveries: "500"
+ cluster.routing.allocation.disk.watermark.low: 500mb
+ cluster.routing.allocation.disk.watermark.high: 500mb
+ cluster.routing.allocation.disk.watermark.flood_stage: 500mb
+ discovery.type: "single-node"
+ http.compression_level: "7"
+ node.name: clearml
+ reindex.remote.whitelist: "'*.*'"
+ xpack.security.enabled: "false"
+ ulimits:
+ memlock:
+ soft: -1
+ hard: -1
+ nofile:
+ soft: 65536
+ hard: 65536
+ image: docker.elastic.co/elasticsearch/elasticsearch:8.15.3
+ restart: unless-stopped
+ volumes:
+ - c:/opt/clearml/data/elastic_7:/usr/share/elasticsearch/data
+ - /usr/share/elasticsearch/logs
+
+ fileserver:
+ networks:
+ - backend
+ - frontend
+ command:
+ - fileserver
+ container_name: clearml-fileserver
+ image: allegroai/clearml:1.17.1-554
+ environment:
+ CLEARML__fileserver__delete__allow_batch: "true"
+ restart: unless-stopped
+ volumes:
+ - c:/opt/clearml/logs:/var/log/clearml
+ - c:/opt/clearml/data/fileserver:/mnt/fileserver
+ - c:/opt/clearml/config:/opt/clearml/config
+ ports:
+ - "8081:8081"
+
+ mongo:
+ networks:
+ - backend
+ container_name: clearml-mongo
+ image: mongo:5.0.26
+ restart: unless-stopped
+ command: --setParameter internalQueryMaxBlockingSortMemoryUsageBytes=196100200
+ volumes:
+ - c:/opt/clearml/data/mongo_4/db:/data/db
+ - c:/opt/clearml/data/mongo_4/configdb:/data/configdb
+
+ redis:
+ networks:
+ - backend
+ container_name: clearml-redis
+ image: redis:6.2
+ restart: unless-stopped
+ volumes:
+ - c:/opt/clearml/data/redis:/data
+
+ webserver:
+ command:
+ - webserver
+ container_name: clearml-webserver
+ image: allegroai/clearml:1.17.1-554
+ restart: unless-stopped
+ volumes:
+      - c:/opt/clearml/logs:/var/log/clearml
+ depends_on:
+ - apiserver
+ ports:
+ - "8080:80"
+ networks:
+ - backend
+ - frontend
+
+ async_delete:
+ depends_on:
+ - apiserver
+ - redis
+ - mongo
+ - elasticsearch
+ - fileserver
+ container_name: async_delete
+ image: allegroai/clearml:1.17.1-554
+ networks:
+ - backend
+ restart: unless-stopped
+ environment:
+ CLEARML_ELASTIC_SERVICE_HOST: elasticsearch
+ CLEARML_ELASTIC_SERVICE_PORT: 9200
+ CLEARML_MONGODB_SERVICE_HOST: mongo
+ CLEARML_MONGODB_SERVICE_PORT: 27017
+ CLEARML_REDIS_SERVICE_HOST: redis
+ CLEARML_REDIS_SERVICE_PORT: 6379
+ PYTHONPATH: /opt/clearml/apiserver
+ CLEARML__services__async_urls_delete__fileserver__url_prefixes: "[${CLEARML_FILES_HOST:-}]"
+ entrypoint:
+ - python3
+ - -m
+ - jobs.async_urls_delete
+ - --fileserver-host
+ - http://fileserver:8081
+ volumes:
+ - c:/opt/clearml/logs:/var/log/clearml
+ - c:/opt/clearml/config:/opt/clearml/config
+
+networks:
+ backend:
+ driver: bridge
+ frontend:
+ name: frontend
+ driver: bridge
\ No newline at end of file
diff --git a/upgrade/1_17_to_2_0/docker-compose.yml b/upgrade/1_17_to_2_0/docker-compose.yml
new file mode 100644
index 0000000..4a7684a
--- /dev/null
+++ b/upgrade/1_17_to_2_0/docker-compose.yml
@@ -0,0 +1,195 @@
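+# Interim configuration for upgrading ClearML Server 1.17 to 2.0: MongoDB must
+# be upgraded one major version at a time, hence the intermediate mongo:5.0 image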
+version: "3.6"
+services:
+
+ apiserver:
+ command:
+ - apiserver
+ container_name: clearml-apiserver
+ image: allegroai/clearml:1.17.1-554
+ restart: unless-stopped
+ volumes:
+ - /opt/clearml/logs:/var/log/clearml
+ - /opt/clearml/config:/opt/clearml/config
+ - /opt/clearml/data/fileserver:/mnt/fileserver
+ depends_on:
+ - redis
+ - mongo
+ - elasticsearch
+ - fileserver
+ environment:
+ CLEARML_ELASTIC_SERVICE_HOST: elasticsearch
+ CLEARML_ELASTIC_SERVICE_PORT: 9200
+ CLEARML_MONGODB_SERVICE_HOST: mongo
+ CLEARML_MONGODB_SERVICE_PORT: 27017
+ CLEARML_REDIS_SERVICE_HOST: redis
+ CLEARML_REDIS_SERVICE_PORT: 6379
+ CLEARML_SERVER_DEPLOYMENT_TYPE: linux
+ CLEARML__apiserver__pre_populate__enabled: "true"
+ CLEARML__apiserver__pre_populate__zip_files: "/opt/clearml/db-pre-populate"
+ CLEARML__apiserver__pre_populate__artifacts_path: "/mnt/fileserver"
+ CLEARML__services__async_urls_delete__enabled: "true"
+ CLEARML__services__async_urls_delete__fileserver__url_prefixes: "[${CLEARML_FILES_HOST:-}]"
+ CLEARML__secure__credentials__services_agent__user_key: ${CLEARML_AGENT_ACCESS_KEY:-}
+ CLEARML__secure__credentials__services_agent__user_secret: ${CLEARML_AGENT_SECRET_KEY:-}
+ ports:
+ - "8008:8008"
+ networks:
+ - backend
+ - frontend
+
+ elasticsearch:
+ networks:
+ - backend
+ container_name: clearml-elastic
+ environment:
+ bootstrap.memory_lock: "true"
+ cluster.name: clearml
+ cluster.routing.allocation.node_initial_primaries_recoveries: "500"
+ cluster.routing.allocation.disk.watermark.low: 500mb
+ cluster.routing.allocation.disk.watermark.high: 500mb
+ cluster.routing.allocation.disk.watermark.flood_stage: 500mb
+ discovery.type: "single-node"
+ http.compression_level: "7"
+ node.name: clearml
+ reindex.remote.whitelist: "'*.*'"
+ xpack.security.enabled: "false"
+ ulimits:
+ memlock:
+ soft: -1
+ hard: -1
+ nofile:
+ soft: 65536
+ hard: 65536
+ image: docker.elastic.co/elasticsearch/elasticsearch:8.15.3
+ restart: unless-stopped
+ volumes:
+ - /opt/clearml/data/elastic_7:/usr/share/elasticsearch/data
+ - /usr/share/elasticsearch/logs
+
+ fileserver:
+ networks:
+ - backend
+ - frontend
+ command:
+ - fileserver
+ container_name: clearml-fileserver
+ image: allegroai/clearml:1.17.1-554
+ environment:
+ CLEARML__fileserver__delete__allow_batch: "true"
+ restart: unless-stopped
+ volumes:
+ - /opt/clearml/logs:/var/log/clearml
+ - /opt/clearml/data/fileserver:/mnt/fileserver
+ - /opt/clearml/config:/opt/clearml/config
+ ports:
+ - "8081:8081"
+
+ mongo:
+ networks:
+ - backend
+ container_name: clearml-mongo
+ image: mongo:5.0.26
+ restart: unless-stopped
+ command: --setParameter internalQueryMaxBlockingSortMemoryUsageBytes=196100200
+ volumes:
+ - /opt/clearml/data/mongo_4/db:/data/db
+ - /opt/clearml/data/mongo_4/configdb:/data/configdb
+
+ redis:
+ networks:
+ - backend
+ container_name: clearml-redis
+ image: redis:6.2
+ restart: unless-stopped
+ volumes:
+ - /opt/clearml/data/redis:/data
+
+ webserver:
+ command:
+ - webserver
+ container_name: clearml-webserver
+ # environment:
+ # CLEARML_SERVER_SUB_PATH : clearml-web # Allow Clearml to be served with a URL path prefix.
+ image: allegroai/clearml:1.17.1-554
+ restart: unless-stopped
+ depends_on:
+ - apiserver
+ ports:
+ - "8080:80"
+ networks:
+ - backend
+ - frontend
+
+ async_delete:
+ depends_on:
+ - apiserver
+ - redis
+ - mongo
+ - elasticsearch
+ - fileserver
+ container_name: async_delete
+ image: allegroai/clearml:1.17.1-554
+ networks:
+ - backend
+ restart: unless-stopped
+ environment:
+ CLEARML_ELASTIC_SERVICE_HOST: elasticsearch
+ CLEARML_ELASTIC_SERVICE_PORT: 9200
+ CLEARML_MONGODB_SERVICE_HOST: mongo
+ CLEARML_MONGODB_SERVICE_PORT: 27017
+ CLEARML_REDIS_SERVICE_HOST: redis
+ CLEARML_REDIS_SERVICE_PORT: 6379
+ PYTHONPATH: /opt/clearml/apiserver
+ CLEARML__services__async_urls_delete__fileserver__url_prefixes: "[${CLEARML_FILES_HOST:-}]"
+ entrypoint:
+ - python3
+ - -m
+ - jobs.async_urls_delete
+ - --fileserver-host
+ - http://fileserver:8081
+ volumes:
+ - /opt/clearml/logs:/var/log/clearml
+ - /opt/clearml/config:/opt/clearml/config
+
+ agent-services:
+ networks:
+ - backend
+ container_name: clearml-agent-services
+ image: allegroai/clearml-agent-services:latest
+ deploy:
+ restart_policy:
+ condition: on-failure
+ privileged: true
+ environment:
+ CLEARML_HOST_IP: ${CLEARML_HOST_IP}
+ CLEARML_WEB_HOST: ${CLEARML_WEB_HOST:-}
+ CLEARML_API_HOST: http://apiserver:8008
+ CLEARML_FILES_HOST: ${CLEARML_FILES_HOST:-}
+ CLEARML_API_ACCESS_KEY: ${CLEARML_AGENT_ACCESS_KEY:-$CLEARML_API_ACCESS_KEY}
+ CLEARML_API_SECRET_KEY: ${CLEARML_AGENT_SECRET_KEY:-$CLEARML_API_SECRET_KEY}
+ CLEARML_AGENT_GIT_USER: ${CLEARML_AGENT_GIT_USER}
+ CLEARML_AGENT_GIT_PASS: ${CLEARML_AGENT_GIT_PASS}
+ CLEARML_AGENT_UPDATE_VERSION: ${CLEARML_AGENT_UPDATE_VERSION:->=0.17.0}
+ CLEARML_AGENT_DEFAULT_BASE_DOCKER: "ubuntu:18.04"
+ AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID:-}
+ AWS_SECRET_ACCESS_KEY: ${AWS_SECRET_ACCESS_KEY:-}
+ AWS_DEFAULT_REGION: ${AWS_DEFAULT_REGION:-}
+ AZURE_STORAGE_ACCOUNT: ${AZURE_STORAGE_ACCOUNT:-}
+ AZURE_STORAGE_KEY: ${AZURE_STORAGE_KEY:-}
+ GOOGLE_APPLICATION_CREDENTIALS: ${GOOGLE_APPLICATION_CREDENTIALS:-}
+ CLEARML_WORKER_ID: "clearml-services"
+ CLEARML_AGENT_DOCKER_HOST_MOUNT: "/opt/clearml/agent:/root/.clearml"
+ SHUTDOWN_IF_NO_ACCESS_KEY: 1
+ volumes:
+ - /var/run/docker.sock:/var/run/docker.sock
+ - /opt/clearml/agent:/root/.clearml
+ depends_on:
+ - apiserver
+ entrypoint: >
+ bash -c "curl --retry 10 --retry-delay 10 --retry-connrefused 'http://apiserver:8008/debug.ping' && /usr/agent/entrypoint.sh"
+
+networks:
+ backend:
+ driver: bridge
+ frontend:
+ driver: bridge