Compare commits

59 Commits

Author SHA1 Message Date
allegroai
5de7c12062 Version bump to v1.6 2022-07-08 18:05:43 +03:00
allegroai
3f79c19079 Add v1.6.0 mongodb migration 2022-07-08 18:05:33 +03:00
allegroai
fe29743c54 Add support for new IDs generation when importing projects 2022-07-08 18:04:40 +03:00
allegroai
d760cf5835 Remove use of dpath in query projection 2022-07-08 18:04:02 +03:00
allegroai
3695f25a5f Fix internal error returned to clients 2022-07-08 18:03:38 +03:00
allegroai
c6f1beafdd Update API version to 2.20 2022-07-08 18:02:44 +03:00
allegroai
68a54c34f3 Add user creation time to users.get_current_user 2022-07-08 17:59:45 +03:00
allegroai
ab495ae586 Fix archived projects handling 2022-07-08 17:55:02 +03:00
allegroai
b058770af1 Fix handling of empty hyperparam/configuration keys 2022-07-08 17:54:19 +03:00
allegroai
f7e833bf6f Fix loading expired worker entries from Redis 2022-07-08 17:53:12 +03:00
allegroai
36b9ab0453 Fix handling of empty keys in query 2022-07-08 17:51:49 +03:00
allegroai
ec0436d0da Fix task cleanup 2022-07-08 17:50:49 +03:00
allegroai
0f6c4e75b7 Fix debug images URL handling and task routing 2022-07-08 17:50:26 +03:00
allegroai
a41ae112a1 Fix backward compatibility when importing old projects 2022-07-08 17:49:36 +03:00
allegroai
c28f478ea8 Fix worker Id is used instead of worker key when processing report 2022-07-08 17:48:17 +03:00
allegroai
c18eb99d06 Return getting_started_info in users.get_current_user 2022-07-08 17:45:33 +03:00
allegroai
3a60f00d93 Add support for Dataset projects 2022-07-08 17:45:03 +03:00
allegroai
ee87778548 Better support for PyJWT 2.4 2022-07-08 17:44:17 +03:00
allegroai
52c0c4d438 Add model task cleanup on tasks.reset 2022-07-08 17:43:54 +03:00
allegroai
d117a4f022 Support max_task_entries option in queues.get_by_id/get_all
Add queues.peek_task and queues.get_num_entries
2022-07-08 17:42:20 +03:00
allegroai
6683d2d7a9 Fix task cleanup 2022-07-08 17:40:55 +03:00
allegroai
05357fe25e Support publish option in tasks.completed 2022-07-08 17:40:43 +03:00
allegroai
adc1825843 Add support for model statistics 2022-07-08 17:39:41 +03:00
allegroai
0c15169668 Improve tests 2022-07-08 17:38:31 +03:00
allegroai
123dc1dcfb Improve query error handling 2022-07-08 17:38:22 +03:00
allegroai
b2feafac09 Support workers filtering with tags 2022-07-08 17:37:33 +03:00
allegroai
b41ab8c550 Better support for queue metrics and queue metrics refresh on sample 2022-07-08 17:36:46 +03:00
allegroai
62d5779bd5 Count own tasks/models for projects 2022-07-08 17:35:01 +03:00
allegroai
f8b9d9802e Add support for organization.get_entities_count 2022-07-08 17:32:56 +03:00
allegroai
dd8a1503b0 Add support for navigate_current_metric in events.get_debug_image_sample 2022-07-08 17:31:44 +03:00
allegroai
cff98ae900 Add support for events.get_task_single_value_metrics, events.plots, events.get_plot_sample and events.next_plot_sample 2022-07-08 17:29:39 +03:00
allegroai
9b108740da Bump PyJWT version due to "Key confusion through non-blocklisted public key formats" vulnerability 2022-05-25 16:50:19 +03:00
allegroai
08a7bc7c9f Fix not all the event logs are returned from sharded ES 2022-05-20 15:11:05 +03:00
allegroai
fb256d7e5b Version bump to v1.5 2022-05-18 15:29:45 +03:00
allegroai
710443b078 Fix move task to trash is not thread-safe 2022-05-18 10:31:20 +03:00
allegroai
e0cde2f7c9 Add support for deleting pipeline projects 2022-05-18 10:30:21 +03:00
allegroai
60b9c8de14 Allow arbitrary task fields in project statistics filter 2022-05-18 10:29:36 +03:00
allegroai
ecffe26be4 Fix auth.edit_credentials 2022-05-18 10:28:58 +03:00
allegroai
2570bd9e26 Fix ES issue with capital letters in index name 2022-05-18 10:18:23 +03:00
allegroai
174f84514a Fix no destination when merging projects 2022-05-18 10:17:34 +03:00
allegroai
65cb8d7b43 Refactor method name 2022-05-18 10:16:41 +03:00
allegroai
5f8ef808a3 Fix ES issue with capital letters in index name 2022-05-18 10:16:19 +03:00
allegroai
4941ac70e0 Add events.clear_task_log 2022-05-17 16:09:23 +03:00
allegroai
67cd461145 Add auth.edit_credentials 2022-05-17 16:08:12 +03:00
allegroai
92b5fc6f9a Fix handling hidden sub-projects 2022-05-17 16:06:34 +03:00
allegroai
b90165b4e4 Support queue_name in tasks enqueue 2022-05-17 16:04:34 +03:00
allegroai
6c2dcb5c8a Improve error message 2022-05-17 15:56:18 +03:00
allegroai
3efed32934 Add X-Jwt-Payload to redacted headers 2022-05-17 15:55:41 +03:00
allegroai
69737308fe Version bump to v1.4.0 2022-04-18 16:38:22 +03:00
allegroai
a6dbea808a Add indices for task.last_update and task.status_changed 2022-04-18 16:37:22 +03:00
allegroai
5131b17901 Support not returning hidden sub-projects when include_stats is specified without search_hidden 2022-04-18 16:36:14 +03:00
allegroai
5f21c3a56d Add support for searching hidden projects and tasks 2022-04-18 16:34:18 +03:00
allegroai
2350ac64ed Fix internal error on count task events if there is no events index 2022-04-18 16:31:02 +03:00
allegroai
d146127c18 Add events.clear_scroll endpoint to clear event search scrolls 2022-04-18 16:29:57 +03:00
Mal Miller
abd65e103e Ensure agent-services waits for API server to be ready (#129) 2022-03-31 11:10:45 +03:00
pollfly
bf65ea7bd0 Resize admonitions (#126) 2022-03-27 15:04:43 +03:00
pollfly
73e278a8ed Add deprecation notes to legacy docs (#124) 2022-03-23 23:51:55 +02:00
Zied ANDOLSI
d92dfbbdb7 Allow ClearML to be served with a URL path prefix (#121)
* add server root url

* [Feature Request] Add proxy_pass for root url other than /

* [Feature Request] Add proxy_pass for root url other than /

* add support for web sub path

* add support for web sub path

* use default conf instead of created a custom one

* code reivew: move cp command in if block

* Add commented env var in the docker-compose file

Co-authored-by: Zied ANDOLSI <zandolsi@prophesee.ai>
2022-03-22 17:21:58 +02:00
Zied ANDOLSI
5c1e419eb5 Allow overriding clearml web git url on build (#122)
* add server root url

* [Feature Request] Add possibility to override clearml web git url

Co-authored-by: Zied ANDOLSI <zandolsi@prophesee.ai>
2022-03-17 14:35:50 +02:00
87 changed files with 3654 additions and 1246 deletions

View File

@@ -26,6 +26,9 @@
23: ["invalid_domain_name", "malformed domain name"]
24: ["not_public_object", "object is not public"]
# Auth / Login
75: ["invalid_access_key", "access key not found for user"]
# Tasks
100: ["task_error", "general task error"]
101: ["invalid_task_id", "invalid task id"]
@@ -86,7 +89,7 @@
# Database
800: ["data_validation_error", "data validation error"]
801: ["expected_unique_data", "value combination already exists"]
801: ["expected_unique_data", "value combination already exists (unique field already contains this value)"]
# Workers
1001: ["invalid_worker_id", "invalid worker id"]

View File

@@ -96,6 +96,11 @@ class GetCredentialsResponse(Base):
credentials = ListField(CredentialsResponse)
class EditCredentialsRequest(Base):
access_key = StringField(required=True)
label = StringField()
class RevokeCredentialsRequest(Base):
access_key = StringField(required=True)

View File

@@ -48,7 +48,7 @@ class TaskMetric(Base):
variants: Sequence[str] = ListField(items_types=str)
class DebugImagesRequest(Base):
class MetricEventsRequest(Base):
metrics: Sequence[TaskMetric] = ListField(
items_types=TaskMetric, validators=[Length(minimum_value=1)]
)
@@ -64,13 +64,14 @@ class TaskMetricVariant(Base):
variant: str = StringField(required=True)
class GetDebugImageSampleRequest(TaskMetricVariant):
class GetHistorySampleRequest(TaskMetricVariant):
iteration: Optional[int] = IntField()
refresh: bool = BoolField(default=False)
scroll_id: Optional[str] = StringField()
navigate_current_metric: bool = BoolField(default=True)
class NextDebugImageSampleRequest(Base):
class NextHistorySampleRequest(Base):
task: str = StringField(required=True)
scroll_id: Optional[str] = StringField()
navigate_earlier: bool = BoolField(default=True)
@@ -119,15 +120,22 @@ class MetricEvents(Base):
iterations: Sequence[IterationEvents] = ListField(items_types=IterationEvents)
class DebugImageResponse(Base):
class MetricEventsResponse(Base):
metrics: Sequence[MetricEvents] = ListField(items_types=MetricEvents)
scroll_id: str = StringField()
class TaskMetricsRequest(Base):
class MultiTasksRequestBase(Base):
tasks: Sequence[str] = ListField(
items_types=str, validators=[Length(minimum_value=1)]
)
class SingleValueMetricsRequest(MultiTasksRequestBase):
pass
class TaskMetricsRequest(MultiTasksRequestBase):
event_type: EventType = ActualEnumField(EventType, required=True)
@@ -137,3 +145,13 @@ class TaskPlotsRequest(Base):
scroll_id: str = StringField()
no_scroll: bool = BoolField(default=False)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class ClearScrollRequest(Base):
scroll_id: str = StringField()
class ClearTaskLogRequest(Base):
task: str = StringField(required=True)
threshold_sec = IntField()
allow_locked = BoolField(default=False)

View File

@@ -75,3 +75,7 @@ class DeleteMetadataRequest(DeleteMetadata):
class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
model = fields.StringField(required=True)
class ModelsGetRequest(models.Base):
include_stats = fields.BoolField(default=False)

View File

@@ -1,5 +1,7 @@
from jsonmodels import fields, models
from apiserver.apimodels import DictField
class Filter(models.Base):
tags = fields.ListField([str])
@@ -9,3 +11,11 @@ class Filter(models.Base):
class TagsRequest(models.Base):
include_system = fields.BoolField(default=False)
filter = fields.EmbeddedField(Filter)
class EntitiesCountRequest(models.Base):
projects = DictField()
tasks = DictField()
models = DictField()
pipelines = DictField()
datasets = DictField()

View File

@@ -57,6 +57,7 @@ class ProjectModelMetadataValuesRequest(MultiProjectRequest):
class ProjectsGetRequest(models.Base):
include_dataset_stats = fields.BoolField(default=False)
include_stats = fields.BoolField(default=False)
include_stats_filter = DictField()
stats_with_children = fields.BoolField(default=True)
@@ -65,3 +66,4 @@ class ProjectsGetRequest(models.Base):
active_users = fields.ListField(str)
check_own_contents = fields.BoolField(default=False)
shallow_search = fields.BoolField(default=False)
search_hidden = fields.BoolField(default=False)

View File

@@ -26,6 +26,10 @@ class QueueRequest(Base):
queue = StringField(required=True)
class GetByIdRequest(QueueRequest):
max_task_entries = IntField()
class GetNextTaskRequest(QueueRequest):
queue = StringField(required=True)
get_task_info = BoolField(default=False)
@@ -59,6 +63,7 @@ class GetMetricsRequest(Base):
from_date = FloatField(required=True, validators=validators.Min(0))
to_date = FloatField(required=True, validators=validators.Min(0))
interval = IntField(required=True, validators=validators.Min(1))
refresh = BoolField(default=False)
class QueueMetrics(Base):

View File

@@ -96,6 +96,7 @@ class UpdateRequest(TaskUpdateRequest):
class EnqueueRequest(UpdateRequest):
queue = StringField()
queue_name = StringField()
class DeleteRequest(UpdateRequest):
@@ -108,6 +109,14 @@ class SetRequirementsRequest(TaskRequest):
requirements = DictField(required=True)
class CompletedRequest(UpdateRequest):
publish = BoolField(default=False)
class CompletedResponse(UpdateResponse):
published = IntField(default=0)
class PublishRequest(UpdateRequest):
publish_model = BoolField(default=True)
@@ -262,6 +271,7 @@ class StopManyRequest(TaskBatchRequest):
class EnqueueManyRequest(TaskBatchRequest):
queue = StringField()
queue_name = StringField()
validate_tasks = BoolField(default=False)

View File

@@ -96,6 +96,7 @@ class WorkerResponseEntry(WorkerEntry):
class GetAllRequest(Base):
last_seen = IntField(default=3600)
tags = ListField(str)
class GetAllResponse(Base):

View File

@@ -64,7 +64,7 @@ class AuthBLL:
feature_set="basic",
)
return GetTokenResponse(token=token.decode("ascii"))
return GetTokenResponse(token=token)
@staticmethod
def create_user(request: CreateUserRequest, call: APICall = None) -> str:

View File

@@ -8,28 +8,31 @@ from datetime import datetime
from operator import attrgetter
from typing import Sequence, Set, Tuple, Optional, List, Mapping, Union
from elasticsearch import helpers
import elasticsearch
from elasticsearch.helpers import BulkIndexError
from mongoengine import Q
from nested_dict import nested_dict
from apiserver.bll.event.debug_sample_history import DebugSampleHistory
from apiserver.bll.event.event_common import (
EventType,
EventSettings,
get_index_name,
check_empty_data,
search_company_events,
delete_company_events,
MetricVariants,
get_metric_variants_condition,
uncompress_plot,
get_max_metric_and_variant_counts,
)
from apiserver.bll.event.events_iterator import EventsIterator, TaskEventsResult
from apiserver.bll.event.history_debug_image_iterator import HistoryDebugImageIterator
from apiserver.bll.event.history_plot_iterator import HistoryPlotIterator
from apiserver.bll.event.metric_debug_images_iterator import MetricDebugImagesIterator
from apiserver.bll.event.metric_plots_iterator import MetricPlotsIterator
from apiserver.bll.util import parallel_chunked_decorator
from apiserver.database import utils as dbutils
from apiserver.es_factory import es_factory
from apiserver.apierrors import errors
from apiserver.bll.event.debug_images_iterator import DebugImagesIterator
from apiserver.bll.event.event_metrics import EventMetrics
from apiserver.bll.task import TaskBLL
from apiserver.config_repo import config
@@ -38,7 +41,7 @@ from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.redis_manager import redman
from apiserver.timing_context import TimingContext
from apiserver.tools import safe_get
from apiserver.utilities.dicts import flatten_nested_items
from apiserver.utilities.dicts import flatten_nested_items, nested_get
from apiserver.utilities.json import loads
# noinspection PyTypeChecker
@@ -48,6 +51,9 @@ MAX_LONG = 2 ** 63 - 1
MIN_LONG = -(2 ** 63)
log = config.logger(__file__)
class PlotFields:
valid_plot = "valid_plot"
plot_len = "plot_len"
@@ -66,13 +72,19 @@ class EventBLL(object):
def __init__(self, events_es=None, redis=None):
self.es = events_es or es_factory.connect("events")
self.redis = redis or redman.connection("apiserver")
self._metrics = EventMetrics(self.es)
self._skip_iteration_for_metric = set(
config.get("services.events.ignore_iteration.metrics", [])
)
self.redis = redis or redman.connection("apiserver")
self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
self.debug_sample_history = DebugSampleHistory(es=self.es, redis=self.redis)
self.debug_images_iterator = MetricDebugImagesIterator(
es=self.es, redis=self.redis
)
self.debug_image_sample_history = HistoryDebugImageIterator(
es=self.es, redis=self.redis
)
self.plots_iterator = MetricPlotsIterator(es=self.es, redis=self.redis)
self.plot_sample_history = HistoryPlotIterator(es=self.es, redis=self.redis)
self.events_iterator = EventsIterator(es=self.es)
@property
@@ -219,7 +231,7 @@ class EventBLL(object):
with TimingContext("es", "events_add_batch"):
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
with closing(
helpers.streaming_bulk(
elasticsearch.helpers.streaming_bulk(
self.es,
actions,
chunk_size=chunk_size,
@@ -304,11 +316,7 @@ class EventBLL(object):
@parallel_chunked_decorator(chunk_size=10)
def uncompress_plots(self, plot_events: Sequence[dict]):
for event in plot_events:
plot_data = event.pop(PlotFields.plot_data, None)
if plot_data and event.get(PlotFields.plot_str) is None:
event[PlotFields.plot_str] = zlib.decompress(
base64.b64decode(plot_data)
).decode()
uncompress_plot(event)
@staticmethod
def _is_valid_json(text: str) -> bool:
@@ -476,6 +484,13 @@ class EventBLL(object):
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
query = {"bool": {"must": must}}
search_args = dict(
es=self.es, company_id=company_id, event_type=event_type
)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
)
max_variants = int(max_variants // num_last_iterations)
es_req: dict = {
"size": 0,
@@ -483,14 +498,14 @@ class EventBLL(object):
"metrics": {
"terms": {
"field": "metric",
"size": EventSettings.max_metrics_count,
"size": max_metrics,
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"size": max_variants,
"order": {"_key": "asc"},
},
"aggs": {
@@ -512,9 +527,7 @@ class EventBLL(object):
with translate_errors_context(), TimingContext(
"es", "task_last_iter_metric_variant"
):
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req
)
es_res = search_company_events(body=es_req, **search_args)
if "aggregations" not in es_res:
return []
@@ -635,6 +648,42 @@ class EventBLL(object):
return events, total_events, next_scroll_id
def get_debug_image_urls(
self, company_id: str, task_id: str, after_key: dict = None
) -> Tuple[Sequence[str], Optional[dict]]:
if check_empty_data(self.es, company_id, EventType.metrics_image):
return [], None
es_req = {
"size": 0,
"aggs": {
"debug_images": {
"composite": {
"size": 1000,
**({"after": after_key} if after_key else {}),
"sources": [{"url": {"terms": {"field": "url"}}}],
}
}
},
"query": {
"bool": {
"must": [{"term": {"task": task_id}}, {"exists": {"field": "url"}}]
}
},
}
es_response = search_company_events(
self.es,
company_id=company_id,
event_type=EventType.metrics_image,
body=es_req,
)
res = nested_get(es_response, ("aggregations", "debug_images"))
if not res:
return [], None
return [bucket["key"]["url"] for bucket in res["buckets"]], res.get("after_key")
def get_plot_image_urls(
self, company_id: str, task_id: str, scroll_id: Optional[str]
) -> Tuple[Sequence[dict], Optional[str]]:
@@ -760,20 +809,26 @@ class EventBLL(object):
return {}
query = {"bool": {"must": [{"term": {"task": task_id}}]}}
search_args = dict(
es=self.es, company_id=company_id, event_type=event_type
)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
)
es_req = {
"size": 0,
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": EventSettings.max_metrics_count,
"size": max_metrics,
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"size": max_variants,
"order": {"_key": "asc"},
}
}
@@ -786,9 +841,7 @@ class EventBLL(object):
with translate_errors_context(), TimingContext(
"es", "events_get_metrics_and_variants"
):
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req
)
es_res = search_company_events(body=es_req, **search_args)
metrics = {}
for metric_bucket in es_res["aggregations"]["metrics"].get("buckets"):
@@ -814,6 +867,12 @@ class EventBLL(object):
]
}
}
search_args = dict(
es=self.es, company_id=company_id, event_type=event_type
)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
)
es_req = {
"size": 0,
"query": query,
@@ -821,14 +880,14 @@ class EventBLL(object):
"metrics": {
"terms": {
"field": "metric",
"size": EventSettings.max_metrics_count,
"size": max_metrics,
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"size": max_variants,
"order": {"_key": "asc"},
},
"aggs": {
@@ -859,9 +918,7 @@ class EventBLL(object):
with translate_errors_context(), TimingContext(
"es", "events_get_metrics_and_variants"
):
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req
)
es_res = search_company_events(body=es_req, **search_args)
metrics = []
max_timestamp = 0
@@ -962,18 +1019,23 @@ class EventBLL(object):
for tb in es_res["aggregations"]["tasks"]["buckets"]
}
@staticmethod
def _validate_task_state(company_id: str, task_id: str, allow_locked: bool = False):
extra_msg = None
query = Q(id=task_id, company=company_id)
if not allow_locked:
query &= Q(status__nin=LOCKED_TASK_STATUSES)
extra_msg = "or task published"
res = Task.objects(query).only("id").first()
if not res:
raise errors.bad_request.InvalidTaskId(
extra_msg, company=company_id, id=task_id
)
def delete_task_events(self, company_id, task_id, allow_locked=False):
with translate_errors_context():
extra_msg = None
query = Q(id=task_id, company=company_id)
if not allow_locked:
query &= Q(status__nin=LOCKED_TASK_STATUSES)
extra_msg = "or task published"
res = Task.objects(query).only("id").first()
if not res:
raise errors.bad_request.InvalidTaskId(
extra_msg, company=company_id, id=task_id
)
self._validate_task_state(
company_id=company_id, task_id=task_id, allow_locked=allow_locked
)
es_req = {"query": {"term": {"task": task_id}}}
with translate_errors_context(), TimingContext("es", "delete_task_events"):
@@ -987,6 +1049,49 @@ class EventBLL(object):
return es_res.get("deleted", 0)
def clear_task_log(
self,
company_id: str,
task_id: str,
allow_locked: bool = False,
threshold_sec: int = None,
):
self._validate_task_state(
company_id=company_id, task_id=task_id, allow_locked=allow_locked
)
if check_empty_data(
self.es, company_id=company_id, event_type=EventType.task_log
):
return 0
with translate_errors_context(), TimingContext("es", "clear_task_log"):
must = [{"term": {"task": task_id}}]
sort = None
if threshold_sec:
timestamp_ms = int(threshold_sec * 1000)
must.append(
{
"range": {
"timestamp": {
"lt": (es_factory.get_timestamp_millis() - timestamp_ms)
}
}
}
)
sort = {"timestamp": {"order": "desc"}}
es_req = {
"query": {"bool": {"must": must}},
**({"sort": sort} if sort else {}),
}
es_res = delete_company_events(
es=self.es,
company_id=company_id,
event_type=EventType.task_log,
body=es_req,
refresh=True,
)
return es_res.get("deleted", 0)
def delete_multi_task_events(self, company_id: str, task_ids: Sequence[str]):
"""
Delete mutliple task events. No check is done for tasks write access
@@ -1005,3 +1110,16 @@ class EventBLL(object):
)
return es_res.get("deleted", 0)
def clear_scroll(self, scroll_id: str):
if scroll_id == self.empty_scroll:
return
# noinspection PyBroadException
try:
self.es.clear_scroll(scroll_id=scroll_id)
except elasticsearch.exceptions.NotFoundError:
pass
except elasticsearch.exceptions.RequestError:
pass
except Exception as ex:
log.exception("Failed clearing scroll %s", scroll_id)

View File

@@ -1,10 +1,15 @@
import base64
import zlib
from enum import Enum
from typing import Union, Sequence, Mapping
from typing import Union, Sequence, Mapping, Tuple
from boltons.typeutils import classproperty
from elasticsearch import Elasticsearch
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
from apiserver.tools import safe_get
class EventType(Enum):
@@ -16,10 +21,13 @@ class EventType(Enum):
all = "*"
SINGLE_SCALAR_ITERATION = -2**31
MetricVariants = Mapping[str, Sequence[str]]
class EventSettings:
_max_es_allowed_aggregation_buckets = 10000
@classproperty
def max_workers(self):
return config.get("services.events.events_retrieval.max_metrics_concurrency", 4)
@@ -31,17 +39,23 @@ class EventSettings:
)
@classproperty
def max_metrics_count(self):
return config.get("services.events.events_retrieval.max_metrics_count", 100)
@classproperty
def max_variants_count(self):
return config.get("services.events.events_retrieval.max_variants_count", 100)
def max_es_buckets(self):
percentage = (
min(
100,
config.get(
"services.events.events_retrieval.dynamic_metrics_count_threshold",
80,
),
)
/ 100
)
return int(self._max_es_allowed_aggregation_buckets * percentage)
def get_index_name(company_id: str, event_type: str):
event_type = event_type.lower().replace(" ", "_")
return f"events-{event_type}-{company_id}"
return f"events-{event_type}-{company_id.lower()}"
def check_empty_data(es: Elasticsearch, company_id: str, event_type: EventType) -> bool:
@@ -78,6 +92,46 @@ def count_company_events(
return es.count(index=es_index, body=body, **kwargs)
def get_max_metric_and_variant_counts(
es: Elasticsearch,
company_id: Union[str, Sequence[str]],
event_type: EventType,
query: dict,
**kwargs,
) -> Tuple[int, int]:
dynamic = config.get(
"services.events.events_retrieval.dynamic_metrics_count", False
)
max_metrics_count = config.get(
"services.events.events_retrieval.max_metrics_count", 100
)
max_variants_count = config.get(
"services.events.events_retrieval.max_variants_count", 100
)
if not dynamic:
return max_metrics_count, max_variants_count
es_req: dict = {
"size": 0,
"query": query,
"aggs": {"metrics_count": {"cardinality": {"field": "metric"}}},
}
with translate_errors_context(), TimingContext(
"es", "get_max_metric_and_variant_counts"
):
es_res = search_company_events(
es, company_id=company_id, event_type=event_type, body=es_req, **kwargs,
)
metrics_count = safe_get(
es_res, "aggregations/metrics_count/value", max_metrics_count
)
if not metrics_count:
return max_metrics_count, max_variants_count
return metrics_count, int(EventSettings.max_es_buckets / metrics_count)
def get_metric_variants_condition(metric_variants: MetricVariants,) -> Sequence:
conditions = [
{
@@ -94,3 +148,19 @@ def get_metric_variants_condition(metric_variants: MetricVariants,) -> Sequence:
]
return {"bool": {"should": conditions}}
class PlotFields:
valid_plot = "valid_plot"
plot_len = "plot_len"
plot_str = "plot_str"
plot_data = "plot_data"
source_urls = "source_urls"
def uncompress_plot(event: dict):
plot_data = event.pop(PlotFields.plot_data, None)
if plot_data and event.get(PlotFields.plot_str) is None:
event[PlotFields.plot_str] = zlib.decompress(
base64.b64decode(plot_data)
).decode()

View File

@@ -4,8 +4,9 @@ from collections import defaultdict
from concurrent.futures.thread import ThreadPoolExecutor
from functools import partial
from operator import itemgetter
from typing import Sequence, Tuple
from typing import Sequence, Tuple, Mapping
from boltons.iterutils import bucketize
from elasticsearch import Elasticsearch
from mongoengine import Q
@@ -17,6 +18,8 @@ from apiserver.bll.event.event_common import (
check_empty_data,
MetricVariants,
get_metric_variants_condition,
get_max_metric_and_variant_counts,
SINGLE_SCALAR_ITERATION,
)
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
from apiserver.config_repo import config
@@ -166,6 +169,57 @@ class EventMetrics:
return res
def get_task_single_value_metrics(
self, company_id: str, task_ids: Sequence[str]
) -> Mapping[str, dict]:
"""
For the requested tasks return all the events delivered for the single iteration (-2**31)
"""
if check_empty_data(
self.es, company_id=company_id, event_type=EventType.metrics_scalar
):
return {}
with TimingContext("es", "get_task_single_value_metrics"):
task_events = self._get_task_single_value_metrics(company_id, task_ids)
def _get_value(event: dict):
return {
field: event.get(field)
for field in ("metric", "variant", "value", "timestamp")
}
return {
task: [_get_value(e) for e in events]
for task, events in bucketize(task_events, itemgetter("task")).items()
}
def _get_task_single_value_metrics(
self, company_id: str, task_ids: Sequence[str]
) -> Sequence[dict]:
es_req = {
"size": 10000,
"query": {
"bool": {
"must": [
{"terms": {"task": task_ids}},
{"term": {"iter": SINGLE_SCALAR_ITERATION}},
]
}
},
}
with translate_errors_context():
es_res = search_company_events(
body=es_req,
es=self.es,
company_id=company_id,
event_type=EventType.metrics_scalar,
)
if not es_res["hits"]["total"]["value"]:
return []
return [hit["_source"] for hit in es_res["hits"]["hits"]]
MetricInterval = Tuple[str, str, int, int]
MetricIntervalGroup = Tuple[int, Sequence[Tuple[str, str]]]
@@ -219,11 +273,17 @@ class EventMetrics:
Return the list og metric variant intervals as the following tuple:
(metric, variant, interval, samples)
"""
must = [{"term": {"task": task_id}}]
must = self._task_conditions(task_id)
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
query = {"bool": {"must": must}}
search_args = dict(
es=self.es, company_id=company_id, event_type=event_type
)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
)
max_variants = int(max_variants // 2)
es_req = {
"size": 0,
"query": query,
@@ -231,14 +291,14 @@ class EventMetrics:
"metrics": {
"terms": {
"field": "metric",
"size": EventSettings.max_metrics_count,
"size": max_metrics,
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"size": max_variants,
"order": {"_key": "asc"},
},
"aggs": {
@@ -253,9 +313,7 @@ class EventMetrics:
}
with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req,
)
es_res = search_company_events(body=es_req, **search_args)
aggs_result = es_res.get("aggregations")
if not aggs_result:
@@ -307,33 +365,42 @@ class EventMetrics:
"""
interval, metrics = metrics_interval
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
aggs = {
"metrics": {
"terms": {
"field": "metric",
"size": EventSettings.max_metrics_count,
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"order": {"_key": "asc"},
},
"aggs": aggregation,
}
},
}
}
aggs_result = self._query_aggregation_for_task_metrics(
company_id=company_id,
event_type=event_type,
aggs=aggs,
task_id=task_id,
metrics=metrics,
query = self._get_task_metrics_query(task_id=task_id, metrics=metrics)
search_args = dict(
es=self.es, company_id=company_id, event_type=event_type
)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
)
max_variants = int(max_variants // 2)
es_req = {
"size": 0,
"query": query,
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": max_metrics,
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": max_variants,
"order": {"_key": "asc"},
},
"aggs": aggregation,
}
},
}
},
}
with translate_errors_context():
es_res = search_company_events(body=es_req, **search_args)
aggs_result = es_res.get("aggregations")
if not aggs_result:
return {}
@@ -360,19 +427,18 @@ class EventMetrics:
for key, value in aggregation.items()
}
def _query_aggregation_for_task_metrics(
self,
company_id: str,
event_type: EventType,
aggs: dict,
task_id: str,
metrics: Sequence[Tuple[str, str]],
) -> dict:
"""
Return the result of elastic search query for the given aggregation filtered
by the given task_ids and metrics
"""
must = [{"term": {"task": task_id}}]
@staticmethod
def _task_conditions(task_id: str) -> list:
return [
{"term": {"task": task_id}},
{"range": {"iter": {"gt": SINGLE_SCALAR_ITERATION}}},
]
@classmethod
def _get_task_metrics_query(
cls, task_id: str, metrics: Sequence[Tuple[str, str]],
):
must = cls._task_conditions(task_id)
if metrics:
should = [
{
@@ -387,20 +453,9 @@ class EventMetrics:
]
must.append({"bool": {"should": should}})
es_req = {
"size": 0,
"query": {"bool": {"must": must}},
"aggs": aggs,
}
return {"bool": {"must": must}}
with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req,
)
return es_res.get("aggregations")
def get_tasks_metrics(
def get_task_metrics(
self, company_id, task_ids: Sequence, event_type: EventType
) -> Sequence:
"""
@@ -426,12 +481,12 @@ class EventMetrics:
) -> Sequence:
es_req = {
"size": 0,
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
"query": {"bool": {"must": self._task_conditions(task_id)}},
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": EventSettings.max_metrics_count,
"size": EventSettings.max_es_buckets,
"order": {"_key": "asc"},
}
}

View File

@@ -4,6 +4,7 @@ import attr
import jsonmodels.models
import jwt
from elasticsearch import Elasticsearch
from jwt.algorithms import get_default_algorithms
from apiserver.bll.event.event_common import (
check_empty_data,
@@ -67,6 +68,9 @@ class EventsIterator:
task_id: str,
metric_variants: MetricVariants = None,
) -> int:
if check_empty_data(self.es, company_id, event_type):
return 0
query, _ = self._get_initial_query_and_must(task_id, metric_variants)
es_req = {
"query": query,
@@ -74,11 +78,7 @@ class EventsIterator:
with translate_errors_context(), TimingContext("es", "count_task_events"):
es_result = count_company_events(
self.es,
company_id=company_id,
event_type=event_type,
body=es_req,
routing=task_id,
self.es, company_id=company_id, event_type=event_type, body=es_req,
)
return es_result["count"]
@@ -115,11 +115,7 @@ class EventsIterator:
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_result = search_company_events(
self.es,
company_id=company_id,
event_type=event_type,
body=es_req,
routing=task_id,
self.es, company_id=company_id, event_type=event_type, body=es_req,
)
hits = es_result["hits"]["hits"]
hits_total = es_result["hits"]["total"]["value"]
@@ -139,11 +135,7 @@ class EventsIterator:
},
}
es_result = search_company_events(
self.es,
company_id=company_id,
event_type=event_type,
body=es_req,
routing=task_id,
self.es, company_id=company_id, event_type=event_type, body=es_req,
)
last_second_hits = es_result["hits"]["hits"]
if not last_second_hits or len(last_second_hits) < 2:
@@ -188,7 +180,7 @@ class Scroll(jsonmodels.models.Base):
key=config.get(
"services.events.events_retrieval.scroll_id_key", "1234567890"
),
).decode()
)
@classmethod
def from_scroll_id(cls, scroll_id: str):
@@ -199,6 +191,7 @@ class Scroll(jsonmodels.models.Base):
key=config.get(
"services.events.events_retrieval.scroll_id_key", "1234567890"
),
algorithms=get_default_algorithms(),
)
)
except jwt.PyJWTError:

View File

@@ -0,0 +1,56 @@
from typing import Sequence, Tuple, Callable
from elasticsearch import Elasticsearch
from redis.client import StrictRedis
from apiserver.utilities.dicts import nested_get
from .event_common import EventType
from .history_sample_iterator import HistorySampleIterator, VariantState
class HistoryDebugImageIterator(HistorySampleIterator):
def __init__(self, redis: StrictRedis, es: Elasticsearch):
super().__init__(redis, es, EventType.metrics_image)
def _get_extra_conditions(self) -> Sequence[dict]:
return [{"exists": {"field": "url"}}]
def _get_variants_conditions(self, variants: Sequence[VariantState]) -> dict:
variants_conditions = [
{
"bool": {
"must": [
{"term": {"variant": v.name}},
{"range": {"iter": {"gte": v.min_iteration}}},
]
}
}
for v in variants
]
return {"bool": {"should": variants_conditions}}
def _process_event(self, event: dict) -> dict:
return event
def _get_min_max_aggs(self) -> Tuple[dict, Callable[[dict], Tuple[int, int]]]:
# The min iteration is the lowest iteration that contains non-recycled image url
aggs = {
"last_iter": {"max": {"field": "iter"}},
"urls": {
# group by urls and choose the minimal iteration
# from all the maximal iterations per url
"terms": {"field": "url", "order": {"max_iter": "asc"}, "size": 1},
"aggs": {
# find max iteration for each url
"max_iter": {"max": {"field": "iter"}}
},
},
}
def get_min_max_data(variant_bucket: dict) -> Tuple[int, int]:
urls = nested_get(variant_bucket, ("urls", "buckets"))
min_iter = int(urls[0]["max_iter"]["value"])
max_iter = int(variant_bucket["last_iter"]["value"])
return min_iter, max_iter
return aggs, get_min_max_data

View File

@@ -0,0 +1,36 @@
from typing import Sequence, Tuple, Callable
from elasticsearch import Elasticsearch
from redis.client import StrictRedis
from .event_common import EventType, uncompress_plot
from .history_sample_iterator import HistorySampleIterator, VariantState
class HistoryPlotIterator(HistorySampleIterator):
def __init__(self, redis: StrictRedis, es: Elasticsearch):
super().__init__(redis, es, EventType.metrics_plot)
def _get_extra_conditions(self) -> Sequence[dict]:
return []
def _get_variants_conditions(self, variants: Sequence[VariantState]) -> dict:
return {"terms": {"variant": [v.name for v in variants]}}
def _process_event(self, event: dict) -> dict:
uncompress_plot(event)
return event
def _get_min_max_aggs(self) -> Tuple[dict, Callable[[dict], Tuple[int, int]]]:
# The min iteration is the lowest iteration that contains non-recycled image url
aggs = {
"last_iter": {"max": {"field": "iter"}},
"first_iter": {"min": {"field": "iter"}},
}
def get_min_max_data(variant_bucket: dict) -> Tuple[int, int]:
min_iter = int(variant_bucket["first_iter"]["value"])
max_iter = int(variant_bucket["last_iter"]["value"])
return min_iter, max_iter
return aggs, get_min_max_data

View File

@@ -1,8 +1,10 @@
import abc
import operator
from typing import Sequence, Tuple, Optional
from operator import attrgetter
from typing import Sequence, Tuple, Optional, Callable, Mapping
import attr
from boltons.iterutils import first
from boltons.iterutils import first, bucketize
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, ListField, IntField, BoolField
from jsonmodels.models import Base
@@ -15,6 +17,7 @@ from apiserver.bll.event.event_common import (
EventType,
check_empty_data,
search_company_events,
get_max_metric_and_variant_counts,
)
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.database.errors import translate_errors_context
@@ -24,11 +27,12 @@ from apiserver.utilities.dicts import nested_get
class VariantState(Base):
name: str = StringField(required=True)
metric: str = StringField(default=None)
min_iteration: int = IntField()
max_iteration: int = IntField()
class DebugSampleHistoryState(Base, JsonSerializableMixin):
class HistorySampleState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
iteration: int = IntField()
variant: str = StringField()
@@ -38,78 +42,120 @@ class DebugSampleHistoryState(Base, JsonSerializableMixin):
reached_last: bool = BoolField()
variant_states: Sequence[VariantState] = ListField([VariantState])
warning: str = StringField()
navigate_current_metric = BoolField(default=True)
@attr.s(auto_attribs=True)
class DebugSampleHistoryResult(object):
class HistorySampleResult(object):
scroll_id: str = None
event: dict = None
min_iteration: int = None
max_iteration: int = None
class DebugSampleHistory:
EVENT_TYPE = EventType.metrics_image
def __init__(self, redis: StrictRedis, es: Elasticsearch):
class HistorySampleIterator(abc.ABC):
def __init__(self, redis: StrictRedis, es: Elasticsearch, event_type: EventType):
self.es = es
self.event_type = event_type
self.cache_manager = RedisCacheManager(
state_class=DebugSampleHistoryState,
state_class=HistorySampleState,
redis=redis,
expiration_interval=EventSettings.state_expiration_sec,
)
def get_next_debug_image(
def get_next_sample(
self, company_id: str, task: str, state_id: str, navigate_earlier: bool
) -> DebugSampleHistoryResult:
) -> HistorySampleResult:
"""
Get the debug image for next/prev variant on the current iteration
If does not exist then try getting image for the first/last variant from next/prev iteration
Get the sample for next/prev variant on the current iteration
If does not exist then try getting sample for the first/last variant from next/prev iteration
"""
res = DebugSampleHistoryResult(scroll_id=state_id)
res = HistorySampleResult(scroll_id=state_id)
state = self.cache_manager.get_state(state_id)
if not state or state.task != task:
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
if check_empty_data(self.es, company_id=company_id, event_type=self.EVENT_TYPE):
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
return res
image = self._get_next_for_current_iteration(
event = self._get_next_for_current_iteration(
company_id=company_id, navigate_earlier=navigate_earlier, state=state
) or self._get_next_for_another_iteration(
company_id=company_id, navigate_earlier=navigate_earlier, state=state
)
if not image:
if not event:
return res
self._fill_res_and_update_state(image=image, res=res, state=state)
self._fill_res_and_update_state(event=event, res=res, state=state)
self.cache_manager.set_state(state=state)
return res
def _fill_res_and_update_state(
self, image: dict, res: DebugSampleHistoryResult, state: DebugSampleHistoryState
self, event: dict, res: HistorySampleResult, state: HistorySampleState
):
state.variant = image["variant"]
state.iteration = image["iter"]
res.event = image
var_state = first(s for s in state.variant_states if s.name == state.variant)
self._process_event(event)
state.variant = event["variant"]
state.metric = event["metric"]
state.iteration = event["iter"]
res.event = event
var_state = first(
vs
for vs in state.variant_states
if vs.name == state.variant and vs.metric == state.metric
)
if var_state:
res.min_iteration = var_state.min_iteration
res.max_iteration = var_state.max_iteration
@abc.abstractmethod
def _get_extra_conditions(self) -> Sequence[dict]:
pass
@abc.abstractmethod
def _process_event(self, event: dict) -> dict:
pass
@abc.abstractmethod
def _get_variants_conditions(self, variants: Sequence[VariantState]) -> dict:
pass
def _get_metric_variants_condition(self, variants: Sequence[VariantState]) -> dict:
metrics = bucketize(variants, key=attrgetter("metric"))
metrics_conditions = [
{
"bool": {
"must": [
{"term": {"metric": metric}},
self._get_variants_conditions(vs),
]
}
}
for metric, vs in metrics.items()
]
return {"bool": {"should": metrics_conditions}}
def _get_next_for_current_iteration(
self, company_id: str, navigate_earlier: bool, state: DebugSampleHistoryState
self, company_id: str, navigate_earlier: bool, state: HistorySampleState
) -> Optional[dict]:
"""
Get the image for next (if navigated earlier is False) or previous variant sorted by name for the same iteration
Get the sample for next (if navigate_earlier is False) or previous variant sorted by name for the same iteration
Only variants for which the iteration falls into their valid range are considered
Return None if no such variant or image is found
Return None if no such variant or sample is found
"""
if state.navigate_current_metric:
variants = [
var_state
for var_state in state.variant_states
if var_state.metric == state.metric
]
else:
variants = state.variant_states
cmp = operator.lt if navigate_earlier else operator.gt
variants = [
var_state
for var_state in state.variant_states
if cmp(var_state.name, state.variant)
for var_state in variants
if cmp((var_state.metric, var_state.name), (state.metric, state.variant))
and var_state.min_iteration <= state.iteration
]
if not variants:
@@ -117,14 +163,14 @@ class DebugSampleHistory:
must_conditions = [
{"term": {"task": state.task}},
{"term": {"metric": state.metric}},
{"terms": {"variant": [v.name for v in variants]}},
{"term": {"iter": state.iteration}},
{"exists": {"field": "url"}},
self._get_metric_variants_condition(variants),
*self._get_extra_conditions(),
]
order = "desc" if navigate_earlier else "asc"
es_req = {
"size": 1,
"sort": {"variant": "desc" if navigate_earlier else "asc"},
"sort": [{"metric": order}, {"variant": order}],
"query": {"bool": {"must": must_conditions}},
}
@@ -132,7 +178,10 @@ class DebugSampleHistory:
"es", "get_next_for_current_iteration"
):
es_res = search_company_events(
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
self.es,
company_id=company_id,
event_type=self.event_type,
body=es_req,
)
hits = nested_get(es_res, ("hits", "hits"))
@@ -142,61 +191,58 @@ class DebugSampleHistory:
return hits[0]["_source"]
def _get_next_for_another_iteration(
self, company_id: str, navigate_earlier: bool, state: DebugSampleHistoryState
self, company_id: str, navigate_earlier: bool, state: HistorySampleState
) -> Optional[dict]:
"""
Get the image for the first variant for the next iteration (if navigate_earlier is set to False)
Get the sample for the first variant for the next iteration (if navigate_earlier is set to False)
or from the last variant for the previous iteration (otherwise)
The variants for which the image falls in invalid range are discarded
If no suitable image is found then None is returned
The variants for which the sample falls in invalid range are discarded
If no suitable sample is found then None is returned
"""
must_conditions = [
{"term": {"task": state.task}},
{"term": {"metric": state.metric}},
{"exists": {"field": "url"}},
]
if state.navigate_current_metric:
variants = [
var_state
for var_state in state.variant_states
if var_state.metric == state.metric
]
else:
variants = state.variant_states
if navigate_earlier:
range_operator = "lt"
order = "desc"
variants = [
var_state
for var_state in state.variant_states
for var_state in variants
if var_state.min_iteration < state.iteration
]
else:
range_operator = "gt"
order = "asc"
variants = state.variant_states
variants = variants
if not variants:
return
variants_conditions = [
{
"bool": {
"must": [
{"term": {"variant": v.name}},
{"range": {"iter": {"gte": v.min_iteration}}},
]
}
}
for v in variants
must_conditions = [
{"term": {"task": state.task}},
self._get_metric_variants_condition(variants),
{"range": {"iter": {range_operator: state.iteration}}},
*self._get_extra_conditions(),
]
must_conditions.append({"bool": {"should": variants_conditions}})
must_conditions.append({"range": {"iter": {range_operator: state.iteration}}},)
es_req = {
"size": 1,
"sort": [{"iter": order}, {"variant": order}],
"sort": [{"iter": order}, {"metric": order}, {"variant": order}],
"query": {"bool": {"must": must_conditions}},
}
with translate_errors_context(), TimingContext(
"es", "get_next_for_another_iteration"
):
es_res = search_company_events(
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
self.es,
company_id=company_id,
event_type=self.event_type,
body=es_req,
)
hits = nested_get(es_res, ("hits", "hits"))
@@ -205,7 +251,7 @@ class DebugSampleHistory:
return hits[0]["_source"]
def get_debug_image_for_variant(
def get_sample_for_variant(
self,
company_id: str,
task: str,
@@ -214,36 +260,50 @@ class DebugSampleHistory:
iteration: Optional[int] = None,
refresh: bool = False,
state_id: str = None,
) -> DebugSampleHistoryResult:
navigate_current_metric: bool = True,
) -> HistorySampleResult:
"""
Get the debug image for the requested iteration or the latest before it
Get the sample for the requested iteration or the latest before it
If the iteration is not passed then get the latest event
"""
res = DebugSampleHistoryResult()
if check_empty_data(self.es, company_id=company_id, event_type=self.EVENT_TYPE):
res = HistorySampleResult()
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
return res
def init_state(state_: DebugSampleHistoryState):
def init_state(state_: HistorySampleState):
state_.task = task
state_.metric = metric
state_.navigate_current_metric = navigate_current_metric
self._reset_variant_states(company_id=company_id, state=state_)
def validate_state(state_: DebugSampleHistoryState):
if state_.task != task or state_.metric != metric:
def validate_state(state_: HistorySampleState):
if (
state_.task != task
or state_.navigate_current_metric != navigate_current_metric
or (state_.navigate_current_metric and state_.metric != metric)
):
raise errors.bad_request.InvalidScrollId(
"Task and metric stored in the state do not match the passed ones",
scroll_id=state_.id,
)
# fix old variant states:
for vs in state_.variant_states:
if vs.metric is None:
vs.metric = metric
if refresh:
self._reset_variant_states(company_id=company_id, state=state_)
state: DebugSampleHistoryState
state: HistorySampleState
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state,
) as state:
res.scroll_id = state.id
var_state = first(s for s in state.variant_states if s.name == variant)
var_state = first(
vs
for vs in state.variant_states
if vs.name == variant and vs.metric == metric
)
if not var_state:
return res
@@ -254,7 +314,7 @@ class DebugSampleHistory:
{"term": {"task": task}},
{"term": {"metric": metric}},
{"term": {"variant": variant}},
{"exists": {"field": "url"}},
*self._get_extra_conditions(),
]
if iteration is not None:
must_conditions.append(
@@ -276,12 +336,12 @@ class DebugSampleHistory:
}
with translate_errors_context(), TimingContext(
"es", "get_debug_image_for_variant"
"es", "get_history_sample_for_variant"
):
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.EVENT_TYPE,
event_type=self.event_type,
body=es_req,
)
@@ -290,86 +350,93 @@ class DebugSampleHistory:
return res
self._fill_res_and_update_state(
image=hits[0]["_source"], res=res, state=state
event=hits[0]["_source"], res=res, state=state
)
return res
def _reset_variant_states(self, company_id: str, state: DebugSampleHistoryState):
variant_iterations = self._get_variant_iterations(
company_id=company_id, task=state.task, metric=state.metric
def _reset_variant_states(self, company_id: str, state: HistorySampleState):
metrics = self._get_metric_variant_iterations(
company_id=company_id,
task=state.task,
metric=state.metric if state.navigate_current_metric else None,
)
state.variant_states = [
VariantState(name=var_name, min_iteration=min_iter, max_iteration=max_iter)
for var_name, min_iter, max_iter in variant_iterations
VariantState(
metric=metric,
name=var_name,
min_iteration=min_iter,
max_iteration=max_iter,
)
for metric, variants in metrics.items()
for var_name, min_iter, max_iter in variants
]
def _get_variant_iterations(
self,
company_id: str,
task: str,
metric: str,
variants: Optional[Sequence[str]] = None,
) -> Sequence[Tuple[str, int, int]]:
@abc.abstractmethod
def _get_min_max_aggs(self) -> Tuple[dict, Callable[[dict], Tuple[int, int]]]:
pass
def _get_metric_variant_iterations(
self, company_id: str, task: str, metric: str,
) -> Mapping[str, Tuple[str, str, int, int]]:
"""
Return valid min and max iterations that the task reported images
The min iteration is the lowest iteration that contains non-recycled image url
Return valid min and max iterations that the task reported events of the required type
"""
must = [
{"term": {"task": task}},
{"term": {"metric": metric}},
{"exists": {"field": "url"}},
*self._get_extra_conditions(),
]
if variants:
must.append({"terms": {"variant": variants}})
if metric is not None:
must.append({"term": {"metric": metric}})
query = {"bool": {"must": must}}
search_args = dict(
es=self.es, company_id=company_id, event_type=self.event_type
)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args
)
max_variants = int(max_variants // 2)
min_max_aggs, get_min_max_data = self._get_min_max_aggs()
es_req: dict = {
"size": 0,
"query": {"bool": {"must": must}},
"query": query,
"aggs": {
"variants": {
# all variants that sent debug images
"metrics": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"field": "metric",
"size": max_metrics,
"order": {"_key": "asc"},
},
"aggs": {
"last_iter": {"max": {"field": "iter"}},
"urls": {
# group by urls and choose the minimal iteration
# from all the maximal iterations per url
"variants": {
"terms": {
"field": "url",
"order": {"max_iter": "asc"},
"size": 1,
"field": "variant",
"size": max_variants,
"order": {"_key": "asc"},
},
"aggs": {
# find max iteration for each url
"max_iter": {"max": {"field": "iter"}}
},
},
"aggs": min_max_aggs,
}
},
}
},
}
with translate_errors_context(), TimingContext(
"es", "get_debug_image_iterations"
"es", "get_history_sample_iterations"
):
es_res = search_company_events(
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
)
es_res = search_company_events(body=es_req, **search_args,)
def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
variant = variant_bucket["key"]
urls = nested_get(variant_bucket, ("urls", "buckets"))
min_iter = int(urls[0]["max_iter"]["value"])
max_iter = int(variant_bucket["last_iter"]["value"])
min_iter, max_iter = get_min_max_data(variant_bucket)
return variant, min_iter, max_iter
return [
get_variant_data(variant_bucket)
for variant_bucket in nested_get(
es_res, ("aggregations", "variants", "buckets")
return {
metric_bucket["key"]: [
get_variant_data(variant_bucket)
for variant_bucket in nested_get(metric_bucket, ("variants", "buckets"))
]
for metric_bucket in nested_get(
es_res, ("aggregations", "metrics", "buckets")
)
]
}

View File

@@ -0,0 +1,53 @@
from typing import Sequence, Tuple, Callable
from elasticsearch import Elasticsearch
from redis.client import StrictRedis
from apiserver.utilities.dicts import nested_get
from .event_common import EventType
from .metric_events_iterator import MetricEventsIterator, VariantState
class MetricDebugImagesIterator(MetricEventsIterator):
def __init__(self, redis: StrictRedis, es: Elasticsearch):
super().__init__(redis, es, EventType.metrics_image)
def _get_extra_conditions(self) -> Sequence[dict]:
return [{"exists": {"field": "url"}}]
def _get_variant_state_aggs(self) -> Tuple[dict, Callable[[dict, VariantState], None]]:
aggs = {
"urls": {
"terms": {
"field": "url",
"order": {"max_iter": "desc"},
"size": 1, # we need only one url from the most recent iteration
},
"aggs": {
"max_iter": {"max": {"field": "iter"}},
"iters": {
"top_hits": {
"sort": {"iter": {"order": "desc"}},
"size": 2, # need two last iterations so that we can take
# the second one as invalid
"_source": "iter",
}
},
},
}
}
def fill_variant_state_data(variant_bucket: dict, state: VariantState):
"""If the image urls get recycled then fill the last_invalid_iteration field"""
top_iter_url = nested_get(variant_bucket, ("urls", "buckets"))[0]
iters = nested_get(top_iter_url, ("iters", "hits", "hits"))
if len(iters) > 1:
state.last_invalid_iteration = nested_get(iters[1], ("_source", "iter"))
return aggs, fill_variant_state_data
def _process_event(self, event: dict) -> dict:
return event
def _get_same_variant_events_order(self) -> dict:
return {"url": {"order": "desc"}}

View File

@@ -1,8 +1,9 @@
import abc
from concurrent.futures.thread import ThreadPoolExecutor
from datetime import datetime
from functools import partial
from operator import itemgetter
from typing import Sequence, Tuple, Optional, Mapping
from typing import Sequence, Tuple, Optional, Mapping, Callable
import attr
import dpath
@@ -18,9 +19,10 @@ from apiserver.bll.event.event_common import (
check_empty_data,
search_company_events,
EventType,
get_metric_variants_condition,
get_metric_variants_condition, get_max_metric_and_variant_counts,
)
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.metrics import MetricEventStats
from apiserver.database.model.task.task import Task
@@ -49,25 +51,24 @@ class TaskScrollState(Base):
self.last_min_iter = self.last_max_iter = None
class DebugImageEventsScrollState(Base, JsonSerializableMixin):
class MetricEventsScrollState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
tasks: Sequence[TaskScrollState] = ListField([TaskScrollState])
warning: str = StringField()
@attr.s(auto_attribs=True)
class DebugImagesResult(object):
class MetricEventsResult(object):
metric_events: Sequence[tuple] = []
next_scroll_id: str = None
class DebugImagesIterator:
EVENT_TYPE = EventType.metrics_image
def __init__(self, redis: StrictRedis, es: Elasticsearch):
class MetricEventsIterator:
def __init__(self, redis: StrictRedis, es: Elasticsearch, event_type: EventType):
self.es = es
self.event_type = event_type
self.cache_manager = RedisCacheManager(
state_class=DebugImageEventsScrollState,
state_class=MetricEventsScrollState,
redis=redis,
expiration_interval=EventSettings.state_expiration_sec,
)
@@ -80,14 +81,14 @@ class DebugImagesIterator:
navigate_earlier: bool = True,
refresh: bool = False,
state_id: str = None,
) -> DebugImagesResult:
if check_empty_data(self.es, company_id, self.EVENT_TYPE):
return DebugImagesResult()
) -> MetricEventsResult:
if check_empty_data(self.es, company_id, self.event_type):
return MetricEventsResult()
def init_state(state_: DebugImageEventsScrollState):
def init_state(state_: MetricEventsScrollState):
state_.tasks = self._init_task_states(company_id, task_metrics)
def validate_state(state_: DebugImageEventsScrollState):
def validate_state(state_: MetricEventsScrollState):
"""
Validate that the metrics stored in the state are the same
as requested in the current call.
@@ -99,7 +100,13 @@ class DebugImagesIterator:
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state
) as state:
res = DebugImagesResult(next_scroll_id=state.id)
res = MetricEventsResult(next_scroll_id=state.id)
specific_variants_requested = any(
variants
for t, metrics in task_metrics.items()
if metrics
for m, variants in metrics.items()
)
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
res.metric_events = list(
pool.map(
@@ -108,6 +115,7 @@ class DebugImagesIterator:
company_id=company_id,
iter_count=iter_count,
navigate_earlier=navigate_earlier,
specific_variants_requested=specific_variants_requested,
),
state.tasks,
)
@@ -118,11 +126,11 @@ class DebugImagesIterator:
def _reinit_outdated_task_states(
self,
company_id,
state: DebugImageEventsScrollState,
state: MetricEventsScrollState,
task_metrics: Mapping[str, dict],
):
"""
Determine the metrics for which new debug image events were added
Determine the metrics for which new event_type events were added
since their states were initialized and re-init these states
"""
tasks = Task.objects(id__in=list(task_metrics), company=company_id).only(
@@ -132,7 +140,7 @@ class DebugImagesIterator:
def get_last_update_times_for_task_metrics(
task: Task,
) -> Mapping[str, datetime]:
"""For metrics that reported debug image events get mapping of the metric name to the last update times"""
"""For metrics that reported event_type events get mapping of the metric name to the last update times"""
metric_stats: Mapping[str, MetricEventStats] = task.metric_stats
if not metric_stats:
return {}
@@ -140,10 +148,10 @@ class DebugImagesIterator:
requested_metrics = task_metrics[task.id]
return {
stats.metric: stats.event_stats_by_type[
self.EVENT_TYPE.value
self.event_type.value
].last_update
for stats in metric_stats.values()
if self.EVENT_TYPE.value in stats.event_stats_by_type
if self.event_type.value in stats.event_stats_by_type
and (not requested_metrics or stats.metric in requested_metrics)
}
@@ -213,18 +221,35 @@ class DebugImagesIterator:
for task, metric_states in zip(task_metrics, task_metric_states)
]
@abc.abstractmethod
def _get_extra_conditions(self) -> Sequence[dict]:
pass
@abc.abstractmethod
def _get_variant_state_aggs(self) -> Tuple[dict, Callable[[dict, VariantState], None]]:
pass
def _init_metric_states_for_task(
self, task_metrics: Tuple[str, dict], company_id: str
) -> Sequence[MetricState]:
"""
Return metric scroll states for the task filled with the variant states
for the variants that reported any debug images
for the variants that reported any event_type events
"""
task, metrics = task_metrics
must = [{"term": {"task": task}}, {"exists": {"field": "url"}}]
must = [{"term": {"task": task}}, *self._get_extra_conditions()]
if metrics:
must.append(get_metric_variants_condition(metrics))
query = {"bool": {"must": must}}
search_args = dict(
es=self.es, company_id=company_id, event_type=self.event_type
)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args
)
max_variants = int(max_variants // 2)
variant_state_aggs, fill_variant_state_data = self._get_variant_state_aggs()
es_req: dict = {
"size": 0,
"query": query,
@@ -232,7 +257,7 @@ class DebugImagesIterator:
"metrics": {
"terms": {
"field": "metric",
"size": EventSettings.max_metrics_count,
"size": max_metrics,
"order": {"_key": "asc"},
},
"aggs": {
@@ -240,29 +265,10 @@ class DebugImagesIterator:
"variants": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"size": max_variants,
"order": {"_key": "asc"},
},
"aggs": {
"urls": {
"terms": {
"field": "url",
"order": {"max_iter": "desc"},
"size": 1, # we need only one url from the most recent iteration
},
"aggs": {
"max_iter": {"max": {"field": "iter"}},
"iters": {
"top_hits": {
"sort": {"iter": {"order": "desc"}},
"size": 2, # need two last iterations so that we can take
# the second one as invalid
"_source": "iter",
}
},
},
}
},
**({"aggs": variant_state_aggs} if variant_state_aggs else {}),
},
},
}
@@ -270,22 +276,18 @@ class DebugImagesIterator:
}
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
es_res = search_company_events(
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req,
)
es_res = search_company_events(body=es_req, **search_args)
if "aggregations" not in es_res:
return []
def init_variant_state(variant: dict):
"""
Return new variant state for the passed variant bucket
If the image urls get recycled then fill the last_invalid_iteration field
"""
state = VariantState(variant=variant["key"])
top_iter_url = dpath.get(variant, "urls/buckets")[0]
iters = dpath.get(top_iter_url, "iters/hits/hits")
if len(iters) > 1:
state.last_invalid_iteration = dpath.get(iters[1], "_source/iter")
if fill_variant_state_data:
fill_variant_state_data(variant, state)
return state
return [
@@ -300,12 +302,21 @@ class DebugImagesIterator:
for metric in dpath.get(es_res, "aggregations/metrics/buckets")
]
@abc.abstractmethod
def _process_event(self, event: dict) -> dict:
pass
@abc.abstractmethod
def _get_same_variant_events_order(self) -> dict:
pass
def _get_task_metric_events(
self,
task_state: TaskScrollState,
company_id: str,
iter_count: int,
navigate_earlier: bool,
specific_variants_requested: bool,
) -> Tuple:
"""
Return task metric events grouped by iterations
@@ -321,7 +332,7 @@ class DebugImagesIterator:
must_conditions = [
{"term": {"task": task_state.task}},
{"terms": {"metric": [m.metric for m in task_state.metrics]}},
{"exists": {"field": "url"}},
*self._get_extra_conditions(),
]
range_condition = None
@@ -332,6 +343,8 @@ class DebugImagesIterator:
if range_condition:
must_conditions.append({"range": {"iter": range_condition}})
metrics_count = len(task_state.metrics)
max_variants = int(EventSettings.max_es_buckets / (metrics_count * iter_count))
es_req = {
"size": 0,
"query": {"bool": {"must": must_conditions}},
@@ -346,20 +359,20 @@ class DebugImagesIterator:
"metrics": {
"terms": {
"field": "metric",
"size": EventSettings.max_metrics_count,
"size": metrics_count,
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"size": max_variants,
"order": {"_key": "asc"},
},
"aggs": {
"events": {
"top_hits": {
"sort": {"url": {"order": "desc"}}
"sort": self._get_same_variant_events_order()
}
}
},
@@ -370,9 +383,12 @@ class DebugImagesIterator:
}
},
}
with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
with translate_errors_context(), TimingContext("es", "_get_task_metric_events"):
es_res = search_company_events(
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req,
self.es,
company_id=company_id,
event_type=self.event_type,
body=es_req,
)
if "aggregations" not in es_res:
return task_state.task, []
@@ -382,18 +398,26 @@ class DebugImagesIterator:
for m in task_state.metrics
for v in m.variants
}
allow_uninitialized = (
False
if specific_variants_requested
else config.get(
"services.events.events_retrieval.debug_images.allow_uninitialized_variants",
False,
)
)
def is_valid_event(event: dict) -> bool:
key = event.get("metric"), event.get("variant")
if key not in invalid_iterations:
return False
return allow_uninitialized
max_invalid = invalid_iterations[key]
return max_invalid is None or event.get("iter") > max_invalid
def get_iteration_events(it_: dict) -> Sequence:
return [
ev["_source"]
self._process_event(ev["_source"])
for m in dpath.get(it_, "metrics/buckets")
for v in dpath.get(m, "variants/buckets")
for ev in dpath.get(v, "events/hits/hits")

View File

@@ -0,0 +1,25 @@
from typing import Sequence
from elasticsearch import Elasticsearch
from redis.client import StrictRedis
from .event_common import EventType, uncompress_plot
from .metric_events_iterator import MetricEventsIterator
class MetricPlotsIterator(MetricEventsIterator):
def __init__(self, redis: StrictRedis, es: Elasticsearch):
super().__init__(redis, es, EventType.metrics_plot)
def _get_extra_conditions(self) -> Sequence[dict]:
return []
def _get_variant_state_aggs(self):
return None, None
def _process_event(self, event: dict) -> dict:
uncompress_plot(event)
return event
def _get_same_variant_events_order(self) -> dict:
return {"timestamp": {"order": "desc"}}

View File

@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Callable, Tuple
from typing import Callable, Tuple, Sequence, Dict
from apiserver.apierrors import errors
from apiserver.apimodels.models import ModelTaskPublishResponse
@@ -128,3 +128,33 @@ class ModelBLL:
)
return unarchived
@classmethod
def get_model_stats(
cls, company: str, model_ids: Sequence[str],
) -> Dict[str, dict]:
if not model_ids:
return {}
result = Model.aggregate(
[
{
"$match": {
"company": {"$in": [None, "", company]},
"_id": {"$in": model_ids},
}
},
{
"$addFields": {
"labels_count": {"$size": {"$objectToArray": "$labels"}}
}
},
{
"$project": {"labels_count": 1},
},
]
)
return {
r.pop("_id"): r
for r in result
}

View File

@@ -6,6 +6,7 @@ from redis import Redis
from apiserver.config_repo import config
from apiserver.bll.project import project_ids_with_children
from apiserver.database.model import EntityVisibility
from apiserver.database.model.base import GetMixin
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
@@ -42,6 +43,8 @@ class _TagsCache:
query &= GetMixin.get_list_field_query(name, vals)
if project:
query &= Q(project__in=project_ids_with_children([project]))
else:
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
return self.db_cls.objects(query).distinct(field)

View File

@@ -17,6 +17,7 @@ from typing import (
Any,
)
from boltons.iterutils import partition
from mongoengine import Q, Document
from apiserver import database
@@ -62,20 +63,24 @@ class ProjectBLL:
source=source_id
)
source = Project.get(company, source_id)
destination = Project.get(company, destination_id)
if source_id in destination.path:
raise errors.bad_request.ProjectCannotBeMergedIntoItsChild(
source=source_id, destination=destination_id
)
if destination_id:
destination = Project.get(company, destination_id)
if source_id in destination.path:
raise errors.bad_request.ProjectCannotBeMergedIntoItsChild(
source=source_id, destination=destination_id
)
else:
destination = None
children = _get_sub_projects(
[source.id], _only=("id", "name", "parent", "path")
)[source.id]
cls.validate_projects_depth(
projects=children,
old_parent_depth=len(source.path) + 1,
new_parent_depth=len(destination.path) + 1,
)
if destination:
cls.validate_projects_depth(
projects=children,
old_parent_depth=len(source.path) + 1,
new_parent_depth=len(destination.path) + 1,
)
moved_entities = 0
for entity_type in (Task, Model):
@@ -146,10 +151,8 @@ class ProjectBLL:
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
location=new_parent.name if new_parent else ""
)
if (
new_parent
and project_id == new_parent.id
or project_id in new_parent.path
if new_parent and (
project_id == new_parent.id or project_id in new_parent.path
):
raise errors.bad_request.ProjectCannotBeMovedUnderItself(
project=project_id, parent=new_parent.id
@@ -180,6 +183,7 @@ class ProjectBLL:
if new_location != old_location:
raise errors.bad_request.CannotUpdateProjectLocation(name=new_name)
fields["name"] = new_name
fields["basename"] = new_name.split("/")[-1]
fields["last_update"] = datetime.utcnow()
updated = project.update(upsert=False, **fields)
@@ -222,6 +226,7 @@ class ProjectBLL:
user=user,
company=company,
name=name,
basename=name.split("/")[-1],
description=description,
tags=tags,
system_tags=system_tags,
@@ -325,6 +330,7 @@ class ProjectBLL:
project_ids: Sequence[str],
specific_state: Optional[EntityVisibility] = None,
filter_: Mapping[str, Any] = None,
users: Sequence[str] = None,
) -> Tuple[Sequence, Sequence]:
archived = EntityVisibility.archived.value
@@ -349,7 +355,10 @@ class ProjectBLL:
# count tasks per project per status
{
"$match": cls.get_match_conditions(
company=company_id, project_ids=project_ids, filter_=filter_
company=company_id,
project_ids=project_ids,
filter_=filter_,
users=users,
)
},
ensure_valid_fields(),
@@ -454,22 +463,39 @@ class ProjectBLL:
)
group_step[f"{state.value}_max_task_started"] = max_started_subquery(cond)
def get_state_filter() -> dict:
def add_state_to_filter(f: Mapping[str, Any]) -> Mapping[str, Any]:
if not specific_state:
return {}
return f
f = f or {}
new_f = {k: v for k, v in f.items() if k != "system_tags"}
system_tags = [
tag
for tag in f.get("system_tags", [])
if tag
not in (
EntityVisibility.archived.value,
f"-{EntityVisibility.archived.value}",
)
]
if specific_state == EntityVisibility.archived:
return {"system_tags": {"$eq": EntityVisibility.archived.value}}
return {"system_tags": {"$ne": EntityVisibility.archived.value}}
system_tags.append(EntityVisibility.archived.value)
else:
system_tags.append(f"-{EntityVisibility.archived.value}")
new_f["system_tags"] = system_tags
return new_f
runtime_pipeline = [
# only count run time for these types of tasks
{
"$match": {
**cls.get_match_conditions(
company=company_id, project_ids=project_ids, filter_=filter_
),
**get_state_filter(),
}
"$match": cls.get_match_conditions(
company=company_id,
project_ids=project_ids,
filter_=add_state_to_filter(filter_),
users=users,
)
},
ensure_valid_fields(),
{
@@ -504,6 +530,40 @@ class ProjectBLL:
aggregated[pid] = reduce(func, relevant_data)
return aggregated
@classmethod
def get_dataset_stats(
cls, company: str, project_ids: Sequence[str], users: Sequence[str] = None,
) -> Dict[str, dict]:
if not project_ids:
return {}
task_runtime_pipeline = [
{
"$match": {
**cls.get_match_conditions(
company=company,
project_ids=project_ids,
users=users,
filter_={
"system_tags": [f"-{EntityVisibility.archived.value}"]
},
),
"runtime": {"$exists": True, "$gt": {}},
}
},
{"$project": {"project": 1, "runtime": 1, "last_update": 1}},
{"$sort": {"project": 1, "last_update": 1}},
{"$group": {"_id": "$project", "runtime": {"$last": "$runtime"}}},
]
return {
r["_id"]: {
"file_count": r["runtime"].get("ds_file_count", 0),
"total_size": r["runtime"].get("ds_total_size", 0),
}
for r in Task.aggregate(task_runtime_pipeline)
}
@classmethod
def get_project_stats(
cls,
@@ -511,13 +571,21 @@ class ProjectBLL:
project_ids: Sequence[str],
specific_state: Optional[EntityVisibility] = None,
include_children: bool = True,
search_hidden: bool = False,
filter_: Mapping[str, Any] = None,
users: Sequence[str] = None,
user_active_project_ids: Sequence[str] = None,
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
if not project_ids:
return {}, {}
child_projects = (
_get_sub_projects(project_ids, _only=("id", "name"))
_get_sub_projects(
project_ids,
_only=("id", "name"),
search_hidden=search_hidden,
allowed_ids=user_active_project_ids,
)
if include_children
else {}
)
@@ -529,6 +597,7 @@ class ProjectBLL:
project_ids=list(project_ids_with_children),
specific_state=specific_state,
filter_=filter_,
users=users,
)
default_counts = dict.fromkeys(get_options(TaskStatus), 0)
@@ -695,7 +764,7 @@ class ProjectBLL:
users: Sequence[str],
project_ids: Optional[Sequence[str]] = None,
allow_public: bool = True,
) -> Sequence[str]:
) -> Tuple[Sequence[str], Sequence[str]]:
"""
Get the projects ids where user created any tasks including all the parents of these projects
If project ids are specified then filter the results by these project ids
@@ -719,13 +788,16 @@ class ProjectBLL:
res = list(res)
if not res:
return res
return res, res
ids_with_parents = _ids_with_parents(res)
if project_ids:
return [pid for pid in ids_with_parents if pid in project_ids]
user_active_project_ids = _ids_with_parents(res)
filtered_ids = (
list(set(user_active_project_ids) & set(project_ids))
if project_ids
else list(user_active_project_ids)
)
return ids_with_parents
return filtered_ids, user_active_project_ids
@classmethod
def get_task_parents(
@@ -740,10 +812,13 @@ class ProjectBLL:
If projects is None or empty then get parents for all the company tasks
"""
query = Q(company=company_id)
if projects:
if include_subprojects:
projects = _ids_with_children(projects)
query &= Q(project__in=projects)
else:
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
if state == EntityVisibility.archived:
query &= Q(system_tags__in=[EntityVisibility.archived.value])
@@ -772,7 +847,8 @@ class ProjectBLL:
if project_ids:
project_ids = _ids_with_children(project_ids)
query &= Q(project__in=project_ids)
else:
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
res = Task.objects(query).distinct(field="type")
return set(res).intersection(external_task_types)
@@ -790,32 +866,45 @@ class ProjectBLL:
@staticmethod
def get_match_conditions(
company: str, project_ids: Sequence[str], filter_: Mapping[str, Any]
company: str,
project_ids: Sequence[str],
filter_: Mapping[str, Any],
users: Sequence[str],
):
conditions = {
"company": {"$in": [None, "", company]},
"project": {"$in": project_ids},
}
if users:
conditions["user"] = {"$in": users}
if not filter_:
return conditions
for field in ("tags", "system_tags"):
field_filter = filter_.get(field)
if not field_filter:
continue
if not isinstance(field_filter, list) or not all(
isinstance(t, str) for t in field_filter
for field, field_filter in filter_.items():
if not (
field_filter
and isinstance(field_filter, list)
and all(isinstance(t, str) for t in field_filter)
):
raise errors.bad_request.ValidationError(
f"List of strings expected for the field: {field}"
)
conditions[field] = {"$in": field_filter}
exclude, include = partition(field_filter, lambda x: x.startswith("-"))
conditions[field] = {
**({"$in": include} if include else {}),
**({"$nin": [e[1:] for e in exclude]} if exclude else {}),
}
return conditions
@classmethod
def calc_own_contents(
cls, company: str, project_ids: Sequence[str], filter_: Mapping[str, Any] = None
cls,
company: str,
project_ids: Sequence[str],
filter_: Mapping[str, Any] = None,
users: Sequence[str] = None,
) -> Dict[str, dict]:
"""
Returns the amount of task/models per requested project
@@ -828,7 +917,10 @@ class ProjectBLL:
pipeline = [
{
"$match": cls.get_match_conditions(
company=company, project_ids=project_ids, filter_=filter_
company=company,
project_ids=project_ids,
filter_=filter_,
users=users,
)
},
{"$project": {"project": 1}},

View File

@@ -13,7 +13,7 @@ from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, ArtifactModes
from apiserver.database.model.task.task import Task, ArtifactModes, TaskType
from apiserver.timing_context import TimingContext
from .sub_projects import _ids_with_children
@@ -32,22 +32,28 @@ class DeleteProjectResult:
def validate_project_delete(company: str, project_id: str):
project = Project.get_for_writing(
company=company, id=project_id, _only=("id", "path")
company=company, id=project_id, _only=("id", "path", "system_tags")
)
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
is_pipeline = "pipeline" in (project.system_tags or [])
project_ids = _ids_with_children([project_id])
ret = {}
for cls in (Task, Model):
ret[f"{cls.__name__.lower()}s"] = cls.objects(
project__in=project_ids,
).count()
ret[f"{cls.__name__.lower()}s"] = cls.objects(project__in=project_ids).count()
for cls in (Task, Model):
ret[f"non_archived_{cls.__name__.lower()}s"] = cls.objects(
project__in=project_ids,
system_tags__nin=[EntityVisibility.archived.value],
).count()
query = dict(
project__in=project_ids, system_tags__nin=[EntityVisibility.archived.value]
)
name = f"non_archived_{cls.__name__.lower()}s"
if not is_pipeline:
ret[name] = cls.objects(**query).count()
else:
ret[name] = (
cls.objects(**query, type=TaskType.controller).count()
if cls == Task
else 0
)
return ret
@@ -56,23 +62,30 @@ def delete_project(
company: str, project_id: str, force: bool, delete_contents: bool
) -> Tuple[DeleteProjectResult, Set[str]]:
project = Project.get_for_writing(
company=company, id=project_id, _only=("id", "path")
company=company, id=project_id, _only=("id", "path", "system_tags")
)
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
is_pipeline = "pipeline" in (project.system_tags or [])
project_ids = _ids_with_children([project_id])
if not force:
for cls, error in (
(Task, errors.bad_request.ProjectHasTasks),
(Model, errors.bad_request.ProjectHasModels),
):
non_archived = cls.objects(
project__in=project_ids,
system_tags__nin=[EntityVisibility.archived.value],
).only("id")
query = dict(
project__in=project_ids, system_tags__nin=[EntityVisibility.archived.value]
)
if not is_pipeline:
for cls, error in (
(Task, errors.bad_request.ProjectHasTasks),
(Model, errors.bad_request.ProjectHasModels),
):
non_archived = cls.objects(**query).only("id")
if non_archived:
raise error("use force=true to delete", id=project_id)
else:
non_archived = Task.objects(**query, type=TaskType.controller).only("id")
if non_archived:
raise error("use force=true to delete", id=project_id)
raise errors.bad_request.ProjectHasTasks(
"please archive all the runs inside the project", id=project_id
)
if not delete_contents:
with TimingContext("mongo", "update_children"):

View File

@@ -4,6 +4,7 @@ from typing import Tuple, Optional, Sequence, Mapping
from apiserver import database
from apiserver.apierrors import errors
from apiserver.database.model import EntityVisibility
from apiserver.database.model.project import Project
name_separator = "/"
@@ -50,6 +51,7 @@ def _ensure_project(
created=now,
last_update=now,
name=name,
basename=name.split("/")[-1],
**(creation_params or dict(description="")),
)
parent = _ensure_project(company, user, location, creation_params=creation_params)
@@ -100,12 +102,21 @@ def _get_writable_project_from_name(
def _get_sub_projects(
project_ids: Sequence[str], _only: Sequence[str] = ("id", "path")
project_ids: Sequence[str],
_only: Sequence[str] = ("id", "path"),
search_hidden=True,
allowed_ids: Sequence[str] = None,
) -> Mapping[str, Sequence[Project]]:
"""
Return the list of child projects of all the levels for the parent project ids
"""
qs = Project.objects(path__in=project_ids)
query = dict(path__in=project_ids)
if not search_hidden:
query["system_tags__nin"] = [EntityVisibility.hidden.value]
if allowed_ids:
query["id__in"] = allowed_ids
qs = Project.objects(**query)
if _only:
_only = set(_only) | {"path"}
qs = qs.only(*_only)

View File

@@ -7,7 +7,7 @@ from elasticsearch import Elasticsearch
from apiserver import database
from apiserver.es_factory import es_factory
from apiserver.apierrors import errors
from apiserver.bll.queue.queue_metrics import QueueMetrics
from apiserver.bll.queue.queue_metrics import QueueMetrics, MetricsRefresher
from apiserver.bll.workers import WorkerBLL
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
@@ -50,8 +50,25 @@ class QueueBLL(object):
queue.save()
return queue
def get_by_name(
self, company_id: str, queue_name: str, only: Optional[Sequence[str]] = None,
) -> Queue:
qs = Queue.objects(name=queue_name, company=company_id)
if only:
qs = qs.only(*only)
return qs.first()
@staticmethod
def _get_task_entries_projection(max_task_entries: int) -> dict:
return dict(slice__entries=max_task_entries)
def get_by_id(
self, company_id: str, queue_id: str, only: Optional[Sequence[str]] = None
self,
company_id: str,
queue_id: str,
only: Optional[Sequence[str]] = None,
max_task_entries: int = None,
) -> Queue:
"""
Get queue by id
@@ -62,6 +79,8 @@ class QueueBLL(object):
qs = Queue.objects(**query)
if only:
qs = qs.only(*only)
if max_task_entries:
qs = qs.fields(**self._get_task_entries_projection(max_task_entries))
queue = qs.first()
if not queue:
raise errors.bad_request.InvalidQueueId(**query)
@@ -130,6 +149,7 @@ class QueueBLL(object):
self,
company_id: str,
query_dict: dict,
max_task_entries: int = None,
ret_params: dict = None,
) -> Sequence[dict]:
"""Get all the queues according to the query"""
@@ -138,6 +158,9 @@ class QueueBLL(object):
company=company_id,
parameters=query_dict,
query_dict=query_dict,
projection_fields=self._get_task_entries_projection(max_task_entries)
if max_task_entries
else None,
ret_params=ret_params,
)
@@ -145,6 +168,7 @@ class QueueBLL(object):
self,
company_id: str,
query_dict: dict,
max_task_entries: int = None,
ret_params: dict = None,
) -> Sequence[dict]:
"""
@@ -156,6 +180,9 @@ class QueueBLL(object):
company=company_id,
query_dict=query_dict,
override_projection=projection,
projection_fields=self._get_task_entries_projection(max_task_entries)
if max_task_entries
else None,
ret_params=ret_params,
)
@@ -288,3 +315,25 @@ class QueueBLL(object):
)
return new_position
def count_entries(self, company: str, queue_id: str) -> Optional[int]:
res = next(
Queue.aggregate(
[
{
"$match": {
"company": {"$in": [None, "", company]},
"_id": queue_id,
}
},
{"$project": {"count": {"$size": "$entries"}}},
]
),
None,
)
if res is None:
raise errors.bad_request.InvalidQueueId(queue_id=queue_id)
return int(res.get("count"))
MetricsRefresher.start(queue_metrics=QueueBLL().metrics)

View File

@@ -1,8 +1,10 @@
import json
from collections import defaultdict
from datetime import datetime
from time import sleep
from typing import Sequence
import elasticsearch.helpers
from boltons.typeutils import classproperty
from elasticsearch import Elasticsearch
from apiserver.es_factory import es_factory
@@ -11,25 +13,31 @@ from apiserver.bll.query import Builder as QueryBuilder
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.queue import Queue, Entry
from apiserver.redis_manager import redman
from apiserver.timing_context import TimingContext
from apiserver.utilities.threads_manager import ThreadsManager
log = config.logger(__file__)
_conf = config.get("services.queues")
_queue_metrics_key_pattern = "queue_metrics_{queue}"
redis = redman.connection("apiserver")
class EsKeys:
WAITING_TIME_FIELD = "average_waiting_time"
QUEUE_LENGTH_FIELD = "queue_length"
TIMESTAMP_FIELD = "timestamp"
QUEUE_FIELD = "queue"
class QueueMetrics:
class EsKeys:
WAITING_TIME_FIELD = "average_waiting_time"
QUEUE_LENGTH_FIELD = "queue_length"
TIMESTAMP_FIELD = "timestamp"
QUEUE_FIELD = "queue"
def __init__(self, es: Elasticsearch):
self.es = es
@staticmethod
def _queue_metrics_prefix_for_company(company_id: str) -> str:
"""Returns the es index prefix for the company"""
return f"queue_metrics_{company_id}_"
return f"queue_metrics_{company_id.lower()}_"
@staticmethod
def _get_es_index_suffix():
@@ -49,7 +57,7 @@ class QueueMetrics:
total_waiting_in_secs = sum((now - e.added).total_seconds() for e in entries)
return total_waiting_in_secs / len(entries)
def log_queue_metrics_to_es(self, company_id: str, queues: Sequence[Queue]) -> bool:
def log_queue_metrics_to_es(self, company_id: str, queues: Sequence[Queue]) -> int:
"""
Calculate and write queue statistics (avg waiting time and queue length) to Elastic
:return: True if the write to es was successful, false otherwise
@@ -63,23 +71,22 @@ class QueueMetrics:
def make_doc(queue: Queue) -> dict:
entries = [e for e in queue.entries if e.added]
return dict(
_index=es_index,
_source={
self.EsKeys.TIMESTAMP_FIELD: timestamp,
self.EsKeys.QUEUE_FIELD: queue.id,
self.EsKeys.WAITING_TIME_FIELD: self._calc_avg_waiting_time(
entries
),
self.EsKeys.QUEUE_LENGTH_FIELD: len(entries),
},
)
return {
EsKeys.TIMESTAMP_FIELD: timestamp,
EsKeys.QUEUE_FIELD: queue.id,
EsKeys.WAITING_TIME_FIELD: self._calc_avg_waiting_time(entries),
EsKeys.QUEUE_LENGTH_FIELD: len(entries),
}
actions = list(map(make_doc, queues))
logged = 0
for q in queues:
queue_doc = make_doc(q)
self.es.index(index=es_index, body=queue_doc)
redis_key = _queue_metrics_key_pattern.format(queue=q.id)
redis.set(redis_key, json.dumps(queue_doc))
logged += 1
es_res = elasticsearch.helpers.bulk(self.es, actions)
added, errors = es_res[:2]
return (added == len(actions)) and not errors
return logged
def _log_current_metrics(self, company_id: str, queue_ids=Sequence[str]):
query = dict(company=company_id)
@@ -90,8 +97,7 @@ class QueueMetrics:
def _search_company_metrics(self, company_id: str, es_req: dict) -> dict:
return self.es.search(
index=f"{self._queue_metrics_prefix_for_company(company_id)}*",
body=es_req,
index=f"{self._queue_metrics_prefix_for_company(company_id)}*", body=es_req,
)
@classmethod
@@ -105,13 +111,13 @@ class QueueMetrics:
return {
"dates": {
"date_histogram": {
"field": cls.EsKeys.TIMESTAMP_FIELD,
"field": EsKeys.TIMESTAMP_FIELD,
"fixed_interval": f"{interval}s",
"min_doc_count": 1,
},
"aggs": {
"queues": {
"terms": {"field": cls.EsKeys.QUEUE_FIELD},
"terms": {"field": EsKeys.QUEUE_FIELD},
"aggs": cls._get_top_waiting_agg(),
}
},
@@ -128,13 +134,13 @@ class QueueMetrics:
"top_avg_waiting": {
"top_hits": {
"sort": [
{cls.EsKeys.WAITING_TIME_FIELD: {"order": "desc"}},
{cls.EsKeys.QUEUE_LENGTH_FIELD: {"order": "desc"}},
{EsKeys.WAITING_TIME_FIELD: {"order": "desc"}},
{EsKeys.QUEUE_LENGTH_FIELD: {"order": "desc"}},
],
"_source": {
"includes": [
cls.EsKeys.WAITING_TIME_FIELD,
cls.EsKeys.QUEUE_LENGTH_FIELD,
EsKeys.WAITING_TIME_FIELD,
EsKeys.QUEUE_LENGTH_FIELD,
]
},
"size": 1,
@@ -149,6 +155,7 @@ class QueueMetrics:
to_date: float,
interval: int,
queue_ids: Sequence[str],
refresh: bool = False,
) -> dict:
"""
Get the company queue metrics in the specified time range.
@@ -158,7 +165,8 @@ class QueueMetrics:
In case no queue ids are specified the avg across all the
company queues is calculated for each metric
"""
# self._log_current_metrics(company, queue_ids=queue_ids)
if refresh:
self._log_current_metrics(company_id, queue_ids=queue_ids)
if from_date >= to_date:
raise bad_request.FieldsValueError("from_date must be less than to_date")
@@ -256,7 +264,47 @@ class QueueMetrics:
continue
res = queue_data["top_avg_waiting"]["hits"]["hits"][0]["_source"]
queue_metrics[queue_data["key"]] = {
"queue_length": res[cls.EsKeys.QUEUE_LENGTH_FIELD],
"avg_waiting_time": res[cls.EsKeys.WAITING_TIME_FIELD],
"queue_length": res[EsKeys.QUEUE_LENGTH_FIELD],
"avg_waiting_time": res[EsKeys.WAITING_TIME_FIELD],
}
return queue_metrics
class MetricsRefresher:
threads = ThreadsManager()
@classproperty
def watch_interval_sec(self):
return _conf.get("metrics_refresh_interval_sec", 300)
@classmethod
@threads.register("queue_metrics_refresh_watchdog", daemon=True)
def start(cls, queue_metrics: QueueMetrics):
if not cls.watch_interval_sec:
return
sleep(10)
while not ThreadsManager.terminating:
try:
for queue in Queue.objects():
timestamp = es_factory.get_timestamp_millis()
doc_time = 0
try:
redis_key = _queue_metrics_key_pattern.format(queue=queue.id)
data = redis.get(redis_key)
if data:
queue_doc = json.loads(data)
doc_time = int(queue_doc.get(EsKeys.TIMESTAMP_FIELD))
except Exception as ex:
log.exception(
f"Error reading queue metrics data for queue {queue.id}: {str(ex)}"
)
if (
not doc_time
or (timestamp - doc_time) > cls.watch_interval_sec * 1000
):
queue_metrics.log_queue_metrics_to_es(queue.company, [queue])
except Exception as ex:
log.exception(f"Failed collecting queue metrics: {str(ex)}")
sleep(60)

View File

@@ -121,18 +121,31 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
nested_set(fields, new_path, new_param)
nested_delete(fields, old_params_field)
for param_field in ("hyperparams", "configuration"):
params = fields.get(param_field)
if params:
escaped_params = {
ParameterKeyEscaper.escape(key): {
ParameterKeyEscaper.escape(k): v for k, v in value.items()
}
if isinstance(value, dict)
else value
for key, value in params.items()
def ensure_non_empty(k: str, desc: str) -> str:
if not k:
raise errors.bad_request.ValidationError(
f"Empty {desc} name is not allowed"
)
return k
params = fields.get("hyperparams")
if params:
escaped_params = {
ParameterKeyEscaper.escape(ensure_non_empty(key, "section")): {
ParameterKeyEscaper.escape(ensure_non_empty(k, "parameter")): v
for k, v in value.items()
}
fields[param_field] = escaped_params
for key, value in params.items()
}
fields["hyperparams"] = escaped_params
params = fields.get("configuration")
if params:
escaped_params = {
ParameterKeyEscaper.escape(ensure_non_empty(key, "configuration")): value
for key, value in params.items()
}
fields["configuration"] = escaped_params
def params_unprepare_from_saved(fields, copy_to_legacy=False):
@@ -186,7 +199,7 @@ def escape_paths(paths: Sequence[str]) -> Sequence[str]:
for old_prefix, new_prefix in (
("execution.parameters", f"hyperparams.{hyperparams_default_section}"),
("execution.model_desc", "configuration"),
("execution.docker_cmd", "container")
("execution.docker_cmd", "container"),
):
path: str
paths = [path.replace(old_prefix, new_prefix) for path in paths]

View File

@@ -1,64 +1,19 @@
from itertools import chain
from operator import attrgetter
from typing import Sequence, Generic, Callable, Type, Iterable, TypeVar, List, Set
from typing import Sequence, Set, Tuple
import attr
from boltons.iterutils import partition
from mongoengine import QuerySet, Document
from apiserver.apierrors import errors
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_bll import PlotFields
from apiserver.bll.event.event_common import EventType
from apiserver.bll.task.utils import deleted_prefix
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task, TaskStatus, ArtifactModes
from apiserver.timing_context import TimingContext
event_bll = EventBLL()
T = TypeVar("T", bound=Document)
class DocumentGroup(List[T]):
"""
Operate on a list of documents as if they were a query result
"""
def __init__(self, document_type: Type[T], documents: Iterable[T]):
super(DocumentGroup, self).__init__(documents)
self.type = document_type
@property
def ids(self) -> Set[str]:
return {obj.id for obj in self}
def objects(self, *args, **kwargs) -> QuerySet:
return self.type.objects(id__in=self.ids, *args, **kwargs)
class TaskOutputs(Generic[T]):
"""
Split task outputs of the same type by the ready state
"""
published: DocumentGroup[T]
draft: DocumentGroup[T]
def __init__(
self,
is_published: Callable[[T], bool],
document_type: Type[T],
children: Iterable[T],
):
"""
:param is_published: predicate returning whether items is considered published
:param document_type: type of output
:param children: output documents
"""
self.published, self.draft = map(
lambda x: DocumentGroup(document_type, x),
partition(children, key=is_published),
)
@attr.s(auto_attribs=True)
@@ -124,32 +79,18 @@ def collect_debug_image_urls(company: str, task: str) -> Set[str]:
Return the set of unique image urls
Uses DebugImagesIterator to make sure that we do not retrieve recycled urls
"""
metrics = event_bll.get_metrics_and_variants(
company_id=company, task_id=task, event_type=EventType.metrics_image
)
if not metrics:
return set()
task_metrics = {task: {m: [] for m in metrics}}
scroll_id = None
after_key = None
urls = set()
while True:
res = event_bll.debug_images_iterator.get_task_events(
res, after_key = event_bll.get_debug_image_urls(
company_id=company,
task_metrics=task_metrics,
iter_count=10,
state_id=scroll_id,
task_id=task,
after_key=after_key,
)
if not res.metric_events or not any(
iterations for _, iterations in res.metric_events
):
urls.update(res)
if not after_key:
break
scroll_id = res.next_scroll_id
for task, iterations in res.metric_events:
urls.update(ev.get("url") for it in iterations for ev in it["events"])
urls.discard({None})
return urls
@@ -166,7 +107,9 @@ def cleanup_task(
:param force: whether to delete task with published outputs
:return: count of delete and modified items
"""
models = verify_task_children_and_ouptuts(task, force)
published_models, draft_models, in_use_model_ids = verify_task_children_and_ouptuts(
task, force
)
event_urls, artifact_urls, model_urls = set(), set(), set()
if return_file_urls:
@@ -178,28 +121,37 @@ def cleanup_task(
for a in task.execution.artifacts.values()
if a.mode == ArtifactModes.output and a.uri
}
model_urls = {m.uri for m in models.draft.objects().only("uri") if m.uri}
model_urls = {
m.uri for m in draft_models if m.uri and m.id not in in_use_model_ids
}
deleted_task_id = f"{deleted_prefix}{task.id}"
updated_children = 0
if update_children:
with TimingContext("mongo", "update_task_children"):
updated_children = Task.objects(parent=task.id).update(
parent=deleted_task_id
)
else:
updated_children = 0
if models.draft and delete_output_models:
with TimingContext("mongo", "delete_models"):
deleted_models = models.draft.objects().delete()
else:
deleted_models = 0
deleted_models = 0
updated_models = 0
for models, allow_delete in ((draft_models, True), (published_models, False)):
if not models:
continue
if delete_output_models and allow_delete:
deleted_models += Model.objects(
id__in=[m.id for m in models if m.id not in in_use_model_ids]
).delete()
if in_use_model_ids:
Model.objects(id__in=list(in_use_model_ids)).update(unset__task=1)
continue
if models.published and update_children:
with TimingContext("mongo", "update_task_models"):
updated_models = models.published.objects().update(task=deleted_task_id)
else:
updated_models = 0
if update_children:
updated_models += Model.objects(id__in=[m.id for m in models]).update(
task=deleted_task_id
)
else:
Model.objects(id__in=[m.id for m in models]).update(unset__task=1)
event_bll.delete_task_events(task.company, task.id, allow_locked=force)
@@ -217,7 +169,9 @@ def cleanup_task(
)
def verify_task_children_and_ouptuts(task: Task, force: bool) -> TaskOutputs[Model]:
def verify_task_children_and_ouptuts(
task, force: bool
) -> Tuple[Sequence[Model], Sequence[Model], Set[str]]:
if not force:
with TimingContext("mongo", "count_published_children"):
published_children_count = Task.objects(
@@ -230,49 +184,42 @@ def verify_task_children_and_ouptuts(task: Task, force: bool) -> TaskOutputs[Mod
children=published_children_count,
)
with TimingContext("mongo", "get_task_models"):
models = TaskOutputs(
attrgetter("ready"),
Model,
Model.objects(task=task.id).only("id", "task", "ready"),
model_fields = ["id", "ready", "uri"]
published_models, draft_models = partition(
Model.objects(task=task.id).only(*model_fields), key=attrgetter("ready"),
)
if not force and published_models:
raise errors.bad_request.TaskCannotBeDeleted(
"has output models, use force=True",
task=task.id,
models=len(published_models),
)
if not force and models.published:
raise errors.bad_request.TaskCannotBeDeleted(
"has output models, use force=True",
task=task.id,
models=len(models.published),
)
if task.models and task.models.output:
with TimingContext("mongo", "get_task_output_model"):
model_ids = [m.model for m in task.models.output]
for output_model in Model.objects(id__in=model_ids):
if output_model.ready:
if not force:
raise errors.bad_request.TaskCannotBeDeleted(
"has output model, use force=True",
task=task.id,
model=output_model.id,
)
models.published.append(output_model)
else:
models.draft.append(output_model)
model_ids = [m.model for m in task.models.output]
for output_model in Model.objects(id__in=model_ids).only(*model_fields):
if output_model.ready:
if not force:
raise errors.bad_request.TaskCannotBeDeleted(
"has output model, use force=True",
task=task.id,
model=output_model.id,
)
published_models.append(output_model)
else:
draft_models.append(output_model)
if models.draft:
with TimingContext("mongo", "get_execution_models"):
model_ids = models.draft.ids
dependent_tasks = Task.objects(models__input__model__in=model_ids).only(
"id", "models"
in_use_model_ids = {}
if draft_models:
model_ids = {m.id for m in draft_models}
dependent_tasks = Task.objects(models__input__model__in=list(model_ids)).only(
"id", "models"
)
in_use_model_ids = model_ids & {
m.model
for m in chain.from_iterable(
t.models.input for t in dependent_tasks if t.models
)
input_models = {
m.model
for m in chain.from_iterable(
t.models.input for t in dependent_tasks if t.models
)
}
if input_models:
models.draft = DocumentGroup(
Model, (m for m in models.draft if m.id not in input_models)
)
}
return models
return published_models, draft_models, in_use_model_ids

View File

@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Callable, Any, Tuple, Union
from typing import Callable, Any, Tuple, Union, Sequence
from apiserver.apierrors import errors, APIError
from apiserver.bll.queue import QueueBLL
@@ -25,6 +25,7 @@ from apiserver.database.model.task.task import (
)
from apiserver.utilities.dicts import nested_set
log = config.logger(__file__)
queue_bll = QueueBLL()
@@ -83,10 +84,7 @@ def unarchive_task(
def dequeue_task(
task_id: str,
company_id: str,
status_message: str,
status_reason: str,
task_id: str, company_id: str, status_message: str, status_reason: str,
) -> Tuple[int, dict]:
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(**query)
@@ -94,10 +92,7 @@ def dequeue_task(
raise errors.bad_request.InvalidTaskId(**query)
res = TaskBLL.dequeue_and_change_status(
task,
company_id,
status_message=status_message,
status_reason=status_reason,
task, company_id, status_message=status_message, status_reason=status_reason,
)
return 1, res
@@ -108,9 +103,23 @@ def enqueue_task(
queue_id: str,
status_message: str,
status_reason: str,
queue_name: str = None,
validate: bool = False,
force: bool = False,
) -> Tuple[int, dict]:
if queue_id and queue_name:
raise errors.bad_request.ValidationError(
"Either queue id or queue name should be provided"
)
if queue_name:
queue = queue_bll.get_by_name(
company_id=company_id, queue_name=queue_name, only=("id",)
)
if not queue:
queue = queue_bll.create(company_id=company_id, name=queue_name)
queue_id = queue.id
if not queue_id:
# try to get default queue
queue_id = queue_bll.get_default(company_id).id
@@ -155,6 +164,30 @@ def enqueue_task(
return 1, res
def move_tasks_to_trash(tasks: Sequence[str]) -> int:
try:
collection_name = Task._get_collection_name()
trash_collection_name = f"{collection_name}__trash"
Task.aggregate(
[
{"$match": {"_id": {"$in": tasks}}},
{
"$merge": {
"into": trash_collection_name,
"on": "_id",
"whenMatched": "replace",
"whenNotMatched": "insert",
}
},
],
allow_disk_use=True,
)
except Exception as ex:
log.error(f"Error copying tasks to trash {str(ex)}")
return Task.objects(id__in=tasks).delete()
def delete_task(
task_id: str,
company_id: str,
@@ -200,18 +233,12 @@ def delete_task(
)
if move_to_trash:
collection_name = task._get_collection_name()
archived_collection = "{}__trash".format(collection_name)
task.switch_collection(archived_collection)
try:
# A simple save() won't do due to mongoengine caching (nothing will be saved), so we have to force
# an insert. However, if for some reason such an ID exists, let's make sure we'll keep going.
task.save(force_insert=True)
except Exception:
pass
task.switch_collection(collection_name)
# make sure that whatever changes were done to the task are saved
# the task itself will be deleted later in the move_tasks_to_trash operation
task.save()
else:
task.delete()
task.delete()
update_project_time(task.project)
return 1, task, cleanup_res

View File

@@ -76,7 +76,7 @@ class WorkerBLL:
raise bad_request.InvalidUserId(**query)
company = Company.objects(id=company_id).only("id", "name").first()
if not company:
raise server_error.InternalError("invalid company", company=company_id)
raise bad_request.InvalidId("invalid company", company=company_id)
queue_objs = Queue.objects(company=company_id, id__in=queues).only("id")
if len(queue_objs) < len(queues):
@@ -143,7 +143,7 @@ class WorkerBLL:
self._log_stats_to_es(
company_id=company_id,
company_name=entry.company.name,
worker=report.worker,
worker=entry.key,
timestamp=report.timestamp,
task=report.task,
machine_stats=report.machine_stats,
@@ -189,7 +189,10 @@ class WorkerBLL:
self._save_worker(entry)
def get_all(
self, company_id: str, last_seen: Optional[int] = None
self,
company_id: str,
last_seen: Optional[int] = None,
tags: Sequence[str] = None,
) -> Sequence[WorkerEntry]:
"""
Get all the company workers that were active during the last_seen period
@@ -210,16 +213,26 @@ class WorkerBLL:
if w.last_activity_time.replace(tzinfo=None) >= ref_time
]
if tags:
include = {t for t in tags if not t.startswith("-")}
exclude = {t[1:] for t in tags if t.startswith("-")}
workers = [
w
for w in workers
if (not include or any(t in include for t in w.tags))
and (not exclude or all(t not in exclude for t in w.tags))
]
return workers
def get_all_with_projection(
self, company_id: str, last_seen: int
self, company_id: str, last_seen: int, tags: Sequence[str] = None
) -> Sequence[WorkerResponseEntry]:
helpers = list(
map(
WorkerConversionHelper.from_worker_entry,
self.get_all(company_id=company_id, last_seen=last_seen),
self.get_all(company_id=company_id, last_seen=last_seen, tags=tags),
)
)
@@ -352,10 +365,15 @@ class WorkerBLL:
self, company: str, user: str = "*", worker_id: str = "*"
) -> Sequence[WorkerEntry]:
"""Get worker entries matching the company and user, worker patterns"""
entries = []
match = self._get_worker_key(company, user, worker_id)
with TimingContext("redis", "workers_get_all"):
res = self.redis.scan_iter(match)
return [WorkerEntry.from_json(self.redis.get(r)) for r in res]
for r in self.redis.scan_iter(match):
data = self.redis.get(r)
if data:
entries.append(WorkerEntry.from_json(data))
return entries
@staticmethod
def _get_es_index_suffix():

View File

@@ -20,7 +20,7 @@ class WorkerStats:
@staticmethod
def worker_stats_prefix_for_company(company_id: str) -> str:
"""Returns the es index prefix for the company"""
return f"worker_stats_{company_id}_"
return f"worker_stats_{company_id.lower()}_"
def _search_company_stats(self, company_id: str, es_req: dict) -> dict:
return self.es.search(

View File

@@ -146,4 +146,11 @@
max_backoff_sec: 5
}
getting_started_info {
"agentName": "clearml",
"configure": "clearml-init",
"install": "pip install clearml",
"packageName": "clearml"
}
}

View File

@@ -12,12 +12,23 @@ events_retrieval {
# should not exceed the amount of concurrent connections set in the ES driver
max_metrics_concurrency: 4
# If set then max_metrics_count and max_variants_count are calculated dynamically on user data
dynamic_metrics_count: true
# The percentage from the ES aggs limit (10000) to use for the max_metrics and max_variants calculation
dynamic_metrics_count_threshold: 80
# the max amount of metrics to aggregate on
max_metrics_count: 100
# the max amount of variants to aggregate on
max_variants_count: 100
debug_images {
# Allow to return the debug images for the variants with uninitialized valid iterations border
allow_uninitialized_variants: true
}
max_raw_scalars_size: 200000
scroll_id_key: "cTN5VEtWEC6QrHvUl0FTx9kNyO0CcCK1p57akxma"

View File

@@ -0,0 +1,5 @@
{
metrics_before_from_date: 3600
# interval in seconds to update queue metrics. Put 0 to disable
metrics_refresh_interval_sec: 300
}

View File

@@ -166,7 +166,10 @@ class MongoEngineErrorsHandler(object):
@classmethod
@throws_default_error(errors.server_error.InternalError)
def invalid_query_error(cls, e, message, **_):
pass
if e.args:
inner = e.args[0]
if isinstance(inner, LookUpError):
cls.lookup_error(inner, message)
@contextmanager

View File

@@ -60,3 +60,4 @@ def validate_id(cls, company, **kwargs):
class EntityVisibility(Enum):
active = "active"
archived = "archived"
hidden = "hidden"

View File

@@ -648,6 +648,7 @@ class GetMixin(PropsMixin):
allow_public=False,
override_projection=None,
expand_reference_ids=True,
projection_fields: dict = None,
ret_params: dict = None,
):
"""
@@ -684,6 +685,7 @@ class GetMixin(PropsMixin):
query=query,
query_options=query_options,
allow_public=allow_public,
projection_fields=projection_fields,
ret_params=ret_params,
)
@@ -704,6 +706,45 @@ class GetMixin(PropsMixin):
v for k, v in cls._field_collation_overrides.items() if field.startswith(k)
)
@classmethod
def get_count(
cls: Union["GetMixin", Document],
company,
query_dict: dict = None,
query_options: QueryParameterOptions = None,
query: Q = None,
allow_public=False,
) -> int:
_query = cls._get_combined_query(
company=company,
query_dict=query_dict,
query_options=query_options,
query=query,
allow_public=allow_public,
)
return cls.objects(_query).count()
@classmethod
def _get_combined_query(
cls,
company,
query_dict: dict = None,
query_options: QueryParameterOptions = None,
query: Q = None,
allow_public=False,
) -> Q:
if query_dict is not None:
q = cls.prepare_query(
parameters=query_dict,
company=company,
parameters_options=query_options,
allow_public=allow_public,
)
else:
q = cls._prepare_perm_query(company, allow_public=allow_public)
return (q & query) if query else q
@classmethod
def get_many(
cls,
@@ -715,6 +756,7 @@ class GetMixin(PropsMixin):
allow_public=False,
override_projection: Collection[str] = None,
return_dicts=True,
projection_fields: dict = None,
ret_params: dict = None,
):
"""
@@ -749,16 +791,13 @@ class GetMixin(PropsMixin):
if override_collation:
break
if query_dict is not None:
q = cls.prepare_query(
parameters=query_dict,
company=company,
parameters_options=query_options,
allow_public=allow_public,
)
else:
q = cls._prepare_perm_query(company, allow_public=allow_public)
_query = (q & query) if query else q
_query = cls._get_combined_query(
company=company,
query_dict=query_dict,
query_options=query_options,
query=query,
allow_public=allow_public,
)
if return_dicts:
data_getter = partial(
@@ -767,6 +806,7 @@ class GetMixin(PropsMixin):
parameters=parameters,
override_projection=override_projection,
override_collation=override_collation,
projection_fields=projection_fields,
)
return cls.get_data_with_scroll_support(
query_dict=query_dict, data_getter=data_getter, ret_params=ret_params,
@@ -777,6 +817,7 @@ class GetMixin(PropsMixin):
parameters=parameters,
override_projection=override_projection,
override_collation=override_collation,
projection_fields=projection_fields,
)
@classmethod
@@ -801,6 +842,7 @@ class GetMixin(PropsMixin):
parameters=None,
override_projection=None,
override_collation=None,
projection_fields: dict = None,
):
"""
Fetch all documents matching a provided query.
@@ -843,6 +885,9 @@ class GetMixin(PropsMixin):
if exclude:
qs = qs.exclude(*exclude)
if projection_fields:
qs = qs.fields(**projection_fields)
if start is not None and size:
# add paging
qs = qs.skip(start).limit(size)
@@ -884,6 +929,7 @@ class GetMixin(PropsMixin):
parameters: dict = None,
override_projection: Collection[str] = None,
override_collation: dict = None,
projection_fields: dict = None,
) -> Sequence[dict]:
"""
Fetch all documents matching a provided query. For the first order by field
@@ -941,6 +987,9 @@ class GetMixin(PropsMixin):
if exclude:
query_sets = [qs.exclude(*exclude) for qs in query_sets]
if projection_fields:
query_sets = [qs.fields(**projection_fields) for qs in query_sets]
if start is None or not size:
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]

View File

@@ -9,7 +9,7 @@ from apiserver.database.model.base import GetMixin
class Project(AttributedDocument):
get_all_query_options = GetMixin.QueryParameterOptions(
pattern_fields=("name", "description"),
pattern_fields=("name", "basename", "description"),
list_fields=("tags", "system_tags", "id", "parent", "path"),
range_fields=("last_update",),
)
@@ -21,6 +21,7 @@ class Project(AttributedDocument):
"parent",
"path",
("company", "name"),
("company", "basename"),
{
"name": "%s.project.main_text_index" % Database.backend,
"fields": ["$name", "$id", "$description"],
@@ -37,6 +38,7 @@ class Project(AttributedDocument):
min_length=3,
sparse=True,
)
basename = StrippedStringField(required=True)
description = StringField()
created = DateTimeField(required=True)
tags = SafeSortedListField(StringField(required=True))

View File

@@ -175,6 +175,8 @@ class Task(AttributedDocument):
"active_duration",
"parent",
"project",
"last_update",
"status_changed",
"models.input.model",
("company", "name"),
("company", "user"),

View File

@@ -3,8 +3,6 @@ from concurrent.futures import ThreadPoolExecutor
from itertools import groupby, chain
from typing import Sequence, Dict, Callable, Tuple, Any, Type
import dpath.path
from apiserver.apierrors import errors
from apiserver.database.props import PropsMixin
@@ -275,25 +273,26 @@ class ProjectionHelper(object):
norm_path = doc_cls.get_dpath_translated_path(path)
globlist = norm_path.strip(SEP).split(SEP)
obj_paths = self._cached_results_paths.get(id(obj))
if obj_paths is None:
obj_paths = self._cached_results_paths[id(obj)] = list(
dpath.path.paths(obj, dirs=True, skip=True)
)
paths = [p for p in obj_paths if dpath.path.match(p, globlist)]
def search_and_replace(p: Sequence[Tuple[str, Type]]) -> Any:
def _search_and_replace(target: dict, p: Sequence[str]) -> Sequence[str]:
parent = None
target = obj
for part in p:
parent = target
target = target[part[0]]
if parent and factory:
parent[p[-1][0]] = factory(target)
return target
for idx, part in enumerate(p):
if isinstance(target, dict) and part in target:
parent = target
target = target[part]
elif isinstance(target, list) and part == "*":
return list(
chain.from_iterable(
_search_and_replace(t, p[idx + 1 :]) for t in target
)
)
else:
return []
return [search_and_replace(p) for p in paths]
if parent and factory:
parent[p[-1]] = factory(target)
return [target]
return _search_and_replace(obj, globlist)
def project(self, results, projection_func):
"""

View File

@@ -24,6 +24,7 @@ from typing import (
Callable,
)
from urllib.parse import unquote, urlparse
from uuid import uuid4
from zipfile import ZipFile, ZIP_BZIP2
import mongoengine
@@ -690,6 +691,19 @@ class PrePopulate:
continue
yield clean
@classmethod
def _generate_new_ids(
cls, reader: ZipFile, entity_files: Sequence
) -> Mapping[str, str]:
ids = {}
for entity_file in entity_files:
with reader.open(entity_file) as f:
for item in cls.json_lines(f):
orig_id = json.loads(item).get("_id")
if orig_id:
ids[orig_id] = str(uuid4()).replace("-", "")
return ids
@classmethod
def _import(
cls,
@@ -704,37 +718,46 @@ class PrePopulate:
Start from entities since event import will require the tasks already in DB
"""
event_file_ending = cls.events_file_suffix + ".json"
entity_files = (
entity_files = [
fi
for fi in reader.filelist
if not fi.orig_filename.endswith(event_file_ending)
and fi.orig_filename != cls.metadata_filename
)
]
metadata = metadata or {}
old_to_new_ids = (
cls._generate_new_ids(reader, entity_files)
if metadata.get("new_ids")
else {}
)
tasks = []
for entity_file in entity_files:
with reader.open(entity_file) as f:
full_name = splitext(entity_file.orig_filename)[0]
print(f"Reading {reader.filename}:{full_name}...")
res = cls._import_entity(f, full_name, company_id, user_id, metadata)
res = cls._import_entity(
f, full_name, company_id, user_id, metadata, old_to_new_ids
)
if res:
tasks = res
if sort_tasks_by_last_updated:
tasks = sorted(tasks, key=attrgetter("last_update"))
new_to_old_ids = {v: k for k, v in old_to_new_ids.items()}
for task in tasks:
old_task_id = new_to_old_ids.get(task.id, task.id)
events_file = first(
fi
for fi in reader.filelist
if fi.orig_filename.endswith(task.id + event_file_ending)
if fi.orig_filename.endswith(old_task_id + event_file_ending)
)
if not events_file:
continue
with reader.open(events_file) as f:
full_name = splitext(events_file.orig_filename)[0]
print(f"Reading {reader.filename}:{full_name}...")
cls._import_events(f, full_name, company_id, user_id)
cls._import_events(f, company_id, user_id, task.id)
@classmethod
def _get_entity_type(cls, full_name) -> Type[mongoengine.Document]:
@@ -746,6 +769,15 @@ class PrePopulate:
module = importlib.import_module(module_name)
return getattr(module, class_name)
@staticmethod
def _upgrade_project_data(project_data: dict) -> dict:
if not project_data.get("basename"):
name: str = project_data["name"]
_, _, basename = name.rpartition("/")
project_data["basename"] = basename
return project_data
@staticmethod
def _upgrade_model_data(model_data: dict) -> dict:
metadata_key = "metadata"
@@ -838,6 +870,7 @@ class PrePopulate:
company_id: str,
user_id: str,
metadata: Mapping[str, Any],
old_to_new_ids: Mapping[str, str] = None,
) -> Optional[Sequence[Task]]:
cls_ = cls._get_entity_type(full_name)
print(f"Writing {cls_.__name__.lower()}s into database")
@@ -846,8 +879,14 @@ class PrePopulate:
data_upgrade_funcs: Mapping[Type, Callable] = {
cls.task_cls: cls._upgrade_task_data,
cls.model_cls: cls._upgrade_model_data,
cls.project_cls: cls._upgrade_project_data,
}
for item in cls.json_lines(f):
if old_to_new_ids:
for old_id, new_id in old_to_new_ids.items():
# replace ids only when they are standalone strings
# otherwise artifacts uris that contain old ids may get damaged
item = item.replace(f'"{old_id}"', f'"{new_id}"')
upgrade_func = data_upgrade_funcs.get(cls_)
if upgrade_func:
item = json.dumps(upgrade_func(json.loads(item)))
@@ -884,11 +923,15 @@ class PrePopulate:
return tasks
@classmethod
def _import_events(cls, f: IO[bytes], full_name: str, company_id: str, _):
_, _, task_id = full_name[0 : -len(cls.events_file_suffix)].rpartition("_")
def _import_events(
cls, f: IO[bytes], company_id: str, _, task_id: str
):
print(f"Writing events for task {task_id} into database")
for events_chunk in chunked_iter(cls.json_lines(f), 1000):
events = [json.loads(item) for item in events_chunk]
for ev in events:
ev["task"] = task_id
ev["company_id"] = company_id
cls.event_bll.add_events(
company_id, events=events, worker="", allow_locked_tasks=True
)

View File

@@ -0,0 +1,12 @@
from pymongo.collection import Collection
from pymongo.database import Database
def migrate_backend(db: Database):
projects: Collection = db["project"]
for doc in projects.find({"basename": None}):
name: str = doc["name"]
_, _, basename = name.rpartition("/")
projects.update_one(
{"_id": doc["_id"]}, {"$set": {"basename": basename}},
)

View File

@@ -21,7 +21,7 @@ nested_dict>=1.61
packaging==20.3
psutil>=5.6.5
pyhocon>=0.3.35
pyjwt<2.0.0
pyjwt>=2.4.0
pymongo[srv]==3.12.0
python-rapidjson>=0.6.3
redis==3.5.3
@@ -31,4 +31,4 @@ requests>=2.13.0
semantic_version>=2.8.3,<3
six
tqdm
validators>=0.12.4
validators>=0.12.4

View File

@@ -262,6 +262,38 @@ get_credentials {
}
}
edit_credentials {
allow_roles = [ "*" ]
internal: false
"2.19" {
description: """Updates the label of the existing credentials for the authenticated user."""
request {
type: object
required: [ access_key ]
properties {
access_key {
type: string
description: Existing credentials key
}
label {
type: string
description: New credentials label
}
}
}
response {
type: object
properties {
updated {
description: "Number of credentials updated"
type: integer
enum: [0, 1]
}
}
}
}
}
revoke_credentials {
allow_roles = [ "*" ]
internal: false

View File

@@ -302,6 +302,48 @@ _definitions {
}
}
}
plots_response_task_metrics {
type: object
properties {
task {
type: string
description: Task ID
}
iterations {
type: array
items {
type: object
properties {
iter {
type: integer
description: Iteration number
}
events {
type: array
items {
type: object
description: Plot event
}
}
}
}
}
}
}
plots_response {
type: object
properties {
scroll_id {
type: string
description: "Scroll ID for getting more results"
}
metrics {
type: array
description: "Plot events grouped by tasks and iterations"
items {"$ref": "#/definitions/plots_response_task_metrics"}
}
}
}
debug_image_sample_response {
type: object
properties {
@@ -323,6 +365,27 @@ _definitions {
}
}
}
plot_sample_response {
type: object
properties {
scroll_id {
type: string
description: "Scroll ID to pass to the next calls to get_plot_sample or next_plot_sample"
}
event {
type: object
description: "Plot event"
}
min_iteration {
type: integer
description: "minimal valid iteration for the variant"
}
max_iteration {
type: integer
description: "maximal valid iteration for the variant"
}
}
}
}
add {
"2.1" {
@@ -486,6 +549,41 @@ debug_images {
}
}
}
plots {
"2.20" {
description: "Get plot events for the requested amount of iterations per each task"
request {
type: object
required: [
metrics
]
properties {
metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/task_metric_variants" }
}
iters {
type: integer
description: "Max number of latest iterations for which to return debug images"
}
navigate_earlier {
type: boolean
description: "If set then events are retreived from latest iterations to earliest ones. Otherwise from earliest iterations to the latest. The default is True"
}
refresh {
type: boolean
description: "If set then scroll will be moved to the latest iterations. The default is False"
}
scroll_id {
type: string
description: "Scroll ID of previous call (used for getting more results)"
}
}
}
response {"$ref": "#/definitions/plots_response"}
}
}
get_debug_image_sample {
"2.12": {
description: "Return the debug image per metric and variant for the provided iteration"
@@ -521,6 +619,13 @@ get_debug_image_sample {
}
response {"$ref": "#/definitions/debug_image_sample_response"}
}
"2.20": ${get_debug_image_sample."2.12"} {
request.properties.navigate_current_metric {
description: If set then subsequent navigation with next_debug_image_sample is done on the debug images for the passed metric only. Otherwise for all the metrics
type: boolean
default: true
}
}
}
next_debug_image_sample {
"2.12": {
@@ -547,6 +652,72 @@ next_debug_image_sample {
response {"$ref": "#/definitions/debug_image_sample_response"}
}
}
get_plot_sample {
"2.20": {
description: "Return the plot per metric and variant for the provided iteration"
request {
type: object
required: [task, metric, variant]
properties {
task {
description: "Task ID"
type: string
}
metric {
description: "Metric name"
type: string
}
variant {
description: "Metric variant"
type: string
}
iteration {
description: "The iteration to bring plot from. If not specified then the latest reported plot is retrieved"
type: integer
}
refresh {
description: "If set then scroll state will be refreshed to reflect the latest changes in the plots"
type: boolean
}
scroll_id {
type: string
description: "Scroll ID from the previous call to get_plot_sample or empty"
}
navigate_current_metric {
description: If set then subsequent navigation with next_plot_sample is done on the plots for the passed metric only. Otherwise for all the metrics
type: boolean
default: true
}
}
}
response {"$ref": "#/definitions/plot_sample_response"}
}
}
next_plot_sample {
"2.20": {
description: "Get the plot for the next variant for the same iteration or for the next iteration"
request {
type: object
required: [task, scroll_id]
properties {
task {
description: "Task ID"
type: string
}
scroll_id {
type: string
description: "Scroll ID from the previous call to get_plot_sample"
}
navigate_earlier {
type: boolean
description: """If set then get the either previous variant event from the current iteration or (if does not exist) the last variant event from the previous iteration.
Otherwise next variant event from the current iteration or first variant event from the next iteration"""
}
}
}
response {"$ref": "#/definitions/plot_sample_response"}
}
}
get_task_metrics{
"2.7": {
description: "For each task, get a list of metrics for which the requested event type was reported"
@@ -1087,7 +1258,7 @@ multi_task_scalar_metrics_iter_histogram {
type: array
items {
type: string
description: "List of task Task IDs"
description: "Task ID"
}
}
samples {
@@ -1112,6 +1283,55 @@ multi_task_scalar_metrics_iter_histogram {
}
}
}
get_task_single_value_metrics {
"2.20" {
description: Get single value metrics for the passed tasks
request {
type: object
required: [tasks]
properties {
tasks {
description: "List of task Task IDs"
type: array
items {
type: string
description: "Task ID"
}
}
}
}
response {
type: object
properties {
tasks {
description: Single value metrics grouped by task
type: array
items {
type: object
properties {
task {
type: string
description: Task ID
}
values {
type: array
items {
type: object
properties {
metric { type: string }
variant { type: string}
value { type: number }
timestamp { type: number }
}
}
}
}
}
}
}
}
}
}
get_task_latest_scalar_values {
"2.1" {
description: "Get the tasks's latest scalar values"
@@ -1304,3 +1524,58 @@ scalar_metrics_iter_raw {
}
}
}
clear_scroll {
"2.18" {
description: "Clear an open Scroll ID"
request {
type: object
required: [
scroll_id
]
properties {
scroll_id {
description: "Scroll ID as returned by previous events service calls"
type: string
}
}
}
response {
type: object
additionalProperties: false
}
}
}
clear_task_log {
"2.19" {
description: Remove old logs from task
request {
type: object
required: [task]
properties {
task {
description: Task ID
type: string
}
allow_locked {
type: boolean
description: Allow deleting events even if the task is locked
default: false
}
threshold_sec {
description: The amount of seconds ago to retain the log records. The older log records will be deleted. If not passed or 0 then all the log records for the task will be deleted
type: integer
}
}
}
response {
type: object
properties {
deleted {
description: The number of deleted log records
type: integer
}
}
}
}
}

View File

@@ -104,6 +104,16 @@ _definitions {
"$ref": "#/definitions/metadata_item"
}
}
stats {
description: "Model statistics"
type: object
properties {
labels_count {
description: Number of the model labels
type: integer
}
}
}
}
}
published_task_item {
@@ -224,6 +234,13 @@ get_all_ex {
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
}
}
"2.20": ${get_all_ex."2.15"} {
request.properties.include_stats {
description: "If true, include models statistic in response"
type: boolean
default: false
}
}
}
get_all {
"2.1" {

View File

@@ -102,4 +102,64 @@ get_user_companies {
}
}
}
}
}
get_entities_count {
"2.20": {
description: "Get counts for the company entities according to the passed search criteria"
request {
type: object
properties {
projects {
type: object
additionalProperties: true
description: Search criteria for projects
}
tasks {
type: object
additionalProperties: true
description: Search criteria for experiments
}
models {
type: object
additionalProperties: true
description: Search criteria for models
}
pipelines {
type: object
additionalProperties: true
description: Search criteria for pipelines
}
datasets {
type: object
additionalProperties: true
description: Search criteria for datasets
}
}
}
response {
type: object
properties {
projects {
type: integer
description: The number of projects matching the criteria
}
tasks {
type: integer
description: The number of experiments matching the criteria
}
models {
type: integer
description: The number of models matching the criteria
}
pipelines {
type: integer
description: The number of pipelines matching the criteria
}
datasets {
type: integer
description: The number of datasets matching the criteria
}
}
}
}
}

View File

@@ -25,6 +25,10 @@ _definitions {
description: "Project name"
type: string
}
basename {
description: "Project base name"
type: string
}
description {
description: "Project description"
type: string
@@ -156,6 +160,10 @@ _definitions {
description: "Project name"
type: string
}
basename {
description: "Project base name"
type: string
}
description {
description: "Project description"
type: string
@@ -214,6 +222,28 @@ _definitions {
}
}
}
own_tasks {
description: "The amount of tasks under this project (without children projects). Returned if 'check_own_contents' flag is set in the request"
type: integer
}
own_models {
description: "The amount of models under this project (without children projects). Returned if 'check_own_contents' flag is set in the request"
type: integer
}
dataset_stats {
description: Project dataset statistics
type: object
properties {
file_count {
type: integer
description: The number of files stored in the dataset
}
total_size {
type: integer
description: The total dataset size in bytes
}
}
}
}
}
metric_variant_result {
@@ -385,6 +415,10 @@ get_all {
description: "Get only projects whose name matches this pattern (python regular expression syntax)"
type: string
}
basename {
description: "Project base name"
type: string
}
description {
description: "Get only projects whose description matches this pattern (python regular expression syntax)"
type: string
@@ -455,7 +489,14 @@ get_all {
}
}
}
"2.15": ${get_all."2.13"} {
"2.14": ${get_all."2.13"} {
request.properties.search_hidden {
description: "If set to 'true' then hidden projects are included in the search results"
type: boolean
default: false
}
}
"2.15": ${get_all."2.14"} {
request {
properties {
scroll_id {
@@ -523,20 +564,15 @@ get_all_ex {
}
}
}
response {
properties {
own_tasks {
description: "The amount of tasks under this project (without children projects). Returned if 'check_own_contents' flag is set in the request"
type: integer
}
own_models {
description: "The amount of models under this project (without children projects). Returned if 'check_own_contents' flag is set in the request"
type: integer
}
}
}
"2.14": ${get_all_ex."2.13"} {
request.properties.search_hidden {
description: "If set to 'true' then hidden projects are included in the search results"
type: boolean
default: false
}
}
"2.15": ${get_all_ex."2.13"} {
"2.15": ${get_all_ex."2.14"} {
request {
properties {
scroll_id {
@@ -568,15 +604,16 @@ get_all_ex {
}
"2.17": ${get_all_ex."2.16"} {
request.properties.include_stats_filter {
description: The filter for selecting entities that participate in statistics calculation
description: The filter for selecting entities that participate in statistics calculation. For each task field that you want to filter on pass the list of allowed values. Prepend the value with '-' to exclude
type: object
properties {
system_tags {
description: The list of allowed system tags
type: array
items { type: string }
}
}
additionalProperties: true
}
}
"2.20": ${get_all_ex."2.17"} {
request.properties.include_dataset_stats {
description: "If true, include project dataset statistic in response"
type: boolean
default: false
}
}
}

View File

@@ -112,6 +112,12 @@ get_by_id {
}
}
}
"2.20": ${get_by_id."2.4"} {
request.properties.max_task_entries {
description: Max number of queue task entries to return
type: integer
}
}
}
// typescript generation hack
get_all_ex {
@@ -140,6 +146,12 @@ get_all_ex {
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
}
}
"2.20": ${get_all_ex."2.15"} {
request.properties.max_task_entries {
description: Max number of queue task entries to return
type: integer
}
}
}
get_all {
"2.4" {
@@ -226,6 +238,12 @@ get_all {
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
}
}
"2.20": ${get_all."2.15"} {
request.properties.max_task_entries {
description: Max number of queue task entries to return
type: integer
}
}
}
get_default {
"2.4" {
@@ -634,6 +652,13 @@ get_queue_metrics : {
}
}
}
"2.20": ${get_queue_metrics."2.4"} {
request.properties.refresh {
type: boolean
default: false
description: If set then the new queue metrics is taken
}
}
}
add_or_update_metadata {
"2.13" {
@@ -700,3 +725,51 @@ delete_metadata {
}
}
}
peek_task {
"2.15" {
description: "Peek the next task from a given queue"
request {
type: object
required: [ queue ]
properties {
queue {
description: "ID of the queue"
type: string
}
}
}
response {
type: object
properties {
task {
description: "Task ID"
type: string
}
}
}
}
}
get_num_entries {
"2.15" {
description: "Get the number of task entries in the given queue"
request {
type: object
required: [ queue ]
properties {
queue {
description: "ID of the queue"
type: string
}
}
}
response {
type: object
properties {
num {
description: "Number of entries"
type: integer
}
}
}
}
}

View File

@@ -685,7 +685,14 @@ get_all_ex {
}
}
}
"2.15": ${get_all_ex."2.13"} {
"2.14": ${get_all_ex."2.13"} {
request.properties.search_hidden {
description: "If set to 'true' then hidden tasks are included in the search results"
type: boolean
default: false
}
}
"2.15": ${get_all_ex."2.14"} {
request {
properties {
scroll_id {
@@ -822,7 +829,14 @@ get_all {
}
}
}
"2.15": ${get_all."2.1"} {
"2.14": ${get_all."2.1"} {
request.properties.search_hidden {
description: "If set to 'true' then hidden tasks are included in the search results"
type: boolean
default: false
}
}
"2.15": ${get_all."2.14"} {
request {
properties {
scroll_id {
@@ -1884,7 +1898,7 @@ Fails if the following parameters in the task were not filled:
]
properties {
queue {
description: "Queue id. If not provided, task is added to the default queue."
description: "Queue id. If not provided and no queue name is passed then task is added to the default queue."
type: string
}
}
@@ -1900,6 +1914,12 @@ Fails if the following parameters in the task were not filled:
}
}
}
"2.19": ${enqueue."1.5"} {
request.properties.queue_name {
description: The name of the queue. If the queue does not exist then it is auto-created. Cannot be used together with the queue id
type: string
}
}
}
enqueue_many {
"2.13": ${_definitions.change_many_request} {
@@ -1908,7 +1928,7 @@ enqueue_many {
properties {
ids.description: "IDs of the tasks to enqueue"
queue {
description: "Queue id. If not provided, tasks are added to the default queue."
description: "Queue id. If not provided and no queue name is passed then tasks are added to the default queue."
type: string
}
validate_tasks {
@@ -1927,6 +1947,12 @@ enqueue_many {
}
}
}
"2.19": ${enqueue_many."2.13"} {
request.properties.queue_name {
description: The name of the queue. If the queue does not exist then it is auto-created. Cannot be used together with the queue id
type: string
}
}
}
dequeue {
"1.5" {
@@ -2005,6 +2031,18 @@ completed {
} ${_references.status_change_request}
response: ${_definitions.update_response}
}
"2.20": ${completed."2.2"} {
request.properties.publish {
type: boolean
default: false
description: If set and the task is completed successfully then it is published
}
response.properties.published {
description: "Number of tasks published (0 or 1)"
type: integer
enum: [0, 1]
}
}
}
ping {

View File

@@ -139,6 +139,22 @@ get_current_user {
}
}
}
"2.20": ${get_current_user."2.1"} {
response {
properties {
getting_started {
type: object
description: Getting stated info
additionalProperties: true
}
created {
type: string
description: User creation time
format: date-time
}
}
}
}
}
get_all_ex {

View File

@@ -1,152 +1,320 @@
{
_description: "Provides an API for worker machines, allowing workers to report status and get tasks for execution"
_definitions {
metrics_category {
_description: "Provides an API for worker machines, allowing workers to report status and get tasks for execution"
_definitions {
metrics_category {
type: object
properties {
name {
type: string
description: "Name of the metrics category."
}
metric_keys {
type: array
items { type: string }
description: "The names of the metrics in the category."
}
}
}
aggregation_type {
type: string
enum: [ avg, min, max ]
description: "Metric aggregation type"
}
stat_item {
type: object
properties {
key {
type: string
description: "Name of a metric"
}
category {
"$ref": "#/definitions/aggregation_type"
}
}
}
aggregation_stats {
type: object
properties {
aggregation {
"$ref": "#/definitions/aggregation_type"
}
values {
type: array
description: "List of values corresponding to the dates in metric statistics"
items { type: number }
}
}
}
metric_stats {
type: object
properties {
metric {
type: string
description: "Name of the metric ("cpu_usage", "memory_used" etc.)"
}
variant {
type: string
description: "Name of the metric component. Set only if 'split_by_variant' was set in the request"
}
dates {
type: array
description: "List of timestamps (in seconds from epoch) in the acceding order. The timestamps are separated by the requested interval. Timestamps where no workers activity was recorded are omitted."
items { type: integer }
}
stats {
type: array
description: "Statistics data by type"
items { "$ref": "#/definitions/aggregation_stats" }
}
}
}
worker_stats {
type: object
properties {
worker {
type: string
description: "ID of the worker"
}
metrics {
type: array
description: "List of the metrics statistics for the worker"
items { "$ref": "#/definitions/metric_stats" }
}
}
}
activity_series {
type: object
properties {
dates {
type: array
description: "List of timestamps (in seconds from epoch) in the acceding order. The timestamps are separated by the requested interval."
items {type: integer}
}
counts {
type: array
description: "List of worker counts corresponding to the timestamps in the dates list. None values are returned for the dates with no workers."
items {type: integer}
}
}
}
worker {
type: object
properties {
id {
description: "Worker ID"
type: string
}
user {
description: "Associated user (under whose credentials are used by the worker daemon)"
"$ref": "#/definitions/id_name_entry"
}
company {
description: "Associated company"
"$ref": "#/definitions/id_name_entry"
}
ip {
description: "IP of the worker"
type: string
}
register_time {
description: "Registration time"
type: string
format: "date-time"
}
last_activity_time {
description: "Last activity time (even if an error occurred)"
type: string
format: "date-time"
}
last_report_time {
description: "Last successful report time"
type: string
format: "date-time"
}
task {
description: "Task currently being run by the worker"
"$ref": "#/definitions/current_task_entry"
}
project {
description: "Project in which currently executing task resides"
"$ref": "#/definitions/id_name_entry"
}
queue {
description: "Queue from which running task was taken"
"$ref": "#/definitions/queue_entry"
}
queues {
description: "List of queues on which the worker is listening"
type: array
items { "$ref": "#/definitions/queue_entry" }
}
tags {
description: "User tags for the worker"
type: array
items: { type: string }
}
}
}
id_name_entry {
type: object
properties {
id {
description: "ID"
type: string
}
name {
description: "Name"
type: string
}
}
}
current_task_entry = ${_definitions.id_name_entry} {
properties {
running_time {
description: "Task running time"
type: integer
}
last_iteration {
description: "Last task iteration"
type: integer
}
}
}
queue_entry = ${_definitions.id_name_entry} {
properties {
next_task {
description: "Next task in the queue"
"$ref": "#/definitions/id_name_entry"
}
num_tasks {
description: "Number of task entries in the queue"
type: integer
}
}
}
machine_stats {
type: object
properties {
cpu_usage {
description: "Average CPU usage per core"
type: array
items { type: number }
}
gpu_usage {
description: "Average GPU usage per GPU card"
type: array
items { type: number }
}
memory_used {
description: "Used memory MBs"
type: integer
}
memory_free {
description: "Free memory MBs"
type: integer
}
gpu_memory_free {
description: "GPU free memory MBs"
type: array
items { type: integer }
}
gpu_memory_used {
description: "GPU used memory MBs"
type: array
items { type: integer }
}
network_tx {
description: "Mbytes per second"
type: integer
}
network_rx {
description: "Mbytes per second"
type: integer
}
disk_free_home {
description: "Mbytes free space of /home drive"
type: integer
}
disk_free_temp {
description: "Mbytes free space of /tmp drive"
type: integer
}
disk_read {
description: "Mbytes read per second"
type: integer
}
disk_write {
description: "Mbytes write per second"
type: integer
}
cpu_temperature {
description: "CPU temperature"
type: array
items { type: number }
}
gpu_temperature {
description: "GPU temperature"
type: array
items { type: number }
}
}
}
}
get_all {
"2.4" {
description: "Returns information on all registered workers."
request {
type: object
properties {
name {
type: string
description: "Name of the metrics category."
}
metric_keys {
type: array
items { type: string }
description: "The names of the metrics in the category."
last_seen {
description: """Filter out workers not active for more than last_seen seconds.
A value or 0 or 'none' will disable the filter."""
type: integer
default: 3600
}
}
}
aggregation_type {
type: string
enum: [ avg, min, max ]
description: "Metric aggregation type"
}
stat_item {
response {
type: object
properties {
key {
type: string
description: "Name of a metric"
}
category {
"$ref": "#/definitions/aggregation_type"
workers {
type: array
items { "$ref": "#/definitions/worker" }
}
}
}
aggregation_stats {
type: object
properties {
aggregation {
"$ref": "#/definitions/aggregation_type"
}
values {
type: array
description: "List of values corresponding to the dates in metric statistics"
items { type: number }
}
}
}
"2.20": ${get_all."2.4"} {
request.properties.tags {
description: The list of allowed worker tags. Prepend tag value with '-' in order to exclude
type: array
items { type: string }
}
metric_stats {
type: object
properties {
metric {
type: string
description: "Name of the metric ("cpu_usage", "memory_used" etc.)"
}
variant {
type: string
description: "Name of the metric component. Set only if 'split_by_variant' was set in the request"
}
dates {
type: array
description: "List of timestamps (in seconds from epoch) in the acceding order. The timestamps are separated by the requested interval. Timestamps where no workers activity was recorded are omitted."
items { type: integer }
}
stats {
type: array
description: "Statistics data by type"
items { "$ref": "#/definitions/aggregation_stats" }
}
}
}
worker_stats {
}
}
register {
"2.4" {
description: "Register a worker in the system. Called by the Worker Daemon."
request {
required: [ worker ]
type: object
properties {
worker {
type: string
description: "ID of the worker"
}
metrics {
type: array
description: "List of the metrics statistics for the worker"
items { "$ref": "#/definitions/metric_stats" }
}
}
}
activity_series {
type: object
properties {
dates {
type: array
description: "List of timestamps (in seconds from epoch) in the acceding order. The timestamps are separated by the requested interval."
items {type: integer}
}
counts {
type: array
description: "List of worker counts corresponding to the timestamps in the dates list. None values are returned for the dates with no workers."
items {type: integer}
}
}
}
worker {
type: object
properties {
id {
description: "Worker ID"
description: "Worker id. Must be unique in company."
type: string
}
user {
description: "Associated user (under whose credentials are used by the worker daemon)"
"$ref": "#/definitions/id_name_entry"
}
company {
description: "Associated company"
"$ref": "#/definitions/id_name_entry"
}
ip {
description: "IP of the worker"
type: string
}
register_time {
description: "Registration time"
type: string
format: "date-time"
}
last_activity_time {
description: "Last activity time (even if an error occurred)"
type: string
format: "date-time"
}
last_report_time {
description: "Last successful report time"
type: string
format: "date-time"
}
task {
description: "Task currently being run by the worker"
"$ref": "#/definitions/current_task_entry"
}
project {
description: "Project in which currently executing task resides"
"$ref": "#/definitions/id_name_entry"
}
queue {
description: "Queue from which running task was taken"
"$ref": "#/definitions/queue_entry"
timeout {
description: "Registration timeout in seconds. If timeout seconds have passed since the worker's last call to register or status_report, the worker is automatically removed from the list of registered workers."
type: integer
default: 600
}
queues {
description: "List of queues on which the worker is listening"
description: "List of queue IDs on which the worker is listening."
type: array
items { "$ref": "#/definitions/queue_entry" }
items { type: string }
}
tags {
description: "User tags for the worker"
@@ -155,348 +323,185 @@
}
}
}
id_name_entry {
response {
type: object
properties {}
}
}
}
unregister {
"2.4" {
description: "Unregister a worker in the system. Called by the Worker Daemon."
request {
required: [ worker ]
type: object
properties {
id {
description: "ID"
type: string
}
name {
description: "Name"
worker {
description: "Worker id. Must be unique in company."
type: string
}
}
}
current_task_entry = ${_definitions.id_name_entry} {
properties {
running_time {
description: "Task running time"
type: integer
}
last_iteration {
description: "Last task iteration"
type: integer
}
}
response {
type: object
properties {}
}
queue_entry = ${_definitions.id_name_entry} {
properties {
next_task {
description: "Next task in the queue"
"$ref": "#/definitions/id_name_entry"
}
num_tasks {
description: "Number of task entries in the queue"
type: integer
}
}
}
machine_stats {
}
}
status_report {
"2.4" {
description: "Called periodically by the worker daemon to report machine status"
request {
required: [
worker
timestamp
]
type: object
properties {
cpu_usage {
description: "Average CPU usage per core"
type: array
items { type: number }
worker {
description: "Worker id."
type: string
}
gpu_usage {
description: "Average GPU usage per GPU card"
type: array
items { type: number }
task {
description: "ID of a task currently being run by the worker. If no task is sent, the worker's task field will be cleared."
type: string
}
memory_used {
description: "Used memory MBs"
queue {
description: "ID of the queue from which task was received. If no queue is sent, the worker's queue field will be cleared."
type: string
}
queues {
description: "List of queue IDs on which the worker is listening. If null, the worker's queues list will not be updated."
type: array
items { type: string }
}
timestamp {
description: "UNIX time in seconds since epoch."
type: integer
}
memory_free {
description: "Free memory MBs"
type: integer
machine_stats {
description: "The machine statistics."
"$ref": "#/definitions/machine_stats"
}
gpu_memory_free {
description: "GPU free memory MBs"
tags {
description: "New user tags for the worker"
type: array
items { type: integer }
items: { type: string }
}
gpu_memory_used {
description: "GPU used memory MBs"
}
}
response {
type: object
properties {}
}
}
}
get_metric_keys {
"2.4" {
description: "Returns worker statistics metric keys grouped by categories."
request {
type: object
properties {
worker_ids {
description: "List of worker ids to collect metrics for. If not provided or empty then all the company workers metrics are analyzed."
type: array
items { type: integer }
items { type: string }
}
network_tx {
description: "Mbytes per second"
type: integer
}
network_rx {
description: "Mbytes per second"
type: integer
}
disk_free_home {
description: "Mbytes free space of /home drive"
type: integer
}
disk_free_temp {
description: "Mbytes free space of /tmp drive"
type: integer
}
disk_read {
description: "Mbytes read per second"
type: integer
}
disk_write {
description: "Mbytes write per second"
type: integer
}
cpu_temperature {
description: "CPU temperature"
}
}
response {
type: object
properties {
categories {
type: array
items { type: number }
description: "List of unique metric categories found in the statistics of the requested workers."
items { "$ref": "#/definitions/metrics_category" }
}
gpu_temperature {
description: "GPU temperature"
}
}
}
}
get_stats {
"2.4" {
description: "Returns statistics for the selected workers and time range aggregated by date intervals."
request {
type: object
required: [ from_date, to_date, interval, items ]
properties {
worker_ids {
description: "List of worker ids to collect metrics for. If not provided or empty then all the company workers metrics are analyzed."
type: array
items { type: number }
items { type: string }
}
from_date {
description: "Starting time (in seconds from epoch) for collecting statistics"
type: number
}
to_date {
description: "Ending time (in seconds from epoch) for collecting statistics"
type: number
}
interval {
description: "Time interval in seconds for a single statistics point. The minimal value is 1"
type: integer
}
items {
description: "List of metric keys and requested statistics"
type: array
items { "$ref": "#/definitions/stat_item" }
}
split_by_variant {
description: "If true then break statistics by hardware sub types"
type: boolean
default: false
}
}
}
response {
type: object
properties {
workers {
type: array
description: "List of the requested workers with their statistics"
items { "$ref": "#/definitions/worker_stats" }
}
}
}
}
get_all {
"2.4" {
description: "Returns information on all registered workers."
request {
type: object
properties {
last_seen {
description: """Filter out workers not active for more than last_seen seconds.
A value or 0 or 'none' will disable the filter."""
type: integer
default: 3600
}
}
get_activity_report {
"2.4" {
description: "Returns count of active company workers in the selected time range."
request {
type: object
required: [ from_date, to_date, interval ]
properties {
from_date {
description: "Starting time (in seconds from epoch) for collecting statistics"
type: number
}
to_date {
description: "Ending time (in seconds from epoch) for collecting statistics"
type: number
}
interval {
description: "Time interval in seconds for a single statistics point. The minimal value is 1"
type: integer
}
}
response {
type: object
properties {
workers {
type: array
items { "$ref": "#/definitions/worker" }
}
}
response {
type: object
properties {
total {
description: "Activity series that include all the workers that sent reports in the given time interval."
"$ref": "#/definitions/activity_series"
}
active {
description: "Activity series that include only workers that worked on a task in the given time interval."
"$ref": "#/definitions/activity_series"
}
}
}
}
register {
"2.4" {
description: "Register a worker in the system. Called by the Worker Daemon."
request {
required: [ worker ]
type: object
properties {
worker {
description: "Worker id. Must be unique in company."
type: string
}
timeout {
description: "Registration timeout in seconds. If timeout seconds have passed since the worker's last call to register or status_report, the worker is automatically removed from the list of registered workers."
type: integer
default: 600
}
queues {
description: "List of queue IDs on which the worker is listening."
type: array
items { type: string }
}
tags {
description: "User tags for the worker"
type: array
items: { type: string }
}
}
}
response {
type: object
properties {}
}
}
}
unregister {
"2.4" {
description: "Unregister a worker in the system. Called by the Worker Daemon."
request {
required: [ worker ]
type: object
properties {
worker {
description: "Worker id. Must be unique in company."
type: string
}
}
}
response {
type: object
properties {}
}
}
}
status_report {
"2.4" {
description: "Called periodically by the worker daemon to report machine status"
request {
required: [
worker
timestamp
]
type: object
properties {
worker {
description: "Worker id."
type: string
}
task {
description: "ID of a task currently being run by the worker. If no task is sent, the worker's task field will be cleared."
type: string
}
queue {
description: "ID of the queue from which task was received. If no queue is sent, the worker's queue field will be cleared."
type: string
}
queues {
description: "List of queue IDs on which the worker is listening. If null, the worker's queues list will not be updated."
type: array
items { type: string }
}
timestamp {
description: "UNIX time in seconds since epoch."
type: integer
}
machine_stats {
description: "The machine statistics."
"$ref": "#/definitions/machine_stats"
}
tags {
description: "New user tags for the worker"
type: array
items: { type: string }
}
}
}
response {
type: object
properties {}
}
}
}
get_metric_keys {
"2.4" {
description: "Returns worker statistics metric keys grouped by categories."
request {
type: object
properties {
worker_ids {
description: "List of worker ids to collect metrics for. If not provided or empty then all the company workers metrics are analyzed."
type: array
items { type: string }
}
}
}
response {
type: object
properties {
categories {
type: array
description: "List of unique metric categories found in the statistics of the requested workers."
items { "$ref": "#/definitions/metrics_category" }
}
}
}
}
}
get_stats {
"2.4" {
description: "Returns statistics for the selected workers and time range aggregated by date intervals."
request {
type: object
required: [ from_date, to_date, interval, items ]
properties {
worker_ids {
description: "List of worker ids to collect metrics for. If not provided or empty then all the company workers metrics are analyzed."
type: array
items { type: string }
}
from_date {
description: "Starting time (in seconds from epoch) for collecting statistics"
type: number
}
to_date {
description: "Ending time (in seconds from epoch) for collecting statistics"
type: number
}
interval {
description: "Time interval in seconds for a single statistics point. The minimal value is 1"
type: integer
}
items {
description: "List of metric keys and requested statistics"
type: array
items { "$ref": "#/definitions/stat_item" }
}
split_by_variant {
description: "If true then break statistics by hardware sub types"
type: boolean
default: false
}
}
}
response {
type: object
properties {
workers {
type: array
description: "List of the requested workers with their statistics"
items { "$ref": "#/definitions/worker_stats" }
}
}
}
}
}
get_activity_report {
"2.4" {
description: "Returns count of active company workers in the selected time range."
request {
type: object
required: [ from_date, to_date, interval ]
properties {
from_date {
description: "Starting time (in seconds from epoch) for collecting statistics"
type: number
}
to_date {
description: "Ending time (in seconds from epoch) for collecting statistics"
type: number
}
interval {
description: "Time interval in seconds for a single statistics point. The minimal value is 1"
type: integer
}
}
}
response {
type: object
properties {
total {
description: "Activity series that include all the workers that sent reports in the given time interval."
"$ref": "#/definitions/activity_series"
}
active {
description: "Activity series that include only workers that worked on a task in the given time interval."
"$ref": "#/definitions/activity_series"
}
}
}
}
}
}
}

View File

@@ -313,6 +313,7 @@ class APICall(DataContainer):
_redacted_headers = {
HEADER_AUTHORIZATION: " ",
"Cookie": "=",
"X-Jwt-Payload": "",
}
""" Headers whose value should be redacted. Maps header name to partition char """
@@ -692,6 +693,10 @@ class APICall(DataContainer):
# this will allow us to debug authorization errors).
for header, sep in self._redacted_headers.items():
if header in headers:
prefix, _, redact = headers[header].partition(sep)
if sep:
prefix, _, redact = headers[header].partition(sep)
else:
prefix = sep = ""
redact = headers[header]
headers[header] = prefix + sep + f"<{len(redact)} bytes redacted>"
return headers

View File

@@ -2,6 +2,8 @@ import jwt
from datetime import datetime, timedelta
from jwt.algorithms import get_default_algorithms
from apiserver.apierrors import errors
from apiserver.config_repo import config
from apiserver.database.model.auth import Role
@@ -9,22 +11,24 @@ from apiserver.database.model.auth import Role
from .auth_type import AuthType
from .payload import Payload
token_secret = config.get('secure.auth.token_secret')
token_secret = config.get("secure.auth.token_secret")
log = config.logger(__file__)
class Token(Payload):
default_expiration_sec = config.get('apiserver.auth.default_expiration_sec')
default_expiration_sec = config.get("apiserver.auth.default_expiration_sec")
def __init__(self, exp=None, iat=None, nbf=None, env=None, identity=None, entities=None, **_):
def __init__(
self, exp=None, iat=None, nbf=None, env=None, identity=None, entities=None, **_
):
super(Token, self).__init__(
AuthType.bearer_token, identity=identity, entities=entities)
AuthType.bearer_token, identity=identity, entities=entities
)
self.exp = exp
self.iat = iat
self.nbf = nbf
self._env = env or config.get('env', '<unknown>')
self._env = env or config.get("env", "<unknown>")
@property
def env(self):
@@ -65,7 +69,15 @@ class Token(Payload):
@classmethod
def decode(cls, encoded_token, verify=True):
return jwt.decode(encoded_token, token_secret, verify=verify)
options = (
{"verify_signature": False, "verify_exp": True} if not verify else None
)
return jwt.decode(
encoded_token,
token_secret,
algorithms=get_default_algorithms(),
options=options,
)
@classmethod
def from_encoded_token(cls, encoded_token, verify=True):
@@ -74,23 +86,24 @@ class Token(Payload):
token = Token.from_dict(decoded)
assert isinstance(token, Token)
if not token.identity:
raise errors.unauthorized.InvalidToken('token missing identity')
raise errors.unauthorized.InvalidToken("token missing identity")
return token
except Exception as e:
raise errors.unauthorized.InvalidToken('failed parsing token, %s' % e.args[0])
raise errors.unauthorized.InvalidToken(
"failed parsing token, %s" % e.args[0]
)
@classmethod
def create_encoded_token(cls, identity, expiration_sec=None, entities=None, **extra_payload):
def create_encoded_token(
cls, identity, expiration_sec=None, entities=None, **extra_payload
):
if identity.role not in (Role.system,):
# limit expiration time for all roles but an internal service
expiration_sec = expiration_sec or cls.default_expiration_sec
now = datetime.utcnow()
token = cls(
identity=identity,
entities=entities,
iat=now)
token = cls(identity=identity, entities=entities, iat=now)
if expiration_sec:
# add 'expiration' claim

View File

@@ -8,6 +8,7 @@ import jsonmodels.models
from apiserver.apierrors import APIError, errors
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.utilities.partial_version import PartialVersion
from .apicall import APICall
from .auth import Identity
@@ -38,7 +39,7 @@ class ServiceRepo(object):
"""If the check is set, parsing will fail for endpoint request with the version that is grater than the current
maximum """
_max_version = PartialVersion("2.17")
_max_version = PartialVersion("2.20")
""" Maximum version number (the highest min_version value across all endpoints) """
_endpoint_exp = (
@@ -283,7 +284,8 @@ class ServiceRepo(object):
# In case call does not require authorization, parsing the identity.company might raise an exception
company = cls._get_company(call, endpoint)
ret = endpoint.func(call, company, call.data_model)
with translate_errors_context():
ret = endpoint.func(call, company, call.data_model)
# allow endpoints to return dict or model (instead of setting them explicitly on the call)
if ret is not None:

View File

@@ -14,6 +14,7 @@ from apiserver.apimodels.auth import (
RevokeCredentialsRequest,
EditUserReq,
CreateCredentialsRequest,
EditCredentialsRequest,
)
from apiserver.apimodels.base import UpdateResponse
from apiserver.bll.auth import AuthBLL
@@ -122,6 +123,27 @@ def create_credentials(call: APICall, _, request: CreateCredentialsRequest):
call.result.data_model = CreateCredentialsResponse(credentials=credentials)
@endpoint("auth.edit_credentials")
def edit_credentials(call: APICall, company_id: str, request: EditCredentialsRequest):
identity = call.identity
access_key = request.access_key
updated = User.objects(
id=identity.user,
company=company_id,
credentials__match={"key": access_key},
).update_one(set__credentials__S__label=request.label)
if not updated:
raise errors.bad_request.InvalidAccessKey(
"invalid user or invalid access key",
user=identity.user,
access_key=access_key,
company=company_id,
)
call.result.data = {"updated": updated}
@endpoint(
"auth.revoke_credentials",
request_data_model=RevokeCredentialsRequest,

View File

@@ -12,19 +12,22 @@ from apiserver.apierrors import errors
from apiserver.apimodels.events import (
MultiTaskScalarMetricsIterHistogramRequest,
ScalarMetricsIterHistogramRequest,
DebugImagesRequest,
DebugImageResponse,
MetricEventsRequest,
MetricEventsResponse,
MetricEvents,
IterationEvents,
TaskMetricsRequest,
LogEventsRequest,
LogOrderEnum,
GetDebugImageSampleRequest,
NextDebugImageSampleRequest,
GetHistorySampleRequest,
NextHistorySampleRequest,
MetricVariants as ApiMetrics,
TaskPlotsRequest,
TaskEventsRequest,
ScalarMetricsIterRawRequest,
ClearScrollRequest,
ClearTaskLogRequest,
SingleValueMetricsRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_common import EventType, MetricVariants
@@ -448,6 +451,30 @@ def multi_task_scalar_metrics_iter_histogram(
)
@endpoint("events.get_task_single_value_metrics")
def get_task_single_value_metrics(
call, company_id: str, request: SingleValueMetricsRequest
):
task_ids = call.data["tasks"]
tasks = task_bll.assert_exists(
company_id=call.identity.company,
only=("id", "name", "company", "company_origin"),
task_ids=task_ids,
allow_public=True,
)
companies = {t.get_index_company() for t in tasks}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
)
res = event_bll.metrics.get_task_single_value_metrics(company_id, task_ids)
call.result.data = dict(
tasks=[{"task": task, "values": values} for task, values in res.items()]
)
@endpoint("events.get_multi_task_plots", required_fields=["tasks"])
def get_multi_task_plots_v1_7(call, company_id, _):
task_ids = call.data["tasks"]
@@ -611,6 +638,56 @@ def get_task_plots(call, company_id, request: TaskPlotsRequest):
)
@endpoint(
"events.plots",
request_data_model=MetricEventsRequest,
response_data_model=MetricEventsResponse,
)
def task_plots(call, company_id, request: MetricEventsRequest):
task_metrics = defaultdict(dict)
for tm in request.metrics:
task_metrics[tm.task][tm.metric] = tm.variants
for metrics in task_metrics.values():
if None in metrics:
metrics.clear()
tasks = task_bll.assert_exists(
company_id,
task_ids=list(task_metrics),
allow_public=True,
only=("company", "company_origin"),
)
companies = {t.get_index_company() for t in tasks}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
)
result = event_bll.plots_iterator.get_task_events(
company_id=next(iter(companies)),
task_metrics=task_metrics,
iter_count=request.iters,
navigate_earlier=request.navigate_earlier,
refresh=request.refresh,
state_id=request.scroll_id,
)
call.result.data_model = MetricEventsResponse(
scroll_id=result.next_scroll_id,
metrics=[
MetricEvents(
task=task,
iterations=[
IterationEvents(iter=iteration["iter"], events=iteration["events"])
for iteration in iterations
],
)
for (task, iterations) in result.metric_events
],
)
@endpoint("events.debug_images", required_fields=["task"])
def get_debug_images_v1_7(call, company_id, _):
task_id = call.data["task"]
@@ -680,10 +757,10 @@ def get_debug_images_v1_8(call, company_id, _):
@endpoint(
"events.debug_images",
min_version="2.7",
request_data_model=DebugImagesRequest,
response_data_model=DebugImageResponse,
request_data_model=MetricEventsRequest,
response_data_model=MetricEventsResponse,
)
def get_debug_images(call, company_id, request: DebugImagesRequest):
def get_debug_images(call, company_id, request: MetricEventsRequest):
task_metrics = defaultdict(dict)
for tm in request.metrics:
task_metrics[tm.task][tm.metric] = tm.variants
@@ -713,7 +790,7 @@ def get_debug_images(call, company_id, request: DebugImagesRequest):
state_id=request.scroll_id,
)
call.result.data_model = DebugImageResponse(
call.result.data_model = MetricEventsResponse(
scroll_id=result.next_scroll_id,
metrics=[
MetricEvents(
@@ -731,13 +808,13 @@ def get_debug_images(call, company_id, request: DebugImagesRequest):
@endpoint(
"events.get_debug_image_sample",
min_version="2.12",
request_data_model=GetDebugImageSampleRequest,
request_data_model=GetHistorySampleRequest,
)
def get_debug_image_sample(call, company_id, request: GetDebugImageSampleRequest):
def get_debug_image_sample(call, company_id, request: GetHistorySampleRequest):
task = task_bll.assert_exists(
company_id, task_ids=[request.task], allow_public=True, only=("company",)
)[0]
res = event_bll.debug_sample_history.get_debug_image_for_variant(
res = event_bll.debug_image_sample_history.get_sample_for_variant(
company_id=task.company,
task=request.task,
metric=request.metric,
@@ -745,6 +822,7 @@ def get_debug_image_sample(call, company_id, request: GetDebugImageSampleRequest
iteration=request.iteration,
refresh=request.refresh,
state_id=request.scroll_id,
navigate_current_metric=request.navigate_current_metric,
)
call.result.data = attr.asdict(res, recurse=False)
@@ -752,13 +830,49 @@ def get_debug_image_sample(call, company_id, request: GetDebugImageSampleRequest
@endpoint(
"events.next_debug_image_sample",
min_version="2.12",
request_data_model=NextDebugImageSampleRequest,
request_data_model=NextHistorySampleRequest,
)
def next_debug_image_sample(call, company_id, request: NextDebugImageSampleRequest):
def next_debug_image_sample(call, company_id, request: NextHistorySampleRequest):
task = task_bll.assert_exists(
company_id, task_ids=[request.task], allow_public=True, only=("company",)
)[0]
res = event_bll.debug_sample_history.get_next_debug_image(
res = event_bll.debug_image_sample_history.get_next_sample(
company_id=task.company,
task=request.task,
state_id=request.scroll_id,
navigate_earlier=request.navigate_earlier,
)
call.result.data = attr.asdict(res, recurse=False)
@endpoint(
"events.get_plot_sample", request_data_model=GetHistorySampleRequest,
)
def get_plot_sample(call, company_id, request: GetHistorySampleRequest):
task = task_bll.assert_exists(
company_id, task_ids=[request.task], allow_public=True, only=("company",)
)[0]
res = event_bll.plot_sample_history.get_sample_for_variant(
company_id=task.company,
task=request.task,
metric=request.metric,
variant=request.variant,
iteration=request.iteration,
refresh=request.refresh,
state_id=request.scroll_id,
navigate_current_metric=request.navigate_current_metric,
)
call.result.data = attr.asdict(res, recurse=False)
@endpoint(
"events.next_plot_sample", request_data_model=NextHistorySampleRequest,
)
def next_plot_sample(call, company_id, request: NextHistorySampleRequest):
task = task_bll.assert_exists(
company_id, task_ids=[request.task], allow_public=True, only=("company",)
)[0]
res = event_bll.plot_sample_history.get_next_sample(
company_id=task.company,
task=request.task,
state_id=request.scroll_id,
@@ -768,14 +882,14 @@ def next_debug_image_sample(call, company_id, request: NextDebugImageSampleReque
@endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
def get_tasks_metrics(call: APICall, company_id, request: TaskMetricsRequest):
def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest):
task = task_bll.assert_exists(
company_id,
task_ids=request.tasks,
allow_public=True,
only=("company", "company_origin"),
)[0]
res = event_bll.metrics.get_tasks_metrics(
res = event_bll.metrics.get_task_metrics(
task.get_index_company(), task_ids=request.tasks, event_type=request.event_type
)
call.result.data = {
@@ -796,6 +910,21 @@ def delete_for_task(call, company_id, req_model):
)
@endpoint("events.clear_task_log")
def clear_task_log(call: APICall, company_id: str, request: ClearTaskLogRequest):
task_id = request.task
task_bll.assert_exists(company_id, task_id, return_tasks=False)
call.result.data = dict(
deleted=event_bll.clear_task_log(
company_id=company_id,
task_id=task_id,
allow_locked=request.allow_locked,
threshold_sec=request.threshold_sec,
)
)
def _get_top_iter_unique_events_per_task(events, max_iters, tasks):
key = itemgetter("metric", "variant", "task", "iter")
@@ -936,3 +1065,9 @@ def scalar_metrics_iter_raw(
scroll_id=scroll.get_scroll_id(),
variants=variants,
)
@endpoint("events.clear_scroll", min_version="2.18")
def clear_scroll(_, __, request: ClearScrollRequest):
if request.scroll_id:
event_bll.clear_scroll(request.scroll_id)

View File

@@ -20,6 +20,7 @@ from apiserver.apimodels.models import (
AddOrUpdateMetadataRequest,
ModelsPublishManyRequest,
ModelsDeleteManyRequest,
ModelsGetRequest,
)
from apiserver.bll.model import ModelBLL, Metadata
from apiserver.bll.organization import OrgBLL, Tags
@@ -117,8 +118,8 @@ def _process_include_subprojects(call_data: dict):
call_data["project"] = project_ids_with_children(project_ids)
@endpoint("models.get_all_ex", required_fields=[])
def get_all_ex(call: APICall, company_id, _):
@endpoint("models.get_all_ex", request_data_model=ModelsGetRequest)
def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
conform_tag_fields(call, call.data)
_process_include_subprojects(call.data)
Metadata.escape_query_parameters(call)
@@ -132,6 +133,20 @@ def get_all_ex(call: APICall, company_id, _):
)
conform_output_tags(call, models)
unescape_metadata(call, models)
if not request.include_stats:
call.result.data = {"models": models, **ret_params}
return
model_ids = {model["id"] for model in models}
stats = ModelBLL.get_model_stats(
company=company_id,
model_ids=list(model_ids),
)
for model in models:
model["stats"] = stats.get(model["id"])
call.result.data = {"models": models, **ret_params}

View File

@@ -1,9 +1,15 @@
from collections import defaultdict
from operator import itemgetter
from typing import Mapping, Type
from apiserver.apimodels.organization import TagsRequest
from mongoengine import Q
from apiserver.apimodels.organization import TagsRequest, EntitiesCountRequest
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.database.model import User
from apiserver.database.model import User, AttributedDocument, EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task
from apiserver.service_repo import endpoint, APICall
from apiserver.services.utils import get_tags_filter_dictionary, sort_tags_response
@@ -41,3 +47,29 @@ def get_user_companies(call: APICall, company_id: str, _):
}
]
}
@endpoint("organization.get_entities_count", request_data_model=EntitiesCountRequest)
def get_entities_count(call: APICall, company, _):
entity_classes: Mapping[str, Type[AttributedDocument]] = {
"projects": Project,
"tasks": Task,
"models": Model,
"pipelines": Project,
"datasets": Project,
}
ret = {}
for field, entity_cls in entity_classes.items():
data = call.data.get(field)
if data is None:
continue
query = Q()
if entity_cls in (Project, Task) and not data.get("search_hidden"):
query &= Q(system_tags__ne=EntityVisibility.hidden.value)
ret[field] = entity_cls.get_count(
company=company, query_dict=data, query=query, allow_public=True,
)
call.result.data = ret

View File

@@ -26,6 +26,7 @@ from apiserver.bll.project.project_cleanup import (
validate_project_delete,
)
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
from apiserver.database.model.project import Project
from apiserver.database.utils import (
parse_from_call,
@@ -73,6 +74,16 @@ def get_by_id(call):
call.result.data = {"project": project_dict}
def _hidden_query(search_hidden: bool, ids: Sequence) -> Q:
"""
1. Add only non-hidden tasks search condition (unless specifically specified differently)
"""
if search_hidden or ids:
return Q()
return Q(system_tags__ne=EntityVisibility.hidden.value)
def _adjust_search_parameters(data: dict, shallow_search: bool):
"""
1. Make sure that there is no external query on path
@@ -91,30 +102,31 @@ def _adjust_search_parameters(data: dict, shallow_search: bool):
@endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest)
def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
conform_tag_fields(call, call.data)
allow_public = not request.non_public
data = call.data
conform_tag_fields(call, data)
allow_public = not request.non_public
requested_ids = data.get("id")
_adjust_search_parameters(
data, shallow_search=request.shallow_search,
)
with TimingContext("mongo", "projects_get_all"):
data = call.data
user_active_project_ids = None
if request.active_users:
ids = project_bll.get_projects_with_active_user(
ids, user_active_project_ids = project_bll.get_projects_with_active_user(
company=company_id,
users=request.active_users,
project_ids=requested_ids,
allow_public=allow_public,
)
if not ids:
call.result.data = {"projects": []}
return
return {"projects": []}
data["id"] = ids
_adjust_search_parameters(data, shallow_search=request.shallow_search)
ret_params = {}
projects = Project.get_many_with_join(
projects: Sequence[dict] = Project.get_many_with_join(
company=company_id,
query_dict=data,
query=_hidden_query(search_hidden=request.search_hidden, ids=requested_ids),
allow_public=allow_public,
ret_params=ret_params,
)
@@ -128,47 +140,62 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
company=company_id,
project_ids=list(existing_requested_ids),
filter_=request.include_stats_filter,
users=request.active_users,
)
for project in projects:
project.update(**contents.get(project["id"], {}))
conform_output_tags(call, projects)
if not request.include_stats:
call.result.data = {"projects": projects, **ret_params}
return
if request.include_stats:
project_ids = {project["id"] for project in projects}
stats, children = project_bll.get_project_stats(
company=company_id,
project_ids=list(project_ids),
specific_state=request.stats_for_state,
include_children=request.stats_with_children,
search_hidden=request.search_hidden,
filter_=request.include_stats_filter,
users=request.active_users,
user_active_project_ids=user_active_project_ids,
)
project_ids = {project["id"] for project in projects}
stats, children = project_bll.get_project_stats(
company=company_id,
project_ids=list(project_ids),
specific_state=request.stats_for_state,
include_children=request.stats_with_children,
filter_=request.include_stats_filter,
)
for project in projects:
project["stats"] = stats[project["id"]]
project["sub_projects"] = children[project["id"]]
for project in projects:
project["stats"] = stats[project["id"]]
project["sub_projects"] = children[project["id"]]
if request.include_dataset_stats:
project_ids = {project["id"] for project in projects}
dataset_stats = project_bll.get_dataset_stats(
company=company_id,
project_ids=list(project_ids),
users=request.active_users,
)
for project in projects:
project["dataset_stats"] = dataset_stats.get(project["id"])
call.result.data = {"projects": projects, **ret_params}
@endpoint("projects.get_all")
def get_all(call: APICall):
conform_tag_fields(call, call.data)
data = call.data
_adjust_search_parameters(data, shallow_search=data.get("shallow_search", False))
with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
conform_tag_fields(call, data)
_adjust_search_parameters(
data, shallow_search=data.get("shallow_search", False),
)
with TimingContext("mongo", "projects_get_all"):
ret_params = {}
projects = Project.get_many(
company=call.identity.company,
query_dict=data,
query=_hidden_query(
search_hidden=data.get("search_hidden"), ids=data.get("id")
),
parameters=data,
allow_public=True,
ret_params=ret_params,
)
conform_output_tags(call, projects)
call.result.data = {"projects": projects, **ret_params}

View File

@@ -14,6 +14,7 @@ from apiserver.apimodels.queues import (
AddOrUpdateMetadataRequest,
DeleteMetadataRequest,
GetNextTaskRequest,
GetByIdRequest,
)
from apiserver.bll.model import Metadata
from apiserver.bll.queue import QueueBLL
@@ -33,9 +34,11 @@ worker_bll = WorkerBLL()
queue_bll = QueueBLL(worker_bll)
@endpoint("queues.get_by_id", min_version="2.4", request_data_model=QueueRequest)
def get_by_id(call: APICall, company_id, req_model: QueueRequest):
queue = queue_bll.get_by_id(company_id, req_model.queue)
@endpoint("queues.get_by_id", min_version="2.4", request_data_model=GetByIdRequest)
def get_by_id(call: APICall, company_id, request: GetByIdRequest):
queue = queue_bll.get_by_id(
company_id, request.queue, max_task_entries=request.max_task_entries
)
queue_dict = queue.to_proper_dict()
conform_output_tags(call, queue_dict)
unescape_metadata(call, queue_dict)
@@ -55,7 +58,10 @@ def get_all_ex(call: APICall):
Metadata.escape_query_parameters(call)
queues = queue_bll.get_queue_infos(
company_id=call.identity.company, query_dict=call.data, ret_params=ret_params,
company_id=call.identity.company,
query_dict=call.data,
max_task_entries=call.data.pop("max_task_entries", None),
ret_params=ret_params,
)
conform_output_tags(call, queues)
unescape_metadata(call, queues)
@@ -68,7 +74,10 @@ def get_all(call: APICall):
ret_params = {}
Metadata.escape_query_parameters(call)
queues = queue_bll.get_all(
company_id=call.identity.company, query_dict=call.data, ret_params=ret_params,
company_id=call.identity.company,
query_dict=call.data,
max_task_entries=call.data.pop("max_task_entries", None),
ret_params=ret_params,
)
conform_output_tags(call, queues)
unescape_metadata(call, queues)
@@ -127,9 +136,7 @@ def add_task(call: APICall, company_id, req_model: TaskRequest):
@endpoint("queues.get_next_task", request_data_model=GetNextTaskRequest)
def get_next_task(call: APICall, company_id, req_model: GetNextTaskRequest):
entry = queue_bll.get_next_task(
company_id=company_id, queue_id=req_model.queue
)
entry = queue_bll.get_next_task(company_id=company_id, queue_id=req_model.queue)
if entry:
data = {"entry": entry.to_proper_dict()}
if req_model.get_task_info:
@@ -224,14 +231,15 @@ def move_task_to_back(call: APICall, company_id, req_model: TaskRequest):
response_data_model=GetMetricsResponse,
)
def get_queue_metrics(
call: APICall, company_id, req_model: GetMetricsRequest
call: APICall, company_id, request: GetMetricsRequest
) -> GetMetricsResponse:
ret = queue_bll.metrics.get_queue_metrics(
company_id=company_id,
from_date=req_model.from_date,
to_date=req_model.to_date,
interval=req_model.interval,
queue_ids=req_model.queue_ids,
from_date=request.from_date,
to_date=request.to_date,
interval=request.interval,
queue_ids=request.queue_ids,
refresh=request.refresh,
)
queue_dicts = {
@@ -273,3 +281,17 @@ def delete_metadata(call: APICall, company_id: str, request: DeleteMetadataReque
queue_id = request.queue
queue = queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
return {"updated": Metadata.delete_metadata(queue, keys=request.keys)}
@endpoint("queues.peek_task", min_version="2.15")
def peek_task(call: APICall, company_id: str, request: QueueRequest):
queue_id = request.queue
queue = queue_bll.get_by_id(
company_id=company_id, queue_id=queue_id, max_task_entries=1
)
return {"task": queue.entries[0].task if queue.entries else None}
@endpoint("queues.get_num_entries", min_version="2.15")
def get_num_entries(call: APICall, company_id: str, request: QueueRequest):
return {"num": queue_bll.count_entries(company=company_id, queue_id=request.queue)}

View File

@@ -62,6 +62,8 @@ from apiserver.apimodels.tasks import (
DequeueManyResponse,
ResetManyResponse,
ResetBatchItem,
CompletedRequest,
CompletedResponse,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.model import ModelBLL
@@ -94,10 +96,12 @@ from apiserver.bll.task.task_operations import (
delete_task,
publish_task,
unarchive_task,
move_tasks_to_trash,
)
from apiserver.bll.task.utils import update_task, get_task_for_update, deleted_prefix
from apiserver.bll.util import SetFieldsResolver, run_batch_operation
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
from apiserver.database.model.task.output import Output
from apiserver.database.model.task.task import (
Task,
@@ -117,6 +121,7 @@ from apiserver.services.utils import (
unescape_dict_field,
)
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get
from apiserver.utilities.partial_version import PartialVersion
task_fields = set(Task.get_fields())
@@ -213,6 +218,16 @@ def _process_include_subprojects(call_data: dict):
call_data["project"] = project_ids_with_children(project_ids)
def _hidden_query(data: dict) -> Q:
"""
1. Add only non-hidden tasks search condition (unless specifically specified differently)
"""
if data.get("search_hidden") or data.get("id"):
return Q()
return Q(system_tags__ne=EntityVisibility.hidden.value)
@endpoint("tasks.get_all_ex", required_fields=[])
def get_all_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
@@ -225,6 +240,7 @@ def get_all_ex(call: APICall, company_id, _):
tasks = Task.get_many_with_join(
company=company_id,
query_dict=call_data,
query=_hidden_query(call_data),
allow_public=True,
ret_params=ret_params,
)
@@ -259,6 +275,7 @@ def get_all(call: APICall, company_id, _):
company=company_id,
parameters=call_data,
query_dict=call_data,
query=_hidden_query(call_data),
allow_public=True,
ret_params=ret_params,
)
@@ -848,6 +865,7 @@ def enqueue(call: APICall, company_id, request: EnqueueRequest):
queue_id=request.queue,
status_message=request.status_message,
status_reason=request.status_reason,
queue_name=request.queue_name,
force=request.force,
)
call.result.data_model = EnqueueResponse(queued=queued, **res)
@@ -866,6 +884,7 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
queue_id=request.queue,
status_message=request.status_message,
status_reason=request.status_reason,
queue_name=request.queue_name,
validate=request.validate_tasks,
),
ids=request.ids,
@@ -1060,6 +1079,8 @@ def delete(call: APICall, company_id, request: DeleteRequest):
status_reason=request.status_reason,
)
if deleted:
if request.move_to_trash:
move_tasks_to_trash([request.task])
_reset_cached_tags(company_id, projects=[task.project] if task.project else [])
call.result.data = dict(deleted=bool(deleted), **attr.asdict(cleanup_res))
@@ -1081,6 +1102,10 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest):
)
if results:
if request.move_to_trash:
task_ids = set(task.id for _, (_, task, _) in results)
if task_ids:
move_tasks_to_trash(list(task_ids))
projects = set(task.project for _, (_, task, _) in results)
_reset_cached_tags(company_id, projects=list(projects))
@@ -1139,11 +1164,11 @@ def publish_many(call: APICall, company_id, request: PublishManyRequest):
@endpoint(
"tasks.completed",
min_version="2.2",
request_data_model=UpdateRequest,
response_data_model=UpdateResponse,
request_data_model=CompletedRequest,
response_data_model=CompletedResponse,
)
def completed(call: APICall, company_id, request: PublishRequest):
call.result.data_model = UpdateResponse(
def completed(call: APICall, company_id, request: CompletedRequest):
res = CompletedResponse(
**set_task_status_from_call(
request,
company_id,
@@ -1152,6 +1177,22 @@ def completed(call: APICall, company_id, request: PublishRequest):
)
)
if res.updated and request.publish:
publish_res = publish_task(
task_id=request.task,
company_id=company_id,
force=request.force,
publish_model_func=ModelBLL.publish_model,
status_reason=request.status_reason,
status_message=request.status_message,
)
res.published = publish_res.get("updated")
new_status = nested_get(publish_res, ("fields", "status"))
if new_status:
res.fields["status"] = new_status
call.result.data_model = res
@endpoint("tasks.ping", request_data_model=PingRequest)
def ping(_, company_id, request: PingRequest):

View File

@@ -12,7 +12,7 @@ from apiserver.bll.project import ProjectBLL
from apiserver.bll.user import UserBLL
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.auth import Role
from apiserver.database.model.auth import Role, User as AuthUser
from apiserver.database.model.company import Company
from apiserver.database.model.user import User
from apiserver.database.utils import parse_from_call
@@ -95,26 +95,36 @@ def get_all(call: APICall, company_id, _):
@endpoint("users.get_current_user")
def get_current_user(call: APICall, company_id, _):
with translate_errors_context("retrieving users"):
projection = (
{"company.name"}
.union(User.get_fields())
.difference(User.get_exclude_fields())
)
res = User.get_many_with_join(
query=Q(id=call.identity.user),
company=company_id,
override_projection=projection,
)
user_id = call.identity.user
if not res:
raise errors.bad_request.InvalidUser("failed loading user")
projection = (
{"company.name"}
.union(User.get_fields())
.difference(User.get_exclude_fields())
)
res = User.get_many_with_join(
query=Q(id=user_id),
company=company_id,
override_projection=projection,
)
user = res[0]
user["role"] = call.identity.role
if not res:
raise errors.bad_request.InvalidUser("failed loading user")
resp = {"user": user}
call.result.data = resp
user = res[0]
user["role"] = call.identity.role
auth_user: AuthUser = AuthUser.objects(id=user_id, company=company_id).first()
if not auth_user:
raise errors.bad_request.InvalidUser("failed loading user")
user["created"] = auth_user.created
resp = {
"user": user,
"getting_started": config.get("apiserver.getting_started_info", None),
}
call.result.data = resp
create_fields = {

View File

@@ -41,7 +41,9 @@ worker_bll = WorkerBLL()
)
def get_all(call: APICall, company_id: str, request: GetAllRequest):
call.result.data_model = GetAllResponse(
workers=worker_bll.get_all_with_projection(company_id, request.last_seen)
workers=worker_bll.get_all_with_projection(
company_id, request.last_seen, tags=request.tags
)
)
@@ -72,7 +74,9 @@ def unregister(call: APICall, company_id, req_model: WorkerRequest):
worker_bll.unregister_worker(company_id, call.identity.user, req_model.worker)
@endpoint("workers.status_report", min_version="2.4", request_data_model=StatusReportRequest)
@endpoint(
"workers.status_report", min_version="2.4", request_data_model=StatusReportRequest
)
def status_report(call: APICall, company_id, request: StatusReportRequest):
worker_bll.status_report(
company_id=company_id,

View File

@@ -1,4 +0,0 @@
{
api_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
secret_key: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
}

View File

@@ -20,7 +20,7 @@ class TestBatchOperations(TestService):
ids = [*tasks, missing_id]
# enqueue
res = self.api.tasks.enqueue_many(ids=ids)
res = self.api.tasks.enqueue_many(ids=ids, queue_name="test batch")
self._assert_succeeded(res, tasks)
self._assert_failed(res, [missing_id])
data = self.api.tasks.get_all_ex(id=ids).tasks

View File

@@ -77,7 +77,7 @@ class TestModelsService(TestService):
def test_publish_output_model_no_task(self):
model_id = self.create_temp(
service="models", name='test', uri='file:///a', labels={}, ready=False
service="models", name="test", uri="file:///a", labels={}, ready=False
)
self._assert_model_ready(model_id, False)
@@ -109,7 +109,7 @@ class TestModelsService(TestService):
def test_publish_task_no_output_model(self):
task_id = self.create_temp(
service="tasks", type='testing', name='server-test', input=dict(view={})
service="tasks", type="testing", name="server-test", input=dict(view={})
)
self.api.tasks.started(task=task_id)
self.api.tasks.stopped(task=task_id)
@@ -118,31 +118,48 @@ class TestModelsService(TestService):
assert res.updated == 1 # model updated
self._assert_task_status(task_id, PUBLISHED)
def test_get_models_stats(self):
model1 = self._create_model(labels={"hello": 1, "world": 2})
model2 = self._create_model(labels={"foo": 1})
model3 = self._create_model()
# no stats
res = self.api.models.get_all_ex(id=[model1, model2, model3]).models
self.assertEqual(len(res), 3)
self.assertTrue(all("stats" not in m for m in res))
# stats
res = self.api.models.get_all_ex(
id=[model1, model2, model3], include_stats=True
).models
self.assertEqual(len(res), 3)
stats = {m.id: m.stats.labels_count for m in res}
self.assertEqual(stats[model1], 2)
self.assertEqual(stats[model2], 1)
self.assertEqual(stats[model3], 0)
def test_update_model_iteration_with_task(self):
task_id = self._create_task()
model_id = self._create_model()
self.api.models.update(model=model_id, task=task_id, iteration=1000, labels={"foo": 1})
self.api.models.update(
model=model_id, task=task_id, iteration=1000, labels={"foo": 1}
)
self.assertEqual(
self.api.tasks.get_by_id(task=task_id).task.last_iteration,
1000
self.api.tasks.get_by_id(task=task_id).task.last_iteration, 1000
)
self.api.models.update(model=model_id, task=task_id, iteration=500)
self.assertEqual(
self.api.tasks.get_by_id(task=task_id).task.last_iteration,
1000
self.api.tasks.get_by_id(task=task_id).task.last_iteration, 1000
)
def test_update_model_for_task_iteration(self):
task_id = self._create_task()
res = self.api.models.update_for_task(
task=task_id,
name="test model",
uri="file:///b",
iteration=999,
task=task_id, name="test model", uri="file:///b", iteration=999,
)
model_id = res.id
@@ -150,22 +167,19 @@ class TestModelsService(TestService):
self.defer(self.api.models.delete, can_fail=True, model=model_id, force=True)
self.assertEqual(
self.api.tasks.get_by_id(task=task_id).task.last_iteration,
999
self.api.tasks.get_by_id(task=task_id).task.last_iteration, 999
)
self.api.models.update_for_task(task=task_id, uri="file:///c", iteration=1000)
self.assertEqual(
self.api.tasks.get_by_id(task=task_id).task.last_iteration,
1000
self.api.tasks.get_by_id(task=task_id).task.last_iteration, 1000
)
self.api.models.update_for_task(task=task_id, uri="file:///d", iteration=888)
self.assertEqual(
self.api.tasks.get_by_id(task=task_id).task.last_iteration,
1000
self.api.tasks.get_by_id(task=task_id).task.last_iteration, 1000
)
def test_get_frameworks(self):
@@ -238,8 +252,8 @@ class TestModelsService(TestService):
return self.create_temp(
service="models",
delete_params=dict(can_fail=True, force=True),
name=kwargs.pop("name", 'test'),
uri=kwargs.pop("name", 'file:///a'),
name=kwargs.pop("name", "test"),
uri=kwargs.pop("name", "file:///a"),
labels=kwargs.pop("labels", {}),
**kwargs,
)
@@ -247,8 +261,8 @@ class TestModelsService(TestService):
def _create_task(self, **kwargs):
task_id = self.create_temp(
service="tasks",
type=kwargs.pop("type", 'testing'),
name=kwargs.pop("name", 'server-test'),
type=kwargs.pop("type", "testing"),
name=kwargs.pop("name", "server-test"),
input=kwargs.pop("input", dict(view={})),
**kwargs,
)
@@ -257,21 +271,18 @@ class TestModelsService(TestService):
def _create_task_and_model(self):
execution_model_id = self.create_temp(
service="models",
name='test',
uri='file:///a',
labels={}
service="models", name="test", uri="file:///a", labels={}
)
task_id = self.create_temp(
service="tasks",
type='testing',
name='server-test',
type="testing",
name="server-test",
input=dict(view={}),
execution=dict(model=execution_model_id)
execution=dict(model=execution_model_id),
)
self.api.tasks.started(task=task_id)
output_model_id = self.api.models.update_for_task(
task=task_id, uri='file:///b'
task=task_id, uri="file:///b"
)["id"]
return task_id, output_model_id

View File

@@ -1,4 +1,4 @@
from apiserver.apierrors.errors.bad_request import InvalidProjectId
from apiserver.apierrors.errors.bad_request import InvalidProjectId, ExpectedUniqueData
from apiserver.apierrors.errors.forbidden import NoWritePermission
from apiserver.config_repo import config
from apiserver.tests.automated import TestService
@@ -32,3 +32,12 @@ class TestProjectsEdit(TestService):
res = self.api.projects.get_all(id=[p1])
self.assertEqual([p.id for p in res.projects], [p1])
self.api.projects.update(project=p1, name="Test public change 2")
def test_project_name_uniqueness(self):
name1 = "Test name1"
p1 = self.create_temp("projects", name=name1, description="test")
with self.api.raises(ExpectedUniqueData):
p2 = self.create_temp("projects", name=name1, description="test")
p2 = self.create_temp("projects", name="Test name2", description="test")
with self.api.raises(ExpectedUniqueData):
self.api.projects.update(project=p2, name=name1)

View File

@@ -9,9 +9,6 @@ from apiserver.tests.automated import TestService, utc_now_tz_aware
class TestQueues(TestService):
def setUp(self, version="2.4"):
super().setUp(version=version)
def test_default_queue(self):
res = self.api.queues.get_default()
self.assertIsNotNone(res.id)
@@ -63,6 +60,34 @@ class TestQueues(TestService):
self.assertQueueTasks(res.queue, [task])
self.assertTaskTags(task, system_tags=[])
def test_max_queue_entries(self):
queue = self._temp_queue("TestTempQueue")
tasks = [
self._create_temp_queued_task(t, queue)["id"]
for t in ("temp task1", "temp task2", "temp task3")
]
num = self.api.queues.get_num_entries(queue=queue).num
self.assertEqual(num, 3)
task_id = self.api.queues.peek_task(queue=queue).task
self.assertEqual(task_id, tasks[0])
res = self.api.queues.get_by_id(queue=queue)
self.assertQueueTasks(res.queue, tasks)
res = self.api.queues.get_all(id=[queue]).queues[0]
self.assertQueueTasks(res, tasks)
res = self.api.queues.get_all(id=[queue], max_task_entries=2).queues[0]
self.assertQueueTasks(res, tasks[:2])
res = self.api.queues.get_all_ex(id=[queue]).queues[0]
self.assertEqual([e.task.id for e in res.entries], tasks)
res = self.api.queues.get_all_ex(id=[queue], max_task_entries=2).queues[0]
self.assertEqual([e.task.id for e in res.entries], tasks[:2])
def test_move_task(self):
queue = self._temp_queue("TestTempQueue")
tasks = [

View File

@@ -12,6 +12,19 @@ from apiserver.tests.automated import TestService
class TestSubProjects(TestService):
def test_dataset_stats(self):
project = self._temp_project(name="Dataset test", system_tags=["dataset"])
res = self.api.organization.get_entities_count(datasets={"system_tags": ["dataset"]})
self.assertEqual(res.datasets, 1)
task = self._temp_task(project=project)
data = self.api.projects.get_all_ex(id=[project], include_dataset_stats=True).projects[0]
self.assertIsNone(data.dataset_stats)
self.api.tasks.edit(task=task, runtime={"ds_file_count": 2, "ds_total_size": 1000})
data = self.api.projects.get_all_ex(id=[project], include_dataset_stats=True).projects[0]
self.assertEqual(data.dataset_stats, {"file_count": 2, "total_size": 1000})
def test_project_aggregations(self):
"""This test requires user with user_auth_only... credentials in db"""
user2_client = APIClient(
@@ -20,11 +33,12 @@ class TestSubProjects(TestService):
base_url=f"http://localhost:8008/v2.13",
)
child = self._temp_project(name="Aggregation/Pr1", client=user2_client)
basename = "Pr1"
child = self._temp_project(name=f"Aggregation/{basename}", client=user2_client)
project = self.api.projects.get_all_ex(name="^Aggregation$").projects[0].id
child_project = self.api.projects.get_all_ex(id=[child]).projects[0]
self.assertEqual(child_project.parent.id, project)
self.assertEqual(child_project.basename, basename)
user = self.api.users.get_current_user().user.id
# test aggregations on project with empty subprojects
@@ -42,12 +56,17 @@ class TestSubProjects(TestService):
# test aggregations with non-empty subprojects
task1 = self._temp_task(project=child)
self._temp_task(project=child, parent=task1)
user2_task = self._temp_task(project=child, client=user2_client)
framework = "Test framework"
self._temp_model(project=child, framework=framework)
res = self.api.users.get_all_ex(active_in_projects=[project])
self._assert_ids(res.users, [user])
res = self.api.projects.get_all_ex(id=[project], active_users=[user])
res = self.api.projects.get_all_ex(id=[project], include_stats=True)
self._assert_ids(res.projects, [project])
self.assertEqual(res.projects[0].stats.active.total_tasks, 3)
res = self.api.projects.get_all_ex(id=[project], active_users=[user], include_stats=True)
self._assert_ids(res.projects, [project])
self.assertEqual(res.projects[0].stats.active.total_tasks, 2)
res = self.api.projects.get_task_parents(projects=[project])
self._assert_ids(res.parents, [task1])
res = self.api.models.get_frameworks(projects=[project])
@@ -70,8 +89,11 @@ class TestSubProjects(TestService):
# update
with self.api.raises(errors.bad_request.CannotUpdateProjectLocation):
self.api.projects.update(project=project1, name="Root2/Pr2")
res = self.api.projects.update(project=project1, name="Root1/Pr2")
new_basename = "Pr2"
res = self.api.projects.update(project=project1, name=f"Root1/{new_basename}")
self.assertEqual(res.updated, 1)
res = self.api.projects.get_by_id(project=project1)
self.assertEqual(res.project.basename, new_basename)
res = self.api.projects.get_by_id(project=project1_child)
self.assertEqual(res.project.name, "Root1/Pr2/Pr2")
@@ -80,6 +102,7 @@ class TestSubProjects(TestService):
self.assertEqual(res.moved, 2)
res = self.api.projects.get_by_id(project=project1_child)
self.assertEqual(res.project.name, "Root2/Pr2/Pr2")
self.assertEqual(res.project.basename, "Pr2")
# merge
project_with_task, (active, archived) = self._temp_project_with_tasks(
@@ -102,6 +125,7 @@ class TestSubProjects(TestService):
self.assertEqual(res.moved_projects, 1)
res = self.api.projects.get_by_id(project=project_with_task)
self.assertEqual(res.project.name, "Root2/Pr2/Pr4")
self.assertEqual(res.project.basename, "Pr4")
with self.api.raises(errors.bad_request.InvalidProjectId):
self.api.projects.get_by_id(project=merge_source)
@@ -156,6 +180,11 @@ class TestSubProjects(TestService):
self.assertEqual([p.id for p in res], [project1])
res = self.api.projects.get_all_ex(name="project1", parent=[project1]).projects
self.assertEqual([p.id for p in res], [project2])
# basename search
res = self.api.projects.get_all_ex(
basename="project2", shallow_search=True
).projects
self.assertEqual(res, [])
# global search finds all or below the specified level
res = self.api.projects.get_all_ex(name="project1").projects
@@ -163,7 +192,9 @@ class TestSubProjects(TestService):
project4 = self._temp_project(name="project1/project2/project1")
res = self.api.projects.get_all_ex(name="project1", parent=[project2]).projects
self.assertEqual([p.id for p in res], [project4])
# basename search
res = self.api.projects.get_all_ex(basename="project2").projects
self.assertEqual([p.id for p in res], [project2])
self.api.projects.delete(project=project1, force=True)
def test_get_all_with_check_own_contents(self):
@@ -249,13 +280,14 @@ class TestSubProjects(TestService):
**kwargs,
)
def _temp_task(self, **kwargs):
def _temp_task(self, client=None, **kwargs):
return self.create_temp(
"tasks",
delete_params=self.delete_params,
type="testing",
name=db_id(),
input=dict(view=dict()),
client=client,
**kwargs,
)

View File

@@ -67,6 +67,86 @@ class TestTaskEvents(TestService):
),
)
def test_task_single_value_metrics(self):
metric = "Metric1"
variant = "Variant1"
iter_count = 10
task = self._temp_task()
special_iteration = -(2 ** 31)
events = [
{
**self._create_task_event(
"training_stats_scalar", task, iteration or special_iteration
),
"metric": metric,
"variant": variant,
"value": iteration,
}
for iteration in range(iter_count)
]
self.send_batch(events)
# special iteration is present in the events retrieval
metric_param = {"metric": metric, "variants": [variant]}
res = self.api.events.scalar_metrics_iter_raw(
task=task, batch_size=100, metric=metric_param, count_total=True
)
self.assertEqual(res.returned, iter_count)
self.assertEqual(res.total, iter_count)
self.assertEqual(
res.variants[variant]["iter"],
[x or special_iteration for x in range(iter_count)],
)
self.assertEqual(
res.variants[variant]["y"], list(range(iter_count))
)
# but not in the histogram
data = self.api.events.scalar_metrics_iter_histogram(task=task)
self.assertEqual(data[metric][variant]["x"], list(range(1, iter_count)))
# new api
res = self.api.events.get_task_single_value_metrics(tasks=[task]).tasks
self.assertEqual(len(res), 1)
data = res[0]
self.assertEqual(data.task, task)
self.assertEqual(len(data["values"]), 1)
value = data["values"][0]
self.assertEqual(value.metric, metric)
self.assertEqual(value.variant, variant)
self.assertEqual(value.value, 0)
# update is working
task_data = self.api.tasks.get_by_id(task=task).task
last_metrics = first(first(task_data.last_metrics.values()).values())
self.assertEqual(last_metrics.value, iter_count - 1)
new_value = 1000
new_event = {
**self._create_task_event("training_stats_scalar", task, special_iteration),
"metric": metric,
"variant": variant,
"value": new_value,
}
self.send(new_event)
res = self.api.events.scalar_metrics_iter_raw(
task=task, batch_size=100, metric=metric_param, count_total=True
)
self.assertEqual(
res.variants[variant]["y"],
[y or new_value for y in range(iter_count)],
)
task_data = self.api.tasks.get_by_id(task=task).task
last_metrics = first(first(task_data.last_metrics.values()).values())
self.assertEqual(last_metrics.value, new_value)
data = self.api.events.get_task_single_value_metrics(tasks=[task]).tasks[0]
self.assertEqual(data.task, task)
self.assertEqual(len(data["values"]), 1)
value = data["values"][0]
self.assertEqual(value.value, new_value)
def test_last_scalar_metrics(self):
metric = "Metric1"
variant = "Variant1"

View File

@@ -0,0 +1,321 @@
from functools import partial
from typing import Sequence, Mapping, Optional
from apiserver.es_factory import es_factory
from apiserver.tests.automated import TestService
class TestTaskPlots(TestService):
def _temp_task(self, name="test task events"):
task_input = dict(
name=name, type="training", input=dict(mapping={}, view=dict(entries=[])),
)
return self.create_temp("tasks", **task_input)
@staticmethod
def _create_task_event(task, iteration, **kwargs):
return {
"worker": "test",
"type": "plot",
"task": task,
"iter": iteration,
"timestamp": kwargs.get("timestamp") or es_factory.get_timestamp_millis(),
**kwargs,
}
def test_get_plot_sample(self):
task = self._temp_task()
metric = "Metric1"
variant = "Variant1"
# test empty
res = self.api.events.get_plot_sample(
task=task, metric=metric, variant=variant
)
self.assertEqual(res.min_iteration, None)
self.assertEqual(res.max_iteration, None)
self.assertEqual(res.event, None)
# test existing events
iterations = 10
events = [
self._create_task_event(
task=task,
iteration=n,
metric=metric,
variant=variant,
plot_str=f"Test plot str {n}",
)
for n in range(iterations)
]
self.send_batch(events)
# if iteration is not specified then return the event from the last one
res = self.api.events.get_plot_sample(
task=task, metric=metric, variant=variant
)
self._assertEqualEvent(res.event, events[-1])
self.assertEqual(res.max_iteration, iterations - 1)
self.assertEqual(res.min_iteration, 0)
self.assertTrue(res.scroll_id)
# else from the specific iteration
iteration = 8
res = self.api.events.get_plot_sample(
task=task,
metric=metric,
variant=variant,
iteration=iteration,
scroll_id=res.scroll_id,
)
self._assertEqualEvent(res.event, events[iteration])
def test_next_plot_sample(self):
task = self._temp_task()
metric1 = "Metric1"
variant1 = "Variant1"
metric2 = "Metric2"
variant2 = "Variant2"
metrics = [(metric1, variant1), (metric2, variant2)]
# test existing events
events = [
self._create_task_event(
task=task,
iteration=n,
metric=metric,
variant=variant,
plot_str=f"Test plot str {n}",
)
for n in range(2)
for metric, variant in metrics
]
self.send_batch(events)
# single metric navigation
# init scroll
res = self.api.events.get_plot_sample(
task=task, metric=metric1, variant=variant1
)
self._assertEqualEvent(res.event, events[-2])
# navigate forwards
res = self.api.events.next_plot_sample(
task=task, scroll_id=res.scroll_id, navigate_earlier=False
)
self.assertEqual(res.event, None)
# navigate backwards
res = self.api.events.next_plot_sample(
task=task, scroll_id=res.scroll_id
)
self._assertEqualEvent(res.event, events[-4])
res = self.api.events.next_plot_sample(
task=task, scroll_id=res.scroll_id
)
self._assertEqualEvent(res.event, None)
# all metrics navigation
# init scroll
res = self.api.events.get_plot_sample(
task=task, metric=metric1, variant=variant1, navigate_current_metric=False
)
self._assertEqualEvent(res.event, events[-2])
# navigate forwards
res = self.api.events.next_plot_sample(
task=task, scroll_id=res.scroll_id, navigate_earlier=False
)
self._assertEqualEvent(res.event, events[-1])
# navigate backwards
res = self.api.events.next_plot_sample(
task=task, scroll_id=res.scroll_id
)
self._assertEqualEvent(res.event, events[-2])
res = self.api.events.next_plot_sample(
task=task, scroll_id=res.scroll_id
)
self._assertEqualEvent(res.event, events[-3])
def _assertEqualEvent(self, ev1: dict, ev2: Optional[dict]):
if ev2 is None:
self.assertIsNone(ev1)
return
self.assertIsNotNone(ev1)
for field in ("iter", "timestamp", "metric", "variant", "plot_str", "task"):
self.assertEqual(ev1[field], ev2[field])
def test_task_plots(self):
task = self._temp_task()
# test empty
res = self.api.events.plots(metrics=[{"task": task}], iters=5)
self.assertFalse(res.metrics[0].iterations)
res = self.api.events.plots(
metrics=[{"task": task}], iters=5, scroll_id=res.scroll_id, refresh=True
)
self.assertFalse(res.metrics[0].iterations)
# test not empty
metrics = {
"Metric1": ["Variant1", "Variant2"],
"Metric2": ["Variant3", "Variant4"],
}
events = [
self._create_task_event(
task=task,
iteration=1,
metric=metric,
variant=variant,
plot_str=f"Test plot str {metric}_{variant}",
)
for metric, variants in metrics.items()
for variant in variants
]
self.send_batch(events)
scroll_id = self._assertTaskMetrics(
task=task, expected_metrics=metrics, iterations=1
)
# test refresh
update = {
"Metric2": ["Variant3", "Variant4", "Variant5"],
"Metric3": ["VariantA", "VariantB"],
}
events = [
self._create_task_event(
task=task,
iteration=2,
metric=metric,
variant=variant,
plot_str=f"Test plot str {metric}_{variant}_2",
)
for metric, variants in update.items()
for variant in variants
]
self.send_batch(events)
# without refresh the metric states are not updated
scroll_id = self._assertTaskMetrics(
task=task, expected_metrics=metrics, iterations=0, scroll_id=scroll_id
)
# with refresh there are new metrics and existing ones are updated
self._assertTaskMetrics(
task=task,
expected_metrics=update,
iterations=1,
scroll_id=scroll_id,
refresh=True,
)
def _assertTaskMetrics(
self,
task: str,
expected_metrics: Mapping[str, Sequence[str]],
iterations,
scroll_id: str = None,
refresh=False,
) -> str:
res = self.api.events.plots(
metrics=[{"task": task}], iters=1, scroll_id=scroll_id, refresh=refresh
)
if not iterations:
self.assertTrue(all(m.iterations == [] for m in res.metrics))
return res.scroll_id
expected_variants = set((m, var) for m, vars_ in expected_metrics.items() for var in vars_)
for metric_data in res.metrics:
self.assertEqual(len(metric_data.iterations), iterations)
for it_data in metric_data.iterations:
self.assertEqual(
set((e.metric, e.variant) for e in it_data.events), expected_variants
)
return res.scroll_id
def test_plots_navigation(self):
task = self._temp_task()
metric = "Metric1"
variants = ["Variant1", "Variant2"]
iterations = 10
# test empty
res = self.api.events.plots(
metrics=[{"task": task, "metric": metric}], iters=5,
)
self.assertFalse(res.metrics[0].iterations)
# create events
events = [
self._create_task_event(
task=task,
iteration=n,
metric=metric,
variant=variant,
plot_str=f"{metric}_{variant}_{n}",
)
for n in range(iterations)
for variant in variants
]
self.send_batch(events)
# init testing
scroll_id = None
assert_plots = partial(
self._assertPlots,
task=task,
metric=metric,
iterations=iterations,
variants=len(variants)
)
# test forward navigation
for page in range(3):
scroll_id = assert_plots(scroll_id=scroll_id, expected_page=page)
# test backwards navigation
scroll_id = assert_plots(
scroll_id=scroll_id, expected_page=0, navigate_earlier=False
)
# beyond the latest iteration and back
res = self.api.events.debug_images(
metrics=[{"task": task, "metric": metric}],
iters=5,
scroll_id=scroll_id,
navigate_earlier=False,
)
self.assertEqual(len(res["metrics"][0]["iterations"]), 0)
assert_plots(scroll_id=scroll_id, expected_page=1)
# refresh
assert_plots(scroll_id=scroll_id, expected_page=0, refresh=True)
def _assertPlots(
self,
task,
metric,
iterations: int,
variants: int,
scroll_id,
expected_page: int,
iters: int = 5,
**extra_params,
) -> str:
res = self.api.events.plots(
metrics=[{"task": task, "metric": metric}],
iters=iters,
scroll_id=scroll_id,
**extra_params,
)
data = res["metrics"][0]
self.assertEqual(data["task"], task)
left_iterations = max(0, iterations - expected_page * iters)
self.assertEqual(len(data["iterations"]), min(iters, left_iterations))
for it in data["iterations"]:
self.assertEqual(len(it["events"]), variants)
return res.scroll_id
def send_batch(self, events):
_, data = self.api.send_batch("events.add_batch", events)
return data

View File

@@ -0,0 +1,58 @@
import time
from apiserver.tests.automated import TestService
class TestTasksRunning(TestService):
STATUS_STOPPED = "stopped"
STATUS_COMPLETED = "completed"
STATUS_PUBLISHED = "published"
STATUS_RUNNING = "in_progress"
def test_stop_regular_task(self):
task_id = self._create_running_task()
data = self.api.tasks.stop(task=task_id).fields
assert data.status == self.STATUS_STOPPED
def test_stop_regular_task_with_active_worker(self):
task_id = self._create_running_task()
worker_id = "worker1"
self.api.workers.register(worker=worker_id)
self.api.workers.status_report(
worker=worker_id, task=task_id, timestamp=int(time.time())
)
data = self.api.tasks.stop(task=task_id).fields
assert data.status == self.STATUS_RUNNING
assert data.status_message == "stopping"
def test_stop_development_task(self):
task_id = self._create_running_task(is_development=True)
data = self.api.tasks.stop(task=task_id).fields
assert data.status == self.STATUS_STOPPED
def test_completed_task(self):
task_id = self._create_running_task()
res = self.api.tasks.completed(task=task_id)
assert res.fields.status == self.STATUS_COMPLETED
assert res.updated == 1
assert res.published == 0
res = self.api.tasks.completed(task=task_id, publish=True)
assert res.fields.status == self.STATUS_PUBLISHED
assert res.updated == 1
assert res.published == 1
def _create_running_task(self, is_development=False):
task_input = dict(
name="task-1",
type="testing",
input=dict(mapping={}, view=dict()),
)
if is_development:
task_input["system_tags"] = ["development"]
task_id = self.create_temp("tasks", **task_input)
self.api.tasks.started(task=task_id)
return task_id

View File

@@ -12,11 +12,8 @@ log = config.logger(__file__)
class TestWorkersService(TestService):
def setUp(self, version="2.4"):
super().setUp(version=version)
def _check_exists(self, worker: str, exists: bool = True):
workers = self.api.workers.get_all(last_seen=100).workers
def _check_exists(self, worker: str, exists: bool = True, tags: list = None):
workers = self.api.workers.get_all(last_seen=100, tags=tags).workers
found = any(w for w in workers if w.id == worker)
assert exists == found
@@ -40,6 +37,14 @@ class TestWorkersService(TestService):
time.sleep(5)
self._check_exists(test_worker, False)
def test_filters(self):
test_worker = f"test_{uuid4().hex}"
self.api.workers.register(worker=test_worker, tags=["application"], timeout=3)
self._check_exists(test_worker)
self._check_exists(test_worker, tags=["application", "test"])
self._check_exists(test_worker, False, tags=["test"])
self._check_exists(test_worker, False, tags=["-application"])
def _simulate_workers(self) -> Sequence[str]:
"""
Two workers writing the same metrics. One for 4 seconds. Another one for 2
@@ -74,7 +79,8 @@ class TestWorkersService(TestService):
self.api.workers.status_report(**data)
time.sleep(1)
return workers
res = self.api.workers.get_all(last_seen=100)
return [w.key for w in res.workers]
def _create_running_task(self, task_name):
task_input = dict(
@@ -104,7 +110,8 @@ class TestWorkersService(TestService):
def test_get_stats(self):
workers = self._simulate_workers()
to_date = utc_now_tz_aware()
to_date = utc_now_tz_aware() + timedelta(seconds=10)
from_date = to_date - timedelta(days=1)
# no variants
@@ -190,8 +197,8 @@ class TestWorkersService(TestService):
self._simulate_workers()
to_date = utc_now_tz_aware()
from_date = to_date - timedelta(minutes=10)
to_date = utc_now_tz_aware() + timedelta(seconds=10)
from_date = to_date - timedelta(minutes=1)
# no variants
res = self.api.workers.get_activity_report(

View File

@@ -1 +1 @@
__version__ = "1.3.0"
__version__ = "1.6.0"

View File

@@ -1,9 +1,11 @@
FROM centos/nodejs-12-centos7 AS webapp
ARG CLEARML_WEB_GIT_URL=https://github.com/allegroai/clearml-web.git
USER root
WORKDIR /opt
RUN git clone https://github.com/allegroai/clearml-web.git
RUN git clone ${CLEARML_WEB_GIT_URL} clearml-web
RUN mv clearml-web /opt/open-webapp
COPY --chmod=744 docker/build/internal_files/build_webapp.sh /tmp/internal_files/
RUN /bin/bash -c '/tmp/internal_files/build_webapp.sh'
@@ -18,6 +20,7 @@ COPY --from=staging_image /opt/clearml/ /opt/clearml/
COPY --chmod=744 docker/build/internal_files/final_image_preparation.sh /tmp/internal_files/
COPY docker/build/internal_files/clearml.conf.template /tmp/internal_files/
COPY docker/build/internal_files/clearml_subpath.conf.template /tmp/internal_files/
RUN /bin/bash -c '/tmp/internal_files/final_image_preparation.sh'
COPY --from=webapp /opt/open-webapp/build /usr/share/nginx/html

View File

@@ -41,6 +41,7 @@ http {
server_name _;
root /usr/share/nginx/html;
proxy_http_version 1.1;
client_max_body_size 0;
# comppression
gzip on;

View File

@@ -0,0 +1,21 @@
location /${CLEARML_SERVER_SUB_PATH} {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header Host $host;
proxy_pass http://localhost:80;
rewrite /${CLEARML_SERVER_SUB_PATH}/(.*) /$1 break;
}
location /${CLEARML_SERVER_SUB_PATH}/api {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header Host $host;
proxy_pass http://localhost:80/api;
rewrite /${CLEARML_SERVER_SUB_PATH}/api/(.*) /api/$1 break;
}
location /${CLEARML_SERVER_SUB_PATH}/files {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header Host $host;
proxy_pass http://localhost:80/files;
rewrite /${CLEARML_SERVER_SUB_PATH}/files/(.*) /files/$1 break;
rewrite /${CLEARML_SERVER_SUB_PATH}/files /files/ break;
}

View File

@@ -48,9 +48,16 @@ EOF
export NGINX_APISERVER_ADDR=${NGINX_APISERVER_ADDRESS:-http://apiserver:8008}
export NGINX_FILESERVER_ADDR=${NGINX_FILESERVER_ADDRESS:-http://fileserver:8081}
envsubst '${NGINX_APISERVER_ADDR} ${NGINX_FILESERVER_ADDR}' < /etc/nginx/clearml.conf.template > /etc/nginx/nginx.conf
if [[ -n "${CLEARML_SERVER_SUB_PATH}" ]]; then
envsubst '${CLEARML_SERVER_SUB_PATH}' < /etc/nginx/clearml_subpath.conf.template > /etc/nginx/default.d/clearml_subpath.conf
cp /usr/share/nginx/html/env.js /usr/share/nginx/html/env.js.origin
envsubst '${CLEARML_SERVER_SUB_PATH}' < /usr/share/nginx/html/env.js.origin > /usr/share/nginx/html/env.js
cp /usr/share/nginx/html/index.html /usr/share/nginx/html/index.html.origin
sed 's/href="\/"/href="\/'${CLEARML_SERVER_SUB_PATH}'\/"/' /usr/share/nginx/html/index.html.origin > /usr/share/nginx/html/index.html
fi
#start the server
/usr/sbin/nginx -g "daemon off;"

View File

@@ -15,4 +15,5 @@ ln -s /dev/stdout /var/log/nginx/access.log
ln -s /dev/stderr /var/log/nginx/error.log
mv /etc/nginx/nginx.conf /etc/nginx/nginx.conf.orig
mv /tmp/internal_files/clearml.conf.template /etc/nginx/clearml.conf.template
mv /tmp/internal_files/clearml_subpath.conf.template /etc/nginx/clearml_subpath.conf.template
yum clean all

View File

@@ -108,6 +108,8 @@ services:
command:
- webserver
container_name: clearml-webserver
# environment:
# CLEARML_SERVER_SUB_PATH : clearml-web # Allow Clearml to be served with a URL path prefix.
image: allegroai/clearml:latest
restart: unless-stopped
depends_on:
@@ -152,6 +154,8 @@ services:
- /opt/clearml/agent:/root/.clearml
depends_on:
- apiserver
entrypoint: >
bash -c "curl --retry 10 --retry-delay 10 --retry-connrefused 'http://apiserver:8008/debug.ping' && /usr/agent/entrypoint.sh"
networks:
backend:

View File

@@ -1,5 +1,8 @@
# trains-server FAQ
## **NOTE**: This page's information is deprecated. See the [ClearML documentation](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server) for up-to-date deployment instructions
Launching **trains-server**
* How do I launch **trains-server** on:

View File

@@ -1,5 +1,7 @@
# Deploying **trains-server** on AWS
## **NOTE**: These instructions are deprecated. See the [ClearML documentation](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server) for up-to-date deployment instructions
To easily deploy **trains-server** on AWS, use one of our pre-built Amazon Machine Images (AMIs).
We provide AMIs per region for each released version of **trains-server**, see [Released versions](#released-versions) below.

View File

@@ -1,5 +1,7 @@
# Deploying Trains Server on Google Cloud Platform
# **NOTE**: These instructions are deprecated. See the [ClearML documentation](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server) for up-to-date deployment instructions
To easily deploy Trains Server on GCP, use one of our pre-built GCP Custom Images.
We provide Custom Images for each released version of Trains Server, see [Released versions](#released-versions) below.

View File

@@ -1,5 +1,7 @@
# Launching the **trains-server** Docker in Linux or macOS
## **NOTE**: These instructions are deprecated. See the [ClearML documentation](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server) for up-to-date deployment instructions
For Linux or macOS, use our pre-built Docker image for easy deployment. The latest Docker images can be found [here](https://hub.docker.com/r/allegroai/trains).
For Linux users:

View File

@@ -1,5 +1,7 @@
# Launching the **trains-server** Docker in Windows 10
## **NOTE**: These instructions are deprecated. See the [ClearML documentation](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server) for up-to-date deployment instructions
For Windows, we recommend launching our pre-built Docker image on a Linux virtual machine.
However, you can launch **trains-server** on Windows 10 using Docker Desktop for Windows (see the Docker [System Requirements](https://docs.docker.com/docker-for-windows/install/#system-requirements)).