revital
2025-03-02 11:47:20 +02:00
93 changed files with 4282 additions and 761 deletions

View File

@@ -1,7 +1,7 @@
Server Side Public License
VERSION 1, OCTOBER 16, 2018
Copyright © 2019 allegro.ai, Inc.
Copyright © 2025 ClearML Inc.
Everyone is permitted to copy and distribute verbatim copies of this
license document, but changing it is not allowed.

View File

@@ -7,42 +7,15 @@
[![GitHub license](https://img.shields.io/badge/license-SSPL-green.svg)](https://img.shields.io/badge/license-SSPL-green.svg)
[![Python versions](https://img.shields.io/badge/python-3.9-blue.svg)](https://img.shields.io/badge/python-3.9-blue.svg)
[![GitHub version](https://img.shields.io/github/release-pre/allegroai/trains-server.svg)](https://img.shields.io/github/release-pre/allegroai/trains-server.svg)
[![Artifact Hub](https://img.shields.io/endpoint?url=https://artifacthub.io/badge/repository/allegroai)](https://artifacthub.io/packages/search?repo=allegroai)
[![GitHub version](https://img.shields.io/github/release-pre/clearml/trains-server.svg)](https://img.shields.io/github/release-pre/clearml/trains-server.svg)
[![Artifact Hub](https://img.shields.io/endpoint?url=https://artifacthub.io/badge/repository/clearml)](https://artifacthub.io/packages/search?repo=clearml)
</div>
---
<div align="center">
**Note regarding Apache Log4j2 Remote Code Execution (RCE) Vulnerability - CVE-2021-44228 - ESA-2021-31**
</div>
According to [ElasticSearch's latest report](https://discuss.elastic.co/t/apache-log4j2-remote-code-execution-rce-vulnerability-cve-2021-44228-esa-2021-31/291476),
supported versions of Elasticsearch (6.8.9+, 7.8+) used with recent versions of the JDK (JDK9+) **are not susceptible to either remote code execution or information leakage**
due to Elasticsearch's usage of the Java Security Manager.
**As the latest version of ClearML Server uses Elasticsearch 7.10+ with JDK15, it is not affected by these vulnerabilities.**
As a precaution, we've upgraded the ES version to 7.16.2 and added the mitigation recommended by ElasticSearch to our latest [docker-compose.yml](https://github.com/allegroai/clearml-server/blob/cfccbe05c158b75e520581f86e9668291da5c70a/docker/docker-compose.yml#L42) file.
While previous Elasticsearch versions (5.6.11+, 6.4.0+ and 7.0.0+) used by older ClearML Server versions are only susceptible to the information leakage vulnerability
(which in any case **does not permit access to data within the Elasticsearch cluster**),
we still recommend upgrading to the latest version of ClearML Server. Alternatively, you can apply the mitigation as implemented in our latest
[docker-compose.yml](https://github.com/allegroai/clearml-server/blob/cfccbe05c158b75e520581f86e9668291da5c70a/docker/docker-compose.yml#L42) file.
**Update 15 December**: A further vulnerability (CVE-2021-45046) was disclosed on December 14th.
ElasticSearch's guidance remains unchanged by this new vulnerability, thus **ClearML Server is not affected**.
**Update 22 December**: In keeping with ElasticSearch's recommendations, we've upgraded the ES version to the newly released 7.16.2.
---
## ClearML Server
#### *Formerly known as Trains Server*
The **ClearML Server** is the backend service infrastructure for [ClearML](https://github.com/allegroai/clearml).
The **ClearML Server** is the backend service infrastructure for [ClearML](https://github.com/clearml/clearml).
It allows multiple users to collaborate and manage their experiments.
**ClearML** offers a [free hosted service](https://app.clear.ml/), which is maintained by **ClearML** and open to anyone.
In order to host your own server, you will need to launch the **ClearML Server** and point **ClearML** to it.
@@ -99,8 +72,10 @@ Launch The **ClearML Server** in any of the following formats:
- [Linux](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_linux_mac)
- [macOS](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_linux_mac)
- [Windows 10](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_win)
- [Kubernetes Helm](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_kubernetes_helm)
- Kubernetes
- [Kubernetes Helm](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_kubernetes_helm)
- Manual [Kubernetes installation](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_kubernetes)
## Connecting ClearML to your ClearML Server
In order to set up the **ClearML** client to work with your **ClearML Server**:
@@ -118,16 +93,13 @@ In order to set up the **ClearML** client to work with your **ClearML Server**:
files_server: "http://localhost:8081"
}
> [!NOTE]
>
> If you have set up your **ClearML Server** in a sub-domain configuration, then there is no need to specify a port number,
**Note**: If you have set up your **ClearML Server** in a sub-domain configuration, then there is no need to specify a port number,
it will be inferred from the http/s scheme.
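For reference, here is a minimal sketch of the full `api` section of `clearml.conf`, assuming the default local ports (web server on 8080, API server on 8008, file server on 8081); the credential values are placeholders generated from your server's web UI:
```
api {
    web_server: "http://localhost:8080"
    api_server: "http://localhost:8008"
    files_server: "http://localhost:8081"
    credentials {
        access_key: "<access_key>"
        secret_key: "<secret_key>"
    }
}
```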
After launching the ClearML Server and configuring the **ClearML** client to use the ClearML Server,
you can use [ClearML](https://github.com/allegroai/clearml) in your experiments and view them in your ClearML Server web server,
After launching the **ClearML Server** and configuring the **ClearML** client to use the **ClearML Server**,
you can [use](https://github.com/clearml/clearml) **ClearML** in your experiments and view them in your **ClearML Server** web server,
for example http://localhost:8080.
For more information about the ClearML client, see [**ClearML**](https://github.com/allegroai/clearml).
For more information about the ClearML client, see [**ClearML**](https://github.com/clearml/clearml).
## ClearML-Agent Services <a name="services"></a>
@@ -144,11 +116,9 @@ increased data transparency)
ClearML-Agent Services container will spin **any** task enqueued into the dedicated `services` queue.
Every task launched by ClearML-Agent Services will be registered as a new node in the system,
providing tracking and transparency capabilities.
You can also run the ClearML-Agent Services manually, see details in [ClearML-agent services mode](https://github.com/allegroai/clearml-agent#clearml-agent-services-mode-)
You can also run the ClearML-Agent Services manually, see details in [ClearML-agent services mode](https://github.com/clearml/clearml-agent#clearml-agent-services-mode-)
> [!NOTE]
>
> It is the user's responsibility to make sure the proper tasks are pushed into the `services` queue.
**Note**: It is the user's responsibility to make sure the proper tasks are pushed into the `services` queue.
Do not enqueue training / inference tasks into the `services` queue, as it will put unnecessary load on the server.
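For illustration, a lightweight task can be pushed into the `services` queue from the ClearML SDK; this is a minimal sketch assuming an existing draft task (the task ID below is a placeholder) and the SDK's `Task.enqueue` helper:
```python
from clearml import Task

# Enqueue an existing draft task into the dedicated "services" queue.
# Use this only for lightweight/maintenance tasks, not training or inference.
Task.enqueue(task="<draft_task_id>", queue_name="services")
```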
## Advanced Functionality
@@ -169,7 +139,7 @@ To restart the **ClearML Server**, you must first stop the containers, and then
## Upgrading <a name="upgrade"></a>
**ClearML Server** releases are also reflected in the [docker compose configuration file](https://github.com/allegroai/trains-server/blob/master/docker/docker-compose.yml).
**ClearML Server** releases are also reflected in the [docker compose configuration file](https://github.com/clearml/trains-server/blob/master/docker/docker-compose.yml).
We strongly encourage you to keep your **ClearML Server** up to date, by keeping up with the current release.
**Note**: The following upgrade instructions use the Linux OS as an example.
@@ -202,7 +172,7 @@ To upgrade your existing **ClearML Server** deployment:
1. Download the latest `docker-compose.yml` file.
```bash
curl https://raw.githubusercontent.com/allegroai/trains-server/master/docker/docker-compose.yml -o docker-compose.yml
curl https://raw.githubusercontent.com/clearml/trains-server/master/docker/docker-compose.yml -o docker-compose.yml
```
1. Configure the ClearML-Agent Services (not supported on Windows installation).
@@ -227,10 +197,10 @@ To upgrade your existing **ClearML Server** deployment:
## Community & Support
If you have any questions, look at the ClearML [FAQ](https://clear.ml/docs/latest/docs/faq), or
If you have any questions, look to the ClearML [FAQ](https://clear.ml/docs/latest/docs/faq), or
tag your questions on [stackoverflow](https://stackoverflow.com/questions/tagged/clearml) with '**clearml**' tag.
For feature requests or bug reports, please use [GitHub issues](https://github.com/allegroai/clearml-server/issues).
For feature requests or bug reports, please use [GitHub issues](https://github.com/clearml/clearml-server/issues).
Additionally, you can always find us at *clearml@allegro.ai*

View File

@@ -1,7 +1,7 @@
Server Side Public License
VERSION 1, OCTOBER 16, 2018
Copyright © 2019 allegro.ai, Inc.
Copyright © 2025 ClearML Inc.
Everyone is permitted to copy and distribute verbatim copies of this
license document, but changing it is not allowed.

View File

@@ -84,6 +84,7 @@
411: ["project_cannot_be_moved_under_itself", "Project can not be moved under itself in the projects hierarchy"]
412: ["project_cannot_be_merged_into_its_child", "Project can not be merged into its own child"]
413: ["project_has_pipelines", "project has associated pipelines with active controllers"]
414: ["public_project_exists", "Cannot create project. Public project with the same name already exists"]
# Queues
701: ["invalid_queue_id", "invalid queue id"]
@@ -106,6 +107,11 @@
1004: ["worker_not_registered", "worker is not registered"]
1005: ["worker_stats_not_found", "worker stats not found"]
# Serving
1050: ["invalid_container_id", "invalid container id"]
1051: ["container_not_registered", "container is not registered"]
1052: ["no_containers_for_url", "no container instances found for service url"]
1104: ["invalid_scroll_id", "Invalid scroll id"]
}

View File

@@ -1,10 +1,11 @@
from enum import Enum
from typing import Union, Type, Iterable
from numbers import Number
from typing import Union, Type, Iterable, Mapping
import jsonmodels.errors
import six
from jsonmodels import fields
from jsonmodels.fields import _LazyType, NotSet
from jsonmodels.fields import _LazyType, NotSet, EmbeddedField
from jsonmodels.models import Base as ModelBase
from jsonmodels.validators import Enum as EnumValidator
from mongoengine.base import BaseDocument
@@ -40,6 +41,34 @@ def make_default(field_cls, default_value):
return _FieldWithDefault
class OneOfEmbeddedField(EmbeddedField):
def __init__(
self,
*args,
discriminator_property: str,
discriminator_mapping: Mapping[str, type],
**kwargs,
):
self.discriminator_property = discriminator_property
self.discriminator_mapping = discriminator_mapping
model_types = tuple(set(self.discriminator_mapping.values()))
super().__init__(model_types, *args, **kwargs)
def parse_value(self, value):
"""Parse value to proper model type."""
if not isinstance(value, dict) or self.discriminator_property not in value:
return super().parse_value(value)
property_value = value.get(self.discriminator_property)
embed_type = self.discriminator_mapping.get(property_value)
if not embed_type:
raise jsonmodels.errors.ValidationError(
f"Could not find type matching discriminator property value: {property_value}"
)
return embed_type(**value)
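A brief hypothetical usage sketch of the discriminator mechanism above (the model and field names here are illustrative only, not taken from this commit):
```python
from jsonmodels.models import Base
from jsonmodels.fields import StringField

class AwsTarget(Base):
    type = StringField()
    bucket = StringField()

class AzureTarget(Base):
    type = StringField()
    container = StringField()

class ExportRequest(Base):
    # the "type" value in the incoming dict decides which embedded model is built
    target = OneOfEmbeddedField(
        discriminator_property="type",
        discriminator_mapping={"aws": AwsTarget, "azure": AzureTarget},
    )

# {"type": "aws", "bucket": "b"} is parsed into AwsTarget,
# {"type": "azure", "container": "c"} into AzureTarget;
# an unknown "type" value raises jsonmodels.errors.ValidationError.
```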
class ListField(fields.ListField):
def __init__(self, items_types=None, *args, default=NotSet, **kwargs):
if default is not NotSet and callable(default):
@@ -68,6 +97,15 @@ class ScalarField(fields.BaseField):
types = (str, int, float, bool)
class SafeStringField(fields.StringField):
"""String field that can also accept numbers as input"""
def parse_value(self, value):
if isinstance(value, Number):
value = str(value)
return super().parse_value(value)
class DictField(fields.BaseField):
types = (dict,)
@@ -115,9 +153,7 @@ class DictField(fields.BaseField):
if len(self.value_types) != 1:
tpl = 'Cannot decide which type to choose from "{types}".'
raise jsonmodels.errors.ValidationError(
tpl.format(
types=', '.join([t.__name__ for t in self.value_types])
)
tpl.format(types=", ".join([t.__name__ for t in self.value_types]))
)
return self.value_types[0](**value)
@@ -179,7 +215,7 @@ class EnumField(fields.StringField):
*args,
required=False,
default=None,
**kwargs
**kwargs,
):
choices = list(map(self.parse_value, values_or_type))
validator_cls = EnumValidator if required else NullableEnumValidator
@@ -202,7 +238,7 @@ class ActualEnumField(fields.StringField):
validators=None,
required=False,
default=None,
**kwargs
**kwargs,
):
self.__enum = enum_class
self.types = (enum_class,)
@@ -215,7 +251,7 @@ class ActualEnumField(fields.StringField):
*args,
required=required,
validators=validators,
**kwargs
**kwargs,
)
def parse_value(self, value):

View File

@@ -17,6 +17,7 @@ class GetDefaultResp(Base):
class CreateRequest(Base):
name = StringField(required=True)
display_name = StringField()
tags = ListField(items_types=[str])
system_tags = ListField(items_types=[str])
metadata = DictField(value_types=[MetadataItem])
@@ -47,6 +48,7 @@ class DeleteRequest(QueueRequest):
class UpdateRequest(QueueRequest):
name = StringField()
display_name = StringField()
tags = ListField(items_types=[str])
system_tags = ListField(items_types=[str])
metadata = DictField(value_types=[MetadataItem])
@@ -56,6 +58,14 @@ class TaskRequest(QueueRequest):
task = StringField(required=True)
class RemoveTaskRequest(TaskRequest):
update_task_status = BoolField(default=False)
class AddTaskRequest(TaskRequest):
update_execution_queue = BoolField(default=True)
class MoveTaskRequest(TaskRequest):
count = IntField(default=1)

View File

@@ -0,0 +1,104 @@
from enum import Enum
from typing import Sequence
from jsonmodels.models import Base
from jsonmodels.fields import (
StringField,
EmbeddedField,
DateTimeField,
IntField,
FloatField,
BoolField,
)
from jsonmodels import validators
from jsonmodels.validators import Min
from apiserver.apimodels import ListField, JsonSerializableMixin, SafeStringField
from apiserver.apimodels import ActualEnumField
from apiserver.config_repo import config
from .workers import MachineStats
class ReferenceItem(Base):
type = StringField(
required=True,
validators=validators.Enum("app_id", "app_instance", "model", "task", "url"),
)
value = StringField(required=True)
class ServingModel(Base):
container_id = StringField(required=True)
endpoint_name = StringField(required=True)
endpoint_url = StringField() # can be not existing yet at registration time
model_name = StringField(required=True)
model_source = StringField()
model_version = StringField()
preprocess_artifact = StringField()
input_type = StringField()
input_size = SafeStringField()
tags = ListField(str)
system_tags = ListField(str)
reference: Sequence[ReferenceItem] = ListField(ReferenceItem)
class RegisterRequest(ServingModel):
timeout = IntField(
default=int(
config.get("services.serving.default_container_timeout_sec", 10 * 60)
),
validators=[Min(1)],
)
""" registration timeout in seconds (default is 10min) """
class UnregisterRequest(Base):
container_id = StringField(required=True)
class StatusReportRequest(ServingModel):
uptime_sec = IntField()
requests_num = IntField()
requests_min = FloatField()
latency_ms = IntField()
machine_stats: MachineStats = EmbeddedField(MachineStats)
class ServingContainerEntry(StatusReportRequest, JsonSerializableMixin):
key = StringField(required=True)
company_id = StringField(required=True)
ip = StringField()
register_time = DateTimeField(required=True)
register_timeout = IntField(required=True)
last_activity_time = DateTimeField(required=True)
class GetEndpointDetailsRequest(Base):
endpoint_url = StringField(required=True)
class MetricType(Enum):
requests = "requests"
requests_min = "requests_min"
latency_ms = "latency_ms"
cpu_count = "cpu_count"
gpu_count = "gpu_count"
cpu_util = "cpu_util"
gpu_util = "gpu_util"
ram_total = "ram_total"
ram_used = "ram_used"
ram_free = "ram_free"
gpu_ram_total = "gpu_ram_total"
gpu_ram_used = "gpu_ram_used"
gpu_ram_free = "gpu_ram_free"
network_rx = "network_rx"
network_tx = "network_tx"
class GetEndpointMetricsHistoryRequest(Base):
from_date = FloatField(required=True, validators=Min(0))
to_date = FloatField(required=True, validators=Min(0))
interval = IntField(required=True, validators=Min(1))
endpoint_url = StringField(required=True)
metric_type = ActualEnumField(MetricType, default=MetricType.requests)
instance_charts = BoolField(default=True)
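For illustration, a hedged sketch of building and validating a container registration request with the models above (all values are placeholders):
```python
# container_id, endpoint_name and model_name are the required fields;
# timeout defaults to services.serving.default_container_timeout_sec (10 minutes).
req = RegisterRequest(
    container_id="container-0001",
    endpoint_name="sentiment-classifier",
    model_name="sentiment-model",
    model_version="2",
)
req.validate()  # jsonmodels validation: required fields plus Min(1) on timeout
```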

View File

@@ -0,0 +1,60 @@
from jsonmodels.fields import StringField, BoolField, ListField, EmbeddedField
from jsonmodels.models import Base
from jsonmodels.validators import Enum
class AWSBucketSettings(Base):
bucket = StringField()
subdir = StringField()
host = StringField()
key = StringField()
secret = StringField()
token = StringField()
multipart = BoolField(default=True)
acl = StringField()
secure = BoolField(default=True)
region = StringField()
verify = BoolField(default=True)
use_credentials_chain = BoolField(default=False)
class AWSSettings(Base):
key = StringField()
secret = StringField()
region = StringField()
token = StringField()
use_credentials_chain = BoolField(default=False)
buckets = ListField(items_types=[AWSBucketSettings])
class GoogleBucketSettings(Base):
bucket = StringField()
subdir = StringField()
project = StringField()
credentials_json = StringField()
class GoogleSettings(Base):
project = StringField()
credentials_json = StringField()
buckets = ListField(items_types=[GoogleBucketSettings])
class AzureContainerSettings(Base):
account_name = StringField()
account_key = StringField()
container_name = StringField()
class AzureSettings(Base):
containers = ListField(items_types=[AzureContainerSettings])
class SetSettingsRequest(Base):
aws = EmbeddedField(AWSSettings)
google = EmbeddedField(GoogleSettings)
azure = EmbeddedField(AzureSettings)
class ResetSettingsRequest(Base):
keys = ListField([str], item_validators=[Enum("aws", "google", "azure")])

View File

@@ -109,6 +109,7 @@ class EnqueueRequest(UpdateRequest):
queue = StringField()
queue_name = StringField()
verify_watched_queue = BoolField(default=False)
update_execution_queue = BoolField(default=True)
class DeleteRequest(UpdateRequest):

View File

@@ -86,6 +86,7 @@ class CurrentTaskEntry(IdNameEntry):
class QueueEntry(IdNameEntry):
display_name = StringField()
next_task = EmbeddedField(IdNameEntry)
num_tasks = IntField()

View File

@@ -41,7 +41,7 @@ from apiserver.bll.event.event_metrics import EventMetrics
from apiserver.bll.task import TaskBLL
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.database.model.task.task import TaskStatus
from apiserver.redis_manager import redman
from apiserver.service_repo.auth import Identity
from apiserver.utilities.dicts import nested_get
@@ -201,6 +201,8 @@ class EventBLL(object):
invalid_iteration_error = f"Iteration number should not exceed {MAX_LONG}"
for event in events:
x_axis_label = event.pop("x_axis_label", None)
# remove spaces from event type
event_type = event.get("type")
if event_type is None:
@@ -296,6 +298,7 @@ class EventBLL(object):
self._update_last_scalar_events_for_task(
last_events=task_last_scalar_events[task_or_model_id],
event=event,
x_axis_label=x_axis_label,
)
actions.append(es_action)
@@ -319,6 +322,7 @@ class EventBLL(object):
if actions:
chunk_size = 500
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
# noinspection PyTypeChecker
with closing(
elasticsearch.helpers.streaming_bulk(
self.es,
@@ -430,7 +434,7 @@ class EventBLL(object):
return False
return True
def _update_last_scalar_events_for_task(self, last_events, event):
def _update_last_scalar_events_for_task(self, last_events, event, x_axis_label=None):
"""
Update last_events structure with the provided event details if this event is more
recent than the currently stored event for its metric/variant combination.
@@ -438,45 +442,47 @@ class EventBLL(object):
last_events contains [hashed_metric_name -> hashed_variant_name -> event]. Keys are hashed to avoid mongodb
key conflicts due to invalid characters and/or long field names.
"""
value = event.get("value")
if value is None:
return
metric = event.get("metric") or ""
variant = event.get("variant") or ""
metric_hash = dbutils.hash_field_name(metric)
variant_hash = dbutils.hash_field_name(variant)
last_event = last_events[metric_hash][variant_hash]
last_event["metric"] = metric
last_event["variant"] = variant
last_event["count"] = last_event.get("count", 0) + 1
last_event["total"] = last_event.get("total", 0) + value
event_iter = event.get("iter", 0)
event_timestamp = event.get("timestamp", 0)
value = event.get("value")
if value is not None and (
(event_iter, event_timestamp)
>= (
last_event.get("iter", event_iter),
last_event.get("timestamp", event_timestamp),
)
if (event_iter, event_timestamp) >= (
last_event.get("iter", event_iter),
last_event.get("timestamp", event_timestamp),
):
event_data = {
k: event[k]
for k in ("value", "metric", "variant", "iter", "timestamp")
if k in event
}
last_event_min_value = last_event.get("min_value", value)
last_event_min_value_iter = last_event.get("min_value_iter", event_iter)
if value < last_event_min_value:
event_data["min_value"] = value
event_data["min_value_iter"] = event_iter
else:
event_data["min_value"] = last_event_min_value
event_data["min_value_iter"] = last_event_min_value_iter
last_event_max_value = last_event.get("max_value", value)
last_event_max_value_iter = last_event.get("max_value_iter", event_iter)
if value > last_event_max_value:
event_data["max_value"] = value
event_data["max_value_iter"] = event_iter
else:
event_data["max_value"] = last_event_max_value
event_data["max_value_iter"] = last_event_max_value_iter
last_events[metric_hash][variant_hash] = event_data
last_event["value"] = value
last_event["iter"] = event_iter
last_event["timestamp"] = event_timestamp
if x_axis_label is not None:
last_event["x_axis_label"] = x_axis_label
first_value_iter = last_event.get("first_value_iter")
if first_value_iter is None or event_iter < first_value_iter:
last_event["first_value"] = value
last_event["first_value_iter"] = event_iter
last_event_min_value = last_event.get("min_value")
if last_event_min_value is None or value < last_event_min_value:
last_event["min_value"] = value
last_event["min_value_iter"] = event_iter
last_event_max_value = last_event.get("max_value")
if last_event_max_value is None or value > last_event_max_value:
last_event["max_value"] = value
last_event["max_value_iter"] = event_iter
def _update_last_metric_events_for_task(self, last_events, event):
"""
@@ -659,7 +665,9 @@ class EventBLL(object):
Release the scroll once it is exhausted
"""
total_events = nested_get(es_res, ("hits", "total", "value"), default=0)
events = [doc["_source"] for doc in nested_get(es_res, ("hits", "hits"), default=[])]
events = [
doc["_source"] for doc in nested_get(es_res, ("hits", "hits"), default=[])
]
next_scroll_id = es_res.get("_scroll_id")
if next_scroll_id and not events:
self.clear_scroll(next_scroll_id)
@@ -1148,34 +1156,6 @@ class EventBLL(object):
for tb in es_res["aggregations"]["tasks"]["buckets"]
}
@staticmethod
def _validate_model_state(
company_id: str, model_id: str, allow_locked: bool = False
):
extra_msg = None
query = Q(id=model_id, company=company_id)
if not allow_locked:
query &= Q(ready__ne=True)
extra_msg = "or model published"
res = Model.objects(query).only("id").first()
if not res:
raise errors.bad_request.InvalidModelId(
extra_msg, company=company_id, id=model_id
)
@staticmethod
def _validate_task_state(company_id: str, task_id: str, allow_locked: bool = False):
extra_msg = None
query = Q(id=task_id, company=company_id)
if not allow_locked:
query &= Q(status__nin=LOCKED_TASK_STATUSES)
extra_msg = "or task published"
res = Task.objects(query).only("id").first()
if not res:
raise errors.bad_request.InvalidTaskId(
extra_msg, company=company_id, id=task_id
)
@staticmethod
def _get_events_deletion_params(async_delete: bool) -> dict:
if async_delete:
@@ -1188,51 +1168,53 @@ class EventBLL(object):
return {"refresh": True}
def delete_task_events(self, company_id, task_id, allow_locked=False, model=False):
if model:
self._validate_model_state(
company_id=company_id,
model_id=task_id,
allow_locked=allow_locked,
)
else:
self._validate_task_state(
company_id=company_id, task_id=task_id, allow_locked=allow_locked
)
async_delete = async_task_events_delete
if async_delete:
total = self.events_iterator.count_task_events(
event_type=EventType.all,
company_id=company_id,
task_ids=[task_id],
)
if total <= async_delete_threshold:
async_delete = False
es_req = {"query": {"term": {"task": task_id}}}
def delete_task_events(
self,
company_id,
task_ids: Union[str, Sequence[str]],
wait_for_delete: bool,
model=False,
):
"""
Delete task events. No check is done for task write access,
so it should be done by the calling code
"""
if isinstance(task_ids, str):
task_ids = [task_ids]
deleted = 0
with translate_errors_context():
es_res = delete_company_events(
es=self.es,
company_id=company_id,
event_type=EventType.all,
body=es_req,
**self._get_events_deletion_params(async_delete),
)
async_delete = async_task_events_delete and not wait_for_delete
if async_delete and len(task_ids) < 100:
total = self.events_iterator.count_task_events(
event_type=EventType.all,
company_id=company_id,
task_ids=task_ids,
)
if total <= async_delete_threshold:
async_delete = False
for tasks in chunked_iter(task_ids, 100):
es_req = {"query": {"terms": {"task": tasks}}}
es_res = delete_company_events(
es=self.es,
company_id=company_id,
event_type=EventType.all,
body=es_req,
**self._get_events_deletion_params(async_delete),
)
if not async_delete:
deleted += es_res.get("deleted", 0)
if not async_delete:
return es_res.get("deleted", 0)
return deleted
def clear_task_log(
self,
company_id: str,
task_id: str,
allow_locked: bool = False,
threshold_sec: int = None,
include_metrics: Sequence[str] = None,
exclude_metrics: Sequence[str] = None,
):
self._validate_task_state(
company_id=company_id, task_id=task_id, allow_locked=allow_locked
)
if check_empty_data(
self.es, company_id=company_id, event_type=EventType.task_log
):
@@ -1274,39 +1256,6 @@ class EventBLL(object):
)
return es_res.get("deleted", 0)
def delete_multi_task_events(
self, company_id: str, task_ids: Sequence[str], model=False
):
"""
Delete multiple task events. No check is done for tasks write access
so it should be checked by the calling code
"""
deleted = 0
with translate_errors_context():
async_delete = async_task_events_delete
if async_delete and len(task_ids) < 100:
total = self.events_iterator.count_task_events(
event_type=EventType.all,
company_id=company_id,
task_ids=task_ids,
)
if total <= async_delete_threshold:
async_delete = False
for tasks in chunked_iter(task_ids, 100):
es_req = {"query": {"terms": {"task": tasks}}}
es_res = delete_company_events(
es=self.es,
company_id=company_id,
event_type=EventType.all,
body=es_req,
**self._get_events_deletion_params(async_delete),
)
if not async_delete:
deleted += es_res.get("deleted", 0)
if not async_delete:
return deleted
def clear_scroll(self, scroll_id: str):
if scroll_id == self.empty_scroll:
return

View File

@@ -24,6 +24,8 @@ from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
from apiserver.bll.query import Builder as QueryBuilder
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
from apiserver.utilities.dicts import nested_get
log = config.logger(__file__)
@@ -43,6 +45,7 @@ class EventMetrics:
samples: int,
key: ScalarKeyEnum,
metric_variants: MetricVariants = None,
model_events: bool = False,
) -> dict:
"""
Get scalar metric histogram per metric and variant
@@ -60,6 +63,7 @@ class EventMetrics:
samples=samples,
key=ScalarKey.resolve(key),
metric_variants=metric_variants,
model_events=model_events,
)
def _get_scalar_average_per_iter_core(
@@ -71,6 +75,7 @@ class EventMetrics:
key: ScalarKey,
run_parallel: bool = True,
metric_variants: MetricVariants = None,
model_events: bool = False,
) -> dict:
intervals = self._get_task_metric_intervals(
company_id=company_id,
@@ -102,7 +107,22 @@ class EventMetrics:
)
ret = defaultdict(dict)
if not metrics:
return ret
last_metrics = {}
cls_ = Model if model_events else Task
task = cls_.objects(id=task_id).only("last_metrics").first()
if task and task.last_metrics:
for m_data in task.last_metrics.values():
for v_data in m_data.values():
last_metrics[(v_data.metric, v_data.variant)] = v_data
for metric_key, metric_values in metrics:
for variant_key, data in metric_values.items():
last_metrics_data = last_metrics.get((metric_key, variant_key))
if last_metrics_data and last_metrics_data.x_axis_label is not None:
data["x_axis_label"] = last_metrics_data.x_axis_label
ret[metric_key].update(metric_values)
return ret
@@ -113,6 +133,7 @@ class EventMetrics:
samples,
key: ScalarKeyEnum,
metric_variants: MetricVariants = None,
model_events: bool = False,
):
"""
Compare scalar metrics for different tasks per metric and variant
@@ -136,6 +157,7 @@ class EventMetrics:
key=ScalarKey.resolve(key),
metric_variants=metric_variants,
run_parallel=False,
model_events=model_events,
)
task_ids, company_ids = zip(
*(
@@ -165,7 +187,7 @@ class EventMetrics:
self,
companies: TaskCompanies,
metric_variants: MetricVariants = None,
) -> Mapping[str, dict]:
) -> Mapping[str, Sequence[dict]]:
"""
For the requested tasks return all the events delivered for the single iteration (-2**31)
"""

View File

@@ -183,7 +183,7 @@ class HistoryDebugImageIterator:
order = "desc" if navigate_earlier else "asc"
es_req = {
"size": 1,
"sort": [{"metric": order}, {"variant": order}],
"sort": [{"metric": order}, {"variant": order}, {"url": "desc"}],
"query": {"bool": {"must": must_conditions}},
}
@@ -242,7 +242,7 @@ class HistoryDebugImageIterator:
]
es_req = {
"size": 1,
"sort": [{"iter": order}, {"metric": order}, {"variant": order}],
"sort": [{"iter": order}, {"metric": order}, {"variant": order}, {"url": "desc"}],
"query": {"bool": {"must": must_conditions}},
}
es_res = search_company_events(
@@ -338,7 +338,7 @@ class HistoryDebugImageIterator:
es_req = {
"size": 1,
"sort": {"iter": "desc"},
"sort": [{"iter": "desc"}, {"url": "desc"}],
"query": {"bool": {"must": must_conditions}},
}

View File

@@ -384,7 +384,8 @@ class MetricEventsIterator:
"aggs": {
"events": {
"top_hits": {
"sort": self._get_same_variant_events_order()
"sort": self._get_same_variant_events_order(),
"size": 1,
}
}
},

View File

@@ -6,7 +6,6 @@ from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.apimodels.models import ModelTaskPublishResponse
from apiserver.bll.task.utils import deleted_prefix, get_last_metric_updates
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task, TaskStatus
@@ -15,8 +14,6 @@ from .metadata import Metadata
class ModelBLL:
event_bll = None
@classmethod
def get_company_model_by_id(
cls, company_id: str, model_id: str, only_fields=None
@@ -94,7 +91,7 @@ class ModelBLL:
@classmethod
def delete_model(
cls, model_id: str, company_id: str, user_id: str, force: bool, delete_external_artifacts: bool = True,
cls, model_id: str, company_id: str, user_id: str, force: bool
) -> Tuple[int, Model]:
model = cls.get_company_model_by_id(
company_id=company_id,
@@ -147,34 +144,6 @@ class ModelBLL:
set__last_changed_by=user_id,
)
delete_external_artifacts = delete_external_artifacts and config.get(
"services.async_urls_delete.enabled", True
)
if delete_external_artifacts:
from apiserver.bll.task.task_cleanup import (
collect_debug_image_urls,
collect_plot_image_urls,
_schedule_for_delete,
)
urls = set()
urls.update(collect_debug_image_urls(company_id, model_id))
urls.update(collect_plot_image_urls(company_id, model_id))
if model.uri:
urls.add(model.uri)
if urls:
_schedule_for_delete(
task_id=model_id,
company=company_id,
user=user_id,
urls=urls,
can_delete_folders=False,
)
if not cls.event_bll:
from apiserver.bll.event import EventBLL
cls.event_bll = EventBLL()
cls.event_bll.delete_task_events(company_id, model_id, allow_locked=True, model=True)
del_count = Model.objects(id=model_id, company=company_id).delete()
return del_count, model
@@ -217,7 +186,7 @@ class ModelBLL:
[
{
"$match": {
"company": {"$in": [None, "", company]},
"company": {"$in": ["", company]},
"_id": {"$in": model_ids},
}
},

View File

@@ -1,4 +1,5 @@
from collections import defaultdict
from datetime import datetime
from enum import Enum
from typing import Sequence, Dict, Type
@@ -28,6 +29,7 @@ class OrgBLL:
def edit_entity_tags(
self,
company_id,
user_id: str,
entity_cls: Type[AttributedDocument],
entity_ids: Sequence[str],
add_tags: Sequence[str],
@@ -47,13 +49,17 @@ class OrgBLL:
)
updated = 0
last_changed = {
"set__last_change": datetime.utcnow(),
"set__last_changed_by": user_id,
}
if add_tags:
updated += entity_cls.objects(company=company_id, id__in=entity_ids).update(
add_to_set__tags=add_tags
add_to_set__tags=add_tags, **last_changed,
)
if remove_tags:
updated += entity_cls.objects(company=company_id, id__in=entity_ids).update(
pull_all__tags=remove_tags
pull_all__tags=remove_tags, **last_changed,
)
if not updated:
return 0

View File

@@ -6,7 +6,6 @@ from redis import Redis
from apiserver.config_repo import config
from apiserver.bll.project import project_ids_with_children
from apiserver.database.model import EntityVisibility
from apiserver.database.model.base import GetMixin
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
@@ -43,8 +42,8 @@ class _TagsCache:
query &= GetMixin.get_list_field_query(name, vals)
if project:
query &= Q(project__in=project_ids_with_children([project]))
else:
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
# else:
# query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
return self.db_cls.objects(query).distinct(field)

View File

@@ -41,6 +41,7 @@ from .sub_projects import (
_ids_with_parents,
_get_project_depth,
ProjectsChildren,
_get_writable_project_from_name,
)
log = config.logger(__file__)
@@ -169,6 +170,7 @@ class ProjectBLL:
now = datetime.utcnow()
affected = set()
p: Project
for p in filter(None, (old_parent, new_parent)):
p.update(last_update=now)
affected.update({p.id, *(p.path or [])})
@@ -183,6 +185,7 @@ class ProjectBLL:
new_name = fields.pop("name", None)
if new_name:
# noinspection PyTypeChecker
new_name, new_location = _validate_project_name(new_name)
old_name, old_location = _validate_project_name(project.name)
if new_location != old_location:
@@ -225,6 +228,18 @@ class ProjectBLL:
raise errors.bad_request.ProjectPathExceedsMax(max_depth=max_depth)
name, location = _validate_project_name(name)
existing = _get_writable_project_from_name(
company=company,
name=name,
)
if existing:
raise errors.bad_request.ExpectedUniqueData(
replacement_msg="Project with the same name already exists",
name=name,
company=company,
)
now = datetime.utcnow()
project = Project(
id=database.utils.id(),
@@ -810,7 +825,7 @@ class ProjectBLL:
}
def sum_runtime(
a: Mapping[str, Mapping], b: Mapping[str, Mapping]
a: Mapping[str, dict], b: Mapping[str, dict]
) -> Dict[str, dict]:
return {
section: a.get(section, 0) + b.get(section, 0)
@@ -1015,8 +1030,8 @@ class ProjectBLL:
if include_subprojects:
projects = _ids_with_children(projects)
query &= Q(project__in=projects)
else:
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
# else:
# query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
if state == EntityVisibility.archived:
query &= Q(system_tags__in=[EntityVisibility.archived.value])
@@ -1046,7 +1061,7 @@ class ProjectBLL:
if not parent_ids:
return []
parents = Task.get_many_with_join(
parents: Sequence[dict] = Task.get_many_with_join(
company_id,
query=Q(id__in=parent_ids),
query_dict={"name": name} if name else None,
@@ -1101,7 +1116,7 @@ class ProjectBLL:
project_field: str = "project",
):
conditions = {
"company": {"$in": [None, "", company]},
"company": {"$in": ["", company]},
project_field: {"$in": project_ids},
}
if users:
@@ -1153,7 +1168,7 @@ class ProjectBLL:
if or_conditions:
if len(or_conditions) == 1:
conditions = next(iter(or_conditions))
conditions.update(next(iter(or_conditions)))
else:
conditions["$and"] = [c for c in or_conditions]

View File

@@ -8,10 +8,9 @@ from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.bll.event import EventBLL
from apiserver.bll.task.task_cleanup import (
collect_debug_image_urls,
collect_plot_image_urls,
TaskUrls,
_schedule_for_delete,
schedule_for_delete,
delete_task_events_and_collect_urls,
)
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
@@ -192,7 +191,7 @@ def delete_project(
)
event_urls = task_event_urls | model_event_urls
if delete_external_artifacts:
scheduled = _schedule_for_delete(
scheduled = schedule_for_delete(
task_id=project_id,
company=company,
user=user,
@@ -206,7 +205,6 @@ def delete_project(
deleted_models=deleted_models,
urls=TaskUrls(
model_urls=list(model_urls),
event_urls=list(event_urls),
artifact_urls=list(artifact_urls),
),
)
@@ -243,9 +241,6 @@ def _delete_tasks(
last_changed_by=user,
)
event_urls = collect_debug_image_urls(company, task_ids) | collect_plot_image_urls(
company, task_ids
)
artifact_urls = set()
for task in tasks:
if task.execution and task.execution.artifacts:
@@ -257,8 +252,11 @@ def _delete_tasks(
}
)
event_bll.delete_multi_task_events(company, task_ids)
event_urls = delete_task_events_and_collect_urls(
company=company, task_ids=task_ids, wait_for_delete=False
)
deleted = tasks.delete()
return deleted, event_urls, artifact_urls
@@ -317,11 +315,10 @@ def _delete_models(
set__last_changed_by=user,
)
event_urls = collect_debug_image_urls(company, model_ids) | collect_plot_image_urls(
company, model_ids
)
model_urls = {m.uri for m in models if m.uri}
event_bll.delete_multi_task_events(company, model_ids, model=True)
event_urls = delete_task_events_and_collect_urls(
company=company, task_ids=model_ids, model=True, wait_for_delete=False
)
deleted = models.delete()
return deleted, event_urls, model_urls

View File

@@ -47,7 +47,7 @@ class ProjectQueries:
@staticmethod
def _get_company_constraint(company_id: str, allow_public: bool = True) -> dict:
if allow_public:
return {"company": {"$in": [None, "", company_id]}}
return {"company": {"$in": ["", company_id]}}
return {"company": company_id}

View File

@@ -2,6 +2,8 @@ import itertools
from datetime import datetime
from typing import Tuple, Optional, Sequence, Mapping
from boltons.iterutils import first
from apiserver import database
from apiserver.apierrors import errors
from apiserver.database.model import EntityVisibility
@@ -96,10 +98,21 @@ def _get_writable_project_from_name(
"""
Return a writable project by name. If no project is found, return None;
if only a public project with this name exists, PublicProjectExists is raised
"""
qs = Project.objects(company=company, name=name)
qs = Project.objects(company__in=[company, ""], name=name)
if _only:
if "company" not in _only:
_only = ["company", *_only]
qs = qs.only(*_only)
return qs.first()
projects = list(qs)
if not projects:
return
project = first(p for p in projects if p.company == company)
if not project:
raise errors.bad_request.PublicProjectExists(name=name)
return project
ProjectsChildren = Mapping[str, Sequence[Project]]

View File

@@ -9,20 +9,35 @@ RANGE_IGNORE_VALUE = -1
class Builder:
@staticmethod
def dates_range(from_date: Union[int, float], to_date: Union[int, float]) -> dict:
def dates_range(
from_date: Optional[Union[int, float]] = None,
to_date: Optional[Union[int, float]] = None,
) -> dict:
assert (
from_date or to_date
), "range condition requires that at least one of from_date or to_date specified"
conditions = {}
if from_date:
conditions["gte"] = int(from_date)
if to_date:
conditions["lte"] = int(to_date)
return {
"range": {
"timestamp": {
"gte": int(from_date),
"lte": int(to_date),
**conditions,
"format": "epoch_second",
}
}
}
@staticmethod
def terms(field: str, values: Iterable[str]) -> dict:
def terms(field: str, values: Iterable) -> dict:
if isinstance(values, str):
assert not isinstance(values, str), "apparently 'term' should be used here"
return {"terms": {field: list(values)}}
@staticmethod
def term(field: str, value) -> dict:
return {"term": {field: value}}
@staticmethod
def normalize_range(

View File

@@ -1,6 +1,6 @@
from collections import defaultdict
from datetime import datetime
from typing import Sequence, Optional, Tuple, Union
from typing import Sequence, Optional, Tuple, Union, Iterable
from elasticsearch import Elasticsearch
from mongoengine import Q
@@ -34,6 +34,7 @@ class QueueBLL(object):
def create(
company_id: str,
name: str,
display_name: str = None,
tags: Optional[Sequence[str]] = None,
system_tags: Optional[Sequence[str]] = None,
metadata: Optional[dict] = None,
@@ -46,6 +47,7 @@ class QueueBLL(object):
company=company_id,
created=now,
name=name,
display_name=display_name,
tags=tags or [],
system_tags=system_tags or [],
metadata=metadata,
@@ -135,51 +137,78 @@ class QueueBLL(object):
self.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
return Queue.safe_update(company_id, queue_id, update_fields)
def delete(self, company_id: str, user_id: str, queue_id: str, force: bool) -> None:
def _update_task_status_on_removal_from_queue(
self,
company_id: str,
user_id: str,
task_ids: Iterable[str],
queue_id: str,
reason: str
) -> Sequence[str]:
from apiserver.bll.task import ChangeStatusRequest
tasks = []
for task_id in task_ids:
try:
task = Task.get(
company=company_id,
id=task_id,
execution__queue=queue_id,
_only=[
"id",
"company",
"status",
"enqueue_status",
"project",
],
)
if not task:
continue
tasks.append(task.id)
ChangeStatusRequest(
task=task,
new_status=task.enqueue_status or TaskStatus.created,
status_reason=reason,
status_message="",
user_id=user_id,
force=True,
).execute(
enqueue_status=None,
unset__execution__queue=1,
)
except Exception as ex:
log.error(
f"Failed updating task {task_id} status on removal from queue: {queue_id}, {str(ex)}"
)
return tasks
def delete(self, company_id: str, user_id: str, queue_id: str, force: bool) -> Sequence[str]:
"""
Delete the queue
:raise errors.bad_request.InvalidQueueId: if the queue is not found
:raise errors.bad_request.QueueNotEmpty: if the queue is not empty and 'force' not set
"""
with translate_errors_context():
queue = self.get_by_id(company_id=company_id, queue_id=queue_id)
if queue.entries:
if not force:
raise errors.bad_request.QueueNotEmpty(
"use force=true to delete", id=queue_id
)
from apiserver.bll.task import ChangeStatusRequest
for item in queue.entries:
try:
task = Task.get(
company=company_id,
id=item.task,
_only=[
"id",
"company",
"status",
"enqueue_status",
"project",
],
)
if not task:
continue
ChangeStatusRequest(
task=task,
new_status=task.enqueue_status or TaskStatus.created,
status_reason="Queue was deleted",
status_message="",
user_id=user_id,
force=True,
).execute(enqueue_status=None)
except Exception as ex:
log.exception(
f"Failed dequeuing task {item.task} from queue: {queue_id}"
)
queue = self.get_by_id(company_id=company_id, queue_id=queue_id)
if not queue.entries:
queue.delete()
return []
if not force:
raise errors.bad_request.QueueNotEmpty(
"use force=true to delete", id=queue_id
)
tasks = self._update_task_status_on_removal_from_queue(
company_id=company_id,
user_id=user_id,
task_ids={item.task for item in queue.entries},
queue_id=queue_id,
reason=f"Queue {queue_id} was deleted",
)
queue.delete()
return tasks
def get_all(
self,
@@ -307,7 +336,36 @@ class QueueBLL(object):
return queue.entries[0]
def remove_task(self, company_id: str, queue_id: str, task_id: str) -> int:
def clear_queue(
self,
company_id: str,
user_id: str,
queue_id: str,
):
queue = Queue.objects(company=company_id, id=queue_id).first()
if not queue:
raise errors.bad_request.InvalidQueueId(
queue=queue_id
)
if not queue.entries:
return []
tasks = self._update_task_status_on_removal_from_queue(
company_id=company_id,
user_id=user_id,
task_ids={item.task for item in queue.entries},
queue_id=queue_id,
reason=f"Queue {queue_id} was cleared",
)
queue.update(entries=[])
queue.reload()
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
return tasks
def remove_task(self, company_id: str, user_id: str, queue_id: str, task_id: str, update_task_status: bool = False) -> int:
"""
Removes the task from the queue and returns the number of removed items
:raise errors.bad_request.InvalidQueueOrTaskNotQueued: if the task is not found in the queue
@@ -322,6 +380,14 @@ class QueueBLL(object):
res = Queue.objects(entries__task=task_id, **query).update_one(
pull_all__entries=entries_to_remove, last_update=datetime.utcnow()
)
if res and update_task_status:
self._update_task_status_on_removal_from_queue(
company_id=company_id,
user_id=user_id,
task_ids=[task_id],
queue_id=queue_id,
reason=f"Task was removed from the queue {queue_id}",
)
queue.reload()
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
@@ -461,7 +527,7 @@ class QueueBLL(object):
[
{
"$match": {
"company": {"$in": [None, "", company]},
"company": {"$in": ["", company]},
"_id": queue_id,
}
},

View File

@@ -0,0 +1,376 @@
from datetime import datetime, timedelta, timezone
from enum import Enum, auto
from operator import attrgetter
from time import time
from typing import Optional, Sequence, Union
import attr
from boltons.iterutils import chunked_iter, bucketize
from pyhocon import ConfigTree
from apiserver.apimodels.serving import (
ServingContainerEntry,
RegisterRequest,
StatusReportRequest,
)
from apiserver.apimodels.workers import MachineStats
from apiserver.apierrors import errors
from apiserver.config_repo import config
from apiserver.redis_manager import redman
from .stats import ServingStats
log = config.logger(__file__)
class ServingBLL:
def __init__(self, redis=None):
self.conf = config.get("services.serving", ConfigTree())
self.redis = redis or redman.connection("workers")
@staticmethod
def _get_url_key(company: str, url: str):
return f"serving_url_{company}_{url}"
@staticmethod
def _get_container_key(company: str, container_id: str) -> str:
"""Build redis key from company and container_id"""
return f"serving_container_{company}_{container_id}"
def _save_serving_container_entry(self, entry: ServingContainerEntry):
self.redis.setex(
entry.key, timedelta(seconds=entry.register_timeout), entry.to_json()
)
url_key = self._get_url_key(entry.company_id, entry.endpoint_url)
expiration = int(time()) + entry.register_timeout
container_item = {entry.key: expiration}
self.redis.zadd(url_key, container_item)
# make sure that url set will not get stuck in redis
# indefinitely in case no more containers report to it
self.redis.expire(url_key, max(3600, entry.register_timeout))
def _get_serving_container_entry(
self, company_id: str, container_id: str
) -> Optional[ServingContainerEntry]:
"""
Get a container entry for the provided container ID.
"""
key = self._get_container_key(company_id, container_id)
data = self.redis.get(key)
if not data:
return
try:
entry = ServingContainerEntry.from_json(data)
return entry
except Exception as e:
msg = "Failed parsing container entry"
log.exception(f"{msg}: {str(e)}")
def register_serving_container(
self,
company_id: str,
request: RegisterRequest,
ip: str = "",
) -> ServingContainerEntry:
"""
Register a serving container
"""
now = datetime.now(timezone.utc)
key = self._get_container_key(company_id, request.container_id)
entry = ServingContainerEntry(
**request.to_struct(),
key=key,
company_id=company_id,
ip=ip,
register_time=now,
register_timeout=request.timeout,
last_activity_time=now,
)
self._save_serving_container_entry(entry)
return entry
def unregister_serving_container(
self,
company_id: str,
container_id: str,
) -> None:
"""
Unregister a serving container
"""
entry = self._get_serving_container_entry(company_id, container_id)
if entry:
url_key = self._get_url_key(entry.company_id, entry.endpoint_url)
self.redis.zrem(url_key, entry.key)
key = self._get_container_key(company_id, container_id)
res = self.redis.delete(key)
if res:
return
if not self.conf.get("container_auto_unregister", True):
raise errors.bad_request.ContainerNotRegistered(container=container_id)
def container_status_report(
self,
company_id: str,
report: StatusReportRequest,
ip: str = "",
) -> None:
"""
Serving container status report
"""
container_id = report.container_id
now = datetime.now(timezone.utc)
entry = self._get_serving_container_entry(company_id, container_id)
if entry:
ip = ip or entry.ip
register_time = entry.register_time
register_timeout = entry.register_timeout
else:
if not self.conf.get("container_auto_register", True):
raise errors.bad_request.ContainerNotRegistered(container=container_id)
ip = ip
register_time = now
register_timeout = int(
self.conf.get("default_container_timeout_sec", 10 * 60)
)
key = self._get_container_key(company_id, container_id)
entry = ServingContainerEntry(
**report.to_struct(),
key=key,
company_id=company_id,
ip=ip,
register_time=register_time,
register_timeout=register_timeout,
last_activity_time=now,
)
self._save_serving_container_entry(entry)
ServingStats.log_stats_to_es(entry)
def _get_all(
self,
company_id: str,
) -> Sequence[ServingContainerEntry]:
keys = list(self.redis.scan_iter(self._get_container_key(company_id, "*")))
entries = []
for keys in chunked_iter(keys, 1000):
data = self.redis.mget(keys)
if not data:
continue
for d in data:
try:
entries.append(ServingContainerEntry.from_json(d))
except Exception as ex:
log.error(f"Failed parsing container entry {str(ex)}")
return entries
@attr.s(auto_attribs=True)
class Counter:
class AggType(Enum):
avg = auto()
max = auto()
total = auto()
count = auto()
name: str
field: str
agg_type: AggType
float_precision: int = None
_max: Union[int, float, datetime] = attr.field(init=False, default=None)
_total: Union[int, float] = attr.field(init=False, default=0)
_count: int = attr.field(init=False, default=0)
def add(self, entry: ServingContainerEntry):
value = getattr(entry, self.field, None)
if value is None:
return
self._count += 1
if self.agg_type == self.AggType.max:
self._max = value if self._max is None else max(self._max, value)
else:
self._total += value
def __call__(self):
if self.agg_type == self.AggType.count:
return self._count
if self.agg_type == self.AggType.max:
return self._max
if self.agg_type == self.AggType.total:
return self._total
if not self._count:
return None
avg = self._total / self._count
return (
round(avg, self.float_precision) if self.float_precision else round(avg)
)
def _get_summary(self, entries: Sequence[ServingContainerEntry]) -> dict:
counters = [
self.Counter(
name="uptime_sec",
field="uptime_sec",
agg_type=self.Counter.AggType.max,
),
self.Counter(
name="requests",
field="requests_num",
agg_type=self.Counter.AggType.total,
),
self.Counter(
name="requests_min",
field="requests_min",
agg_type=self.Counter.AggType.avg,
float_precision=2,
),
self.Counter(
name="latency_ms",
field="latency_ms",
agg_type=self.Counter.AggType.avg,
),
self.Counter(
name="last_update",
field="last_activity_time",
agg_type=self.Counter.AggType.max,
),
]
for entry in entries:
for counter in counters:
counter.add(entry)
first_entry = entries[0]
ret = {
"endpoint": first_entry.endpoint_name,
"model": first_entry.model_name,
"url": first_entry.endpoint_url,
"instances": len(entries),
**{counter.name: counter() for counter in counters},
}
ret["last_update"] = ret.get("last_update")
return ret
def get_endpoints(self, company_id: str):
"""
Group instances by urls and return a summary for each url
Do not return data for "loading" instances that have no url
"""
entries = self._get_all(company_id)
by_url = bucketize(entries, key=attrgetter("endpoint_url"))
by_url.pop(None, None)
return [self._get_summary(url_entries) for url_entries in by_url.values()]
def _get_endpoint_entries(
self, company_id, endpoint_url: Union[str, None]
) -> Sequence[ServingContainerEntry]:
url_key = self._get_url_key(company_id, endpoint_url)
timestamp = int(time())
self.redis.zremrangebyscore(url_key, min=0, max=timestamp)
container_keys = {key.decode() for key in self.redis.zrange(url_key, 0, -1)}
if not container_keys:
return []
entries = []
found_keys = set()
data = self.redis.mget(container_keys) or []
for d in data:
try:
entry = ServingContainerEntry.from_json(d)
if entry.endpoint_url == endpoint_url:
entries.append(entry)
found_keys.add(entry.key)
except Exception as ex:
log.error(f"Failed parsing container entry {str(ex)}")
missing_keys = container_keys - found_keys
if missing_keys:
self.redis.zrem(url_key, *missing_keys)
return entries
def get_loading_instances(self, company_id: str):
entries = self._get_endpoint_entries(company_id, None)
return [
{
"id": entry.container_id,
"endpoint": entry.endpoint_name,
"url": entry.endpoint_url,
"model": entry.model_name,
"model_source": entry.model_source,
"model_version": entry.model_version,
"preprocess_artifact": entry.preprocess_artifact,
"input_type": entry.input_type,
"input_size": entry.input_size,
"uptime_sec": entry.uptime_sec,
"age_sec": int((datetime.now(timezone.utc) - entry.register_time).total_seconds()),
"last_update": entry.last_activity_time,
}
for entry in entries
]
def get_endpoint_details(self, company_id, endpoint_url: str) -> dict:
entries = self._get_endpoint_entries(company_id, endpoint_url)
if not entries:
raise errors.bad_request.NoContainersForUrl(url=endpoint_url)
instances = []
entry: ServingContainerEntry
for entry in entries:
instances.append(
{
"endpoint": entry.endpoint_name,
"model": entry.model_name,
"url": entry.endpoint_url,
}
)
def get_machine_stats_data(machine_stats: MachineStats) -> dict:
ret = {"cpu_count": 0, "gpu_count": 0}
if not machine_stats:
return ret
for value, field in (
(machine_stats.cpu_usage, "cpu_count"),
(machine_stats.gpu_usage, "gpu_count"),
):
if value is None:
continue
ret[field] = len(value) if isinstance(value, (list, tuple)) else 1
return ret
first_entry = entries[0]
return {
"endpoint": first_entry.endpoint_name,
"model": first_entry.model_name,
"url": first_entry.endpoint_url,
"preprocess_artifact": first_entry.preprocess_artifact,
"input_type": first_entry.input_type,
"input_size": first_entry.input_size,
"model_source": first_entry.model_source,
"model_version": first_entry.model_version,
"uptime_sec": max(e.uptime_sec for e in entries),
"last_update": max(e.last_activity_time for e in entries),
"instances": [
{
"id": entry.container_id,
"uptime_sec": entry.uptime_sec,
"requests": entry.requests_num,
"requests_min": entry.requests_min,
"latency_ms": entry.latency_ms,
"last_update": entry.last_activity_time,
"reference": [ref.to_struct() for ref in entry.reference]
if isinstance(entry.reference, list)
else entry.reference,
**get_machine_stats_data(entry.machine_stats),
}
for entry in entries
],
}

View File

@@ -0,0 +1,340 @@
from collections import defaultdict
from datetime import datetime, timezone
from enum import Enum
from typing import Tuple, Optional, Sequence
from elasticsearch import Elasticsearch
from apiserver.apimodels.serving import (
ServingContainerEntry,
GetEndpointMetricsHistoryRequest,
MetricType,
)
from apiserver.apierrors import errors
from apiserver.utilities.dicts import nested_get
from apiserver.bll.query import Builder as QueryBuilder
from apiserver.config_repo import config
from apiserver.es_factory import es_factory
class _AggregationType(Enum):
avg = "avg"
sum = "sum"
class ServingStats:
min_chart_interval = config.get("services.serving.min_chart_interval_sec", 40)
es: Elasticsearch = es_factory.connect("workers")
@classmethod
def _serving_stats_prefix(cls, company_id: str) -> str:
"""Returns the es index prefix for the company"""
return f"serving_stats_{company_id.lower()}_"
@staticmethod
def _get_es_index_suffix():
"""Get the index name suffix for storing current month data"""
return datetime.now(timezone.utc).strftime("%Y-%m")
@staticmethod
def _get_average_value(value) -> Tuple[Optional[float], Optional[int]]:
if value is None:
return None, None
if isinstance(value, (list, tuple)):
count = len(value)
if not count:
return None, None
return sum(value) / count, count
return value, 1
@classmethod
def log_stats_to_es(
cls,
entry: ServingContainerEntry,
) -> int:
"""
Actually writing the worker statistics to Elastic
:return: The amount of logged documents
"""
company_id = entry.company_id
es_index = (
f"{cls._serving_stats_prefix(company_id)}" f"{cls._get_es_index_suffix()}"
)
entry_data = entry.to_struct()
doc = {
"timestamp": es_factory.get_timestamp_millis(),
**{
field: entry_data.get(field)
for field in (
"container_id",
"company_id",
"endpoint_url",
"requests_num",
"requests_min",
"uptime_sec",
"latency_ms",
)
},
}
stats = entry_data.get("machine_stats")
if stats:
for category in ("cpu", "gpu"):
usage, num = cls._get_average_value(stats.get(f"{category}_usage"))
doc.update({f"{category}_usage": usage, f"{category}_num": num})
for category in ("memory", "gpu_memory"):
free, _ = cls._get_average_value(stats.get(f"{category}_free"))
used, _ = cls._get_average_value(stats.get(f"{category}_used"))
doc.update(
{
f"{category}_free": free,
f"{category}_used": used,
f"{category}_total": round((free or 0) + (used or 0), 3),
}
)
doc.update(
{
field: stats.get(field)
for field in ("disk_free_home", "network_rx", "network_tx")
}
)
cls.es.index(index=es_index, document=doc)
return 1
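# Shape of the document indexed above (all values are illustrative only):
# {"timestamp": 1735725600000, "container_id": "c-1", "company_id": "co-1",
#  "endpoint_url": "/serve/model-a", "requests_num": 120, "requests_min": 4.0,
#  "uptime_sec": 3600, "latency_ms": 35, "cpu_usage": 42.5, "cpu_num": 8,
#  "memory_free": 1024.0, "memory_used": 2048.0, "memory_total": 3072.0, ...}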
@staticmethod
def round_series(values: Sequence, koeff) -> list:
return [round(v * koeff, 2) if v else 0 for v in values]
_mb_to_gb = 1 / 1024
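# Each entry maps a metric type to:
# (ES field name, chart title, aggregation used to combine instances per interval, optional unit multiplier)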
agg_fields = {
MetricType.requests: (
"requests_num",
"Number of Requests",
_AggregationType.sum,
None,
),
MetricType.requests_min: (
"requests_min",
"Requests per Minute",
_AggregationType.sum,
None,
),
MetricType.latency_ms: (
"latency_ms",
"Average Latency (ms)",
_AggregationType.avg,
None,
),
MetricType.cpu_count: ("cpu_num", "CPU Count", _AggregationType.sum, None),
MetricType.gpu_count: ("gpu_num", "GPU Count", _AggregationType.sum, None),
MetricType.cpu_util: (
"cpu_usage",
"Average CPU Load (%)",
_AggregationType.avg,
None,
),
MetricType.gpu_util: (
"gpu_usage",
"Average GPU Utilization (%)",
_AggregationType.avg,
None,
),
MetricType.ram_total: (
"memory_total",
"RAM Total (GB)",
_AggregationType.sum,
_mb_to_gb,
),
MetricType.ram_used: (
"memory_used",
"RAM Used (GB)",
_AggregationType.sum,
_mb_to_gb,
),
MetricType.ram_free: (
"memory_free",
"RAM Free (GB)",
_AggregationType.sum,
_mb_to_gb,
),
MetricType.gpu_ram_total: (
"gpu_memory_total",
"GPU RAM Total (GB)",
_AggregationType.sum,
_mb_to_gb,
),
MetricType.gpu_ram_used: (
"gpu_memory_used",
"GPU RAM Used (GB)",
_AggregationType.sum,
_mb_to_gb,
),
MetricType.gpu_ram_free: (
"gpu_memory_free",
"GPU RAM Free (GB)",
_AggregationType.sum,
_mb_to_gb,
),
MetricType.network_rx: (
"network_rx",
"Network Throughput RX (MBps)",
_AggregationType.sum,
None,
),
MetricType.network_tx: (
"network_tx",
"Network Throughput TX (MBps)",
_AggregationType.sum,
None,
),
}
@classmethod
def get_endpoint_metrics(
cls,
company_id: str,
metrics_request: GetEndpointMetricsHistoryRequest,
) -> dict:
from_date = metrics_request.from_date
to_date = metrics_request.to_date
if from_date >= to_date:
raise errors.bad_request.FieldsValueError(
"from_date must be less than to_date"
)
metric_type = metrics_request.metric_type
agg_data = cls.agg_fields.get(metric_type)
if not agg_data:
raise NotImplementedError(f"Charts for {metric_type} not implemented")
agg_field, title, agg_type, multiplier = agg_data
if agg_type == _AggregationType.sum:
instance_sum_type = "sum_bucket"
else:
instance_sum_type = "avg_bucket"
interval = max(metrics_request.interval, cls.min_chart_interval)
endpoint_url = metrics_request.endpoint_url
hist_ret = {
"computed_interval": interval,
"total": {
"title": title,
"dates": [],
"values": [],
},
"instances": {},
}
must_conditions = [
QueryBuilder.term("company_id", company_id),
QueryBuilder.term("endpoint_url", endpoint_url),
QueryBuilder.dates_range(from_date, to_date),
]
query = {"bool": {"must": must_conditions}}
es_index = f"{cls._serving_stats_prefix(company_id)}*"
res = cls.es.search(
index=es_index,
size=0,
query=query,
aggs={"instances": {"terms": {"field": "container_id"}}},
)
instance_buckets = nested_get(res, ("aggregations", "instances", "buckets"))
if not instance_buckets:
return hist_ret
instance_keys = {ib["key"] for ib in instance_buckets}
must_conditions.append(QueryBuilder.terms("container_id", instance_keys))
query = {"bool": {"must": must_conditions}}
sample_func = "avg" if metric_type != MetricType.requests else "max"
aggs = {
"instances": {
"terms": {
"field": "container_id",
"size": max(len(instance_keys), 10),
},
"aggs": {
"sample": {sample_func: {"field": agg_field}},
},
},
"total_instances": {
instance_sum_type: {
"gap_policy": "insert_zeros",
"buckets_path": "instances>sample",
}
},
}
hist_params = {}
if metric_type == MetricType.requests:
hist_params["min_doc_count"] = 1
else:
hist_params["extended_bounds"] = {
"min": int(from_date) * 1000,
"max": int(to_date) * 1000,
}
aggs = {
"dates": {
"date_histogram": {
"field": "timestamp",
"fixed_interval": f"{interval}s",
**hist_params,
},
"aggs": aggs,
}
}
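# Resulting aggregation tree (sketch):
#   dates (date_histogram on timestamp, fixed_interval)
#     -> instances (terms on container_id) -> sample (avg or max of the metric field)
#     -> total_instances (sum_bucket / avg_bucket over "instances>sample", gap_policy=insert_zeros)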
filter_path = None
if not metrics_request.instance_charts:
filter_path = "aggregations.dates.buckets.total_instances"
data = cls.es.search(
index=es_index,
size=0,
query=query,
aggs=aggs,
filter_path=filter_path,
)
agg_res = data.get("aggregations")
if not agg_res:
return hist_ret
dates_ = []
total = []
instances = defaultdict(list)
# remove last interval if it's incomplete. Allow 10% tolerance
last_valid_timestamp = (to_date - 0.9 * interval) * 1000
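# e.g. with interval=60 and to_date at 12:00:00, buckets starting after 11:59:06 are dropped:
# the last bucket is kept only if at least 90% of its interval has already elapsed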
for point in agg_res["dates"]["buckets"]:
date_ = point["key"]
if date_ > last_valid_timestamp:
break
dates_.append(date_)
total.append(nested_get(point, ("total_instances", "value"), 0))
if metrics_request.instance_charts:
found_keys = set()
for instance in nested_get(point, ("instances", "buckets"), []):
instances[instance["key"]].append(
nested_get(instance, ("sample", "value"), 0)
)
found_keys.add(instance["key"])
for missing_key in instance_keys - found_keys:
instances[missing_key].append(0)
koeff = multiplier if multiplier else 1.0
hist_ret["total"]["dates"] = dates_
hist_ret["total"]["values"] = cls.round_series(total, koeff)
hist_ret["instances"] = {
key: {
"title": key,
"dates": dates_,
"values": cls.round_series(values, koeff),
}
for key, values in sorted(instances.items(), key=lambda p: p[0])
}
return hist_ret

View File

@@ -1,14 +1,32 @@
import json
import os
import tempfile
from copy import copy
from datetime import datetime
from typing import Optional, Sequence
import attr
from boltons.cacheutils import cachedproperty
from clearml.backend_config.bucket_config import (
S3BucketConfigurations,
AzureContainerConfigurations,
GSBucketConfigurations,
AzureContainerConfig,
GSBucketConfig,
S3BucketConfig,
)
from apiserver.apierrors import errors
from apiserver.apimodels.storage import SetSettingsRequest
from apiserver.config_repo import config
from apiserver.database.model.storage_settings import (
StorageSettings,
GoogleBucketSettings,
AWSSettings,
AzureStorageSettings,
GoogleStorageSettings,
)
from apiserver.database.utils import id as db_id
log = config.logger(__file__)
@@ -32,17 +50,224 @@ class StorageBLL:
def get_azure_settings_for_company(
self,
company_id: str,
db_settings: StorageSettings = None,
query_db: bool = True,
) -> AzureContainerConfigurations:
return copy(self._default_azure_configs)
if not db_settings and query_db:
db_settings = (
StorageSettings.objects(company=company_id).only("azure").first()
)
if not db_settings or not db_settings.azure:
return copy(self._default_azure_configs)
azure = db_settings.azure
return AzureContainerConfigurations(
container_configs=[
AzureContainerConfig(**entry.to_proper_dict())
for entry in (azure.containers or [])
]
)
def get_gs_settings_for_company(
self,
company_id: str,
db_settings: StorageSettings = None,
query_db: bool = True,
json_string: bool = False,
) -> GSBucketConfigurations:
return copy(self._default_gs_configs)
if not db_settings and query_db:
db_settings = (
StorageSettings.objects(company=company_id).only("google").first()
)
if not db_settings or not db_settings.google:
if not json_string:
return copy(self._default_gs_configs)
if self._default_gs_configs._buckets:
buckets = [
attr.evolve(
b,
credentials_json=self._assure_json_string(b.credentials_json),
)
for b in self._default_gs_configs._buckets
]
else:
buckets = self._default_gs_configs._buckets
return GSBucketConfigurations(
buckets=buckets,
default_project=self._default_gs_configs._default_project,
default_credentials=self._assure_json_string(
self._default_gs_configs._default_credentials
),
)
def get_bucket_config(bc: GoogleBucketSettings) -> GSBucketConfig:
data = bc.to_proper_dict()
if not json_string and bc.credentials_json:
data["credentials_json"] = self._assure_json_file(bc.credentials_json)
return GSBucketConfig(**data)
google = db_settings.google
buckets_configs = [get_bucket_config(b) for b in (google.buckets or [])]
return GSBucketConfigurations(
buckets=buckets_configs,
default_project=google.project,
default_credentials=google.credentials_json
if json_string
else self._assure_json_file(google.credentials_json),
)
def get_aws_settings_for_company(
self,
company_id: str,
db_settings: StorageSettings = None,
query_db: bool = True,
) -> S3BucketConfigurations:
return copy(self._default_aws_configs)
if not db_settings and query_db:
db_settings = (
StorageSettings.objects(company=company_id).only("aws").first()
)
if not db_settings or not db_settings.aws:
return copy(self._default_aws_configs)
aws = db_settings.aws
buckets_configs = S3BucketConfig.from_list(
[b.to_proper_dict() for b in (aws.buckets or [])]
)
return S3BucketConfigurations(
buckets=buckets_configs,
default_key=aws.key,
default_secret=aws.secret,
default_region=aws.region,
default_use_credentials_chain=aws.use_credentials_chain,
default_token=aws.token,
default_extra_args={},
)
def _assure_json_file(self, name_or_content: str) -> str:
if not name_or_content:
return name_or_content
if name_or_content.endswith(".json") or os.path.exists(name_or_content):
return name_or_content
try:
json.loads(name_or_content)
except Exception:
return name_or_content
with tempfile.NamedTemporaryFile(
mode="wt", delete=False, suffix=".json"
) as tmp:
tmp.write(name_or_content)
return tmp.name
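# Illustrative behaviour (inputs are hypothetical):
#   _assure_json_file('{"type": "service_account"}') -> path of a temp .json file holding that content
#   _assure_json_file("/etc/gcs_creds.json")         -> "/etc/gcs_creds.json" (returned unchanged)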
def _assure_json_string(self, name_or_content: str) -> Optional[str]:
if not name_or_content:
return name_or_content
try:
json.loads(name_or_content)
return name_or_content
except Exception:
pass
try:
with open(name_or_content) as fp:
return fp.read()
except Exception:
return None
def get_company_settings(self, company_id: str) -> dict:
db_settings = StorageSettings.objects(company=company_id).first()
aws = self.get_aws_settings_for_company(company_id, db_settings, query_db=False)
aws_dict = {
"key": aws._default_key,
"secret": aws._default_secret,
"token": aws._default_token,
"region": aws._default_region,
"use_credentials_chain": aws._default_use_credentials_chain,
"buckets": [attr.asdict(b) for b in aws._buckets],
}
gs = self.get_gs_settings_for_company(
company_id, db_settings, query_db=False, json_string=True
)
gs_dict = {
"project": gs._default_project,
"credentials_json": gs._default_credentials,
"buckets": [attr.asdict(b) for b in gs._buckets],
}
azure = self.get_azure_settings_for_company(company_id, db_settings)
azure_dict = {
"containers": [attr.asdict(ac) for ac in azure._container_configs],
}
return {
"aws": aws_dict,
"google": gs_dict,
"azure": azure_dict,
"last_update": db_settings.last_update if db_settings else None,
}
def set_company_settings(
self, company_id: str, settings: SetSettingsRequest
) -> int:
update_dict = {}
if settings.aws:
update_dict["aws"] = {
**{
k: v
for k, v in settings.aws.to_struct().items()
if k in AWSSettings.get_fields()
}
}
if settings.azure:
update_dict["azure"] = {
**{
k: v
for k, v in settings.azure.to_struct().items()
if k in AzureStorageSettings.get_fields()
}
}
if settings.google:
update_dict["google"] = {
**{
k: v
for k, v in settings.google.to_struct().items()
if k in GoogleStorageSettings.get_fields()
}
}
cred_json = update_dict["google"].get("credentials_json")
if cred_json:
try:
json.loads(cred_json)
except Exception as ex:
raise errors.bad_request.ValidationError(
f"Invalid json credentials: {str(ex)}"
)
if not update_dict:
raise errors.bad_request.ValidationError("No settings were provided")
settings = StorageSettings.objects(company=company_id).only("id").first()
settings_id = settings.id if settings else db_id()
return StorageSettings.objects(id=settings_id).update(
upsert=True,
id=settings_id,
company=company_id,
last_update=datetime.utcnow(),
**update_dict,
)
def reset_company_settings(self, company_id: str, keys: Sequence[str]) -> int:
return StorageSettings.objects(company=company_id).update(
last_update=datetime.utcnow(), **{f"unset__{k}": 1 for k in keys}
)

View File

@@ -193,7 +193,7 @@ class HyperParams:
pipeline = [
{
"$match": {
"company": {"$in": [None, "", company_id]},
"company": {"$in": ["", company_id]},
"_id": {"$in": task_ids},
}
},

View File

@@ -39,6 +39,7 @@ from apiserver.database.utils import (
from apiserver.es_factory import es_factory
from apiserver.redis_manager import redman
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
from apiserver.utilities.dicts import nested_set
from .artifacts import artifacts_prepare_for_save
from .param_utils import params_prepare_for_save
from .utils import (
@@ -163,18 +164,36 @@ class TaskBLL:
input_models: Optional[Sequence[TaskInputModel]] = None,
validate_references: bool = False,
new_project_name: str = None,
hyperparams_overrides: Optional[dict] = None,
configuration_overrides: Optional[dict] = None,
) -> Tuple[Task, dict]:
validate_tags(tags, system_tags)
params_dict = {
field: value
for field, value in (
("hyperparams", hyperparams),
("configuration", configuration),
)
if value is not None
}
task: Task = cls.get_by_id(
company_id=company_id, task_id=task_id, allow_public=True
)
task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
params_dict = {}
if hyperparams:
params_dict["hyperparams"] = hyperparams
elif hyperparams_overrides:
updated_hyperparams = {
sec: {k: value for k, value in sec_data.items()}
for sec, sec_data in (task.hyperparams or {}).items()
}
for section, section_data in hyperparams_overrides.items():
for key, value in section_data.items():
nested_set(updated_hyperparams, (section, key), value)
params_dict["hyperparams"] = updated_hyperparams
if configuration:
params_dict["configuration"] = configuration
elif configuration_overrides:
updated_configuration = {
k: value for k, value in (task.configuration or {}).items()
}
for key, value in configuration_overrides.items():
updated_configuration[key] = value
params_dict["configuration"] = updated_configuration
now = datetime.utcnow()
if input_models:
@@ -389,7 +408,7 @@ class TaskBLL:
task's last iteration value.
:param last_iteration_max: Last reported iteration. Use this to conditionally set a value only
if the current task's last iteration value is smaller than the provided value.
:param last_scalar_values: Last reported metrics summary for scalar events (value, metric, variant).
:param last_scalar_events: Last reported metrics summary for scalar events (value, metric, variant).
:param last_events: Last reported metrics summary (value, metric, event type).
:param extra_updates: Extra task updates to include in this update call.
:return:
@@ -439,8 +458,13 @@ class TaskBLL:
return ret
@staticmethod
def remove_task_from_all_queues(company_id: str, task_id: str) -> int:
return Queue.objects(company=company_id, entries__task=task_id).update(
def remove_task_from_all_queues(
company_id: str, task_id: str, exclude: str = None
) -> int:
more = {}
if exclude:
more["id__ne"] = exclude
return Queue.objects(company=company_id, entries__task=task_id, **more).update(
pull__entries__task=task_id, last_update=datetime.utcnow()
)
@@ -454,9 +478,10 @@ class TaskBLL:
status_reason: str,
remove_from_all_queues=False,
new_status=None,
new_status_for_aborted_task=None,
):
try:
cls.dequeue(task, company_id, silent_fail=True)
cls.dequeue(task, company_id=company_id, user_id=user_id, silent_fail=True)
except APIError:
# dequeue may fail if the queue was deleted
pass
@@ -467,6 +492,9 @@ class TaskBLL:
if task.status not in [TaskStatus.queued, TaskStatus.in_progress]:
return {"updated": 0}
if new_status_for_aborted_task and task.status == TaskStatus.in_progress:
new_status = new_status_for_aborted_task
return ChangeStatusRequest(
task=task,
new_status=new_status or task.enqueue_status or TaskStatus.created,
@@ -477,7 +505,7 @@ class TaskBLL:
).execute(enqueue_status=None)
@classmethod
def dequeue(cls, task: Task, company_id: str, silent_fail=False):
def dequeue(cls, task: Task, company_id: str, user_id: str, silent_fail=False):
"""
Dequeue the task from the queue
:param task: task to dequeue
@@ -504,6 +532,9 @@ class TaskBLL:
return {
"removed": queue_bll.remove_task(
company_id=company_id, queue_id=task.execution.queue, task_id=task.id
company_id=company_id,
user_id=user_id,
queue_id=task.execution.queue,
task_id=task.id,
)
}

View File

@@ -31,8 +31,8 @@ event_bll = EventBLL()
@attr.s(auto_attribs=True)
class TaskUrls:
model_urls: Sequence[str]
event_urls: Sequence[str]
artifact_urls: Sequence[str]
event_urls: Sequence[str] = []  # left here in order not to break the api
def __add__(self, other: "TaskUrls"):
if not other:
@@ -40,7 +40,6 @@ class TaskUrls:
return TaskUrls(
model_urls=list(set(self.model_urls) | set(other.model_urls)),
event_urls=list(set(self.event_urls) | set(other.event_urls)),
artifact_urls=list(set(self.artifact_urls) | set(other.artifact_urls)),
)
@@ -54,8 +53,23 @@ class CleanupResult:
updated_children: int
updated_models: int
deleted_models: int
deleted_model_ids: Set[str]
urls: TaskUrls = None
def to_res_dict(self, return_file_urls: bool) -> dict:
remove_fields = ["deleted_model_ids"]
if not return_file_urls:
remove_fields.append("urls")
# noinspection PyTypeChecker
res = attr.asdict(
self, filter=lambda attrib, value: attrib.name not in remove_fields
)
if not return_file_urls:
res["urls"] = None
return res
def __add__(self, other: "CleanupResult"):
if not other:
return self
@@ -65,6 +79,16 @@ class CleanupResult:
updated_models=self.updated_models + other.updated_models,
deleted_models=self.deleted_models + other.deleted_models,
urls=self.urls + other.urls if self.urls else other.urls,
deleted_model_ids=self.deleted_model_ids | other.deleted_model_ids,
)
@staticmethod
def empty():
return CleanupResult(
updated_children=0,
updated_models=0,
deleted_models=0,
deleted_model_ids=set(),
)
@@ -130,7 +154,7 @@ supported_storage_types.update(
)
def _schedule_for_delete(
def schedule_for_delete(
company: str,
user: str,
task_id: str,
@@ -197,15 +221,27 @@ def _schedule_for_delete(
return processed_urls
def delete_task_events_and_collect_urls(
company: str, task_ids: Sequence[str], wait_for_delete: bool, model=False
) -> Set[str]:
event_urls = collect_debug_image_urls(company, task_ids) | collect_plot_image_urls(
company, task_ids
)
event_bll.delete_task_events(
company, task_ids, model=model, wait_for_delete=wait_for_delete
)
return event_urls
def cleanup_task(
company: str,
user: str,
task: Task,
force: bool = False,
update_children=True,
return_file_urls=False,
delete_output_models=True,
delete_external_artifacts=True,
) -> CleanupResult:
"""
Validate task deletion and delete/modify all its output.
@@ -216,22 +252,16 @@ def cleanup_task(
published_models, draft_models, in_use_model_ids = verify_task_children_and_ouptuts(
task, force
)
delete_external_artifacts = delete_external_artifacts and config.get(
"services.async_urls_delete.enabled", True
)
event_urls, artifact_urls, model_urls = set(), set(), set()
if return_file_urls or delete_external_artifacts:
event_urls = collect_debug_image_urls(task.company, task.id)
event_urls.update(collect_plot_image_urls(task.company, task.id))
if task.execution and task.execution.artifacts:
artifact_urls = {
a.uri
for a in task.execution.artifacts.values()
if a.mode == ArtifactModes.output and a.uri
}
model_urls = {
m.uri for m in draft_models if m.uri and m.id not in in_use_model_ids
artifact_urls = (
{
a.uri
for a in task.execution.artifacts.values()
if a.mode == ArtifactModes.output and a.uri
}
if task.execution and task.execution.artifacts
else set()
)
model_urls = {m.uri for m in draft_models if m.uri and m.id not in in_use_model_ids}
deleted_task_id = f"{deleted_prefix}{task.id}"
updated_children = 0
@@ -245,22 +275,15 @@ def cleanup_task(
deleted_models = 0
updated_models = 0
deleted_model_ids = set()
for models, allow_delete in ((draft_models, True), (published_models, False)):
if not models:
continue
if delete_output_models and allow_delete:
model_ids = list({m.id for m in models if m.id not in in_use_model_ids})
if model_ids:
if return_file_urls or delete_external_artifacts:
event_urls.update(collect_debug_image_urls(task.company, model_ids))
event_urls.update(collect_plot_image_urls(task.company, model_ids))
event_bll.delete_multi_task_events(
task.company,
model_ids,
model=True,
)
deleted_models += Model.objects(id__in=model_ids).delete()
deleted_model_ids.update(model_ids)
if in_use_model_ids:
Model.objects(id__in=list(in_use_model_ids)).update(
@@ -283,30 +306,15 @@ def cleanup_task(
set__last_changed_by=user,
)
event_bll.delete_task_events(task.company, task.id, allow_locked=force)
if delete_external_artifacts:
scheduled = _schedule_for_delete(
task_id=task.id,
company=company,
user=user,
urls=event_urls | model_urls | artifact_urls,
can_delete_folders=not in_use_model_ids and not published_models,
)
for urls in (event_urls, model_urls, artifact_urls):
urls.difference_update(scheduled)
return CleanupResult(
deleted_models=deleted_models,
updated_children=updated_children,
updated_models=updated_models,
urls=TaskUrls(
event_urls=list(event_urls),
artifact_urls=list(artifact_urls),
model_urls=list(model_urls),
)
if return_file_urls
else None,
),
deleted_model_ids=deleted_model_ids,
)

View File

@@ -22,7 +22,8 @@ from apiserver.database.model.task.task import (
TaskStatusMessage,
ArtifactModes,
Execution,
DEFAULT_LAST_ITERATION, TaskType,
DEFAULT_LAST_ITERATION,
TaskType,
)
from apiserver.database.utils import get_options
from apiserver.service_repo.auth import Identity
@@ -85,6 +86,7 @@ def archive_task(
status_message=status_message,
status_reason=status_reason,
remove_from_all_queues=True,
new_status_for_aborted_task=TaskStatus.stopped,
)
except APIError:
# dequeue may fail if the task was not enqueued
@@ -99,7 +101,9 @@ def archive_task(
)
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(task, company_id, only=fields)
step_tasks := _get_pipeline_steps_for_controller_task(
task, company_id, only=fields
)
):
for step in step_tasks:
archive_task_core(step)
@@ -136,7 +140,9 @@ def unarchive_task(
)
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(task, company_id, only=fields)
step_tasks := _get_pipeline_steps_for_controller_task(
task, company_id, only=fields
)
):
for step in step_tasks:
unarchive_task_core(step)
@@ -204,12 +210,25 @@ def enqueue_task(
queue_name: str = None,
validate: bool = False,
force: bool = False,
update_execution_queue: bool = True,
) -> Tuple[int, dict]:
if queue_id and queue_name:
raise errors.bad_request.ValidationError(
"Either queue id or queue name should be provided"
)
task = get_task_with_write_access(
task_id=task_id, company_id=company_id, identity=identity
)
if not update_execution_queue:
if not (
task.status == TaskStatus.queued and task.execution and task.execution.queue
):
raise errors.bad_request.ValidationError(
"Cannot skip setting execution queue for a task "
"that is not enqueued or does not have execution queue set"
)
if queue_name:
queue = queue_bll.get_by_name(
company_id=company_id, queue_name=queue_name, only=("id",)
@@ -222,23 +241,21 @@ def enqueue_task(
# try to get default queue
queue_id = queue_bll.get_default(company_id).id
task = get_task_with_write_access(
task_id=task_id, company_id=company_id, identity=identity
)
user_id = identity.user
if validate:
TaskBLL.validate(task)
before_enqueue_status = task.status
if task.status == TaskStatus.queued and task.enqueue_status:
before_enqueue_status = task.enqueue_status
res = ChangeStatusRequest(
task=task,
new_status=TaskStatus.queued,
status_reason=status_reason,
status_message=status_message,
allow_same_state_transition=False,
force=force,
user_id=user_id,
).execute(enqueue_status=task.status)
).execute(enqueue_status=before_enqueue_status)
try:
queue_bll.add_task(company_id=company_id, queue_id=queue_id, task_id=task.id)
@@ -255,12 +272,19 @@ def enqueue_task(
raise
# set the current queue ID in the task
if task.execution:
Task.objects(id=task_id).update(execution__queue=queue_id, multi=False)
else:
Task.objects(id=task_id).update(execution=Execution(queue=queue_id), multi=False)
if update_execution_queue:
if task.execution:
Task.objects(id=task_id).update(execution__queue=queue_id, multi=False)
else:
Task.objects(id=task_id).update(
execution=Execution(queue=queue_id), multi=False
)
nested_set(res, ("fields", "execution.queue"), queue_id)
nested_set(res, ("fields", "execution.queue"), queue_id)
# make sure that the task is not queued in any other queue
TaskBLL.remove_task_from_all_queues(
company_id=company_id, task_id=task_id, exclude=queue_id
)
return 1, res
@@ -294,17 +318,13 @@ def delete_task(
identity: Identity,
move_to_trash: bool,
force: bool,
return_file_urls: bool,
delete_output_models: bool,
status_message: str,
status_reason: str,
delete_external_artifacts: bool,
include_pipeline_steps: bool,
) -> Tuple[int, Task, CleanupResult]:
user_id = identity.user
task = get_task_with_write_access(
task_id, company_id=company_id, identity=identity
)
task = get_task_with_write_access(task_id, company_id=company_id, identity=identity)
if (
task.status != TaskStatus.created
@@ -318,7 +338,7 @@ def delete_task(
current=task.status,
)
def delete_task_core(task_: Task, force_: bool):
def delete_task_core(task_: Task, force_: bool) -> CleanupResult:
try:
TaskBLL.dequeue_and_change_status(
task_,
@@ -337,9 +357,7 @@ def delete_task(
user=user_id,
task=task_,
force=force_,
return_file_urls=return_file_urls,
delete_output_models=delete_output_models,
delete_external_artifacts=delete_external_artifacts,
)
if move_to_trash:
@@ -353,11 +371,12 @@ def delete_task(
return res
task_ids = [task.id]
cleanup_res = CleanupResult.empty()
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(task, company_id)
):
for step in step_tasks:
delete_task_core(step, True)
cleanup_res += delete_task_core(step, True)
task_ids.append(step.id)
cleanup_res = delete_task_core(task, force)
@@ -373,15 +392,11 @@ def reset_task(
company_id: str,
identity: Identity,
force: bool,
return_file_urls: bool,
delete_output_models: bool,
clear_all: bool,
delete_external_artifacts: bool,
) -> Tuple[dict, CleanupResult, dict]:
user_id = identity.user
task = get_task_with_write_access(
task_id, company_id=company_id, identity=identity
)
task = get_task_with_write_access(task_id, company_id=company_id, identity=identity)
if not force and task.status == TaskStatus.published:
raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)
@@ -390,7 +405,9 @@ def reset_task(
updates = {}
try:
dequeued = TaskBLL.dequeue(task, company_id, silent_fail=True)
dequeued = TaskBLL.dequeue(
task, company_id=company_id, user_id=user_id, silent_fail=True
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
@@ -403,9 +420,7 @@ def reset_task(
task=task,
force=force,
update_children=False,
return_file_urls=return_file_urls,
delete_output_models=delete_output_models,
delete_external_artifacts=delete_external_artifacts,
)
updates.update(
@@ -466,9 +481,7 @@ def publish_task(
status_reason: str = "",
) -> dict:
user_id = identity.user
task = get_task_with_write_access(
task_id, company_id=company_id, identity=identity
)
task = get_task_with_write_access(task_id, company_id=company_id, identity=identity)
if not force:
validate_status_change(task.status, TaskStatus.published)
@@ -566,7 +579,9 @@ def stop_task(
if set_stopped:
if is_queued:
try:
TaskBLL.dequeue(task_, company_id=company_id, silent_fail=True)
TaskBLL.dequeue(
task_, company_id=company_id, user_id=user_id, silent_fail=True
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
@@ -587,7 +602,9 @@ def stop_task(
).execute()
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(task, company_id, only=fields)
step_tasks := _get_pipeline_steps_for_controller_task(
task, company_id, only=fields
)
):
for step in step_tasks:
stop_task_core(step, True)

View File

@@ -144,7 +144,12 @@ state_machine = {
TaskStatus.publishing,
TaskStatus.stopped,
},
TaskStatus.failed: {TaskStatus.created, TaskStatus.stopped, TaskStatus.published},
TaskStatus.failed: {
TaskStatus.created,
TaskStatus.stopped,
TaskStatus.published,
TaskStatus.queued,
},
TaskStatus.publishing: {TaskStatus.published},
TaskStatus.published: set(),
TaskStatus.completed: {
@@ -177,7 +182,7 @@ def get_many_tasks_for_writing(
throw_on_forbidden: bool = True,
) -> Sequence[Task]:
if only:
missing = [f for f in ("company", ) if f not in only]
missing = [f for f in ("company",) if f not in only]
if missing:
only = [*only, *missing]
@@ -230,7 +235,7 @@ def get_task_for_update(
task_id: str,
identity: Identity,
allow_all_statuses: bool = False,
force: bool = False
force: bool = False,
) -> Task:
"""
Loads only task id and return the task only if it is updatable (status == 'created')
@@ -286,13 +291,62 @@ def get_last_metric_updates(
new_metrics = []
def add_last_metric_mean_update(
metric_path: str,
metric_count: int,
metric_total: float,
):
"""
Update new mean field based on the value in db and new data
The count field is updated here too and not with inc__ so that
it will not get updated in the db earlier than the corresponding mean
"""
metric_path = metric_path.replace("__", ".")
mean_value_field = f"{metric_path}.mean_value"
count_field = f"{metric_path}.count"
raw_updates[mean_value_field] = {
"$round": [
{
"$divide": [
{
"$add": [
{
"$multiply": [
{"$ifNull": [f"${mean_value_field}", 0]},
{"$ifNull": [f"${count_field}", 0]},
]
},
metric_total,
]
},
{
"$add": [
{"$ifNull": [f"${count_field}", 0]},
metric_count,
]
},
]
},
2,
]
}
raw_updates[count_field] = {
"$add": [
{"$ifNull": [f"${count_field}", 0]},
metric_count,
]
}
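# In plain terms (sketch of the pipeline update above):
#   new_count = old_count + metric_count
#   new_mean  = round((old_mean * old_count + metric_total) / new_count, 2)
# expressed as a MongoDB aggregation update so both fields change atomically on the server side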
def add_last_metric_conditional_update(
metric_path: str, metric_value, iter_value: int, is_min: bool
metric_path: str, metric_value, iter_value: int, is_min: bool, is_first: bool
):
"""
Build an aggregation for an atomic update of the min, max, or first value and the corresponding iteration
"""
if is_min:
if is_first:
field_prefix = "first"
op = None
elif is_min:
field_prefix = "min"
op = "$gt"
else:
@@ -300,18 +354,23 @@ def get_last_metric_updates(
op = "$lt"
value_field = f"{metric_path}__{field_prefix}_value".replace("__", ".")
condition = {
"$or": [
{"$lte": [f"${value_field}", None]},
{op: [f"${value_field}", metric_value]},
]
}
exists = {"$lte": [f"${value_field}", None]}
if op:
condition = {
"$or": [
exists,
{op: [f"${value_field}", metric_value]},
]
}
else:
condition = exists
raw_updates[value_field] = {
"$cond": [condition, metric_value, f"${value_field}"]
}
value_iteration_field = f"{metric_path}__{field_prefix}_value_iteration".replace(
"__", "."
value_iteration_field = (
f"{metric_path}__{field_prefix}_value_iteration".replace("__", ".")
)
raw_updates[value_iteration_field] = {
"$cond": [condition, iter_value, f"${value_iteration_field}"]
@@ -328,15 +387,25 @@ def get_last_metric_updates(
new_metrics.append(metric)
path = f"last_metrics__{metric_key}__{variant_key}"
for key, value in variant_data.items():
if key in ("min_value", "max_value"):
if key in ("min_value", "max_value", "first_value"):
add_last_metric_conditional_update(
metric_path=path,
metric_value=value,
iter_value=variant_data.get(f"{key}_iter", 0),
is_min=(key == "min_value"),
is_first=(key == "first_value"),
)
elif key in ("metric", "variant", "value"):
elif key in ("metric", "variant", "value", "x_axis_label"):
extra_updates[f"set__{path}__{key}"] = value
count = variant_data.get("count")
total = variant_data.get("total")
if count is not None and total is not None:
add_last_metric_mean_update(
metric_path=path,
metric_count=count,
metric_total=total,
)
if new_metrics:
extra_updates["add_to_set__unique_metrics"] = new_metrics

View File

@@ -2,6 +2,7 @@ from datetime import datetime
from apiserver.apierrors import errors
from apiserver.apimodels.users import CreateRequest
from apiserver.config.info import get_version
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.user import User
@@ -14,7 +15,11 @@ class UserBLL:
if user_id and User.objects(id=user_id).only("id"):
raise errors.bad_request.UserIdExists(id=user_id)
user = User(**request.to_struct(), created=datetime.utcnow())
user = User(
**request.to_struct(),
created=datetime.utcnow(),
created_in_version=get_version(),
)
user.save(force_insert=True)
@staticmethod

View File

@@ -297,6 +297,7 @@ class WorkerBLL:
{
"$project": {
"name": 1,
"display_name": 1,
"next_entry": {"$arrayElemAt": ["$entries", 0]},
"num_entries": {"$size": "$entries"},
}
@@ -330,6 +331,7 @@ class WorkerBLL:
if not info:
continue
entry.name = info.get("name", None)
entry.display_name = info.get("display_name", None)
entry.num_tasks = info.get("num_entries", 0)
task_id = nested_get(info, ("next_entry", "task"))
if task_id:

View File

@@ -73,9 +73,13 @@ class WorkerStats:
Buckets with no metrics are not returned
Note: all the statistics are retrieved as one ES query
"""
if request.from_date >= request.to_date:
from_date = request.from_date
to_date = request.to_date
if from_date >= to_date:
raise bad_request.FieldsValueError("from_date must be less than to_date")
interval = max(request.interval, self.min_chart_interval)
def get_dates_agg() -> dict:
es_to_agg_types = (
("avg", AggregationType.avg.value),
@@ -87,8 +91,11 @@ class WorkerStats:
"dates": {
"date_histogram": {
"field": "timestamp",
"fixed_interval": f"{request.interval}s",
"min_doc_count": 1,
"fixed_interval": f"{interval}s",
"extended_bounds": {
"min": int(from_date) * 1000,
"max": int(to_date) * 1000,
}
},
"aggs": {
agg_type: {es_agg: {"field": "value"}}
@@ -120,7 +127,7 @@ class WorkerStats:
}
query_terms = [
QueryBuilder.dates_range(request.from_date, request.to_date),
QueryBuilder.dates_range(from_date, to_date),
QueryBuilder.terms("metric", {item.key for item in request.items}),
]
if request.worker_ids:
@@ -130,16 +137,16 @@ class WorkerStats:
with translate_errors_context():
data = self._search_company_stats(company_id, es_req)
return self._extract_results(data, request.items, request.split_by_variant)
cutoff_date = (to_date - 0.9 * interval) * 1000 # do not return the point for the incomplete last interval
return self._extract_results(data, request.items, request.split_by_variant, cutoff_date)
@staticmethod
def _extract_results(
data: dict, request_items: Sequence[StatItem], split_by_variant: bool
data: dict, request_items: Sequence[StatItem], split_by_variant: bool, cutoff_date
) -> dict:
"""
Clean results returned from elastic search (remove "aggregations", "buckets" etc.),
leave only the aggregation types requested by the user and return a clean dictionary
:param data: aggregation data retrieved from ES
:param request_items: aggs types requested by the user
:param split_by_variant: if False then aggregate by metric type, otherwise metric type + variant
@@ -157,7 +164,7 @@ class WorkerStats:
return {
"date": date["key"],
"count": date["doc_count"],
**{agg: date[agg]["value"] for agg in aggs_per_metric[metric_key]},
**{agg: date[agg]["value"] or 0.0 for agg in aggs_per_metric[metric_key]},
}
def extract_metric_results(
@@ -166,7 +173,7 @@ class WorkerStats:
return [
extract_date_stats(date, metric_key)
for date in metric_or_variant["dates"]["buckets"]
if date["doc_count"]
if date["key"] <= cutoff_date
]
def extract_variant_results(metric: dict) -> dict:

View File

@@ -0,0 +1,7 @@
default_container_timeout_sec: 600
# Auto-register unknown serving containers on status reports and other calls
container_auto_register: true
# Assume unknown serving containers have unregistered (i.e. do not raise unregistered error)
container_auto_unregister: true
# The minimal sampling interval for serving model monitor charts
min_chart_interval_sec: 40

View File

@@ -37,6 +37,8 @@ OVERRIDE_QUERY_ENV_KEY = "CLEARML_MONGODB_SERVICE_QUERY"
class DatabaseEntry(models.Base):
host = StringField(required=True)
alias = StringField()
name = StringField()
db = StringField()
class DatabaseFactory:
@@ -78,10 +80,13 @@ class DatabaseFactory:
missing.append(key)
continue
entry = cls._create_db_entry(alias=alias, settings=db_entries.get(key))
settings = {**db_entries.get(key)}
if not any(field in settings for field in ("name", "db")):
settings["name"] = key
entry = cls._create_db_entry(alias=alias, settings=settings)
if override_connection_string:
con_str = f"{override_connection_string.rstrip('/')}/{key}"
con_str = override_connection_string
log.info(f"Using override mongodb connection string for {alias}: {con_str}")
entry.host = con_str
else:

View File

@@ -1,5 +1,5 @@
import re
from collections import namedtuple, defaultdict
from collections import defaultdict
from datetime import datetime
from functools import reduce, partial
from typing import (
@@ -107,7 +107,18 @@ class GetMixin(PropsMixin):
("_any_", "_or_"): lambda a, b: a | b,
("_all_", "_and_"): lambda a, b: a & b,
}
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
@attr.s(auto_attribs=True)
class MultiFieldParameters:
fields: Sequence[str]
pattern: str = None
datetime: Union[list, str] = None
def __attrs_post_init__(self):
if not any(f is not None for f in (self.pattern, self.datetime)):
raise ValueError("Either 'pattern' or 'datetime' should be provided")
if all(f is not None for f in (self.pattern, self.datetime)):
raise ValueError("Only one of the 'pattern' and 'datetime' can be provided")
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
_field_collation_overrides = {}
@@ -323,6 +334,8 @@ class GetMixin(PropsMixin):
specific rules on handling values). Only items matching ALL of these conditions will be retrieved.
- <any|all>: {fields: [<field1>, <field2>, ...], pattern: <pattern>} Will query for items where any or all
provided fields match the provided pattern.
- <any|all>: {fields: [<field1>, <field2>, ...], datetime: <datetime condition>} Will query for items where any or all
provided datetime fields match the provided condition.
:return: mongoengine.Q query object
"""
return cls._prepare_query_no_company(
@@ -376,6 +389,46 @@ class GetMixin(PropsMixin):
return cls._try_convert_to_numeric(value)
return value
@classmethod
def _get_dates_query(cls, field: str, data: Union[list, str]) -> Union[Q, dict]:
"""
Return dates query for the field
If the data is 2 values array and none of the values starts from dates comparison operations
then return the simplified range query
Otherwise return the dictionary of dates conditions
"""
if not isinstance(data, list):
data = [data]
if len(data) == 2 and not any(
d.startswith(mod)
for d in data
if d is not None
for mod in ACCESS_MODIFIER
):
return cls.get_range_field_query(field, data)
dict_query = {}
for d in data:
m = ACCESS_REGEX.match(d)
if not m:
continue
try:
value = parse_datetime(m.group("value"))
prefix = m.group("prefix")
modifier = ACCESS_MODIFIER.get(prefix)
f = (
field
if not modifier
else "__".join((field, modifier))
)
dict_query[f] = value
except (ValueError, OverflowError):
pass
return dict_query
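# Illustrative (assuming the ">=" / "<" access modifiers map to gte / lt):
#   _get_dates_query("last_update", [">=2024-01-01", "<2024-02-01"])
#   -> {"last_update__gte": datetime(2024, 1, 1), "last_update__lt": datetime(2024, 2, 1)}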
@classmethod
def _prepare_query_no_company(
cls, parameters=None, parameters_options=QueryParameterOptions()
@@ -446,33 +499,11 @@ class GetMixin(PropsMixin):
for field in opts.datetime_fields or []:
data = parameters.pop(field, None)
if data is not None:
if not isinstance(data, list):
data = [data]
# date time fields also support simplified range queries. Check if this is the case
if len(data) == 2 and not any(
d.startswith(mod)
for d in data
if d is not None
for mod in ACCESS_MODIFIER
):
query &= cls.get_range_field_query(field, data)
else:
for d in data: # type: str
m = ACCESS_REGEX.match(d)
if not m:
continue
try:
value = parse_datetime(m.group("value"))
prefix = m.group("prefix")
modifier = ACCESS_MODIFIER.get(prefix)
f = (
field
if not modifier
else "__".join((field, modifier))
)
dict_query[f] = value
except (ValueError, OverflowError):
pass
dates_q = cls._get_dates_query(field, data)
if isinstance(dates_q, Q):
query &= dates_q
elif isinstance(dates_q, dict):
dict_query.update(dates_q)
for field, value in parameters.items():
for keys, func in cls._multi_field_param_prefix.items():
@@ -484,27 +515,40 @@ class GetMixin(PropsMixin):
raise MakeGetAllQueryError("incorrect field format", field)
if not data.fields:
break
if any("._" in f for f in data.fields):
q = reduce(
lambda a, x: func(
a,
RegexQ(
__raw__={
x: {"$regex": data.pattern, "$options": "i"}
}
if data.pattern is not None:
if any("._" in f for f in data.fields):
q = reduce(
lambda a, x: func(
a,
RegexQ(
__raw__={
x: {"$regex": data.pattern, "$options": "i"}
}
),
),
),
data.fields,
RegexQ(),
)
data.fields,
RegexQ(),
)
else:
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
sep_fields = [f.replace(".", "__") for f in data.fields]
q = reduce(
lambda a, x: func(a, RegexQ(**{x: regex})),
sep_fields,
RegexQ(),
)
else:
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
sep_fields = [f.replace(".", "__") for f in data.fields]
q = reduce(
lambda a, x: func(a, RegexQ(**{x: regex})),
sep_fields,
RegexQ(),
)
date_fields = [field for field in data.fields if field in opts.datetime_fields]
if not date_fields:
break
q = Q()
for date_f in date_fields:
dates_q = cls._get_dates_query(date_f, data.datetime)
if isinstance(dates_q, dict):
dates_q = RegexQ(**dates_q)
q = func(q, dates_q)
query = query & q
except APIError:
raise
@@ -1394,7 +1438,7 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
else:
items = list(
cls.objects(
id__in=ids, company__in=(None, ""), company_origin=company_id
id__in=ids, company="", company_origin=company_id
).only("id")
)
update: dict = dict(set__company=company_id, unset__company_origin=1)

View File

@@ -37,10 +37,18 @@ class Model(AttributedDocument):
"project",
"task",
"last_update",
("company", "framework"),
("company", "last_update"),
("company", "name"),
("company", "user"),
("company", "uri"),
# distinct queries support
("company", "tags"),
("company", "system_tags"),
("company", "project", "tags"),
("company", "project", "system_tags"),
("company", "user"),
("company", "project", "user"),
("company", "framework"),
("company", "project", "framework"),
{
"name": "%s.model.main_text_index" % Database.backend,
"fields": ["$name", "$id", "$comment", "$parent", "$task", "$project"],
@@ -71,8 +79,8 @@ class Model(AttributedDocument):
"parent",
"metadata.*",
),
range_fields=("last_metrics.*", "last_iteration"),
datetime_fields=("last_update",),
range_fields=("created", "last_metrics.*", "last_iteration"),
datetime_fields=("last_update", "last_change"),
)
id = StringField(primary_key=True)

View File

@@ -47,6 +47,7 @@ class Queue(DbModelMixin, Document):
name = StrippedStringField(
required=True, unique_with="company", min_length=3, user_set_allowed=True
)
display_name = StringField(user_set_allowed=True)
company = StringField(required=True, reference_field=Company)
created = DateTimeField(required=True)
tags = SafeSortedListField(

View File

@@ -0,0 +1,76 @@
from mongoengine import (
Document,
EmbeddedDocument,
StringField,
DateTimeField,
EmbeddedDocumentListField,
EmbeddedDocumentField,
BooleanField,
)
from apiserver.database import Database, strict
from apiserver.database.model import DbModelMixin
from apiserver.database.model.base import ProperDictMixin
class AWSBucketSettings(EmbeddedDocument, ProperDictMixin):
bucket = StringField()
subdir = StringField()
host = StringField()
key = StringField()
secret = StringField()
token = StringField()
multipart = BooleanField()
acl = StringField()
secure = BooleanField()
region = StringField()
verify = BooleanField()
use_credentials_chain = BooleanField()
class AWSSettings(EmbeddedDocument, DbModelMixin):
key = StringField()
secret = StringField()
region = StringField()
token = StringField()
use_credentials_chain = BooleanField()
buckets = EmbeddedDocumentListField(AWSBucketSettings)
class GoogleBucketSettings(EmbeddedDocument, ProperDictMixin):
bucket = StringField()
subdir = StringField()
project = StringField()
credentials_json = StringField()
class GoogleStorageSettings(EmbeddedDocument, DbModelMixin):
project = StringField()
credentials_json = StringField()
buckets = EmbeddedDocumentListField(GoogleBucketSettings)
class AzureStorageContainerSettings(EmbeddedDocument, ProperDictMixin):
account_name = StringField(required=True)
account_key = StringField(required=True)
container_name = StringField()
class AzureStorageSettings(EmbeddedDocument, DbModelMixin):
containers = EmbeddedDocumentListField(AzureStorageContainerSettings)
class StorageSettings(DbModelMixin, Document):
meta = {
"db_alias": Database.backend,
"strict": strict,
"indexes": [
"company"
],
}
id = StringField(primary_key=True)
company = StringField(required=True, unique=True)
last_update = DateTimeField()
aws: AWSSettings = EmbeddedDocumentField(AWSSettings)
google: GoogleStorageSettings = EmbeddedDocumentField(GoogleStorageSettings)
azure: AzureStorageSettings = EmbeddedDocumentField(AzureStorageSettings)

View File

@@ -5,6 +5,7 @@ from mongoengine import (
LongField,
EmbeddedDocumentField,
IntField,
FloatField,
)
from apiserver.database.fields import SafeMapField
@@ -23,6 +24,11 @@ class MetricEvent(EmbeddedDocument):
min_value_iteration = IntField()
max_value = DynamicField() # for backwards compatibility reasons
max_value_iteration = IntField()
first_value = FloatField()
first_value_iteration = IntField()
count = IntField()
mean_value = FloatField()
x_axis_label = StringField()
class EventStats(EmbeddedDocument):

View File

@@ -183,9 +183,8 @@ class Task(AttributedDocument):
"status_changed",
"models.input.model",
("company", "name"),
("company", "user"),
("company", "status", "type"),
("company", "system_tags", "last_update"),
("company", "last_update", "system_tags"),
("company", "type", "system_tags", "status"),
("company", "project", "type", "system_tags", "status"),
("status", "last_update"), # for maintenance tasks
@@ -193,6 +192,17 @@ class Task(AttributedDocument):
"fields": ["company", "project"],
"collation": AttributedDocument._numeric_locale,
},
# distinct queries support
("company", "tags"),
("company", "system_tags"),
("company", "project", "tags"),
("company", "project", "system_tags"),
("company", "user"),
("company", "project", "user"),
("company", "parent"),
("company", "project", "parent"),
("company", "type"),
("company", "project", "type"),
{
"name": "%s.task.main_text_index" % Database.backend,
"fields": [
@@ -233,8 +243,8 @@ class Task(AttributedDocument):
"execution.queue",
"models.input.model",
),
range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
datetime_fields=("status_changed", "last_update"),
range_fields=("created", "started", "active_duration", "last_metrics.*", "last_iteration"),
datetime_fields=("status_changed", "last_update", "last_change"),
pattern_fields=("name", "comment", "report"),
fields=("runtime.*",),
)

View File

@@ -20,4 +20,5 @@ class User(DbModelMixin, Document):
given_name = StringField(user_set_allowed=True)
avatar = StringField()
preferences = DynamicField(default="", exclude_by_default=True)
created_in_version = StringField()
created = DateTimeField()

View File

@@ -121,8 +121,8 @@ def init_cls_from_base(cls, instance):
)
def get_company_or_none_constraint(company=None):
return Q(company__in=(company, None, "")) | Q(company__exists=False)
def get_company_or_none_constraint(company=""):
return Q(company__in=list({company, ""}))
def field_does_not_exist(field: str, empty_value=None, is_list=False) -> Q:

View File

@@ -2,6 +2,7 @@
| Release | ApiVersion |
|---------|------------|
| v1.17 | 2.31 |
| v1.16 | 2.30 |
| v1.15 | 2.29 |
| v1.14 | 2.28 |

View File

@@ -0,0 +1,79 @@
{
"index_patterns": "serving_stats_*",
"template": {
"settings": {
"number_of_replicas": 0,
"number_of_shards": 1
},
"mappings": {
"_source": {
"enabled": true
},
"properties": {
"timestamp": {
"type": "date"
},
"container_id": {
"type": "keyword"
},
"company_id": {
"type": "keyword"
},
"endpoint_url": {
"type": "keyword"
},
"requests_num": {
"type": "integer"
},
"requests_min": {
"type": "float"
},
"uptime_sec": {
"type": "integer"
},
"latency_ms": {
"type": "integer"
},
"cpu_usage": {
"type": "float"
},
"cpu_num": {
"type": "integer"
},
"gpu_usage": {
"type": "float"
},
"gpu_num": {
"type": "integer"
},
"memory_used": {
"type": "float"
},
"memory_free": {
"type": "float"
},
"memory_total": {
"type": "float"
},
"gpu_memory_used": {
"type": "float"
},
"gpu_memory_free": {
"type": "float"
},
"gpu_memory_total": {
"type": "float"
},
"disk_free_home": {
"type": "float"
},
"network_rx": {
"type": "float"
},
"network_tx": {
"type": "float"
}
}
}
}
}

apiserver/fix_mongo_urls.py
View File

@@ -0,0 +1,122 @@
import logging
from argparse import (
ArgumentDefaultsHelpFormatter,
ArgumentParser,
ArgumentTypeError,
)
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.database import Database
logging.getLogger().setLevel(logging.INFO)
def fix_mongo_urls(mongo_host: str, host_source: str, host_target: str):
logging.info(f"Connecting to Mongo on {mongo_host}")
client = MongoClient(host=mongo_host)
backend_db: Database = client.backend
def get_updated_uri(uri: str):
if not uri or not uri.startswith(host_source):
return
relative_url = uri[len(host_source) :]
return f"{host_target.rstrip('/')}/{relative_url.lstrip('/')}"
# drop a trailing slash so that prefix matching and url rebuilding stay consistent
host_source = normalise_host(host_source)
host_target = normalise_host(host_target)
model_collection: Collection = backend_db.get_collection("model")
if model_collection is not None:
logging.info("Updating model uris")
models_count = model_collection.count_documents({})
updated_models = 0
for model in model_collection.find(
{"uri": {"$regex": "^{}".format(host_source)}}, projection=["uri"]
):
updated_uri = get_updated_uri(model.get("uri"))
if updated_uri:
result = model_collection.update_one(
{"_id": model["_id"]}, {"$set": {"uri": updated_uri}}
)
updated_models += result.modified_count
logging.info(f"Updated {updated_models} models from {models_count}")
task_collection: Collection = backend_db.get_collection("task")
if task_collection is not None:
logging.info("Updating task uris")
tasks_count = task_collection.count_documents({})
updated_tasks = 0
for task in task_collection.find(
{"execution.artifacts": {"$exists": 1, "$ne": {}}},
projection=["execution.artifacts"],
):
artifacts = task.get("execution", {}).get("artifacts")
if not artifacts:
continue
uri_updated = False
for artifact in artifacts.values():
updated_uri = get_updated_uri(artifact.get("uri"))
if updated_uri:
artifact["uri"] = updated_uri
uri_updated = True
if uri_updated:
result = task_collection.update_one(
{"_id": task["_id"]}, {"$set": {"execution.artifacts": artifacts}}
)
updated_tasks += result.modified_count
logging.info(f"Updated {updated_tasks} tasks from {tasks_count}")
def normalise_host(host):
if not host.endswith("/"):
return host
return host[:-1]
def main():
def valid_url_prefix(url: str):
if "://" not in url:
raise ArgumentTypeError("url schema is missing")
return url
parser = ArgumentParser(
description=__doc__, formatter_class=ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--mongo-host",
"-mh",
type=str,
default="mongodb://mongo:27017",
help="Mongo server host. The default is mongodb://mongo:27017",
)
parser.add_argument(
"--host-source",
"-hs",
type=valid_url_prefix,
required=True,
help="Source host for the files uploaded to the fileserver (in the form http://<host>:<port>)",
)
parser.add_argument(
"--host-target",
"-ht",
type=valid_url_prefix,
required=True,
help="Target host for the files uploaded to the fileserver (in the form http://<host>:<port>)",
)
args = parser.parse_args()
fix_mongo_urls(
mongo_host=args.mongo_host,
host_source=args.host_source,
host_target=args.host_target,
)
logging.info("Completed successfully")
if __name__ == "__main__":
main()
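# Example invocation (hypothetical hosts):
#   python fix_mongo_urls.py --host-source http://old-fileserver:8081 \
#       --host-target http://new-fileserver:8081 --mongo-host mongodb://mongo:27017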

View File

@@ -8,13 +8,16 @@ import pymongo.database
from mongoengine.connection import get_db
from packaging.version import Version, parse
from apiserver.config_repo import config
from apiserver.database import utils
from apiserver.database import Database
from apiserver.database.model.version import Version as DatabaseVersion
from apiserver.utilities.dicts import nested_get
_migrations = "migrations"
_parent_dir = Path(__file__).resolve().parents[1]
_migration_dir = _parent_dir / _migrations
log = config.logger(__file__)
def check_mongo_empty() -> bool:
@@ -41,6 +44,26 @@ def get_last_server_version() -> Version:
return previous_versions[0] if previous_versions else Version("0.0.0")
def _ensure_mongodb_version():
db: pymongo.database.Database = get_db(Database.backend)
db_version = db.client.server_info()["version"]
if not db_version.startswith("6.0"):
log.warning(f"Database version should be 6.0.x. Instead: {str(db_version)}")
return
res = db.client.admin.command({"getParameter": 1, "featureCompatibilityVersion": 1})
version = nested_get(res, ("featureCompatibilityVersion", "version"))
if version == "6.0":
return
if version != "5.0":
log.warning(f"Cannot upgrade DB version. Should be 5.0. {str(res)}")
return
log.info("Upgrading db version from 5.0 to 6.0")
res = db.client.admin.command({"setFeatureCompatibilityVersion": "6.0"})
log.info(res)
def _apply_migrations(log: Logger):
"""
Apply migrations as found in the migration dir.
@@ -50,6 +73,8 @@ def _apply_migrations(log: Logger):
log.info(f"Started mongodb migrations")
_ensure_mongodb_version()
if not _migration_dir.is_dir():
raise ValueError(f"Invalid migration dir {_migration_dir}")

View File

@@ -28,10 +28,11 @@ from urllib.parse import unquote, urlparse
from uuid import uuid4, UUID, uuid5
from zipfile import ZipFile, ZIP_BZIP2
import attr
import mongoengine
from boltons.iterutils import chunked_iter, first
from furl import furl
from mongoengine import Q
from mongoengine import Q, Document
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_common import EventType
@@ -61,6 +62,8 @@ from apiserver.utilities import json
from apiserver.utilities.dicts import nested_get, nested_set, nested_delete
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
replace_s3_scheme = os.getenv("CLEARML_REPLACE_S3_SCHEME")
class PrePopulate:
module_name_prefix = "apiserver."
@@ -84,6 +87,11 @@ class PrePopulate:
user_cls: Type[User]
auth_user_cls: Type[AuthUser]
@attr.s(auto_attribs=True)
class ParentPrefix:
prefix: str
path: Sequence[str]
# noinspection PyTypeChecker
@classmethod
def _init_entity_types(cls):
@@ -469,20 +477,35 @@ class PrePopulate:
@classmethod
def _check_projects_hierarchy(cls, projects: Set[Project]):
"""
For any exported project all its parents up to the root should be present
For projects that are exported from below the root,
fix their parent tree to exclude the parents that were not exported
"""
if not projects:
return
project_ids = {p.id for p in projects}
orphans = [p.id for p in projects if p.parent and p.parent not in project_ids]
orphans = [p for p in projects if p.parent and p.parent not in project_ids]
if not orphans:
return
print(
f"ERROR: the following projects are exported without their parents: {orphans}"
)
exit(1)
prefixes = [
cls.ParentPrefix(prefix=f"{project.name.rpartition('/')[0]}/", path=project.path)
for project in orphans
]
prefixes.sort(key=lambda p: len(p.path), reverse=True)
for project in projects:
prefix = first(pref for pref in prefixes if project.path[:len(pref.path)] == pref.path)
if not prefix:
continue
project.path = project.path[len(prefix.path):]
if not project.path:
project.parent = None
project.name = project.name.removeprefix(prefix.prefix)
# print(
# f"ERROR: the following projects are exported without their parents: {orphans}"
# )
# exit(1)
@classmethod
def _resolve_entities(
@@ -491,6 +514,7 @@ class PrePopulate:
projects: Sequence[str] = None,
task_statuses: Sequence[str] = None,
) -> Dict[Type[mongoengine.Document], Set[mongoengine.Document]]:
# noinspection PyTypeChecker
entities: Dict[Any] = defaultdict(set)
if projects:
@@ -539,6 +563,7 @@ class PrePopulate:
print("Reading models...")
entities[cls.model_cls] = set(cls.model_cls.objects(id__in=list(model_ids)))
# noinspection PyTypeChecker
return entities
@classmethod
@@ -643,8 +668,9 @@ class PrePopulate:
@staticmethod
def _get_fixed_url(url: Optional[str]) -> Optional[str]:
if not (url and url.lower().startswith("s3://")):
if not (replace_s3_scheme and url and url.lower().startswith("s3://")):
return url
try:
fixed = furl(url)
fixed.scheme = "https"
@@ -983,8 +1009,10 @@ class PrePopulate:
module = importlib.import_module(module_name)
return getattr(module, class_name)
@staticmethod
def _upgrade_project_data(project_data: dict) -> dict:
@classmethod
def _upgrade_project_data(cls, project_data: dict) -> dict:
cls._remove_incompatible_fields(cls.project_cls, project_data)
if not project_data.get("basename"):
name: str = project_data["name"]
_, _, basename = name.rpartition("/")
@@ -992,8 +1020,10 @@ class PrePopulate:
return project_data
@staticmethod
def _upgrade_model_data(model_data: dict) -> dict:
@classmethod
def _upgrade_model_data(cls, model_data: dict) -> dict:
cls._remove_incompatible_fields(cls.model_cls, model_data)
metadata_key = "metadata"
metadata = model_data.get(metadata_key)
if isinstance(metadata, list):
@@ -1006,7 +1036,13 @@ class PrePopulate:
return model_data
@staticmethod
def _upgrade_task_data(task_data: dict) -> dict:
def _remove_incompatible_fields(cls_: Type[Document], data: dict):
for field in ("company_origin",):
if field not in cls_._db_field_map:
data.pop(field, None)
@classmethod
def _upgrade_task_data(cls, task_data: dict) -> dict:
"""
Migrate from execution/parameters and model_desc to hyperparams and configuration fields
Upgrade artifacts list to dict
@@ -1015,6 +1051,8 @@ class PrePopulate:
:param task_data: Upgraded in place
:return: The upgraded task data
"""
cls._remove_incompatible_fields(cls.task_cls, task_data)
for old_param_field, new_param_field, default_section in (
("execution.parameters", "hyperparams", hyperparams_default_section),
("execution.model_desc", "configuration", None),
@@ -1133,7 +1171,7 @@ class PrePopulate:
if isinstance(doc, cls.task_cls):
tasks.append(doc)
cls.event_bll.delete_task_events(company_id, doc.id, allow_locked=True)
cls.event_bll.delete_task_events(company_id, doc.id, wait_for_delete=True)
if tasks:
return tasks
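
The `_check_projects_hierarchy` change above no longer aborts an export when parent projects are missing; instead it rebases the exported sub-tree onto the deepest missing ancestor. A minimal standalone sketch of that rebasing (not the server classes; `Proj` and `trim_orphan_prefixes` are illustrative names):

```python
# Sketch of the parent-prefix trimming: projects exported from a sub-tree get
# the unexported ancestors stripped from their name and ancestor path.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Proj:
    id: str
    name: str                                      # full "a/b/c" style name
    parent: Optional[str] = None
    path: List[str] = field(default_factory=list)  # ancestor ids, root first


def trim_orphan_prefixes(projects: List[Proj]) -> None:
    ids = {p.id for p in projects}
    # "orphans" are exported projects whose parent was not exported
    orphans = [p for p in projects if p.parent and p.parent not in ids]
    prefixes = sorted(
        ((p.name.rpartition("/")[0] + "/", tuple(p.path)) for p in orphans),
        key=lambda pref: len(pref[1]),
        reverse=True,
    )
    for p in projects:
        match = next(
            (pref for pref in prefixes if tuple(p.path[: len(pref[1])]) == pref[1]),
            None,
        )
        if not match:
            continue
        prefix, path = match
        p.path = p.path[len(path):]
        if not p.path:
            p.parent = None        # the project becomes a new root
        if p.name.startswith(prefix):
            p.name = p.name[len(prefix):]


root = Proj(id="b", name="top/mid", parent="a", path=["a"])
child = Proj(id="c", name="top/mid/leaf", parent="b", path=["a", "b"])
trim_orphan_prefixes([root, child])
print(root.name, root.parent)    # -> mid None
print(child.name, child.parent)  # -> mid/leaf b
```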

View File

@@ -6,7 +6,7 @@ boto3>=1.26
boto3-stubs[s3]>=1.26
clearml>=1.10.3
dpath>=1.4.2,<2.0
elasticsearch==8.12.0
elasticsearch==8.17.0
fastjsonschema>=2.8
flask-compress>=1.4.0
flask-cors>=3.0.5
@@ -19,15 +19,15 @@ jinja2
jsonmodels>=2.3
jsonschema>=2.6.0
luqum>=0.10.0
mongoengine==0.27.0
mongoengine==0.29.1
nested_dict>=1.61
packaging==20.3
psutil>=5.6.5
pyhocon>=0.3.35
pyjwt>=2.4.0
pymongo==4.6.3
pymongo==4.10.1
python-rapidjson>=0.6.3
redis>=4.5.4,<5
redis==5.2.1
requests>=2.13.0
semantic_version>=2.8.3,<3
setuptools>=65.5.1

View File

@@ -74,7 +74,11 @@ multi_field_pattern_data {
type: object
properties {
pattern {
description: "Pattern string (regex)"
description: "Pattern string (regex). Either 'pattern' or 'datetime' should be specified"
type: string
}
datetime {
description: "Date time conditions (applicable only to datetime fields). Either 'pattern' or 'datetime' should be specified"
type: string
}
fields {

View File

@@ -283,6 +283,26 @@ last_metrics_event {
description: "The iteration at which the maximum value was reported"
type: integer
}
first_value {
description: "First value reported"
type: number
}
first_value_iteration {
description: "The iteration at which the first value was reported"
type: integer
}
mean_value {
description: "The mean value"
type: number
}
count {
description: "The total count of reported values"
type: integer
}
x_axis_label {
description: "The user-defined value for the X-Axis name stored with the event"
type: string
}
}
}
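
The new per-variant summary fields above (first value, mean, count) lend themselves to incremental maintenance as scalar events stream in. A small illustrative sketch of the running-mean bookkeeping they imply (not the server's aggregation code):

```python
def update_last_metrics(entry: dict, value: float, iteration: int) -> dict:
    """Fold one scalar event into a last_metrics-style summary dict."""
    if "first_value" not in entry:
        entry["first_value"] = value
        entry["first_value_iteration"] = iteration
        entry["mean_value"] = float(value)
        entry["count"] = 1
    else:
        entry["count"] += 1
        # incremental mean: m_n = m_{n-1} + (x_n - m_{n-1}) / n
        entry["mean_value"] += (value - entry["mean_value"]) / entry["count"]
    return entry


summary = {}
for it, v in enumerate([0.9, 0.7, 0.4, 0.35]):
    update_last_metrics(summary, v, it)
print(summary)  # first_value=0.9, mean_value=0.5875, count=4
```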
last_metrics_variants {

View File

@@ -0,0 +1,67 @@
machine_stats {
type: object
properties {
cpu_usage {
description: "Average CPU usage per core"
type: array
items { type: number }
}
gpu_usage {
description: "Average GPU usage per GPU card"
type: array
items { type: number }
}
memory_used {
description: "Used memory MBs"
type: number
}
memory_free {
description: "Free memory MBs"
type: number
}
gpu_memory_free {
description: "GPU free memory MBs"
type: array
items { type: number }
}
gpu_memory_used {
description: "GPU used memory MBs"
type: array
items { type: number }
}
network_tx {
description: "Mbytes per second"
type: number
}
network_rx {
description: "Mbytes per second"
type: number
}
disk_free_home {
description: "Free space in % of /home drive"
type: number
}
disk_free_temp {
description: "Free space in % of /tmp drive"
type: number
}
disk_read {
description: "Mbytes read per second"
type: number
}
disk_write {
description: "Mbytes write per second"
type: number
}
cpu_temperature {
description: "CPU temperature"
type: array
items { type: number }
}
gpu_temperature {
description: "GPU temperature"
type: array
items { type: number }
}
}
}
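
The `machine_stats` definition above is shared between worker and serving status reports. A hypothetical payload matching this schema could look as follows (all values are made up; per-device fields are arrays with one entry per core or card):

```python
# Illustrative machine_stats payload; numbers are invented for the example.
machine_stats = {
    "cpu_usage": [12.5, 30.1, 8.0, 22.4],     # average CPU usage per core (%)
    "gpu_usage": [76.0],                      # average GPU usage per card (%)
    "memory_used": 14336.0,                   # MBs
    "memory_free": 18432.0,                   # MBs
    "gpu_memory_used": [10240.0],             # MBs per card
    "gpu_memory_free": [6144.0],              # MBs per card
    "network_tx": 4.2,                        # Mbytes per second
    "network_rx": 11.8,                       # Mbytes per second
    "disk_free_home": 63.5,                   # % free on the /home drive
    "disk_free_temp": 81.0,                   # % free on the /tmp drive
    "disk_read": 2.1,                         # Mbytes read per second
    "disk_write": 0.7,                        # Mbytes written per second
    "cpu_temperature": [54.0, 57.5, 52.0, 55.0],
    "gpu_temperature": [66.0],
}
```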

View File

@@ -27,13 +27,17 @@ _definitions {
type: string
}
variant {
description: "E.g. 'class_1', 'total', 'average"
description: "E.g. 'class_1', 'total', 'average'"
type: string
}
value {
description: ""
type: number
}
x_axis_label {
description: "Custom X-Axis label to be used when displaying the scalars histogram"
type: string
}
}
}
metrics_vector_event {

View File

@@ -1,20 +1,6 @@
_description: """This service provides a management interface for models (results of training tasks) stored in the system."""
_definitions {
include "_tasks_common.conf"
multi_field_pattern_data {
type: object
properties {
pattern {
description: "Pattern string (regex)"
type: string
}
fields {
description: "List of field names"
type: array
items { type: string }
}
}
}
model {
type: object
properties {

View File

@@ -1,20 +1,6 @@
_description: "Provides support for defining Projects containing Tasks, Models and Dataset Versions."
_definitions {
include "_common.conf"
multi_field_pattern_data {
type: object
properties {
pattern {
description: "Pattern string (regex)"
type: string
}
fields {
description: "List of field names"
type: array
items { type: string }
}
}
}
project {
type: object
properties {

View File

@@ -50,6 +50,10 @@ _definitions {
description: "Queue name"
type: string
}
display_name {
description: "Display name"
type: string
}
user {
description: "Associated user id"
type: string
@@ -324,7 +328,7 @@ create {
}
}
"2.13": ${create."2.4"} {
metadata {
request.properties.metadata {
description: "Queue metadata"
type: object
additionalProperties {
@@ -332,6 +336,12 @@ create {
}
}
}
"2.31": ${create."2.13"} {
request.properties.display_name {
description: "Display name"
type: string
}
}
}
update {
"2.4" {
@@ -377,7 +387,7 @@ update {
}
}
"2.13": ${update."2.4"} {
metadata {
request.properties.metadata {
description: "Queue metadata"
type: object
additionalProperties {
@@ -385,6 +395,12 @@ update {
}
}
}
"2.31": ${update."2.13"} {
request.properties.display_name {
description: "Display name"
type: string
}
}
}
delete {
"2.4" {
@@ -447,6 +463,13 @@ add_task {
}
}
}
"2.31": ${add_task."2.4"} {
request.properties.update_execution_queue {
description: If set to false then the task 'execution.queue' is not updated
type: boolean
default: true
}
}
}
get_next_task {
"2.4" {
@@ -530,8 +553,41 @@ remove_task {
}
}
}
"2.31": ${remove_task."2.4"} {
request.properties {
update_task_status {
type: boolean
default: false
description: If set to 'true' then change the removed task status to the one it had prior to enqueuing or 'created'
}
}
}
}
clear_queue {
"2.31" {
description: Remove all tasks from the queue and change their statuses to what they were prior to enqueuing or 'created'
request {
type: object
required: [queue]
properties {
queue {
description: "Queue id"
type: string
}
}
}
response {
type: object
properties {
removed_tasks {
description: IDs of the removed tasks
type: array
items {type: string}
}
}
}
}
}
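
The queue additions above — `display_name` on create/update, `update_execution_queue` on `add_task`, `update_task_status` on `remove_task`, and the new `queues.clear_queue` call — are plain JSON-over-HTTP endpoints. A rough usage sketch with `requests`; the server URL, credentials, basic-auth scheme, task id, and the `data` response envelope are assumptions for illustration only:

```python
import requests

API = "http://localhost:8008"            # apiserver URL (placeholder)
AUTH = ("ACCESS_KEY", "SECRET_KEY")      # placeholder credentials


def call(endpoint: str, **payload):
    # POST a JSON body to an apiserver endpoint and unwrap the assumed data envelope
    r = requests.post(f"{API}/{endpoint}", json=payload, auth=AUTH)
    r.raise_for_status()
    return r.json()["data"]


# create a queue with the new display_name field (API >= 2.31)
queue = call("queues.create", name="ingest_q", display_name="Ingest queue")["id"]

# enqueue a task without touching its execution.queue field
call("queues.add_task", queue=queue, task="TASK_ID", update_execution_queue=False)

# remove it again and restore the status it had before enqueuing
call("queues.remove_task", queue=queue, task="TASK_ID", update_task_status=True)

# or drop every pending task at once
removed = call("queues.clear_queue", queue=queue)["removed_tasks"]
print(removed)
```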
move_task_forward: {
"2.4" {
description: "Moves a task entry one step forward towards the top of the queue."

View File

@@ -0,0 +1,437 @@
_description: "Serving APIs"
_definitions {
include "_workers_common.conf"
reference_item {
type: object
required = [type, value]
properties {
type {
description: The type of the reference item
type: string
enum: [app_id, app_instance, model, task, url]
}
value {
description: The reference item value
type: string
}
}
}
reference {
description: Array of reference items provided by the container instance. Can contain multiple reference items with the same type
type: array
items: ${_definitions.reference_item}
}
serving_model_report {
type: object
required: [container_id, endpoint_name, model_name]
properties {
container_id {
type: string
description: Container ID. Should uniquely identify a specific container instance
}
endpoint_name {
type: string
description: Endpoint name
}
endpoint_url {
type: string
description: Endpoint URL
}
model_name {
type: string
description: Model name
}
model_source {
type: string
description: Model source
}
model_version {
type: string
description: Model version
}
preprocess_artifact {
type: string
description: Preprocess Artifact
}
input_type {
type: string
description: Input type
}
input_size {
type: string
description: Input size
}
reference: ${_definitions.reference}
}
}
endpoint_stats {
type: object
properties {
endpoint {
type: string
description: Endpoint name
}
model {
type: string
description: Model name
}
url {
type: string
description: Model url
}
instances {
type: integer
description: The number of model serving instances
}
uptime_sec {
type: integer
description: Maximum model instance uptime in seconds
}
requests {
type: integer
description: Total requests processed by model instances
}
requests_min {
type: number
description: Average request rate of model instances per minute
}
latency_ms {
type: integer
description: Average latency of model instances in ms
}
last_update {
type: string
format: "date-time"
description: The latest time when one of the model instances was updated
}
}
}
container_instance_stats {
type: object
properties {
id {
type: string
description: Container ID
}
uptime_sec {
type: integer
description: Uptime in seconds
}
requests {
type: integer
description: Number of requests
}
requests_min {
type: number
description: Average requests per minute
}
latency_ms {
type: integer
description: Average request latency in ms
}
last_update {
type: string
format: "date-time"
description: The latest time the container instance sent an update
}
cpu_count {
type: integer
description: CPU Count
}
gpu_count {
type: integer
description: GPU Count
}
reference: ${_definitions.reference}
}
}
serving_model_info {
type: object
properties {
endpoint {
type: string
description: Endpoint name
}
model {
type: string
description: Model name
}
url {
type: string
description: Model url
}
model_source {
type: string
description: Model source
}
model_version {
type: string
description: Model version
}
preprocess_artifact {
type: string
description: Preprocess Artifact
}
input_type {
type: string
description: Input type
}
input_size {
type: string
description: Input size
}
}
}
container_info: ${_definitions.serving_model_info} {
properties {
id {
type: string
description: Container ID
}
uptime_sec {
type: integer
description: Model instance uptime in seconds
}
last_update {
type: string
format: "date-time"
description: The latest time the container instance sent an update
}
age_sec {
type: integer
description: Amount of seconds since the container registration
}
}
}
metrics_history_series {
type: object
properties {
title {
type: string
description: "The title of the series"
}
dates {
type: array
description: "List of timestamps (in seconds from epoch) in ascending order. The timestamps are separated by the requested interval."
items {type: integer}
}
values {
type: array
description: "List of values corresponding to the timestamps in the dates list."
items {type: number}
}
}
}
}
register_container {
"2.31" {
description: Register container
request = ${_definitions.serving_model_report} {
properties {
timeout {
description: "Registration timeout in seconds. If this many seconds pass since the serving container's last call to register or status_report, the container is automatically removed from the list of registered containers."
type: integer
default: 600
}
}
}
response {
type: object
additionalProperties: false
}
}
}
unregister_container {
"2.31" {
description: Unregister container
request {
type: object
required: [container_id]
properties {
container_id {
type: string
description: Container ID
}
}
}
response {
type: object
additionalProperties: false
}
}
}
container_status_report {
"2.31" {
description: Container status report
request = ${_definitions.serving_model_report} {
properties {
uptime_sec {
type: integer
description: Uptime in seconds
}
requests_num {
type: integer
description: Number of requests
}
requests_min {
type: number
description: Average requests per minute
}
latency_ms {
type: integer
description: Average request latency in ms
}
machine_stats {
description: "The machine statistics"
"$ref": "#/definitions/machine_stats"
}
}
}
response {
type: object
additionalProperties: false
}
}
}
get_endpoints {
"2.31" {
description: Get all the registered endpoints
request {
type: object
additionalProperties: false
}
response {
type: object
properties {
endpoints {
type: array
items { "$ref": "#/definitions/endpoint_stats" }
}
}
}
}
}
get_loading_instances {
"2.31" {
description: "Get loading instances (endpoint_url not set yet)"
request {
type: object
additionalProperties: false
}
response {
type: object
properties {
instances {
type: array
items { "$ref": "#/definitions/container_info" }
}
}
}
}
}
get_endpoint_details {
"2.31" {
description: Get endpoint details
request {
type: object
required: [endpoint_url]
properties {
endpoint_url {
type: string
description: Endpoint URL
}
}
}
response: ${_definitions.serving_model_info} {
properties {
uptime_sec {
type: integer
description: Maximum model instance uptime in seconds
}
last_update {
type: string
format: "date-time"
description: The latest time when one of the model instances was updated
}
instances {
type: array
items {"$ref": "#/definitions/container_instance_stats"}
}
}
}
}
}
get_endpoint_metrics_history {
"2.31" {
description: Get endpoint charts
request {
type: object
required: [endpoint_url, from_date, to_date, interval]
properties {
endpoint_url {
description: Endpoint Url
type: string
}
from_date {
description: "Starting time (in seconds from epoch) for collecting statistics"
type: number
}
to_date {
description: "Ending time (in seconds from epoch) for collecting statistics"
type: number
}
interval {
description: "Time interval in seconds for a single statistics point. The minimal value is 1"
type: integer
}
metric_type {
description: The type of the metrics to return on the chart
type: string
default: requests
enum: [
requests
requests_min
latency_ms
cpu_count
gpu_count
cpu_util
gpu_util
ram_total
ram_used
ram_free
gpu_ram_total
gpu_ram_used
gpu_ram_free
network_rx
network_tx
]
}
instance_charts {
type: boolean
default: true
description: If set then return instance charts and total. Otherwise total only
}
}
}
response {
type: object
properties {
computed_interval {
description: The interval that was actually used for the histogram. May be larger than the requested one
type: integer
}
total: ${_definitions.metrics_history_series} {
properties {
description: The total histogram
}
}
instances {
description: Instance charts
type: object
additionalProperties: ${_definitions.metrics_history_series}
}
}
}
}
}
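
Taken together, the new `serving` endpoints above cover a container's lifecycle: register, report status periodically, and unregister, with the reports feeding the endpoint statistics and metric-history charts. A rough sketch of that flow (URL, credentials, ids, and the `data` envelope are placeholder assumptions, reusing the same `call` helper idea as the queues sketch):

```python
import requests

API = "http://localhost:8008"            # apiserver URL (placeholder)
AUTH = ("ACCESS_KEY", "SECRET_KEY")      # placeholder credentials


def call(endpoint: str, **payload):
    r = requests.post(f"{API}/{endpoint}", json=payload, auth=AUTH)
    r.raise_for_status()
    return r.json()["data"]


container = {
    "container_id": "serving-c01",        # must uniquely identify the instance
    "endpoint_name": "sentiment",
    "endpoint_url": "/serve/sentiment",
    "model_name": "sentiment-classifier",
    "model_version": "3",
}

# register with a 10 minute timeout; stale containers drop off automatically
call("serving.register_container", timeout=600, **container)

# periodic status report that feeds the endpoint stats and charts
call(
    "serving.container_status_report",
    uptime_sec=3600,
    requests_num=1500,
    requests_min=25.0,
    latency_ms=42,
    **container,
)

# aggregated view across all containers serving each endpoint
for ep in call("serving.get_endpoints")["endpoints"]:
    print(ep["endpoint"], ep.get("instances"), ep.get("latency_ms"))

call("serving.unregister_container", container_id=container["container_id"])
```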

View File

@@ -0,0 +1,242 @@
_description: """This service provides storage settings management"""
_default {
internal: true
}
_definitions {
include "_common.conf"
aws_bucket {
type: object
description: Settings per S3 bucket
properties {
bucket {
description: The name of the bucket
type: string
}
subdir {
description: The path to match
type: string
}
host {
description: Host address (for minio servers)
type: string
}
key {
description: Access key
type: string
}
secret {
description: Secret key
type: string
}
token {
description: Access token
type: string
}
multipart {
description: Multipart upload
type: boolean
default: true
}
acl {
description: ACL
type: string
}
secure {
description: Use SSL connection
type: boolean
default: true
}
region {
description: AWS Region
type: string
}
verify {
description: Verify server certificate
type: boolean
default: true
}
use_credentials_chain {
description: Use host configured credentials
type: boolean
default: false
}
}
}
aws {
type: object
description: AWS S3 storage settings
properties {
key {
description: Access key
type: string
}
secret {
description: Secret key
type: string
}
region {
description: AWS region
type: string
}
token {
description: Access token
type: string
}
use_credentials_chain {
description: If set then use host credentials
type: boolean
default: false
}
buckets {
description: Credential settings per bucket
type: array
items {"$ref": "#/definitions/aws_bucket"}
}
}
}
google_bucket {
type: object
description: Settings per Google storage bucket
properties {
bucket {
description: The name of the bucket
type: string
}
project {
description: The name of the project
type: string
}
subdir {
description: The path to match
type: string
}
credentials_json {
description: The contents of the credentials json file
type: string
}
}
}
google {
type: object
description: Google storage settings
properties {
project {
description: Project name
type: string
}
credentials_json {
description: The contents of the credentials json file
type: string
}
buckets {
description: Credentials per bucket
type: array
items {"$ref": "#/definitions/google_bucket"}
}
}
}
azure_container {
type: object
description: Azure container settings
properties {
account_name {
description: Account name
type: string
}
account_key {
description: Account key
type: string
}
container_name {
description: The name of the container
type: string
}
}
}
azure {
type: object
description: Azure storage settings
properties {
containers {
description: Credentials per container
type: array
items {"$ref": "#/definitions/azure_container"}
}
}
}
}
set_settings {
"2.31" {
description: Set Storage settings
request {
type: object
properties {
aws {"$ref": "#/definitions/aws"}
google {"$ref": "#/definitions/google"}
azure {"$ref": "#/definitions/azure"}
}
}
response {
type: object
properties {
updated {
description: "Number of settings documents updated (0 or 1)"
type: integer
enum: [0, 1]
}
}
}
}
}
reset_settings {
"2.31" {
description: Reset selected storage settings
request {
type: object
properties {
keys {
description: The names of the settings to delete
type: array
items {
type: string
enum: ["azure", "aws", "google"]
}
}
}
}
response {
type: object
properties {
updated {
description: "Number of settings documents updated (0 or 1)"
type: integer
enum: [0, 1]
}
}
}
}
}
get_settings {
"2.22" {
description: Get storage settings
request {
type: object
additionalProperties: false
}
response {
type: object
properties {
last_update {
description: "Settings last update time (UTC) "
type: string
format: "date-time"
}
aws {"$ref": "#/definitions/aws"}
google {"$ref": "#/definitions/google"}
azure {"$ref": "#/definitions/azure"}
}
}
}
}
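
The `storage` service above lets a company store, read back, and selectively reset object-storage credentials. A short sketch of that round trip (URL, credentials, bucket names, and the `data`/`settings` envelopes are illustrative assumptions):

```python
import requests

API = "http://localhost:8008"            # apiserver URL (placeholder)
AUTH = ("ACCESS_KEY", "SECRET_KEY")      # placeholder credentials


def call(endpoint: str, **payload):
    r = requests.post(f"{API}/{endpoint}", json=payload, auth=AUTH)
    r.raise_for_status()
    return r.json()["data"]


# store company-level S3 credentials, with one per-bucket override
call(
    "storage.set_settings",
    aws={
        "key": "AKIA...",                          # placeholder access key
        "secret": "***",
        "region": "us-east-1",
        "buckets": [
            {"bucket": "my-artifacts", "subdir": "clearml", "secure": True},
        ],
    },
)

# read the settings back (includes the last update time)
settings = call("storage.get_settings")["settings"]
print(settings.get("last_update"))

# drop only the Azure section, keeping AWS and Google settings intact
call("storage.reset_settings", keys=["azure"])
```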

View File

@@ -1507,6 +1507,13 @@ Fails if the following parameters in the task were not filled:
type: boolean
}
}
"2.31": ${enqueue."2.22"} {
request.properties.update_execution_queue {
description: If set to false then the task 'execution.queue' is not updated. This can only be done for a task that is already enqueued
type: boolean
default: true
}
}
}
enqueue_many {
"2.13": ${_definitions.change_many_request} {

View File

@@ -147,7 +147,7 @@ get_current_user {
description: Getting started info
additionalProperties: true
}
created {
user.properties.created {
type: string
description: User creation time
format: date-time
@@ -166,6 +166,14 @@ get_current_user {
}
}
}
"2.31": ${get_current_user."2.26"} {
response.properties {
user.properties.created_in_version {
type: string
description: Server version at user creation time
}
}
}
}
get_all_ex {

View File

@@ -1,5 +1,6 @@
_description: "Provides an API for worker machines, allowing workers to report status and get tasks for execution"
_definitions {
include "_workers_common.conf"
metrics_category {
type: object
properties {
@@ -193,6 +194,10 @@ _definitions {
queue_entry = ${_definitions.id_name_entry} {
properties {
display_name {
description: "Display name for the queue (if defined)"
type: string
}
next_task {
description: "Next task in the queue"
"$ref": "#/definitions/id_name_entry"
@@ -203,74 +208,6 @@ _definitions {
}
}
}
machine_stats {
type: object
properties {
cpu_usage {
description: "Average CPU usage per core"
type: array
items { type: number }
}
gpu_usage {
description: "Average GPU usage per GPU card"
type: array
items { type: number }
}
memory_used {
description: "Used memory MBs"
type: integer
}
memory_free {
description: "Free memory MBs"
type: integer
}
gpu_memory_free {
description: "GPU free memory MBs"
type: array
items { type: integer }
}
gpu_memory_used {
description: "GPU used memory MBs"
type: array
items { type: integer }
}
network_tx {
description: "Mbytes per second"
type: integer
}
network_rx {
description: "Mbytes per second"
type: integer
}
disk_free_home {
description: "Mbytes free space of /home drive"
type: integer
}
disk_free_temp {
description: "Mbytes free space of /tmp drive"
type: integer
}
disk_read {
description: "Mbytes read per second"
type: integer
}
disk_write {
description: "Mbytes write per second"
type: integer
}
cpu_temperature {
description: "CPU temperature"
type: array
items { type: number }
}
gpu_temperature {
description: "GPU temperature"
type: array
items { type: number }
}
}
}
}
get_all {
"2.4" {

View File

@@ -2,9 +2,11 @@ import unicodedata
import urllib.parse
from functools import partial
from boltons.iterutils import first
from flask import request, Response, redirect
from werkzeug.datastructures import ImmutableMultiDict
from werkzeug.exceptions import BadRequest
from werkzeug.http import quote_header_value
from apiserver.apierrors import APIError
from apiserver.apierrors.base import BaseError
@@ -21,12 +23,26 @@ log = config.logger(__file__)
class RequestHandlers:
_request_strip_prefix = config.get("apiserver.request.strip_prefix", None)
_server_header = config.get("apiserver.response.headers.server", "clearml")
_basic_cookie_settings = config.get("apiserver.auth.cookies")
_custom_cookie_settings = {
c["name"]: c["settings"]
for c in config.get("apiserver.auth.custom_cookies", {}).values()
if c.get("enabled") and c.get("settings")
}
def _get_cookie_settings(self, cookie_key=None):
settings = (
self._custom_cookie_settings.get(cookie_key) or self._basic_cookie_settings
).copy()
if isinstance(settings["domain"], list):
host_without_port, _, _ = request.host.partition(":")
domain = first(
settings["domain"],
key=lambda d: host_without_port.endswith(d) if d else False,
)
settings["domain"] = domain
return settings
def before_request(self):
if request.method == "OPTIONS":
return "", 200
@@ -54,17 +70,18 @@ class RequestHandlers:
if call.result.filename:
# make sure that downloaded files are not cached by the client
disable_cache = True
download_name = call.result.filename
try:
call.result.filename.encode("ascii")
download_name.encode("ascii")
except UnicodeEncodeError:
simple = unicodedata.normalize("NFKD", call.result.filename)
simple = unicodedata.normalize("NFKD", download_name)
simple = simple.encode("ascii", "ignore").decode("ascii")
# safe = RFC 5987 attr-char
quoted = urllib.parse.quote(call.result.filename, safe="")
filenames = f"filename={simple}; filename*=UTF-8''{quoted}"
quoted = urllib.parse.quote(download_name, safe="")
filenames = f"filename={quote_header_value(simple)}; filename*=UTF-8''{quoted}"
else:
filenames = f"filename={call.result.filename}"
headers = {"Content-Disposition": "attachment; " + filenames}
filenames = f"filename={quote_header_value(download_name)}"
headers = {f"Content-Disposition": f"attachment; {filenames}"}
response = Response(
content,
@@ -78,10 +95,7 @@ class RequestHandlers:
if call.result.cookies:
for key, value in call.result.cookies.items():
kwargs = (
self._custom_cookie_settings.get(key)
or config.get("apiserver.auth.cookies")
).copy()
kwargs = self._get_cookie_settings(key)
if value is None:
# Removing a cookie
kwargs["max_age"] = 0
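
The handler changes above add per-cookie setting overrides and allow the configured cookie `domain` to be a list of candidates, with the first one matching the request host being used. A minimal standalone sketch of that selection rule (the config values shown are made up):

```python
from boltons.iterutils import first


def pick_cookie_domain(host: str, domains):
    """Return the first configured domain that the request host falls under."""
    host_without_port, _, _ = host.partition(":")
    return first(domains, key=lambda d: host_without_port.endswith(d) if d else False)


# e.g. with the cookie domain configured as a list covering two deployments
domains = [".clear.ml", ".example.com"]
print(pick_cookie_domain("app.clear.ml:443", domains))   # -> .clear.ml
print(pick_cookie_domain("localhost:8080", domains))     # -> None (no domain set)
```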

View File

@@ -39,7 +39,7 @@ class ServiceRepo(object):
"""If the check is set, parsing will fail for an endpoint request with a version that is greater than the current
maximum """
_max_version = PartialVersion("2.30")
_max_version = PartialVersion("2.31")
""" Maximum version number (the highest min_version value across all endpoints) """
_endpoint_exp = (

View File

@@ -42,6 +42,7 @@ from apiserver.apimodels.events import (
LegacyMultiTaskEventsRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_bll import LOCKED_TASK_STATUSES
from apiserver.bll.event.event_common import EventType, MetricVariants, TaskCompanies
from apiserver.bll.event.events_iterator import Scroll
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
@@ -52,6 +53,7 @@ from apiserver.config_repo import config
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
from apiserver.service_repo import APICall, endpoint
from apiserver.service_repo.auth import Identity
from apiserver.utilities import json, extract_properties_to_lists
task_bll = TaskBLL()
@@ -488,6 +490,7 @@ def scalar_metrics_iter_histogram(
samples=request.samples,
key=request.key,
metric_variants=_get_metric_variants_from_request(request.metrics),
model_events=request.model_events,
)
call.result.data = metrics
@@ -538,12 +541,13 @@ def multi_task_scalar_metrics_iter_histogram(
samples=request.samples,
key=request.key,
metric_variants=_get_metric_variants_from_request(request.metrics),
model_events=request.model_events,
)
)
def _get_single_value_metrics_response(
companies: TaskCompanies, value_metrics: Mapping[str, dict]
companies: TaskCompanies, value_metrics: Mapping[str, Sequence[dict]]
) -> Sequence[dict]:
task_names = {
task.id: task.name for task in itertools.chain.from_iterable(companies.values())
@@ -1001,30 +1005,64 @@ def get_multi_task_metrics(call: APICall, company_id, request: MultiTaskMetricsR
call.result.data = {"metrics": sorted(res, key=itemgetter("metric"))}
def _validate_task_for_events_update(
company_id: str, task_id: str, identity: Identity, allow_locked: bool
):
task = get_task_with_write_access(
task_id=task_id,
company_id=company_id,
identity=identity,
only=("id", "status"),
)
if not allow_locked and task.status in LOCKED_TASK_STATUSES:
raise errors.bad_request.InvalidTaskId(
replacement_msg="Cannot update events for a published task",
company=company_id,
id=task_id,
)
@endpoint("events.delete_for_task")
def delete_for_task(call, company_id, request: TaskRequest):
task_id = request.task
allow_locked = call.data.get("allow_locked", False)
get_task_with_write_access(
task_id=task_id, company_id=company_id, identity=call.identity, only=("id",)
_validate_task_for_events_update(
company_id=company_id,
task_id=task_id,
identity=call.identity,
allow_locked=allow_locked,
)
call.result.data = dict(
deleted=event_bll.delete_task_events(
company_id, task_id, allow_locked=allow_locked
)
deleted=event_bll.delete_task_events(company_id, task_id, wait_for_delete=True)
)
def _validate_model_for_events_update(
company_id: str, model_id: str, allow_locked: bool
):
model = model_bll.assert_exists(company_id, model_id, only=("id", "ready"))[0]
if not allow_locked and model.ready:
raise errors.bad_request.InvalidModelId(
replacement_msg="Cannot update events for a published model",
company=company_id,
id=model_id,
)
@endpoint("events.delete_for_model")
def delete_for_model(call: APICall, company_id: str, request: ModelRequest):
model_id = request.model
allow_locked = call.data.get("allow_locked", False)
model_bll.assert_exists(company_id, model_id, return_models=False)
_validate_model_for_events_update(
company_id=company_id, model_id=model_id, allow_locked=allow_locked
)
call.result.data = dict(
deleted=event_bll.delete_task_events(
company_id, model_id, allow_locked=allow_locked, model=True
company_id, model_id, model=True, wait_for_delete=True
)
)
@@ -1033,14 +1071,17 @@ def delete_for_model(call: APICall, company_id: str, request: ModelRequest):
def clear_task_log(call: APICall, company_id: str, request: ClearTaskLogRequest):
task_id = request.task
get_task_with_write_access(
task_id=task_id, company_id=company_id, identity=call.identity, only=("id",)
_validate_task_for_events_update(
company_id=company_id,
task_id=task_id,
identity=call.identity,
allow_locked=request.allow_locked,
)
call.result.data = dict(
deleted=event_bll.clear_task_log(
company_id=company_id,
task_id=task_id,
allow_locked=request.allow_locked,
threshold_sec=request.threshold_sec,
exclude_metrics=request.exclude_metrics,
include_metrics=request.include_metrics,

View File

@@ -27,12 +27,17 @@ from apiserver.apimodels.models import (
UpdateModelRequest,
)
from apiserver.apimodels.tasks import UpdateTagsRequest
from apiserver.bll.event import EventBLL
from apiserver.bll.model import ModelBLL, Metadata
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.task_cleanup import (
schedule_for_delete,
delete_task_events_and_collect_urls,
)
from apiserver.bll.task.task_operations import publish_task
from apiserver.bll.task.utils import get_task_with_write_access
from apiserver.bll.task.utils import get_task_with_write_access, deleted_prefix
from apiserver.bll.util import run_batch_operation
from apiserver.config_repo import config
from apiserver.database.model import validate_id
@@ -64,6 +69,7 @@ from apiserver.services.utils import (
log = config.logger(__file__)
org_bll = OrgBLL()
project_bll = ProjectBLL()
event_bll = EventBLL()
def conform_model_data(call: APICall, model_data: Union[Sequence[dict], dict]):
@@ -182,7 +188,12 @@ def get_all(call: APICall, company_id, _):
def get_frameworks(call: APICall, company_id, request: GetFrameworksRequest):
call.result.data = {
"frameworks": sorted(
project_bll.get_model_frameworks(company_id, project_ids=request.projects)
filter(
None,
project_bll.get_model_frameworks(
company_id, project_ids=request.projects
),
)
)
}
@@ -216,6 +227,9 @@ last_update_fields = (
def parse_model_fields(call, valid_fields):
task_id = call.data.get("task")
if isinstance(task_id, str) and task_id.startswith(deleted_prefix):
call.data.pop("task")
fields = parse_from_call(call.data, valid_fields, Model.get_fields())
conform_tag_fields(call, fields, validate=True)
escape_metadata(fields)
@@ -555,16 +569,67 @@ def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest):
)
def _delete_model_events(
company_id: str,
user_id: str,
models: Sequence[Model],
delete_external_artifacts: bool,
sync_delete: bool,
):
if not models:
return
model_ids = [m.id for m in models]
delete_external_artifacts = delete_external_artifacts and config.get(
"services.async_urls_delete.enabled", True
)
if delete_external_artifacts:
model_urls = {m.uri for m in models if m.uri}
if model_urls:
schedule_for_delete(
task_id=model_ids[0],
company=company_id,
user=user_id,
urls=model_urls,
can_delete_folders=False,
)
event_urls = delete_task_events_and_collect_urls(
company=company_id,
task_ids=model_ids,
model=True,
wait_for_delete=sync_delete,
)
if event_urls:
schedule_for_delete(
task_id=model_ids[0],
company=company_id,
user=user_id,
urls=event_urls,
can_delete_folders=False,
)
event_bll.delete_task_events(
company_id, model_ids, model=True, wait_for_delete=sync_delete
)
@endpoint("models.delete", request_data_model=DeleteModelRequest)
def delete(call: APICall, company_id, request: DeleteModelRequest):
user_id = call.identity.user
del_count, model = ModelBLL.delete_model(
model_id=request.model,
company_id=company_id,
user_id=call.identity.user,
user_id=user_id,
force=request.force,
delete_external_artifacts=request.delete_external_artifacts,
)
if del_count:
_delete_model_events(
company_id=company_id,
user_id=user_id,
models=[model],
delete_external_artifacts=request.delete_external_artifacts,
sync_delete=True,
)
_reset_cached_tags(
company_id, projects=[model.project] if model.project else []
)
@@ -577,27 +642,38 @@ def delete(call: APICall, company_id, request: DeleteModelRequest):
request_data_model=ModelsDeleteManyRequest,
response_data_model=BatchResponse,
)
def delete(call: APICall, company_id, request: ModelsDeleteManyRequest):
def delete_many(call: APICall, company_id, request: ModelsDeleteManyRequest):
user_id = call.identity.user
results, failures = run_batch_operation(
func=partial(
ModelBLL.delete_model,
company_id=company_id,
user_id=call.identity.user,
force=request.force,
delete_external_artifacts=request.delete_external_artifacts,
),
ids=request.ids,
)
if results:
projects = set(model.project for _, (_, model) in results)
succeeded = []
deleted_models = []
for _id, (deleted, model) in results:
succeeded.append(dict(id=_id, deleted=bool(deleted), url=model.uri))
deleted_models.append(model)
if deleted_models:
_delete_model_events(
company_id=company_id,
user_id=user_id,
models=deleted_models,
delete_external_artifacts=request.delete_external_artifacts,
sync_delete=False,
)
projects = set(model.project for model in deleted_models)
_reset_cached_tags(company_id, projects=list(projects))
call.result.data_model = BatchResponse(
succeeded=[
dict(id=_id, deleted=bool(deleted), url=model.uri)
for _id, (deleted, model) in results
],
succeeded=succeeded,
failed=failures,
)
@@ -684,10 +760,11 @@ def move(call: APICall, company_id: str, request: MoveRequest):
@endpoint("models.update_tags")
def update_tags(_, company_id: str, request: UpdateTagsRequest):
def update_tags(call: APICall, company_id: str, request: UpdateTagsRequest):
return {
"updated": org_bll.edit_entity_tags(
company_id=company_id,
user_id=call.identity.user,
entity_cls=Model,
entity_ids=request.ids,
add_tags=request.add_tags,

View File

@@ -301,7 +301,7 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAllReque
future = pool.submit(get_fn, page, min(page_size, items_left))
with StringIO() as fp:
writer = csv.writer(fp)
writer = csv.writer(fp, quoting=csv.QUOTE_NONNUMERIC)
if page == 1:
fp.write("\ufeff") # utf-8 signature
writer.writerow(field_mappings)

View File

@@ -1,8 +1,6 @@
import re
from functools import partial
import attr
from apiserver.apierrors.errors.bad_request import CannotRemoveAllRuns
from apiserver.apimodels.pipelines import (
StartPipelineRequest,
@@ -18,6 +16,7 @@ from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskType
from apiserver.service_repo import APICall, endpoint
from apiserver.utilities.dicts import nested_get
from .tasks import _delete_task_events
org_bll = OrgBLL()
project_bll = ProjectBLL()
@@ -62,21 +61,31 @@ def delete_runs(call: APICall, company_id: str, request: DeleteRunsRequest):
identity=call.identity,
move_to_trash=False,
force=True,
return_file_urls=False,
delete_output_models=True,
status_message="",
status_reason="Pipeline run deleted",
delete_external_artifacts=True,
include_pipeline_steps=True,
),
ids=list(ids),
)
succeeded = []
tasks = {}
if results:
for _id, (deleted, task, cleanup_res) in results:
if deleted:
tasks[_id] = cleanup_res
succeeded.append(
dict(id=_id, deleted=bool(deleted), **attr.asdict(cleanup_res))
dict(id=_id, deleted=bool(deleted), **cleanup_res.to_res_dict(False))
)
if tasks:
_delete_task_events(
company_id=company_id,
user_id=call.identity.user,
tasks=tasks,
delete_external_artifacts=True,
sync_delete=True,
)
call.result.data = dict(succeeded=succeeded, failed=failures)
@@ -101,7 +110,7 @@ def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest
company_id=company_id,
user_id=call.identity.user,
task_id=request.task,
hyperparams=hyperparams,
hyperparams_overrides=hyperparams,
)
_update_task_name(task)

View File

@@ -370,6 +370,7 @@ def delete(call: APICall, company_id: str, request: DeleteRequest):
delete_external_artifacts=request.delete_external_artifacts,
)
_reset_cached_tags(company_id, projects=list(affected_projects))
# noinspection PyTypeChecker
call.result.data = {**attr.asdict(res)}

View File

@@ -20,13 +20,15 @@ from apiserver.apimodels.queues import (
GetNextTaskRequest,
GetByIdRequest,
GetAllRequest,
AddTaskRequest,
RemoveTaskRequest,
)
from apiserver.bll.model import Metadata
from apiserver.bll.queue import QueueBLL
from apiserver.bll.queue.queue_bll import MOVE_FIRST, MOVE_LAST
from apiserver.bll.workers import WorkerBLL
from apiserver.config_repo import config
from apiserver.database.model.task.task import Task
from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.service_repo import APICall, endpoint
from apiserver.services.utils import (
conform_tag_fields,
@@ -46,7 +48,7 @@ def conform_queue_data(call: APICall, queue_data: Union[Sequence[dict], dict]):
unescape_metadata(call, queue_data)
@endpoint("queues.get_by_id", min_version="2.4", request_data_model=GetByIdRequest)
@endpoint("queues.get_by_id", min_version="2.4")
def get_by_id(call: APICall, company_id, request: GetByIdRequest):
queue = queue_bll.get_by_id(
company_id, request.queue, max_task_entries=request.max_task_entries
@@ -111,7 +113,7 @@ def get_all(call: APICall, company: str, request: GetAllRequest):
call.result.data = {"queues": queues, **ret_params}
@endpoint("queues.create", min_version="2.4", request_data_model=CreateRequest)
@endpoint("queues.create", min_version="2.4")
def create(call: APICall, company_id, request: CreateRequest):
tags, system_tags = conform_tags(
call, request.tags, request.system_tags, validate=True
@@ -119,6 +121,7 @@ def create(call: APICall, company_id, request: CreateRequest):
queue = queue_bll.create(
company_id=company_id,
name=request.name,
display_name=request.display_name,
tags=tags,
system_tags=system_tags,
metadata=Metadata.metadata_from_api(request.metadata),
@@ -129,60 +132,82 @@ def create(call: APICall, company_id, request: CreateRequest):
@endpoint(
"queues.update",
min_version="2.4",
request_data_model=UpdateRequest,
response_data_model=UpdateResponse,
)
def update(call: APICall, company_id, req_model: UpdateRequest):
def update(call: APICall, company_id, request: UpdateRequest):
data = call.data_model_for_partial_update
conform_tag_fields(call, data, validate=True)
escape_metadata(data)
updated, fields = queue_bll.update(
company_id=company_id, queue_id=req_model.queue, **data
company_id=company_id, queue_id=request.queue, **data
)
conform_queue_data(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
@endpoint("queues.delete", min_version="2.4", request_data_model=DeleteRequest)
def delete(call: APICall, company_id, req_model: DeleteRequest):
@endpoint("queues.delete", min_version="2.4")
def delete(call: APICall, company_id, request: DeleteRequest):
queue_bll.delete(
company_id=company_id,
user_id=call.identity.user,
queue_id=req_model.queue,
force=req_model.force,
queue_id=request.queue,
force=request.force,
)
call.result.data = {"deleted": 1}
@endpoint("queues.add_task", min_version="2.4", request_data_model=TaskRequest)
def add_task(call: APICall, company_id, req_model: TaskRequest):
call.result.data = {
"added": queue_bll.add_task(
company_id=company_id, queue_id=req_model.queue, task_id=req_model.task
@endpoint("queues.add_task", min_version="2.4")
def add_task(call: APICall, company_id, request: AddTaskRequest):
added = queue_bll.add_task(
company_id=company_id, queue_id=request.queue, task_id=request.task
)
if added and request.update_execution_queue:
Task.objects(id=request.task).update(
execution__queue=request.queue, multi=False
)
}
call.result.data = {"added": added}
@endpoint("queues.get_next_task", request_data_model=GetNextTaskRequest)
@endpoint("queues.get_next_task")
def get_next_task(call: APICall, company_id, request: GetNextTaskRequest):
entry = queue_bll.get_next_task(
company_id=company_id, queue_id=request.queue, task_id=request.task
)
if entry:
data = {"entry": entry.to_proper_dict()}
if request.get_task_info:
task = Task.objects(id=entry.task).only("company", "user").first()
if task:
task = Task.objects(id=entry.task).only("company", "user", "status").first()
if task:
# fix race condition that can result in the task being aborted
# by an agent after it was already placed in a queue
if task.status == TaskStatus.stopped:
task.update(status=TaskStatus.queued)
if request.get_task_info:
data["task_info"] = {"company": task.company, "user": task.user}
call.result.data = data
@endpoint("queues.remove_task", min_version="2.4", request_data_model=TaskRequest)
def remove_task(call: APICall, company_id, req_model: TaskRequest):
@endpoint("queues.remove_task", min_version="2.4")
def remove_task(call: APICall, company_id, request: RemoveTaskRequest):
call.result.data = {
"removed": queue_bll.remove_task(
company_id=company_id, queue_id=req_model.queue, task_id=req_model.task
company_id=company_id,
user_id=call.identity.user,
queue_id=request.queue,
task_id=request.task,
update_task_status=request.update_task_status,
)
}
@endpoint("queues.clear_queue")
def clear_queue(call: APICall, company_id, request: QueueRequest):
call.result.data = {
"removed_tasks": queue_bll.clear_queue(
company_id=company_id,
user_id=call.identity.user,
queue_id=request.queue,
)
}
@@ -190,16 +215,15 @@ def remove_task(call: APICall, company_id, req_model: TaskRequest):
@endpoint(
"queues.move_task_forward",
min_version="2.4",
request_data_model=MoveTaskRequest,
response_data_model=MoveTaskResponse,
)
def move_task_forward(call: APICall, company_id, req_model: MoveTaskRequest):
def move_task_forward(call: APICall, company_id, request: MoveTaskRequest):
call.result.data_model = MoveTaskResponse(
position=queue_bll.reposition_task(
company_id=company_id,
queue_id=req_model.queue,
task_id=req_model.task,
move_count=-req_model.count,
queue_id=request.queue,
task_id=request.task,
move_count=-request.count,
)
)
@@ -207,16 +231,15 @@ def move_task_forward(call: APICall, company_id, req_model: MoveTaskRequest):
@endpoint(
"queues.move_task_backward",
min_version="2.4",
request_data_model=MoveTaskRequest,
response_data_model=MoveTaskResponse,
)
def move_task_backward(call: APICall, company_id, req_model: MoveTaskRequest):
def move_task_backward(call: APICall, company_id, request: MoveTaskRequest):
call.result.data_model = MoveTaskResponse(
position=queue_bll.reposition_task(
company_id=company_id,
queue_id=req_model.queue,
task_id=req_model.task,
move_count=req_model.count,
queue_id=request.queue,
task_id=request.task,
move_count=request.count,
)
)
@@ -224,15 +247,14 @@ def move_task_backward(call: APICall, company_id, req_model: MoveTaskRequest):
@endpoint(
"queues.move_task_to_front",
min_version="2.4",
request_data_model=TaskRequest,
response_data_model=MoveTaskResponse,
)
def move_task_to_front(call: APICall, company_id, req_model: TaskRequest):
def move_task_to_front(call: APICall, company_id, request: TaskRequest):
call.result.data_model = MoveTaskResponse(
position=queue_bll.reposition_task(
company_id=company_id,
queue_id=req_model.queue,
task_id=req_model.task,
queue_id=request.queue,
task_id=request.task,
move_count=MOVE_FIRST,
)
)
@@ -241,15 +263,14 @@ def move_task_to_front(call: APICall, company_id, req_model: TaskRequest):
@endpoint(
"queues.move_task_to_back",
min_version="2.4",
request_data_model=TaskRequest,
response_data_model=MoveTaskResponse,
)
def move_task_to_back(call: APICall, company_id, req_model: TaskRequest):
def move_task_to_back(call: APICall, company_id, request: TaskRequest):
call.result.data_model = MoveTaskResponse(
position=queue_bll.reposition_task(
company_id=company_id,
queue_id=req_model.queue,
task_id=req_model.task,
queue_id=request.queue,
task_id=request.task,
move_count=MOVE_LAST,
)
)
@@ -258,7 +279,6 @@ def move_task_to_back(call: APICall, company_id, req_model: TaskRequest):
@endpoint(
"queues.get_queue_metrics",
min_version="2.4",
request_data_model=GetMetricsRequest,
response_data_model=GetMetricsResponse,
)
def get_queue_metrics(

View File

@@ -282,6 +282,7 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
metric_variants=_get_metric_variants_from_request(
request.scalar_metrics_iter_histogram.metrics
),
model_events=request.model_events,
)
if request.single_value_metrics:

View File

@@ -0,0 +1,69 @@
from apiserver.apimodels.serving import (
RegisterRequest,
UnregisterRequest,
StatusReportRequest,
GetEndpointDetailsRequest,
GetEndpointMetricsHistoryRequest,
)
from apiserver.apierrors import errors
from apiserver.service_repo import endpoint, APICall
from apiserver.bll.serving import ServingBLL, ServingStats
serving_bll = ServingBLL()
@endpoint("serving.register_container")
def register_container(call: APICall, company: str, request: RegisterRequest):
serving_bll.register_serving_container(
company_id=company, ip=call.real_ip, request=request
)
@endpoint("serving.unregister_container")
def unregister_container(_: APICall, company: str, request: UnregisterRequest):
serving_bll.unregister_serving_container(
company_id=company, container_id=request.container_id
)
@endpoint("serving.container_status_report")
def container_status_report(call: APICall, company: str, request: StatusReportRequest):
if not request.endpoint_url:
raise errors.bad_request.ValidationError(
"Missing required field 'endpoint_url'"
)
serving_bll.container_status_report(
company_id=company,
ip=call.real_ip,
report=request,
)
@endpoint("serving.get_endpoints")
def get_endpoints(call: APICall, company: str, _):
call.result.data = {"endpoints": serving_bll.get_endpoints(company)}
@endpoint("serving.get_loading_instances")
def get_loading_instances(call: APICall, company: str, _):
call.result.data = {"instances": serving_bll.get_loading_instances(company)}
@endpoint("serving.get_endpoint_details")
def get_endpoint_details(
call: APICall, company: str, request: GetEndpointDetailsRequest
):
call.result.data = serving_bll.get_endpoint_details(
company_id=company, endpoint_url=request.endpoint_url
)
@endpoint("serving.get_endpoint_metrics_history")
def get_endpoint_metrics_history(
call: APICall, company: str, request: GetEndpointMetricsHistoryRequest
):
call.result.data = ServingStats.get_endpoint_metrics(
company_id=company,
metrics_request=request,
)

View File

@@ -0,0 +1,22 @@
from apiserver.apimodels.storage import ResetSettingsRequest, SetSettingsRequest
from apiserver.bll.storage import StorageBLL
from apiserver.service_repo import endpoint, APICall
storage_bll = StorageBLL()
@endpoint("storage.get_settings")
def get_settings(call: APICall, company: str, _):
call.result.data = {"settings": storage_bll.get_company_settings(company)}
@endpoint("storage.set_settings")
def set_settings(call: APICall, company: str, request: SetSettingsRequest):
call.result.data = {"updated": storage_bll.set_company_settings(company, request)}
@endpoint("storage.reset_settings")
def reset_settings(call: APICall, company: str, request: ResetSettingsRequest):
call.result.data = {
"updated": storage_bll.reset_company_settings(company, request.keys)
}

View File

@@ -1,9 +1,9 @@
import itertools
from copy import deepcopy
from datetime import datetime
from functools import partial
from typing import Sequence, Union, Tuple
from typing import Sequence, Union, Tuple, Mapping
import attr
from mongoengine import EmbeddedDocument, Q
from mongoengine.queryset.transform import COMPARISON_OPERATORS
from pymongo import UpdateOne
@@ -80,6 +80,11 @@ from apiserver.bll.task import (
TaskBLL,
ChangeStatusRequest,
)
from apiserver.bll.task.task_cleanup import (
delete_task_events_and_collect_urls,
schedule_for_delete,
CleanupResult,
)
from apiserver.bll.task.artifacts import (
artifacts_prepare_for_save,
artifacts_unprepare_from_saved,
@@ -109,6 +114,7 @@ from apiserver.bll.task.utils import (
get_task_with_write_access,
)
from apiserver.bll.util import run_batch_operation, update_project_time
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
from apiserver.database.model.project import Project
@@ -295,9 +301,7 @@ def get_types(call: APICall, company_id, request: GetTypesRequest):
}
@endpoint(
"tasks.stop", response_data_model=UpdateResponse
)
@endpoint("tasks.stop", response_data_model=UpdateResponse)
def stop(call: APICall, company_id, request: StopRequest):
"""
stop
@@ -919,6 +923,11 @@ def delete_configuration(
response_data_model=EnqueueResponse,
)
def enqueue(call: APICall, company_id, request: EnqueueRequest):
if request.verify_watched_queue and not request.update_execution_queue:
raise errors.bad_request.ValidationError(
"verify_watched_queue cannot be used with update_execution_queue=False"
)
queued, res = enqueue_task(
task_id=request.task,
company_id=company_id,
@@ -928,6 +937,7 @@ def enqueue(call: APICall, company_id, request: EnqueueRequest):
status_reason=request.status_reason,
queue_name=request.queue_name,
force=request.force,
update_execution_queue=request.update_execution_queue,
)
if request.verify_watched_queue:
res_queue = nested_get(res, ("fields", "execution.queue"))
@@ -1016,21 +1026,98 @@ def dequeue_many(call: APICall, company_id, request: DequeueManyRequest):
)
def _delete_task_events(
company_id: str,
user_id: str,
tasks: Mapping[str, CleanupResult],
delete_external_artifacts: bool,
sync_delete: bool,
):
if not tasks:
return
task_ids = list(tasks)
deleted_model_ids = set(
itertools.chain.from_iterable(
cr.deleted_model_ids for cr in tasks.values() if cr.deleted_model_ids
)
)
delete_external_artifacts = delete_external_artifacts and config.get(
"services.async_urls_delete.enabled", True
)
if delete_external_artifacts:
for t_id, cleanup_res in tasks.items():
urls = set(cleanup_res.urls.model_urls) | set(
cleanup_res.urls.artifact_urls
)
if urls:
schedule_for_delete(
task_id=t_id,
company=company_id,
user=user_id,
urls=urls,
can_delete_folders=False,
)
event_urls = delete_task_events_and_collect_urls(
company=company_id,
task_ids=task_ids,
wait_for_delete=sync_delete,
)
if deleted_model_ids:
event_urls.update(
delete_task_events_and_collect_urls(
company=company_id,
task_ids=list(deleted_model_ids),
model=True,
wait_for_delete=sync_delete,
)
)
if event_urls:
schedule_for_delete(
task_id=task_ids[0],
company=company_id,
user=user_id,
urls=event_urls,
can_delete_folders=False,
)
else:
event_bll.delete_task_events(company_id, task_ids, wait_for_delete=sync_delete)
if deleted_model_ids:
event_bll.delete_task_events(
company_id,
list(deleted_model_ids),
model=True,
wait_for_delete=sync_delete,
)
@endpoint(
"tasks.reset", request_data_model=ResetRequest, response_data_model=ResetResponse
)
def reset(call: APICall, company_id, request: ResetRequest):
task_id = request.task
dequeued, cleanup_res, updates = reset_task(
task_id=request.task,
task_id=task_id,
company_id=company_id,
identity=call.identity,
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
clear_all=request.clear_all,
delete_external_artifacts=request.delete_external_artifacts,
)
res = ResetResponse(**updates, **attr.asdict(cleanup_res), dequeued=dequeued)
_delete_task_events(
company_id=company_id,
user_id=call.identity.user,
tasks={task_id: cleanup_res},
delete_external_artifacts=request.delete_external_artifacts,
sync_delete=True,
)
res = ResetResponse(
**updates,
**cleanup_res.to_res_dict(request.return_file_urls),
dequeued=dequeued,
)
call.result.data_model = res
@@ -1046,25 +1133,33 @@ def reset_many(call: APICall, company_id, request: ResetManyRequest):
company_id=company_id,
identity=call.identity,
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
clear_all=request.clear_all,
delete_external_artifacts=request.delete_external_artifacts,
),
ids=request.ids,
)
succeeded = []
tasks = {}
for _id, (dequeued, cleanup, res) in results:
tasks[_id] = cleanup
succeeded.append(
ResetBatchItem(
id=_id,
dequeued=bool(dequeued.get("removed")) if dequeued else False,
**attr.asdict(cleanup),
**cleanup.to_res_dict(request.return_file_urls),
**res,
)
)
_delete_task_events(
company_id=company_id,
user_id=call.identity.user,
tasks=tasks,
delete_external_artifacts=request.delete_external_artifacts,
sync_delete=False,
)
call.result.data_model = ResetManyResponse(
succeeded=succeeded,
failed=failures,
@@ -1160,16 +1255,23 @@ def delete(call: APICall, company_id, request: DeleteRequest):
identity=call.identity,
move_to_trash=request.move_to_trash,
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
status_message=request.status_message,
status_reason=request.status_reason,
delete_external_artifacts=request.delete_external_artifacts,
include_pipeline_steps=request.include_pipeline_steps,
)
if deleted:
_delete_task_events(
company_id=company_id,
user_id=call.identity.user,
tasks={request.task: cleanup_res},
delete_external_artifacts=request.delete_external_artifacts,
sync_delete=True,
)
_reset_cached_tags(company_id, projects=[task.project] if task.project else [])
call.result.data = dict(deleted=bool(deleted), **attr.asdict(cleanup_res))
call.result.data = dict(
deleted=bool(deleted), **cleanup_res.to_res_dict(request.return_file_urls)
)
@endpoint("tasks.delete_many", request_data_model=DeleteManyRequest)
@@ -1181,25 +1283,42 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest):
identity=call.identity,
move_to_trash=request.move_to_trash,
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
status_message=request.status_message,
status_reason=request.status_reason,
delete_external_artifacts=request.delete_external_artifacts,
include_pipeline_steps=request.include_pipeline_steps,
),
ids=request.ids,
)
succeeded = []
tasks = {}
if results:
projects = set(task.project for _, (_, task, _) in results)
projects = set()
for _id, (deleted, task, cleanup_res) in results:
if deleted:
projects.add(task.project)
tasks[_id] = cleanup_res
succeeded.append(
dict(
id=_id,
deleted=bool(deleted),
**cleanup_res.to_res_dict(request.return_file_urls),
)
)
if tasks:
_delete_task_events(
company_id=company_id,
user_id=call.identity.user,
tasks=tasks,
delete_external_artifacts=request.delete_external_artifacts,
sync_delete=False,
)
_reset_cached_tags(company_id, projects=list(projects))
call.result.data = dict(
succeeded=[
dict(id=_id, deleted=bool(deleted), **attr.asdict(cleanup_res))
for _id, (deleted, _, cleanup_res) in results
],
succeeded=succeeded,
failed=failures,
)
@@ -1359,10 +1478,11 @@ def move(call: APICall, company_id: str, request: MoveRequest):
"project or project_name is required"
)
_assert_writable_tasks(company_id, call.identity, request.ids)
updated_projects = set(
t.project for t in Task.objects(id__in=request.ids).only("project") if t.project
tasks = _assert_writable_tasks(
company_id, call.identity, request.ids, only=("id", "project")
)
updated_projects = set(t.project for t in tasks if t.project)
project_id = project_bll.move_under_project(
entity_cls=Task,
user=call.identity.user,
@@ -1385,6 +1505,7 @@ def update_tags(call: APICall, company_id: str, request: UpdateTagsRequest):
return {
"updated": org_bll.edit_entity_tags(
company_id=company_id,
user_id=call.identity.user,
entity_cls=Task,
entity_ids=request.ids,
add_tags=request.add_tags,

View File

@@ -1,3 +1,5 @@
import unittest
from apiserver.apierrors import errors
from apiserver.apierrors.errors.bad_request import InvalidModelId
from apiserver.tests.automated import TestService
@@ -236,6 +238,23 @@ class TestModelsService(TestService):
res = self.api.models.get_frameworks(projects=[project])
self.assertEqual([], res.frameworks)
@unittest.skip(
"""This test requires the following setting
CLEARML__services__async_urls_delete__fileserver__url_prefixes=["https://files.allegro-master.hosted.allegro.ai"]
Check the test results in the logs of the async_delete service
"""
)
def test_delete_many_with_files(self):
models = [
self._create_model(
name=f"delete model test{idx}",
uri=f"https://files.allegro-master.hosted.allegro.ai/models/test{idx}.txt"
)
for idx in range(2)
]
self.api.models.delete_many(ids=models)
def test_make_public(self):
m1 = self._create_model(name="public model test")
@@ -277,7 +296,7 @@ class TestModelsService(TestService):
service="models",
delete_params=dict(can_fail=True, force=True),
name=kwargs.pop("name", "test"),
uri=kwargs.pop("name", "file:///a"),
uri=kwargs.pop("uri", "file:///a"),
labels=kwargs.pop("labels", {}),
**kwargs,
)

View File

@@ -5,6 +5,18 @@ from apiserver.tests.automated import TestService
class TestPipelines(TestService):
task_hyperparams = {
"properties":
{
"version": {
"section": "properties",
"name": "version",
"type": "str",
"value": "3.2"
}
}
}
def test_controller_operations(self):
task_name = "pipelines test"
project, task = self._temp_project_and_task(name=task_name)
@@ -82,14 +94,17 @@ class TestPipelines(TestService):
self.assertEqual(pipeline.status, "queued")
self.assertEqual(pipeline.project.id, project)
self.assertEqual(
pipeline.hyperparams.Args,
pipeline.hyperparams,
{
a["name"]: {
"section": "Args",
"name": a["name"],
"value": a["value"],
}
for a in args
"Args": {
a["name"]: {
"section": "Args",
"name": a["name"],
"value": a["value"],
}
for a in args
},
**self.task_hyperparams,
},
)
@@ -124,6 +139,7 @@ class TestPipelines(TestService):
type="controller",
project=project,
system_tags=["pipeline"],
hyperparams=self.task_hyperparams,
),
)

View File

@@ -40,6 +40,53 @@ class TestQueues(TestService):
)
self.assertMetricQueues(res["queues"], queue_id)
def test_add_remove_clear(self):
queue1 = self._temp_queue("TestTempQueue1")
queue2 = self._temp_queue("TestTempQueue2")
task_names = ["TempDevTask1", "TempDevTask2"]
tasks = [self._temp_task(name) for name in task_names]
for task in tasks:
self.api.tasks.enqueue(task=task, queue=queue1)
# remove task with and without status update
res = self.api.queues.remove_task(task=tasks[0], queue=queue1)
self.assertEqual(res.removed, 1)
res = self.api.tasks.get_by_id(task=tasks[0])
self.assertEqual(res.task.status, "queued")
self.assertEqual(res.task.execution.queue, queue1)
res = self.api.queues.remove_task(task=tasks[1], queue=queue1, update_task_status=True)
self.assertEqual(res.removed, 1)
res = self.api.tasks.get_by_id(task=tasks[1])
self.assertEqual(res.task.status, "created")
res = self.api.queues.get_by_id(queue=queue1)
self.assertQueueTasks(res.queue, [])
# add task
res = self.api.queues.add_task(queue=queue2, task=tasks[0])
self.assertEqual(res.added, 1)
res = self.api.tasks.get_by_id(task=tasks[0])
self.assertEqual(res.task.status, "queued")
self.assertEqual(res.task.execution.queue, queue2)
res = self.api.queues.get_by_id(queue=queue2)
self.assertQueueTasks(res.queue, [tasks[0]])
# clear queue
res = self.api.queues.clear_queue(queue=queue1)
self.assertEqual(res.removed_tasks, [])
res = self.api.queues.clear_queue(queue=queue2)
self.assertEqual(res.removed_tasks, [tasks[0]])
res = self.api.tasks.get_by_id(task=tasks[0])
self.assertEqual(res.task.status, "created")
res = self.api.queues.get_by_id(queue=queue2)
self.assertQueueTasks(res.queue, [])
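    # Summary of the queue/task status behaviour asserted above:
    #   queues.remove_task(...)                          -> task stays "queued"
    #   queues.remove_task(..., update_task_status=True) -> task status reset to "created"
    #   queues.add_task(...)                             -> task becomes "queued" on the target queue
    #   queues.clear_queue(...)                          -> removed tasks are reset to "created"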
def test_hidden_queues(self):
hidden_name = "TestHiddenQueue"
hidden_queue = self._temp_queue(hidden_name, system_tags=["k8s-glue"])
@@ -212,12 +259,19 @@ class TestQueues(TestService):
def test_get_all_ex(self):
queue_name = "TestTempQueue1"
queue_display_name = "Test display name"
queue_tags = ["Test1", "Test2"]
queue = self._temp_queue(queue_name, tags=queue_tags)
queue = self._temp_queue(queue_name, display_name=queue_display_name, tags=queue_tags)
res = self.api.queues.get_all_ex(name="TestTempQueue*").queues
self.assertQueue(
res, queue_id=queue, name=queue_name, tags=queue_tags, tasks=[], workers=[]
res,
queue_id=queue,
display_name=queue_display_name,
name=queue_name,
tags=queue_tags,
tasks=[],
workers=[],
)
tasks = [
@@ -232,6 +286,7 @@ class TestQueues(TestService):
res,
queue_id=queue,
name=queue_name,
display_name=queue_display_name,
tags=queue_tags,
tasks=tasks,
workers=workers,
@@ -259,6 +314,7 @@ class TestQueues(TestService):
queues: Sequence[AttrDict],
queue_id: str,
name: str,
display_name: str,
tags: Sequence[str],
tasks: Sequence[dict],
workers: Sequence[dict],
@@ -267,15 +323,33 @@ class TestQueues(TestService):
assert queue.last_update
self.assertEqualNoOrder(queue.tags, tags)
self.assertEqual(queue.name, name)
self.assertQueueTasks(queue, tasks)
self.assertQueueWorkers(queue, workers)
self.assertEqual(queue.display_name, display_name)
self.assertQueueTasks(queue, tasks, name, display_name)
self.assertQueueWorkers(queue, workers, name, display_name)
def assertTaskTags(self, task, system_tags):
res = self.api.tasks.get_by_id(task=task)
self.assertSequenceEqual(res.task.system_tags, system_tags)
def assertQueueTasks(self, queue: AttrDict, tasks: Sequence):
def assertQueueTasks(
self,
queue: AttrDict,
tasks: Sequence,
queue_name: str = None,
display_queue_name: str = None,
):
self.assertEqual([e.task for e in queue.entries], tasks)
if queue_name:
for task in tasks:
execution = self.api.tasks.get_by_id_ex(
id=[task["id"]],
only_fields=[
"execution.queue.name",
"execution.queue.display_name",
],
).tasks[0].execution
self.assertEqual(execution.queue.name, queue_name)
self.assertEqual(execution.queue.display_name, display_queue_name)
def assertGetNextTasks(self, queue, tasks):
for task_id in tasks:
@@ -283,11 +357,28 @@ class TestQueues(TestService):
self.assertEqual(res.entry.task, task_id)
assert not self.api.queues.get_next_task(queue=queue)
def assertQueueWorkers(self, queue: AttrDict, workers: Sequence[dict]):
def assertQueueWorkers(
self,
queue: AttrDict,
workers: Sequence[dict],
queue_name: str = None,
display_queue_name: str = None,
):
sort_key = itemgetter("name")
self.assertEqual(
sorted(queue.workers, key=sort_key), sorted(workers, key=sort_key)
)
if not workers:
return
res = self.api.workers.get_all()
worker_ids = {w["key"] for w in workers}
found = [w for w in res.workers if w.key in worker_ids]
self.assertEqual(len(found), len(worker_ids))
for worker in found:
for queue in worker.queues:
self.assertEqual(queue.name, queue_name)
self.assertEqual(queue.display_name, display_queue_name)
def _temp_queue(self, queue_name, **kwargs):
return self.create_temp("queues", name=queue_name, **kwargs)

View File

@@ -12,7 +12,7 @@ class TestReports(TestService):
def _delete_project(self, name):
existing_project = first(
self.api.projects.get_all_ex(
name=f"^{re.escape(name)}$", search_hidden=True
name=f"^{re.escape(name)}$", search_hidden=True, allow_public=False
).projects
)
if existing_project:
@@ -34,10 +34,10 @@ class TestReports(TestService):
self.assertEqual(set(task.tags), set(tags))
self.assertEqual(task.type, "report")
self.assertEqual(set(task.system_tags), {"hidden", "reports"})
projects = self.api.projects.get_all_ex(name=r"^\.reports$").projects
projects = self.api.projects.get_all_ex(name=r"^\.reports$", allow_public=False).projects
self.assertEqual(len(projects), 0)
project = self.api.projects.get_all_ex(
name=r"^\.reports$", search_hidden=True
name=r"^\.reports$", search_hidden=True, allow_public=False
).projects[0]
self.assertEqual(project.id, task.project.id)
self.assertEqual(set(project.system_tags), {"hidden", "reports"})
@@ -108,6 +108,7 @@ class TestReports(TestService):
include_stats=True,
check_own_contents=True,
search_hidden=True,
allow_public=False,
).projects
self.assertEqual(len(projects), 1)
p = projects[0]
@@ -120,6 +121,7 @@ class TestReports(TestService):
include_stats=True,
check_own_contents=True,
search_hidden=True,
allow_public=False,
).projects
self.assertEqual(len(projects), 1)
p = projects[0]

View File

@@ -0,0 +1,124 @@
from time import time, sleep
from apiserver.apierrors import errors
from apiserver.tests.automated import TestService
class TestServing(TestService):
def test_status_report(self):
container_id1 = "container_1"
container_id2 = "container_2"
url = "http://test_url"
reference = [
{"type": "app_id", "value": "test"},
{"type": "app_instance", "value": "abd478c8"},
{"type": "model", "value": "262829d3"},
{"type": "model", "value": "7ea29c04"},
]
container_infos = [
{
"container_id": container_id, # required
"endpoint_name": "my endpoint", # required
"endpoint_url": url, # can be omitted for register but required for status report
"model_name": "my model", # required
"model_source": "s3//my_bucket", # optional right now
"model_version": "3.1.0", # optional right now
"preprocess_artifact": "some string here", # optional right now
"input_type": "another string here", # optional right now
"input_size": 9_000_000, # optional right now, bytes
"tags": ["tag1", "tag2"], # optional
"system_tags": None, # optional
**({"reference": reference} if container_id == container_id1 else {}),
}
for container_id in (container_id1, container_id2)
]
# registering instances
for container_info in container_infos:
self.api.serving.register_container(
**container_info,
timeout=100, # expiration timeout in seconds. Optional, the default value is 600
)
for idx, container_info in enumerate(container_infos):
mul = idx + 1
self.api.serving.container_status_report(
**container_info,
uptime_sec=1000 * mul,
requests_num=1000 * mul,
requests_min=5 * mul, # requests per minute
latency_ms=100 * mul, # average latency
machine_stats={ # the same structure here as used by worker status_reports
"cpu_usage": [10, 20],
"memory_used": 50 * 1024,
},
)
# getting endpoints and endpoint details
endpoints = self.api.serving.get_endpoints().endpoints
self.assertTrue(any(e for e in endpoints if e.url == url))
details = self.api.serving.get_endpoint_details(endpoint_url=url)
self.assertEqual(details.url, url)
self.assertEqual(details.uptime_sec, 2000)
self.assertEqual(
{
inst.id: [
inst[field]
for field in (
"uptime_sec",
"requests",
"requests_min",
"latency_ms",
"cpu_count",
"gpu_count",
"reference",
)
]
for inst in details.instances
},
{
"container_1": [1000, 1000, 5, 100, 2, 0, reference],
"container_2": [2000, 2000, 10, 200, 2, 0, []],
},
)
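        # Worked expectation for the table above (mul = idx + 1 in the report loop):
        #   container_1 reported uptime_sec=1000, requests_num=1000, requests_min=5, latency_ms=100
        #   container_2 reported uptime_sec=2000, requests_num=2000, requests_min=10, latency_ms=200
        # cpu_count=2 presumably mirrors the two entries in machine_stats["cpu_usage"],
        # and gpu_count=0 follows from no GPU stats being reported (assumptions about
        # the server-side derivation, not asserted elsewhere in this test).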
# make sure that the first call did not invalidate anything
new_details = self.api.serving.get_endpoint_details(endpoint_url=url)
self.assertEqual(details, new_details)
# charts
        sleep(5)  # give ES time to accommodate the data
to_date = int(time()) + 40
from_date = to_date - 100
for metric_type, title, value in (
(None, "Number of Requests", 3000),
("requests_min", "Requests per Minute", 15),
("latency_ms", "Average Latency (ms)", 150),
("cpu_count", "CPU Count", 4),
("cpu_util", "Average CPU Load (%)", 15),
("ram_used", "RAM Used (GB)", 100.0),
):
res = self.api.serving.get_endpoint_metrics_history(
endpoint_url=url,
from_date=from_date,
to_date=to_date,
interval=1,
**({"metric_type": metric_type} if metric_type else {}),
)
self.assertEqual(res.computed_interval, 40)
self.assertEqual(res.total.title, title)
length = len(res.total.dates)
self.assertTrue(3 >= length >= 1)
self.assertEqual(len(res.total["values"]), length)
self.assertIn(value, res.total["values"])
self.assertEqual(set(res.instances), {container_id1, container_id2})
for inst in res.instances.values():
self.assertEqual(inst.dates, res.total.dates)
self.assertEqual(len(inst["values"]), length)
# unregistering containers
for container_id in (container_id1, container_id2):
self.api.serving.unregister_container(container_id=container_id)
endpoints = self.api.serving.get_endpoints().endpoints
self.assertFalse(any(e for e in endpoints if e.url == url))
with self.api.raises(errors.bad_request.NoContainersForUrl):
self.api.serving.get_endpoint_details(endpoint_url=url)

View File

@@ -15,7 +15,7 @@ class TestSubProjects(TestService):
def test_dataset_stats(self):
project = self._temp_project(name="Dataset test", system_tags=["dataset"])
res = self.api.organization.get_entities_count(
datasets={"system_tags": ["dataset"]}
datasets={"system_tags": ["dataset"]}, allow_public=False,
)
self.assertEqual(res.datasets, 1)
@@ -439,6 +439,15 @@ class TestSubProjects(TestService):
self.assertEqual(res2.own_tasks, 0)
self.assertEqual(res2.own_models, 0)
def test_public_names_clash(self):
        # cannot create a project whose name matches an existing public project
with self.api.raises(errors.bad_request.PublicProjectExists):
project = self._temp_project(name="ClearML Examples")
# cannot create a subproject under a public project
with self.api.raises(errors.bad_request.PublicProjectExists):
project = self._temp_project(name="ClearML Examples/my project")
def test_get_all_with_stats(self):
project4, _ = self._temp_project_with_tasks(name="project1/project3/project4")
project5, _ = self._temp_project_with_tasks(name="project1/project3/project5")

View File

@@ -217,7 +217,10 @@ class TestTaskEvents(TestService):
self.assertEqual(iter_count - 1, metric_data.max_value_iteration)
self.assertEqual(0, metric_data.min_value)
self.assertEqual(0, metric_data.min_value_iteration)
self.assertEqual(0, metric_data.first_value_iteration)
self.assertEqual(0, metric_data.first_value)
self.assertEqual(iter_count, metric_data.count)
self.assertEqual(sum(i for i in range(iter_count)) / iter_count, metric_data.mean_value)
res = self.api.events.get_task_latest_scalar_values(task=task)
self.assertEqual(iter_count - 1, res.last_iter)
@@ -243,6 +246,7 @@ class TestTaskEvents(TestService):
"variant": f"Variant{variant_idx}",
"value": iteration,
"model_event": True,
"x_axis_label": f"Label_{metric_idx}_{variant_idx}"
}
for iteration in range(2)
for metric_idx in range(5)
@@ -271,6 +275,7 @@ class TestTaskEvents(TestService):
variant_data = metric_data.Variant0
self.assertEqual(variant_data.x, [0, 1])
self.assertEqual(variant_data.y, [0.0, 1.0])
self.assertEqual(variant_data.x_axis_label, "Label_0_0")
model_data = self.api.models.get_all_ex(
id=[model], only_fields=["last_metrics", "last_iteration"]
@@ -282,6 +287,7 @@ class TestTaskEvents(TestService):
self.assertEqual(1, metric_data.max_value_iteration)
self.assertEqual(0, metric_data.min_value)
self.assertEqual(0, metric_data.min_value_iteration)
self.assertEqual("Label_4_4", metric_data.x_axis_label)
self._assert_log_events(task=task, expected_total=1)

View File

@@ -59,7 +59,7 @@ class TestTasksResetDelete(TestService):
event_urls.update(self.send_model_events(model))
res = self.assert_delete_task(task, force=True, return_file_urls=True)
self.assertEqual(set(res.urls.model_urls), draft_model_urls)
self.assertEqual(set(res.urls.event_urls), event_urls)
self.assertFalse(set(res.urls.event_urls)) # event urls are not returned anymore
self.assertEqual(set(res.urls.artifact_urls), artifact_urls)
def test_reset(self):
@@ -84,7 +84,7 @@ class TestTasksResetDelete(TestService):
) = self.create_task_with_data()
res = self.api.tasks.reset(task=task, force=True, return_file_urls=True)
self.assertEqual(set(res.urls.model_urls), draft_model_urls)
self.assertEqual(set(res.urls.event_urls), event_urls)
self.assertFalse(res.urls.event_urls) # event urls are not returned anymore
self.assertEqual(set(res.urls.artifact_urls), artifact_urls)
def test_model_delete(self):
@@ -124,7 +124,7 @@ class TestTasksResetDelete(TestService):
self.assertEqual(res.disassociated_tasks, 0)
self.assertEqual(res.deleted_tasks, 1)
self.assertEqual(res.deleted_models, 2)
self.assertEqual(set(res.urls.event_urls), event_urls)
self.assertFalse(set(res.urls.event_urls)) # event urls are not returned anymore
self.assertEqual(set(res.urls.artifact_urls), artifact_urls)
with self.api.raises(errors.bad_request.InvalidTaskId):
self.api.tasks.get_by_id(task=task)

View File

@@ -71,6 +71,16 @@ class TestTasksFiltering(TestService):
).tasks
self.assertFalse(set(tasks).issubset({t.id for t in res}))
# _any_/_all_ queries
res = self.api.tasks.get_all_ex(
**{"_any_": {"datetime": f">={now.isoformat()}", "fields": ["last_update", "status_changed"]}}
).tasks
self.assertTrue(set(tasks).issubset({t.id for t in res}))
res = self.api.tasks.get_all_ex(
**{"_all_": {"datetime": f">={now.isoformat()}", "fields": ["last_update", "status_changed"]}}
).tasks
self.assertFalse(set(tasks).issubset({t.id for t in res}))
# simplified range syntax
res = self.api.tasks.get_all_ex(last_update=[now.isoformat(), None]).tasks
self.assertTrue(set(tasks).issubset({t.id for t in res}))
@@ -80,6 +90,15 @@ class TestTasksFiltering(TestService):
).tasks
self.assertFalse(set(tasks).issubset({t.id for t in res}))
res = self.api.tasks.get_all_ex(
**{"_any_": {"datetime": [now.isoformat(), None], "fields": ["last_update", "status_changed"]}}
).tasks
self.assertTrue(set(tasks).issubset({t.id for t in res}))
res = self.api.tasks.get_all_ex(
**{"_all_": {"datetime": [now.isoformat(), None], "fields": ["last_update", "status_changed"]}}
).tasks
self.assertFalse(set(tasks).issubset({t.id for t in res}))
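        # _any_ matches a task if at least one of the listed fields satisfies the
        # condition, while _all_ requires every listed field to satisfy it; the
        # inverted subset assertions above are consistent with only one of
        # last_update / status_changed falling inside the queried range here.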
def test_range_queries(self):
tasks = [self.temp_task() for _ in range(5)]
now = datetime.utcnow()

View File

@@ -116,7 +116,7 @@ class TestWorkersService(TestService):
if w == workers[0]:
data["task"] = task_id
self.api.workers.status_report(**data)
timestamp += 1000
timestamp += 60*1000
return workers
@@ -151,7 +151,7 @@ class TestWorkersService(TestService):
        time.sleep(5)  # give ES time to refresh
from_date = start
to_date = start + 10
to_date = start + 40*10
# no variants
res = self.api.workers.get_stats(
items=[
@@ -180,7 +180,7 @@ class TestWorkersService(TestService):
self.assertEqual(
set(stat.aggregation for stat in metric.stats), metric_stats
)
self.assertEqual(len(metric.dates), 4 if worker.worker == workers[0] else 2)
self.assertTrue(11 >= len(metric.dates) >= 10)
# split by variants
res = self.api.workers.get_stats(
@@ -199,7 +199,7 @@ class TestWorkersService(TestService):
set(metric.variant for metric in worker.metrics),
{"0", "1"} if worker.worker == workers[0] else {"0"},
)
self.assertEqual(len(metric.dates), 4 if worker.worker == workers[0] else 2)
self.assertTrue(11 >= len(metric.dates) >= 10)
res = self.api.workers.get_stats(
items=[dict(key="cpu_usage", aggregation="avg")],
@@ -216,25 +216,25 @@ class TestWorkersService(TestService):
def test_get_activity_report(self):
# test no workers data
# run on an empty es db since we have no way
# to pass non existing workers to this api
# to pass non-existing workers to this api
# res = self.api.workers.get_activity_report(
# from_timestamp=from_timestamp.timestamp(),
# to_timestamp=to_timestamp.timestamp(),
# interval=20,
# )
start = int(time.time())
self._simulate_workers(int(time.time()))
self._simulate_workers(start)
        time.sleep(5)  # give ES time to refresh
# no variants
res = self.api.workers.get_activity_report(
from_date=start, to_date=start + 10, interval=2
from_date=start, to_date=start + 10*40, interval=2
)
self.assertWorkerSeries(res["total"], 2, 5)
self.assertWorkerSeries(res["active"], 1, 5)
self.assertWorkerSeries(res["total"], 2, 10)
self.assertWorkerSeries(res["active"], 1, 10)
def assertWorkerSeries(self, series_data: dict, count: int, size: int):
self.assertEqual(len(series_data["dates"]), size)
self.assertEqual(len(series_data["counts"]), size)
self.assertTrue(any(c == count for c in series_data["counts"]))
self.assertTrue(all(c <= count for c in series_data["counts"]))
# self.assertTrue(any(c == count for c in series_data["counts"]))
# self.assertTrue(all(c <= count for c in series_data["counts"]))

View File

@@ -1,4 +1,4 @@
from typing import Sequence, Tuple, Any, Union, Callable, Optional, Mapping
from typing import Sequence, Tuple, Any, Union, Callable, Optional, Protocol
def flatten_nested_items(
@@ -35,8 +35,13 @@ def deep_merge(source: dict, override: dict) -> dict:
return source
class GetItem(Protocol):
def __getitem__(self, key: Any) -> Any:
pass
def nested_get(
dictionary: Mapping,
dictionary: GetItem,
path: Sequence[str],
default: Optional[Union[Any, Callable]] = None,
) -> Any:
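The switch from `Mapping` to a structural `GetItem` protocol means `nested_get` now accepts any object that implements item access, not only real mappings. A minimal sketch of the idea (illustrative only; the walking loop below is an assumption about `nested_get`'s body, which is truncated in this hunk):

    from typing import Any, Callable, Optional, Protocol, Sequence, Union

    class GetItem(Protocol):
        def __getitem__(self, key: Any) -> Any: ...

    def nested_get(
        dictionary: GetItem,
        path: Sequence[str],
        default: Optional[Union[Any, Callable]] = None,
    ) -> Any:
        # walk the path one key at a time; fall back to `default` on a missing key
        node = dictionary
        for key in path:
            try:
                node = node[key]
            except (KeyError, IndexError, TypeError):
                return default() if callable(default) else default
        return node

    # Any indexable object satisfies GetItem: a plain dict, an AttrDict, a pymongo document, etc.
    nested_get({"a": {"b": 1}}, ("a", "b"))        # -> 1
    nested_get({"a": {}}, ("a", "b"), default=0)   # -> 0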

View File

@@ -1 +1 @@
__version__ = "1.16.0"
__version__ = "2.0.0"

View File

@@ -58,7 +58,7 @@ services:
nofile:
soft: 65536
hard: 65536
image: docker.elastic.co/elasticsearch/elasticsearch:7.17.18
image: elasticsearch:8.17.0
restart: unless-stopped
volumes:
- c:/opt/clearml/data/elastic_7:/usr/share/elasticsearch/data
@@ -87,7 +87,7 @@ services:
networks:
- backend
container_name: clearml-mongo
image: mongo:4.4.29
image: mongo:6.0.19
restart: unless-stopped
command: --setParameter internalQueryMaxBlockingSortMemoryUsageBytes=196100200
volumes:
@@ -98,7 +98,7 @@ services:
networks:
- backend
container_name: clearml-redis
image: redis:6.2
image: redis:7.4.1
restart: unless-stopped
volumes:
- c:/opt/clearml/data/redis:/data

View File

@@ -60,7 +60,7 @@ services:
nofile:
soft: 65536
hard: 65536
image: docker.elastic.co/elasticsearch/elasticsearch:7.17.18
image: elasticsearch:8.17.0
restart: unless-stopped
volumes:
- /opt/clearml/data/elastic_7:/usr/share/elasticsearch/data
@@ -88,7 +88,7 @@ services:
networks:
- backend
container_name: clearml-mongo
image: mongo:4.4.29
image: mongo:6.0.19
restart: unless-stopped
command: --setParameter internalQueryMaxBlockingSortMemoryUsageBytes=196100200
volumes:
@@ -99,7 +99,7 @@ services:
networks:
- backend
container_name: clearml-redis
image: redis:6.2
image: redis:7.4.1
restart: unless-stopped
volumes:
- /opt/clearml/data/redis:/data

View File

@@ -1,7 +1,7 @@
Server Side Public License
VERSION 1, OCTOBER 16, 2018
Copyright © 2019 allegro.ai, Inc.
Copyright © 2025 ClearML Inc.
Everyone is permitted to copy and distribute verbatim copies of this
license document, but changing it is not allowed.

View File

@@ -7,6 +7,7 @@ import urllib.parse
from argparse import ArgumentParser
from collections import defaultdict
from pathlib import Path
from typing import Optional
from boltons.iterutils import first
from flask import Flask, request, send_from_directory, abort, Response
@@ -113,8 +114,12 @@ def download(path):
return response
def _get_full_path(path: str) -> Path:
return Path(safe_join(os.fspath(app.config["UPLOAD_FOLDER"]), os.fspath(path)))
def _get_full_path(path: str) -> Optional[Path]:
path_str = safe_join(os.fspath(app.config["UPLOAD_FOLDER"]), os.fspath(path))
if path_str is None:
return path_str
return Path(path_str)
@app.route("/<path:path>", methods=["DELETE"])
@@ -123,7 +128,7 @@ def delete(path):
auth_handler.validate(request)
full_path = _get_full_path(path)
if not full_path.exists() or not full_path.is_file():
if not (full_path and full_path.exists() and full_path.is_file()):
log.error(f"Error deleting file {str(full_path)}. Not found or not a file")
abort(Response(f"File {str(path)} not found", 404))
@@ -161,7 +166,7 @@ def batch_delete():
full_path = _get_full_path(path)
if not full_path.exists():
if not (full_path and full_path.exists()):
record_error("Not found", file, path)
continue
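Context for the added `None` checks: `safe_join` (assuming this file imports Werkzeug's helper, `werkzeug.utils.safe_join` in recent versions) returns `None` instead of raising when the requested path would escape the upload folder, so `_get_full_path` can now propagate that and both delete paths treat it as "not found". A minimal sketch of the behaviour being guarded against (illustrative paths):

    from werkzeug.utils import safe_join  # werkzeug.security.safe_join in older versions

    safe_join("/opt/clearml/fileserver", "models/a.bin")   # -> "/opt/clearml/fileserver/models/a.bin"
    safe_join("/opt/clearml/fileserver", "../etc/passwd")  # -> None (path escapes the base directory)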

View File

@@ -5,7 +5,7 @@ flask-cors>=3.0.5
flask>=2.3.3
gunicorn>=20.1.0
pyhocon>=0.3.35
redis>=4.5.4,<5
redis==5.2.1
setuptools>=65.5.1
urllib3>=1.26.18
werkzeug>=3.0.1

View File

@@ -0,0 +1,158 @@
version: "3.6"
services:
apiserver:
command:
- apiserver
container_name: clearml-apiserver
image: allegroai/clearml:1.17.1-554
restart: unless-stopped
volumes:
- c:/opt/clearml/logs:/var/log/clearml
- c:/opt/clearml/config:/opt/clearml/config
- c:/opt/clearml/data/fileserver:/mnt/fileserver
depends_on:
- redis
- mongo
- elasticsearch
- fileserver
environment:
CLEARML_ELASTIC_SERVICE_HOST: elasticsearch
CLEARML_ELASTIC_SERVICE_PORT: 9200
CLEARML_MONGODB_SERVICE_HOST: mongo
CLEARML_MONGODB_SERVICE_PORT: 27017
CLEARML_REDIS_SERVICE_HOST: redis
CLEARML_REDIS_SERVICE_PORT: 6379
CLEARML_SERVER_DEPLOYMENT_TYPE: win10
CLEARML__apiserver__pre_populate__enabled: "true"
CLEARML__apiserver__pre_populate__zip_files: "/opt/clearml/db-pre-populate"
CLEARML__apiserver__pre_populate__artifacts_path: "/mnt/fileserver"
CLEARML__services__async_urls_delete__enabled: "true"
CLEARML__services__async_urls_delete__fileserver__url_prefixes: "[${CLEARML_FILES_HOST:-}]"
ports:
- "8008:8008"
networks:
- backend
- frontend
elasticsearch:
networks:
- backend
container_name: clearml-elastic
environment:
bootstrap.memory_lock: "true"
cluster.name: clearml
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
cluster.routing.allocation.disk.watermark.low: 500mb
cluster.routing.allocation.disk.watermark.high: 500mb
cluster.routing.allocation.disk.watermark.flood_stage: 500mb
discovery.type: "single-node"
http.compression_level: "7"
node.name: clearml
reindex.remote.whitelist: "'*.*'"
xpack.security.enabled: "false"
ulimits:
memlock:
soft: -1
hard: -1
nofile:
soft: 65536
hard: 65536
image: docker.elastic.co/elasticsearch/elasticsearch:8.15.3
restart: unless-stopped
volumes:
- c:/opt/clearml/data/elastic_7:/usr/share/elasticsearch/data
- /usr/share/elasticsearch/logs
fileserver:
networks:
- backend
- frontend
command:
- fileserver
container_name: clearml-fileserver
image: allegroai/clearml:1.17.1-554
environment:
CLEARML__fileserver__delete__allow_batch: "true"
restart: unless-stopped
volumes:
- c:/opt/clearml/logs:/var/log/clearml
- c:/opt/clearml/data/fileserver:/mnt/fileserver
- c:/opt/clearml/config:/opt/clearml/config
ports:
- "8081:8081"
mongo:
networks:
- backend
container_name: clearml-mongo
image: mongo:5.0.26
restart: unless-stopped
command: --setParameter internalQueryMaxBlockingSortMemoryUsageBytes=196100200
volumes:
- c:/opt/clearml/data/mongo_4/db:/data/db
- c:/opt/clearml/data/mongo_4/configdb:/data/configdb
redis:
networks:
- backend
container_name: clearml-redis
image: redis:6.2
restart: unless-stopped
volumes:
- c:/opt/clearml/data/redis:/data
webserver:
command:
- webserver
container_name: clearml-webserver
image: allegroai/clearml:1.17.1-554
restart: unless-stopped
volumes:
- c:/clearml/logs:/var/log/clearml
depends_on:
- apiserver
ports:
- "8080:80"
networks:
- backend
- frontend
async_delete:
depends_on:
- apiserver
- redis
- mongo
- elasticsearch
- fileserver
container_name: async_delete
image: allegroai/clearml:1.17.1-554
networks:
- backend
restart: unless-stopped
environment:
CLEARML_ELASTIC_SERVICE_HOST: elasticsearch
CLEARML_ELASTIC_SERVICE_PORT: 9200
CLEARML_MONGODB_SERVICE_HOST: mongo
CLEARML_MONGODB_SERVICE_PORT: 27017
CLEARML_REDIS_SERVICE_HOST: redis
CLEARML_REDIS_SERVICE_PORT: 6379
PYTHONPATH: /opt/clearml/apiserver
CLEARML__services__async_urls_delete__fileserver__url_prefixes: "[${CLEARML_FILES_HOST:-}]"
entrypoint:
- python3
- -m
- jobs.async_urls_delete
- --fileserver-host
- http://fileserver:8081
volumes:
- c:/opt/clearml/logs:/var/log/clearml
- c:/opt/clearml/config:/opt/clearml/config
networks:
backend:
driver: bridge
frontend:
name: frontend
driver: bridge

View File

@@ -0,0 +1,195 @@
version: "3.6"
services:
apiserver:
command:
- apiserver
container_name: clearml-apiserver
image: allegroai/clearml:1.17.1-554
restart: unless-stopped
volumes:
- /opt/clearml/logs:/var/log/clearml
- /opt/clearml/config:/opt/clearml/config
- /opt/clearml/data/fileserver:/mnt/fileserver
depends_on:
- redis
- mongo
- elasticsearch
- fileserver
environment:
CLEARML_ELASTIC_SERVICE_HOST: elasticsearch
CLEARML_ELASTIC_SERVICE_PORT: 9200
CLEARML_MONGODB_SERVICE_HOST: mongo
CLEARML_MONGODB_SERVICE_PORT: 27017
CLEARML_REDIS_SERVICE_HOST: redis
CLEARML_REDIS_SERVICE_PORT: 6379
CLEARML_SERVER_DEPLOYMENT_TYPE: linux
CLEARML__apiserver__pre_populate__enabled: "true"
CLEARML__apiserver__pre_populate__zip_files: "/opt/clearml/db-pre-populate"
CLEARML__apiserver__pre_populate__artifacts_path: "/mnt/fileserver"
CLEARML__services__async_urls_delete__enabled: "true"
CLEARML__services__async_urls_delete__fileserver__url_prefixes: "[${CLEARML_FILES_HOST:-}]"
CLEARML__secure__credentials__services_agent__user_key: ${CLEARML_AGENT_ACCESS_KEY:-}
CLEARML__secure__credentials__services_agent__user_secret: ${CLEARML_AGENT_SECRET_KEY:-}
ports:
- "8008:8008"
networks:
- backend
- frontend
elasticsearch:
networks:
- backend
container_name: clearml-elastic
environment:
bootstrap.memory_lock: "true"
cluster.name: clearml
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
cluster.routing.allocation.disk.watermark.low: 500mb
cluster.routing.allocation.disk.watermark.high: 500mb
cluster.routing.allocation.disk.watermark.flood_stage: 500mb
discovery.type: "single-node"
http.compression_level: "7"
node.name: clearml
reindex.remote.whitelist: "'*.*'"
xpack.security.enabled: "false"
ulimits:
memlock:
soft: -1
hard: -1
nofile:
soft: 65536
hard: 65536
image: docker.elastic.co/elasticsearch/elasticsearch:8.15.3
restart: unless-stopped
volumes:
- /opt/clearml/data/elastic_7:/usr/share/elasticsearch/data
- /usr/share/elasticsearch/logs
fileserver:
networks:
- backend
- frontend
command:
- fileserver
container_name: clearml-fileserver
image: allegroai/clearml:1.17.1-554
environment:
CLEARML__fileserver__delete__allow_batch: "true"
restart: unless-stopped
volumes:
- /opt/clearml/logs:/var/log/clearml
- /opt/clearml/data/fileserver:/mnt/fileserver
- /opt/clearml/config:/opt/clearml/config
ports:
- "8081:8081"
mongo:
networks:
- backend
container_name: clearml-mongo
image: mongo:5.0.26
restart: unless-stopped
command: --setParameter internalQueryMaxBlockingSortMemoryUsageBytes=196100200
volumes:
- /opt/clearml/data/mongo_4/db:/data/db
- /opt/clearml/data/mongo_4/configdb:/data/configdb
redis:
networks:
- backend
container_name: clearml-redis
image: redis:6.2
restart: unless-stopped
volumes:
- /opt/clearml/data/redis:/data
webserver:
command:
- webserver
container_name: clearml-webserver
# environment:
    #   CLEARML_SERVER_SUB_PATH: clearml-web  # Allow ClearML to be served under a URL path prefix.
image: allegroai/clearml:1.17.1-554
restart: unless-stopped
depends_on:
- apiserver
ports:
- "8080:80"
networks:
- backend
- frontend
async_delete:
depends_on:
- apiserver
- redis
- mongo
- elasticsearch
- fileserver
container_name: async_delete
image: allegroai/clearml:1.17.1-554
networks:
- backend
restart: unless-stopped
environment:
CLEARML_ELASTIC_SERVICE_HOST: elasticsearch
CLEARML_ELASTIC_SERVICE_PORT: 9200
CLEARML_MONGODB_SERVICE_HOST: mongo
CLEARML_MONGODB_SERVICE_PORT: 27017
CLEARML_REDIS_SERVICE_HOST: redis
CLEARML_REDIS_SERVICE_PORT: 6379
PYTHONPATH: /opt/clearml/apiserver
CLEARML__services__async_urls_delete__fileserver__url_prefixes: "[${CLEARML_FILES_HOST:-}]"
entrypoint:
- python3
- -m
- jobs.async_urls_delete
- --fileserver-host
- http://fileserver:8081
volumes:
- /opt/clearml/logs:/var/log/clearml
- /opt/clearml/config:/opt/clearml/config
agent-services:
networks:
- backend
container_name: clearml-agent-services
image: allegroai/clearml-agent-services:latest
deploy:
restart_policy:
condition: on-failure
privileged: true
environment:
CLEARML_HOST_IP: ${CLEARML_HOST_IP}
CLEARML_WEB_HOST: ${CLEARML_WEB_HOST:-}
CLEARML_API_HOST: http://apiserver:8008
CLEARML_FILES_HOST: ${CLEARML_FILES_HOST:-}
CLEARML_API_ACCESS_KEY: ${CLEARML_AGENT_ACCESS_KEY:-$CLEARML_API_ACCESS_KEY}
CLEARML_API_SECRET_KEY: ${CLEARML_AGENT_SECRET_KEY:-$CLEARML_API_SECRET_KEY}
CLEARML_AGENT_GIT_USER: ${CLEARML_AGENT_GIT_USER}
CLEARML_AGENT_GIT_PASS: ${CLEARML_AGENT_GIT_PASS}
CLEARML_AGENT_UPDATE_VERSION: ${CLEARML_AGENT_UPDATE_VERSION:->=0.17.0}
CLEARML_AGENT_DEFAULT_BASE_DOCKER: "ubuntu:18.04"
AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID:-}
AWS_SECRET_ACCESS_KEY: ${AWS_SECRET_ACCESS_KEY:-}
AWS_DEFAULT_REGION: ${AWS_DEFAULT_REGION:-}
AZURE_STORAGE_ACCOUNT: ${AZURE_STORAGE_ACCOUNT:-}
AZURE_STORAGE_KEY: ${AZURE_STORAGE_KEY:-}
GOOGLE_APPLICATION_CREDENTIALS: ${GOOGLE_APPLICATION_CREDENTIALS:-}
CLEARML_WORKER_ID: "clearml-services"
CLEARML_AGENT_DOCKER_HOST_MOUNT: "/opt/clearml/agent:/root/.clearml"
SHUTDOWN_IF_NO_ACCESS_KEY: 1
volumes:
- /var/run/docker.sock:/var/run/docker.sock
- /opt/clearml/agent:/root/.clearml
depends_on:
- apiserver
entrypoint: >
bash -c "curl --retry 10 --retry-delay 10 --retry-connrefused 'http://apiserver:8008/debug.ping' && /usr/agent/entrypoint.sh"
networks:
backend:
driver: bridge
frontend:
driver: bridge
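The compose file above resolves several variables from the shell environment or an `.env` file placed next to it. A minimal illustrative `.env` sketch (the variable names come from the `${...}` references above; the values are placeholders, not defaults shipped with ClearML Server):

    CLEARML_HOST_IP=10.0.0.5
    CLEARML_WEB_HOST=http://10.0.0.5:8080
    CLEARML_FILES_HOST=http://10.0.0.5:8081
    CLEARML_AGENT_ACCESS_KEY=<generated-access-key>
    CLEARML_AGENT_SECRET_KEY=<generated-secret-key>
    CLEARML_AGENT_GIT_USER=<git-user>
    CLEARML_AGENT_GIT_PASS=<git-token>

With those values in place, a standard `docker compose up -d` in the same directory picks them up automatically.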