Compare commits

43 Commits

Author SHA1 Message Date
allegroai
9c95c63ce0 Version bump to v1.15.0 2024-03-24 11:25:05 +02:00
allegroai
73179f53c2 Use latest patch versions for ES and Mongo 2024-03-24 11:24:51 +02:00
allegroai
ddc8a76279 Set API version to v2.29 2024-03-18 16:02:45 +02:00
allegroai
ac7ea0d477 Allow filtering task models.input.model field by array of ids 2024-03-18 16:01:45 +02:00
allegroai
3544ed19f8 Use latest patch versions for mongodb and ES 2024-03-18 15:59:15 +02:00
allegroai
5e68f053a0 Add widgets link in nginx configuration 2024-03-18 15:58:19 +02:00
allegroai
7bd5fdad59 Update webserver build: allow using external configuration from a file or from environment variables 2024-03-18 15:57:19 +02:00
allegroai
484c72aa0c Upgrade to Debian bookworm 2024-03-18 15:56:18 +02:00
allegroai
2027afbed5 Added missing ES index template for scalar events 2024-03-18 15:54:38 +02:00
allegroai
7d649f1964 Support controlling config value inheritance from the base folder 2024-03-18 15:53:58 +02:00
allegroai
8d237b3cae Upgrade Redis to v6.2 2024-03-18 15:53:07 +02:00
allegroai
e8ee6ce72e Code cleanup 2024-03-18 15:52:22 +02:00
allegroai
5749ff0454 Add security headers to webserver 2024-03-18 15:50:40 +02:00
allegroai
5189adf4f1 Improve handling of fixed users 2024-03-18 15:49:42 +02:00
allegroai
92a4e56c1f Add support for cookies extensions 2024-03-18 15:46:07 +02:00
allegroai
33528870ae Request cookies processing enhanced for more flexibility 2024-03-18 15:45:09 +02:00
allegroai
85f5b8b6f6 Fix last metrics for task not being updated for events reported without variants 2024-03-18 15:44:28 +02:00
allegroai
6112910768 Make sure that legacy templates are deleted and empty db check is done on the new templates 2024-03-18 15:40:13 +02:00
allegroai
d3013ac285 Invalidate token on user logoff 2024-03-18 15:38:44 +02:00
allegroai
88abf28287 Add ElasticSearch 8.x support 2024-03-18 15:37:44 +02:00
allegroai
6a1fc04d1e Set cookies SameSite value to Lax 2024-02-13 16:18:21 +02:00
allegroai
ee8eb03698 Fix crash when importing events for public company tasks 2024-02-13 16:17:52 +02:00
allegroai
5799baae45 Make sure that APIs that aggregate task/model data from projects can be called for the root project 2024-02-13 16:17:33 +02:00
allegroai
801e536c5e Fix tasks.started to correctly handle null values in the started field 2024-02-13 16:17:02 +02:00
allegroai
6e484ea8f4 Fix missing region parameter when deleting files from minio server 2024-02-13 16:16:24 +02:00
allegroai
a47e65d974 Add input parameters check to multiple APIs 2024-02-13 16:15:55 +02:00
allegroai
702b6dc9c8 Version bump to v1.14.0 2024-01-10 15:31:11 +02:00
allegroai
db15f235e4 Make sure files downloaded from the apiserver are not cached by browsers 2024-01-10 15:31:01 +02:00
allegroai
8c347f8fa9 Fix include and exclude filters not processing "no tags" condition 2024-01-10 15:26:55 +02:00
allegroai
768c3d80ff Remove callback_url_prefix and state parameters from login.supported_modes, which no longer returns urls 2024-01-10 15:26:22 +02:00
allegroai
a5c3ef6385 Fix query filter so that the default operator between different query operations on the same parameter is AND instead of OR 2024-01-10 15:24:37 +02:00
allegroai
11b7a384af Set API version 2.28 2024-01-10 15:23:54 +02:00
allegroai
9a70ade4a6 Support task models with missing model field in data_tool import 2024-01-10 15:18:58 +02:00
allegroai
91ce140901 Add "queue watched" indication to pipelines.start_pipeline 2024-01-10 15:15:43 +02:00
allegroai
49084a9c49 Optimize task statistics for projects dashboard and statistics reporter 2024-01-10 15:13:25 +02:00
allegroai
8a99eb6812 Fix model_metrics parameter name in get_multi_task_metrics schema 2024-01-10 15:12:56 +02:00
allegroai
811ab2bf4f Support exporting users with data tool 2024-01-10 15:12:07 +02:00
allegroai
3752db122b Add events.get_multi_task_metrics 2024-01-10 15:11:27 +02:00
allegroai
439911b84c Upgrade werkzeug and flask dependencies 2024-01-10 15:10:46 +02:00
allegroai
262a301e28 Check for dictionary type for some model and task fields 2024-01-10 15:10:41 +02:00
allegroai
a604451b01 Refactor check for tasks write permission 2024-01-10 15:08:20 +02:00
allegroai
88a7773621 Allow filtering on event metrics in multi-task endpoints get_task_single_value_metrics, multi_task_scalar_metrics_iter_histogram and get_multi_task_plots 2024-01-10 15:07:46 +02:00
allegroai
35c4061992 Support filtering by task or model ids in projects.get_unique_metric_variants 2024-01-10 15:06:21 +02:00
78 changed files with 1928 additions and 620 deletions

View File

@@ -13,6 +13,14 @@ from apiserver.config_repo import config
from apiserver.utilities.stringenum import StringEnum
# Request model that carries a single mandatory task id.
class TaskRequest(Base):
# Task id; declared as a required string field.
task: str = StringField(required=True)
# Request model that carries a single mandatory model id.
class ModelRequest(Base):
# Model id; declared as a required string field.
model: str = StringField(required=True)
class HistogramRequestBase(Base):
samples: int = IntField(default=2000, validators=[Min(1), Max(6000)])
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
@@ -29,6 +37,11 @@ class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
model_events: bool = BoolField(default=False)
# Request model with a mandatory task id and an optional model_events flag.
class GetMetricsAndVariantsRequest(Base):
# Task id; declared as a required string field.
task: str = StringField(required=True)
# Defaults to False. NOTE(review): presumably switches the lookup to model
# events instead of task events - confirm against the endpoint handler.
model_events: bool = BoolField(default=False)
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
tasks: Sequence[str] = ListField(
items_types=str,
@@ -41,6 +54,7 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
)
],
)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
model_events: bool = BoolField(default=False)
@@ -50,6 +64,12 @@ class TaskMetric(Base):
variants: Sequence[str] = ListField(items_types=str)
# Request model that extends TaskRequest with iteration, scroll and
# model_events options.
class LegacyMetricEventsRequest(TaskRequest):
# Defaults to 1 and is validated to be at least 1; presumably the number of
# iterations to return - confirm against the endpoint handler.
iters: int = IntField(default=1, validators=validators.Min(1))
# Optional scroll id; no default is declared.
scroll_id: str = StringField()
# Defaults to False.
model_events: bool = BoolField(default=False)
class MetricEventsRequest(Base):
metrics: Sequence[TaskMetric] = ListField(
items_types=TaskMetric, validators=[Length(minimum_value=1)]
@@ -58,7 +78,14 @@ class MetricEventsRequest(Base):
navigate_earlier: bool = BoolField(default=True)
refresh: bool = BoolField(default=False)
scroll_id: str = StringField()
model_events: bool = BoolField()
model_events: bool = BoolField(default=False)
# Request model identifying one metric/variant pair of a task; task, metric
# and variant are all mandatory.
class VectorMetricsIterHistogramRequest(Base):
# Task id; required.
task: str = StringField(required=True)
# Metric name; required.
metric: str = StringField(required=True)
# Variant name; required.
variant: str = StringField(required=True)
# Defaults to False.
model_events: bool = BoolField(default=False)
class GetVariantSampleRequest(Base):
@@ -109,6 +136,11 @@ class TaskEventsRequest(TaskEventsRequestBase):
model_events: bool = BoolField(default=False)
# Request model that extends TaskEventsRequestBase with an ordering option
# and a scroll id.
class LegacyLogEventsRequest(TaskEventsRequestBase):
# Enum field over LogOrderEnum; defaults to LogOrderEnum.desc.
# NOTE(review): the annotation says Optional[str] while the field is
# enum-backed - confirm which type callers actually receive.
order: Optional[str] = ActualEnumField(LogOrderEnum, default=LogOrderEnum.desc)
# Optional scroll id; no default is declared.
scroll_id: str = StringField()
class LogEventsRequest(TaskEventsRequestBase):
batch_size: int = IntField(default=5000)
navigate_earlier: bool = BoolField(default=True)
@@ -148,18 +180,28 @@ class MultiTasksRequestBase(Base):
class SingleValueMetricsRequest(MultiTasksRequestBase):
pass
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
# Multi-task request model in which the event type must be given explicitly.
class TaskMetricsRequest(MultiTasksRequestBase):
# Enum field over EventType; declared as required.
event_type: EventType = ActualEnumField(EventType, required=True)
# Multi-task request model with an optional event type.
class MultiTaskMetricsRequest(MultiTasksRequestBase):
# Enum field over EventType; defaults to EventType.all when omitted.
event_type: EventType = ActualEnumField(EventType, default=EventType.all)
# Multi-task request model with iteration and scroll options.
class LegacyMultiTaskEventsRequest(MultiTasksRequestBase):
# Defaults to 1 and is validated to be at least 1; presumably the number of
# iterations to return - confirm against the endpoint handler.
iters: int = IntField(default=1, validators=validators.Min(1))
# Optional scroll id; no default is declared.
scroll_id: str = StringField()
# Multi-task request model for plots, with scrolling options and an optional
# list of metric/variant filters.
class MultiTaskPlotsRequest(MultiTasksRequestBase):
# Defaults to 1; unlike the legacy request models, no minimum validator is
# declared here.
iters: int = IntField(default=1)
# Optional scroll id; no default is declared.
scroll_id: str = StringField()
# Defaults to False.
no_scroll: bool = BoolField(default=False)
# Defaults to True.
last_iters_per_task_metric: bool = BoolField(default=True)
# List of MetricVariants items. NOTE(review): presumably restricts the
# returned plots to these metrics/variants - confirm against the handler.
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class TaskPlotsRequest(Base):
@@ -171,6 +213,14 @@ class TaskPlotsRequest(Base):
model_events: bool = BoolField(default=False)
# Request model identifying one metric of a task, with scroll options; task
# and metric are mandatory.
class GetScalarMetricDataRequest(Base):
# Task id; required.
task: str = StringField(required=True)
# Metric name; required.
metric: str = StringField(required=True)
# Optional scroll id; no default is declared.
scroll_id: str = StringField()
# Defaults to False.
no_scroll: bool = BoolField(default=False)
# Defaults to False.
model_events: bool = BoolField(default=False)
class ClearScrollRequest(Base):
scroll_id: str = StringField()

View File

@@ -5,8 +5,9 @@ from apiserver.apimodels import DictField, callable_default
class GetSupportedModesRequest(Base):
state = StringField(help_text="ASCII base64 encoded application state")
callback_url_prefix = StringField()
pass
# state = StringField(help_text="ASCII base64 encoded application state")
# callback_url_prefix = StringField()
class BasicGuestMode(Base):

View File

@@ -42,6 +42,21 @@ class ModelRequest(models.Base):
model = fields.StringField(required=True)
# Request model that carries a single mandatory task id.
class TaskRequest(models.Base):
# Task id; declared as a required string field.
task = fields.StringField(required=True)
# Request model that extends TaskRequest (mandatory task id) with three
# optional fields.
class UpdateForTaskRequest(TaskRequest):
# Optional string; no default is declared.
uri = fields.StringField()
# Optional integer; no default is declared.
iteration = fields.IntField()
# Optional string; presumably the id of an existing model to update instead
# of creating a new one - confirm against the endpoint handler.
override_model_id = fields.StringField()
# Request model that extends ModelRequest with optional task and iteration
# fields.
class UpdateModelRequest(ModelRequest):
# Optional task id; unlike TaskRequest.task it is not declared required.
task = fields.StringField()
# Optional integer; no default is declared.
iteration = fields.IntField()
class DeleteModelRequest(ModelRequest):
force = fields.BoolField(default=False)
delete_external_artifacts = fields.BoolField(default=True)

View File

@@ -18,8 +18,4 @@ class StartPipelineRequest(models.Base):
task = fields.StringField(required=True)
queue = fields.StringField(required=True)
args = ListField(Arg)
class StartPipelineResponse(models.Base):
pipeline = fields.StringField(required=True)
enqueued = fields.BoolField(required=True)
verify_watched_queue = fields.BoolField(default=False)

View File

@@ -33,6 +33,7 @@ class ProjectOrNoneRequest(models.Base):
# Request model that extends ProjectOrNoneRequest with a model_metrics flag
# and an optional list of ids.
class GetUniqueMetricsRequest(ProjectOrNoneRequest):
# Defaults to False.
model_metrics = fields.BoolField(default=False)
# List of strings. NOTE(review): presumably task or model ids that narrow
# the query - confirm against the endpoint handler.
ids = fields.ListField(str)
class GetParamsRequest(ProjectOrNoneRequest):
@@ -45,7 +46,7 @@ class ProjectTagsRequest(TagsRequest):
class MultiProjectRequest(models.Base):
projects = fields.ListField(str)
projects = fields.ListField(items_types=[str, type(None)])
include_subprojects = fields.BoolField(default=True)

View File

@@ -6,6 +6,10 @@ class ReportStatsOptionRequest(Base):
enabled = BoolField(default=None, nullable=True)
# Request model with a single optional path field.
class GetConfigRequest(Base):
# Optional string; no default is declared. Presumably the configuration
# path to read - confirm against the endpoint handler.
path = StringField()
class ReportStatsOptionResponse(Base):
supported = BoolField(default=True)
enabled = BoolField()

View File

@@ -4,6 +4,10 @@ from jsonmodels.models import Base
from apiserver.apimodels import DictField
# Request model that carries a single mandatory user id.
class UserRequest(Base):
# User id; declared as a required string field.
user = StringField(required=True)
class CreateRequest(Base):
id = StringField(required=True)
name = StringField(required=True)

View File

@@ -31,6 +31,7 @@ from apiserver.bll.event.history_plots_iterator import HistoryPlotsIterator
from apiserver.bll.event.metric_debug_images_iterator import MetricDebugImagesIterator
from apiserver.bll.event.metric_plots_iterator import MetricPlotsIterator
from apiserver.bll.model import ModelBLL
from apiserver.bll.task.utils import get_many_tasks_for_writing
from apiserver.bll.util import parallel_chunked_decorator
from apiserver.database import utils as dbutils
from apiserver.database.model.model import Model
@@ -42,7 +43,7 @@ from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.redis_manager import redman
from apiserver.tools import safe_get
from apiserver.service_repo.auth import Identity
from apiserver.utilities.dicts import nested_get
from apiserver.utilities.json import loads
@@ -55,7 +56,9 @@ MIN_LONG = -(2**63)
log = config.logger(__file__)
async_task_events_delete = config.get("services.tasks.async_events_delete", False)
async_delete_threshold = config.get("services.tasks.async_events_delete_threshold", 100_000)
async_delete_threshold = config.get(
"services.tasks.async_events_delete_threshold", 100_000
)
class EventBLL(object):
@@ -97,7 +100,9 @@ class EventBLL(object):
return self._metrics
@staticmethod
def _get_valid_entities(company_id, ids: Mapping[str, bool], model=False) -> Set:
def _get_valid_entities(
company_id, ids: Mapping[str, bool], identity: Identity, model=False
) -> Set:
"""Verify that task or model exists and can be updated"""
if not ids:
return set()
@@ -116,20 +121,34 @@ class EventBLL(object):
):
if not requested_ids:
continue
query = Q(id__in=requested_ids, company=company_id)
res.update(
(Model if model else Task).objects(query & locked_q).scalar("id")
)
query = Q(id__in=requested_ids) & locked_q
if model:
ids = Model.objects(query & Q(company=company_id)).scalar("id")
else:
ids = {
t.id
for t in get_many_tasks_for_writing(
company_id=company_id,
identity=identity,
query=query,
only=("id",),
throw_on_forbidden=False,
)
}
res.update(ids)
return res
def add_events(
self,
company_id: str,
user_id: str,
identity: Identity,
events: Sequence[dict],
worker: str,
) -> Tuple[int, int, dict]:
user_id = identity.user
task_ids = {}
model_ids = {}
for event in events:
@@ -161,8 +180,12 @@ class EventBLL(object):
"Inconsistent model_event setting in the passed events",
tasks=found_in_both,
)
valid_models = self._get_valid_entities(company_id, ids=model_ids, model=True)
valid_tasks = self._get_valid_entities(company_id, ids=task_ids)
valid_models = self._get_valid_entities(
company_id, ids=model_ids, identity=identity, model=True
)
valid_tasks = self._get_valid_entities(
company_id, ids=task_ids, identity=identity
)
actions: List[dict] = []
used_task_ids = set()
@@ -351,7 +374,7 @@ class EventBLL(object):
if invalid_iterations_count:
raise BulkIndexError(
f"{invalid_iterations_count} document(s) failed to index.",
[invalid_iteration_error],
[{"_index": invalid_iteration_error}],
)
if not added:
@@ -415,10 +438,8 @@ class EventBLL(object):
last_events contains [hashed_metric_name -> hashed_variant_name -> event]. Keys are hashed to avoid mongodb
key conflicts due to invalid characters and/or long field names.
"""
metric = event.get("metric")
variant = event.get("variant")
if not (metric and variant):
return
metric = event.get("metric") or ""
variant = event.get("variant") or ""
metric_hash = dbutils.hash_field_name(metric)
variant_hash = dbutils.hash_field_name(variant)
@@ -463,9 +484,9 @@ class EventBLL(object):
recent than the currently stored event for its metric/event_type combination.
last_events contains [metric_name -> event_type -> event]
"""
metric = event.get("metric")
metric = event.get("metric") or ""
event_type = event.get("type")
if not (metric and event_type):
if not event_type:
return
timestamp = last_events[metric][event_type].get("timestamp", None)
@@ -637,8 +658,8 @@ class EventBLL(object):
Return events and next scroll id from the scrolled query
Release the scroll once it is exhausted
"""
total_events = safe_get(es_res, "hits/total/value", default=0)
events = [doc["_source"] for doc in safe_get(es_res, "hits/hits", default=[])]
total_events = nested_get(es_res, ("hits", "total", "value"), default=0)
events = [doc["_source"] for doc in nested_get(es_res, ("hits", "hits"), default=[])]
next_scroll_id = es_res.get("_scroll_id")
if next_scroll_id and not events:
self.clear_scroll(next_scroll_id)

View File

@@ -9,7 +9,7 @@ from elasticsearch import Elasticsearch
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.task import Task
from apiserver.tools import safe_get
from apiserver.utilities.dicts import nested_get
class EventType(Enum):
@@ -123,8 +123,8 @@ def get_max_metric_and_variant_counts(
es, company_id=company_id, event_type=event_type, body=es_req, **kwargs,
)
metrics_count = safe_get(
es_res, "aggregations/metrics_count/value", max_metrics_count
metrics_count = nested_get(
es_res, ("aggregations", "metrics_count", "value"), max_metrics_count
)
if not metrics_count:
return max_metrics_count, max_variants_count

View File

@@ -21,9 +21,10 @@ from apiserver.bll.event.event_common import (
TaskCompanies,
)
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
from apiserver.bll.query import Builder as QueryBuilder
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.tools import safe_get
from apiserver.utilities.dicts import nested_get
log = config.logger(__file__)
@@ -161,7 +162,9 @@ class EventMetrics:
return res
def get_task_single_value_metrics(
self, companies: TaskCompanies
self,
companies: TaskCompanies,
metric_variants: MetricVariants = None,
) -> Mapping[str, dict]:
"""
For the requested tasks return all the events delivered for the single iteration (-2**31)
@@ -179,7 +182,13 @@ class EventMetrics:
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
task_events = list(
itertools.chain.from_iterable(
pool.map(self._get_task_single_value_metrics, companies.items())
pool.map(
partial(
self._get_task_single_value_metrics,
metric_variants=metric_variants,
),
companies.items(),
)
),
)
@@ -195,19 +204,19 @@ class EventMetrics:
}
def _get_task_single_value_metrics(
self, tasks: Tuple[str, Sequence[str]]
self, tasks: Tuple[str, Sequence[str]], metric_variants: MetricVariants = None
) -> Sequence[dict]:
company_id, task_ids = tasks
must = [
{"terms": {"task": task_ids}},
{"term": {"iter": SINGLE_SCALAR_ITERATION}},
]
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
es_req = {
"size": 10000,
"query": {
"bool": {
"must": [
{"terms": {"task": task_ids}},
{"term": {"iter": SINGLE_SCALAR_ITERATION}},
]
}
},
"query": {"bool": {"must": must}},
}
with translate_errors_context():
es_res = search_company_events(
@@ -280,7 +289,8 @@ class EventMetrics:
query = {"bool": {"must": must}}
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
query=query,
**search_args,
)
max_variants = int(max_variants // 2)
es_req = {
@@ -332,12 +342,12 @@ class EventMetrics:
total amount of intervals does not exceeds the samples
Return the interval and resulting amount of intervals
"""
count = safe_get(data, "count/value", default=0)
count = nested_get(data, ("count", "value"), default=0)
if count < samples:
return metric, variant, 1, count
min_index = safe_get(data, "min_index/value", default=0)
max_index = safe_get(data, "max_index/value", default=min_index)
min_index = nested_get(data, ("min_index", "value"), default=0)
max_index = nested_get(data, ("max_index", "value"), default=min_index)
index_range = max_index - min_index + 1
interval = max(1, math.ceil(float(index_range) / samples))
max_samples = math.ceil(float(index_range) / interval)
@@ -366,7 +376,8 @@ class EventMetrics:
query = self._get_task_metrics_query(task_id=task_id, metrics=metrics)
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
query=query,
**search_args,
)
max_variants = int(max_variants // 2)
es_req = {
@@ -432,7 +443,9 @@ class EventMetrics:
@classmethod
def _get_task_metrics_query(
cls, task_id: str, metrics: Sequence[Tuple[str, str]],
cls,
task_id: str,
metrics: Sequence[Tuple[str, str]],
):
must = cls._task_conditions(task_id)
if metrics:
@@ -451,12 +464,96 @@ class EventMetrics:
return {"bool": {"must": must}}
def get_multi_task_metrics(
    self, companies: TaskCompanies, event_type: EventType
) -> Mapping[str, list]:
    """
    For the requested tasks return reported metrics and variants

    :param companies: mapping of a company id to the tasks of that company;
        only the ``id`` attribute of each task is read
    :param event_type: event type passed through to the per-company query
    :return: mapping of a metric name to the list of its variant names.
        With a single company the per-company result is returned as is.
        With several companies the variants of the same metric are united
        and both metrics and variants are returned sorted in ascending
        order, matching the ordering requested from ES per company
    """
    tasks_ids = {
        company: [t.id for t in tasks] for company, tasks in companies.items()
    }
    # one query per company, executed concurrently
    with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
        companies_res: Sequence = list(
            pool.map(
                partial(self._get_multi_task_metrics, event_type=event_type),
                tasks_ids.items(),
            )
        )

    if len(companies_res) == 1:
        return companies_res[0]

    merged = defaultdict(set)
    for c_res in companies_res:
        for metric, variants in c_res.items():
            merged[metric].update(variants)

    # sets carry no order: sort so that the merged response is deterministic
    # and consistent with the single company path
    return {metric: sorted(merged[metric]) for metric in sorted(merged)}
def _get_multi_task_metrics(
    self, company_tasks: Tuple[str, Sequence[str]], event_type: EventType
) -> Mapping[str, list]:
    """
    Query the metrics and variants reported for the tasks of one company

    :param company_tasks: (company id, task ids) pair
    :param event_type: event type to search in
    :return: mapping of a metric name to the list of its variant names as
        returned in the ES aggregation buckets. An empty dict is returned
        when check_empty_data reports no data or when the response has no
        aggregations
    """
    company_id, task_ids = company_tasks
    if check_empty_data(self.es, company_id, event_type):
        return {}

    search_args = {
        "es": self.es,
        "company_id": company_id,
        "event_type": event_type,
    }
    query = QueryBuilder.terms("task", task_ids)
    max_metrics, max_variants = get_max_metric_and_variant_counts(
        query=query, **search_args
    )

    def ordered_terms(field: str, size) -> dict:
        # terms aggregation whose buckets are ordered by key, ascending
        return {
            "terms": {"field": field, "size": size, "order": {"_key": "asc"}}
        }

    # variants are aggregated inside each metric bucket
    metrics_agg = ordered_terms("metric", max_metrics)
    metrics_agg["aggs"] = {"variants": ordered_terms("variant", max_variants)}
    es_req = {"size": 0, "query": query, "aggs": {"metrics": metrics_agg}}

    es_res = search_company_events(body=es_req, **search_args)
    aggs_result = es_res.get("aggregations")
    if not aggs_result:
        return {}

    metric_variants = {}
    for metric_bucket in aggs_result["metrics"]["buckets"]:
        variant_buckets = metric_bucket["variants"]["buckets"]
        metric_variants[metric_bucket["key"]] = [
            variant_bucket["key"] for variant_bucket in variant_buckets
        ]
    return metric_variants
def get_task_metrics(
self, company_id, task_ids: Sequence, event_type: EventType
) -> Sequence:
"""
For the requested tasks return all the metrics that
reported events of the requested types
For the requested tasks return reported metrics per task
"""
if check_empty_data(self.es, company_id, event_type):
return {}
@@ -495,5 +592,5 @@ class EventMetrics:
return [
metric["key"]
for metric in safe_get(es_res, "aggregations/metrics/buckets", default=[])
for metric in nested_get(es_res, ("aggregations", "metrics", "buckets"), default=[])
]

View File

@@ -6,7 +6,6 @@ from operator import itemgetter
from typing import Sequence, Tuple, Optional, Mapping, Callable
import attr
import dpath
from boltons.iterutils import first
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, ListField, IntField
@@ -27,6 +26,7 @@ from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.metrics import MetricEventStats
from apiserver.database.model.task.task import Task
from apiserver.utilities.dicts import nested_get
class VariantState(Base):
@@ -305,13 +305,13 @@ class MetricEventsIterator:
return [
MetricState(
metric=metric["key"],
timestamp=dpath.get(metric, "last_event_timestamp/value"),
timestamp=nested_get(metric, ("last_event_timestamp", "value")),
variants=[
init_variant_state(variant)
for variant in dpath.get(metric, "variants/buckets")
for variant in nested_get(metric, ("variants", "buckets"))
],
)
for metric in dpath.get(es_res, "aggregations/metrics/buckets")
for metric in nested_get(es_res, ("aggregations", "metrics", "buckets"))
]
@abc.abstractmethod
@@ -430,14 +430,14 @@ class MetricEventsIterator:
def get_iteration_events(it_: dict) -> Sequence:
return [
self._process_event(ev["_source"])
for m in dpath.get(it_, "metrics/buckets")
for v in dpath.get(m, "variants/buckets")
for ev in dpath.get(v, "events/hits/hits")
for m in nested_get(it_, ("metrics", "buckets"))
for v in nested_get(m, ("variants", "buckets"))
for ev in nested_get(v, ("events", "hits", "hits"))
if is_valid_event(ev["_source"])
]
iterations = []
for it in dpath.get(es_res, "aggregations/iters/buckets"):
for it in nested_get(es_res, ("aggregations", "iters", "buckets")):
events = get_iteration_events(it)
if events:
iterations.append({"iter": it["key"], "events": events})

View File

@@ -10,6 +10,7 @@ from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.service_repo.auth import Identity
from .metadata import Metadata
@@ -57,14 +58,15 @@ class ModelBLL:
cls,
model_id: str,
company_id: str,
user_id: str,
identity: Identity,
force_publish_task: bool = False,
publish_task_func: Callable[[str, str, str, bool], dict] = None,
publish_task_func: Callable[[str, str, Identity, bool], dict] = None,
) -> Tuple[int, ModelTaskPublishResponse]:
model = cls.get_company_model_by_id(company_id=company_id, model_id=model_id)
if model.ready:
raise errors.bad_request.ModelIsReady(company=company_id, model=model_id)
user_id = identity.user
published_task = None
if model.task and publish_task_func:
task = (
@@ -74,7 +76,7 @@ class ModelBLL:
)
if task and task.status != TaskStatus.published:
task_publish_res = publish_task_func(
model.task, company_id, user_id, force_publish_task
model.task, company_id, identity, force_publish_task
)
published_task = ModelTaskPublishResponse(
id=model.task, data=task_publish_res

View File

@@ -341,6 +341,17 @@ class ProjectBLL:
) -> Tuple[Sequence, Sequence]:
archived = EntityVisibility.archived.value
def project_task_fields():
    """
    Return a $project stage spec that selects the project, status,
    system_tags, started and completed fields
    """
    selected = ("project", "status", "system_tags", "started", "completed")
    return {"$project": {field_name: 1 for field_name in selected}}
def ensure_valid_fields():
"""
Make sure system tags is always an array (required by subsequent $in in archived_tasks_cond
@@ -368,6 +379,7 @@ class ProjectBLL:
users=users,
)
},
project_task_fields(),
ensure_valid_fields(),
{
"$group": {
@@ -516,6 +528,7 @@ class ProjectBLL:
users=users,
)
},
project_task_fields(),
ensure_valid_fields(),
{
# for each project
@@ -856,7 +869,7 @@ class ProjectBLL:
company,
project_ids: Sequence[str],
user_ids: Optional[Sequence[str]] = None,
) -> Set[str]:
) -> Set[Union[str, type(None)]]:
"""
Get the set of user ids that created tasks/models in the given projects
If project_ids is empty then all projects are examined
@@ -1112,11 +1125,7 @@ class ProjectBLL:
helper = GetMixin.NewListFieldBucketHelper(
field, data=field_filter, legacy=True
)
op = (
Q.OR
if helper.explicit_operator and helper.global_operator == Q.OR
else Q.AND
)
op = helper.global_operator
db_query = {op: helper.actions}
else:
helper = GetMixin.ListQueryFilter.from_data(field, field_filter)
@@ -1125,7 +1134,7 @@ class ProjectBLL:
for op, actions in db_query.items():
field_conditions = {}
for action, values in actions.items():
value = list(set(values))
value = list(set(values)) if isinstance(values, list) else values
for key in reversed(action.split("__")):
value = {f"${key}": value}
field_conditions.update(value)

View File

@@ -239,6 +239,7 @@ class ProjectQueries:
company_id,
project_ids: Sequence[str],
include_subprojects: bool,
ids: Sequence[str],
model_metrics: bool = False,
):
pipeline = [
@@ -246,6 +247,7 @@ class ProjectQueries:
"$match": {
**cls._get_company_constraint(company_id),
**cls._get_project_constraint(project_ids, include_subprojects),
**({"_id": {"$in": ids}} if ids else {}),
}
},
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},

View File

@@ -152,7 +152,7 @@ class QueueBLL(object):
for item in queue.entries:
try:
task = Task.get_for_writing(
task = Task.get(
company=company_id,
id=item.task,
_only=[

View File

@@ -18,7 +18,7 @@ from apiserver.config.info import get_deployment_type
from apiserver.database.model import Company, User
from apiserver.database.model.queue import Queue
from apiserver.database.model.task.task import Task
from apiserver.tools import safe_get
from apiserver.utilities.dicts import nested_get
from apiserver.utilities.json import dumps
from apiserver.version import __version__ as current_version
from .resource_monitor import ResourceMonitor, stat_threads
@@ -162,7 +162,7 @@ class StatisticsReporter:
def _get_cardinality_fields(categories: Sequence[dict]) -> dict:
names = {"cpu": "num_cores"}
return {
names[c["key"]]: safe_get(c, "count/value")
names[c["key"]]: nested_get(c, ("count", "value"))
for c in categories
if c["key"] in names
}
@@ -175,21 +175,21 @@ class StatisticsReporter:
}
return {
names[m["key"]]: {
"min": safe_get(m, "min/value"),
"max": safe_get(m, "max/value"),
"avg": safe_get(m, "avg/value"),
"min": nested_get(m, ("min", "value")),
"max": nested_get(m, ("max", "value")),
"avg": nested_get(m, ("avg", "value")),
}
for m in metrics
if m["key"] in names
}
buckets = safe_get(res, "aggregations/workers/buckets", default=[])
buckets = nested_get(res, ("aggregations", "workers", "buckets"), default=[])
return {
b["key"]: {
key: {
"interval_sec": agent_resource_threshold_sec,
**_get_cardinality_fields(safe_get(b, "categories/buckets", [])),
**_get_metric_fields(safe_get(b, "metrics/buckets", [])),
**_get_cardinality_fields(nested_get(b, ("categories", "buckets"), [])),
**_get_metric_fields(nested_get(b, ("metrics", "buckets"), [])),
}
}
for b in buckets
@@ -227,7 +227,7 @@ class StatisticsReporter:
},
}
res = cls._run_worker_stats_query(company_id, es_req)
buckets = safe_get(res, "aggregations/workers/buckets", default=[])
buckets = nested_get(res, ("aggregations", "workers", "buckets"), default=[])
return {
b["key"]: {"last_activity_time": b["last_activity_time"]["value"]}
for b in buckets
@@ -254,6 +254,14 @@ class StatisticsReporter:
**({"last_worker": {"$in": workers}} if workers else {}),
}
},
{
"$project": {
"last_worker": 1,
"last_update": 1,
"started": 1,
"last_iteration": 1,
}
},
{
"$group": {
"_id": "$last_worker" if workers else None,

View File

@@ -5,6 +5,7 @@ from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
from apiserver.bll.task.utils import get_task_for_update, update_task
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
from apiserver.database.utils import hash_field_name
from apiserver.service_repo.auth import Identity
from apiserver.utilities.dicts import nested_get, nested_set
from apiserver.utilities.parameter_key_escaper import mongoengine_safe
@@ -48,12 +49,14 @@ class Artifacts:
def add_or_update_artifacts(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
artifacts: Sequence[ApiArtifact],
force: bool,
) -> int:
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,)
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force, identity=identity
)
artifacts = {
get_artifact_id(a): Artifact(**a)
@@ -64,18 +67,20 @@ class Artifacts:
f"set__execution__artifacts__{mongoengine_safe(name)}": value
for name, value in artifacts.items()
}
return update_task(task, user_id=user_id, update_cmds=update_cmds)
return update_task(task, user_id=identity.user, update_cmds=update_cmds)
@classmethod
def delete_artifacts(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
artifact_ids: Sequence[ArtifactId],
force: bool,
) -> int:
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,)
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force, identity=identity
)
artifact_ids = [
get_artifact_id(a)
@@ -85,4 +90,4 @@ class Artifacts:
f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
}
return update_task(task, user_id=user_id, update_cmds=delete_cmds)
return update_task(task, user_id=identity.user, update_cmds=delete_cmds)

View File

@@ -15,6 +15,7 @@ from apiserver.bll.task import TaskBLL
from apiserver.bll.task.utils import get_task_for_update, update_task
from apiserver.config_repo import config
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
from apiserver.service_repo.auth import Identity
from apiserver.utilities.parameter_key_escaper import (
ParameterKeyEscaper,
mongoengine_safe,
@@ -31,7 +32,10 @@ class HyperParams:
def get_params(cls, company_id: str, task_ids: Sequence[str]) -> Dict[str, dict]:
only = ("id", "hyperparams")
tasks = task_bll.assert_exists(
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
company_id=company_id,
task_ids=task_ids,
only=only,
allow_public=True,
)
return {
@@ -63,7 +67,7 @@ class HyperParams:
def delete_params(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
hyperparams: Sequence[HyperParamKey],
force: bool,
@@ -74,6 +78,7 @@ class HyperParams:
task_id=task_id,
allow_all_statuses=properties_only,
force=force,
identity=identity,
)
with_param, without_param = iterutils.partition(
@@ -96,7 +101,7 @@ class HyperParams:
return update_task(
task,
user_id=user_id,
user_id=identity.user,
update_cmds=delete_cmds,
set_last_update=not properties_only,
)
@@ -105,7 +110,7 @@ class HyperParams:
def edit_params(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
hyperparams: Sequence[HyperParamItem],
replace_hyperparams: str,
@@ -117,6 +122,7 @@ class HyperParams:
task_id=task_id,
allow_all_statuses=properties_only,
force=force,
identity=identity,
)
update_cmds = dict()
@@ -135,7 +141,7 @@ class HyperParams:
return update_task(
task,
user_id=user_id,
user_id=identity.user,
update_cmds=update_cmds,
set_last_update=not properties_only,
)
@@ -163,7 +169,10 @@ class HyperParams:
else:
only.append("configuration")
tasks = task_bll.assert_exists(
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
company_id=company_id,
task_ids=task_ids,
only=only,
allow_public=True,
)
return {
@@ -209,13 +218,15 @@ class HyperParams:
def edit_configuration(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
configuration: Sequence[Configuration],
replace_configuration: bool,
force: bool,
) -> int:
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force, identity=identity
)
update_cmds = dict()
configuration = {
@@ -228,22 +239,24 @@ class HyperParams:
for name, value in configuration.items():
update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
return update_task(task, user_id=user_id, update_cmds=update_cmds)
return update_task(task, user_id=identity.user, update_cmds=update_cmds)
@classmethod
def delete_configuration(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
configuration: Sequence[str],
force: bool,
) -> int:
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force, identity=identity
)
delete_cmds = {
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
for name in set(configuration)
}
return update_task(task, user_id=user_id, update_cmds=delete_cmds)
return update_task(task, user_id=identity.user, update_cmds=delete_cmds)

View File

@@ -58,27 +58,6 @@ class TaskBLL:
self.events_es = events_es or es_factory.connect("events")
self.redis: StrictRedis = redis or redman.connection("apiserver")
@staticmethod
def get_task_with_access(
task_id, company_id, only=None, allow_public=False, requires_write_access=False
) -> Task:
"""
Gets a task by its id and company, optionally requiring write access to it
:param task_id: used as the ``id`` filter of the lookup
:param company_id: used as the ``company`` filter of the lookup
:param only: forwarded to the lookup as ``_only``
:param allow_public: forwarded as ``include_public``; only used when write access is not required
:param requires_write_access: if set then the task is loaded with Task.get_for_writing,
otherwise it is loaded with Task.get
:except errors.bad_request.InvalidTaskId: if the task is not found
:except errors.forbidden.NoWritePermission: if write_access was required and the task cannot be modified
"""
with translate_errors_context():
# the same filter is reused for the lookup and for the error details
query = dict(id=task_id, company=company_id)
if requires_write_access:
task = Task.get_for_writing(_only=only, **query)
else:
task = Task.get(_only=only, **query, include_public=allow_public)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
return task
@staticmethod
def get_by_id(
company_id,

View File

@@ -9,6 +9,7 @@ from apiserver.bll.task import (
ChangeStatusRequest,
)
from apiserver.bll.task.task_cleanup import cleanup_task, CleanupResult
from apiserver.bll.task.utils import get_task_with_write_access
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
@@ -24,6 +25,7 @@ from apiserver.database.model.task.task import (
DEFAULT_LAST_ITERATION,
)
from apiserver.database.utils import get_options
from apiserver.service_repo.auth import Identity
from apiserver.utilities.dicts import nested_set
log = config.logger(__file__)
@@ -33,7 +35,7 @@ queue_bll = QueueBLL()
def archive_task(
task: Union[str, Task],
company_id: str,
user_id: str,
identity: Identity,
status_message: str,
status_reason: str,
) -> int:
@@ -42,9 +44,10 @@ def archive_task(
Return 1 if successful
"""
if isinstance(task, str):
task = TaskBLL.get_task_with_access(
task = get_task_with_write_access(
task,
company_id=company_id,
identity=identity,
only=(
"id",
"company",
@@ -54,8 +57,9 @@ def archive_task(
"system_tags",
"enqueue_status",
),
requires_write_access=True,
)
user_id = identity.user
try:
TaskBLL.dequeue_and_change_status(
task,
@@ -79,34 +83,34 @@ def archive_task(
def unarchive_task(
task: str,
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
status_message: str,
status_reason: str,
) -> int:
"""
Unarchive task. Return 1 if successful
"""
task = TaskBLL.get_task_with_access(
task,
task = get_task_with_write_access(
task_id,
company_id=company_id,
identity=identity,
only=("id",),
requires_write_access=True,
)
return task.update(
status_message=status_message,
status_reason=status_reason,
pull__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
last_changed_by=user_id,
last_changed_by=identity.user,
)
def dequeue_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
status_message: str,
status_reason: str,
remove_from_all_queues: bool = False,
@@ -119,7 +123,19 @@ def dequeue_task(
task = Task.get(
id=task_id,
company=company_id,
_only=(
_only=("id",),
include_public=True,
)
if not task:
TaskBLL.remove_task_from_all_queues(company_id, task_id=task_id)
return 1, {"updated": 0}
user_id = identity.user
task = get_task_with_write_access(
task_id,
company_id=company_id,
identity=identity,
only=(
"id",
"company",
"execution",
@@ -127,11 +143,7 @@ def dequeue_task(
"project",
"enqueue_status",
),
include_public=True,
)
if not task:
TaskBLL.remove_task_from_all_queues(company_id, task_id=task_id)
return 1, {"updated": 0}
res = TaskBLL.dequeue_and_change_status(
task,
@@ -148,7 +160,7 @@ def dequeue_task(
def enqueue_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
queue_id: str,
status_message: str,
status_reason: str,
@@ -173,11 +185,11 @@ def enqueue_task(
# try to get default queue
queue_id = queue_bll.get_default(company_id).id
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(**query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
task = get_task_with_write_access(
task_id=task_id, company_id=company_id, identity=identity
)
user_id = identity.user
if validate:
TaskBLL.validate(task)
@@ -207,9 +219,9 @@ def enqueue_task(
# set the current queue ID in the task
if task.execution:
Task.objects(**query).update(execution__queue=queue_id, multi=False)
Task.objects(id=task_id).update(execution__queue=queue_id, multi=False)
else:
Task.objects(**query).update(execution=Execution(queue=queue_id), multi=False)
Task.objects(id=task_id).update(execution=Execution(queue=queue_id), multi=False)
nested_set(res, ("fields", "execution.queue"), queue_id)
return 1, res
@@ -242,7 +254,7 @@ def move_tasks_to_trash(tasks: Sequence[str]) -> int:
def delete_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
move_to_trash: bool,
force: bool,
return_file_urls: bool,
@@ -251,8 +263,9 @@ def delete_task(
status_reason: str,
delete_external_artifacts: bool,
) -> Tuple[int, Task, CleanupResult]:
task = TaskBLL.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
user_id = identity.user
task = get_task_with_write_access(
task_id, company_id=company_id, identity=identity
)
if (
@@ -305,15 +318,16 @@ def delete_task(
def reset_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
force: bool,
return_file_urls: bool,
delete_output_models: bool,
clear_all: bool,
delete_external_artifacts: bool,
) -> Tuple[dict, CleanupResult, dict]:
task = TaskBLL.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
user_id = identity.user
task = get_task_with_write_access(
task_id, company_id=company_id, identity=identity
)
if not force and task.status == TaskStatus.published:
@@ -392,14 +406,15 @@ def reset_task(
def publish_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
force: bool,
publish_model_func: Callable[[str, str, str], Any] = None,
publish_model_func: Callable[[str, str, Identity], Any] = None,
status_message: str = "",
status_reason: str = "",
) -> dict:
task = TaskBLL.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
user_id = identity.user
task = get_task_with_write_access(
task_id, company_id=company_id, identity=identity
)
if not force:
validate_status_change(task.status, TaskStatus.published)
@@ -422,7 +437,7 @@ def publish_task(
.first()
)
if model and not model.ready:
publish_model_func(model.id, company_id, user_id)
publish_model_func(model.id, company_id, identity)
# set task status to published, and update (or set) its new output (view and models)
return ChangeStatusRequest(
@@ -446,7 +461,7 @@ def publish_task(
def stop_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
user_name: str,
status_reason: str,
force: bool,
@@ -459,10 +474,11 @@ def stop_task(
is set to 'stopping' to allow the worker to stop the task and report by itself
:return: updated task fields
"""
task = TaskBLL.get_task_with_access(
user_id = identity.user
task = get_task_with_write_access(
task_id,
company_id=company_id,
identity=identity,
only=(
"status",
"project",
@@ -472,7 +488,6 @@ def stop_task(
"last_update",
"execution.queue",
),
requires_write_access=True,
)
def is_run_by_worker(t: Task) -> bool:

View File

@@ -1,7 +1,9 @@
from datetime import datetime
from typing import Sequence
import attr
import six
from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.bll.util import update_project_time
@@ -10,6 +12,7 @@ from apiserver.database.errors import translate_errors_context
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
from apiserver.database.utils import get_options
from apiserver.service_repo.auth import Identity
from apiserver.utilities.attrs import typed_attrs
valid_statuses = get_options(TaskStatus)
@@ -157,15 +160,78 @@ def get_possible_status_changes(current_status):
return possible
def get_many_tasks_for_writing(
    company_id: str,
    identity: Identity,
    query: Q = None,
    only: Sequence = None,
    throw_on_forbidden: bool = True,
) -> Sequence[Task]:
    """
    Fetch the tasks matching the query with the intention of modifying them.
    Public tasks are fetched as well (allow_public=True) but a task without
    a company is considered not writable when company_id is set
    :param company_id: the company to fetch the tasks for. If empty then
        the fetched tasks are returned without the write permission check
    :param identity: the caller identity. Not consulted by this implementation
    :param query: forwarded to Task.get_many
    :param only: the task fields to load. "company" is appended if missing
    :param throw_on_forbidden: if set then non-writable tasks raise an error,
        otherwise they are dropped from the returned list
    :except errors.forbidden.NoWritePermission: if throw_on_forbidden is set
        and at least one of the fetched tasks has no company
    """
    # the company field is required below to tell the public tasks apart
    if only and "company" not in only:
        only = [*only, "company"]

    tasks = list(
        Task.get_many(
            company=company_id,
            query=query,
            override_projection=only,
            allow_public=True,
            return_dicts=False,
        )
    )
    if not company_id:
        return tasks

    forbidden_tasks = {t.id for t in tasks if not t.company}
    if not forbidden_tasks:
        return tasks

    if throw_on_forbidden:
        raise errors.forbidden.NoWritePermission(
            f"cannot modify public task(s), ids={tuple(forbidden_tasks)}"
        )

    return [t for t in tasks if t.id not in forbidden_tasks]
def get_task_with_write_access(
    task_id: str,
    company_id: str,
    identity: Identity,
    only=None,
) -> Task:
    """
    Load a single task of the given company for modification
    :param task_id: the id of the task to load
    :param company_id: the company that the task should belong to
    :param identity: the caller identity. Not consulted by this implementation
    :param only: the task fields to load, forwarded to Task.get_for_writing as ``_only``
    :except errors.bad_request.InvalidTaskId: if Task.get_for_writing returned no task
    NOTE(review): any write permission error presumably originates from
    Task.get_for_writing - confirm against its implementation
    """
    task = Task.get_for_writing(_only=only, id=task_id, company=company_id)
    if task:
        return task

    raise errors.bad_request.InvalidTaskId(id=task_id, company=company_id)
def get_task_for_update(
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
company_id: str,
task_id: str,
identity: Identity,
allow_all_statuses: bool = False,
force: bool = False
) -> Task:
"""
Loads only task id and return the task only if it is updatable (status == 'created')
"""
task = Task.get_for_writing(company=company_id, id=task_id, _only=("id", "status"))
if not task:
raise errors.bad_request.InvalidTaskId(id=task_id)
task = get_task_with_write_access(
task_id=task_id,
company_id=company_id,
only=("id", "status"),
identity=identity,
)
if allow_all_statuses:
return task

View File

@@ -27,10 +27,9 @@ from apiserver.database.model.project import Project
from apiserver.database.model.queue import Queue
from apiserver.database.model.task.task import Task
from apiserver.redis_manager import redman
from apiserver.tools import safe_get
from apiserver.utilities.dicts import nested_get
from .stats import WorkerStats
log = config.logger(__file__)
@@ -287,7 +286,7 @@ class WorkerBLL:
filter(
None,
(
safe_get(info, "next_entry/task")
nested_get(info, ("next_entry", "task"))
for info in queues_info.values()
),
)
@@ -311,7 +310,7 @@ class WorkerBLL:
continue
entry.name = info.get("name", None)
entry.num_tasks = info.get("num_entries", 0)
task_id = safe_get(info, "next_entry/task")
task_id = nested_get(info, ("next_entry", "task"))
if task_id:
task = tasks_info.get(task_id, None)
entry.next_task = IdNameEntry(

View File

@@ -6,7 +6,7 @@ from functools import reduce
from os import getenv
from os.path import expandvars
from pathlib import Path
from typing import List, Any, TypeVar, Sequence
from typing import List, Any, TypeVar, Sequence, Set
from boltons.iterutils import first
from pyhocon import ConfigTree, ConfigFactory, ConfigValues
@@ -35,6 +35,7 @@ class BasicConfig:
folder: str = None,
verbose: bool = True,
prefix: Sequence[str] = DEFAULT_PREFIXES,
exclude_files_from_base_folder: Sequence[str] = None,
):
folder = (
Path(folder)
@@ -44,6 +45,11 @@ class BasicConfig:
if not folder.is_dir():
raise ValueError("Invalid configuration folder")
self.exclude_files_from_base_folder = (
set(exclude_files_from_base_folder)
if exclude_files_from_base_folder
else set()
)
self.verbose = verbose
self.extra_config_path_override_var = [
@@ -85,7 +91,7 @@ class BasicConfig:
return logging.getLogger(path)
def _read_extra_env_config_values(self) -> ConfigTree:
""" Loads extra configuration from environment-injected values """
"""Loads extra configuration from environment-injected values"""
result = ConfigTree()
for prefix in self.extra_config_values_env_key_prefix:
@@ -125,12 +131,18 @@ class BasicConfig:
def _reload(self) -> ConfigTree:
extra_config_values = self._read_extra_env_config_values()
configs = [self._read_recursive(path) for path in self._paths]
configs = [
self._read_recursive(
path,
exclude_files=(
self.exclude_files_from_base_folder if idx == 0 else None
),
)
for idx, path in enumerate(self._paths)
]
return reduce(
lambda last, config: self._merge_configs(
last, config, copy_trees=True
),
lambda last, config: self._merge_configs(last, config, copy_trees=True),
configs + [extra_config_values],
ConfigTree(),
)
@@ -141,9 +153,14 @@ class BasicConfig:
for key, value in b.items():
override = key.startswith(override_prefix)
if override:
key = key[len(override_prefix):]
key = key[len(override_prefix) :]
# if key is in both a and b and both values are dictionary then merge it otherwise override it
if not override and key in a and isinstance(a[key], ConfigTree) and isinstance(b[key], ConfigTree):
if (
not override
and key in a
and isinstance(a[key], ConfigTree)
and isinstance(b[key], ConfigTree)
):
if copy_trees:
a[key] = a[key].copy()
cls._merge_configs(a[key], b[key], copy_trees=copy_trees)
@@ -156,13 +173,15 @@ class BasicConfig:
a[key] = value
if a.root:
if b.root:
a.history[key] = a.history.get(key, []) + b.history.get(key, [value])
a.history[key] = a.history.get(key, []) + b.history.get(
key, [value]
)
else:
a.history[key] = a.history.get(key, []) + [value]
return a
def _read_recursive(self, conf_root) -> ConfigTree:
def _read_recursive(self, conf_root, exclude_files: Set[str]) -> ConfigTree:
conf = ConfigTree()
if not conf_root:
@@ -180,6 +199,8 @@ class BasicConfig:
print(f"Loading config from {conf_root}")
for file in conf_root.rglob("*.conf"):
if exclude_files and file.name in exclude_files:
continue
key = ".".join(file.relative_to(conf_root).with_suffix("").parts)
conf.put(key, self._read_single_file(file))

View File

@@ -58,6 +58,9 @@
# verify user tokens
verify_user_tokens: false
# If set then users that were created from secure credentials or fixed user settings and are no longer in these settings will be deleted on startup
delete_missing_autocreated_users: true
# max token expiration timeout in seconds (1 year)
max_expiration_sec: 31536000
@@ -72,6 +75,7 @@
httponly: true # allow only http to access the cookies (no JS etc)
secure: false # not using HTTPS
domain: null # Limit to localhost is not supported
samesite: Lax
max_age: 99999999999
}

View File

@@ -2,10 +2,9 @@ fileserver = "http://localhost:8081"
elastic {
events {
hosts: [{host: "127.0.0.1", port: 9200}]
hosts: [{host: "127.0.0.1", port: 9200, scheme: http}]
args {
timeout: 60
dead_timeout: 10
max_retries: 3
retry_on_timeout: true
}
@@ -13,10 +12,9 @@ elastic {
}
workers {
hosts: [{host:"127.0.0.1", port:9200}]
hosts: [{host:"127.0.0.1", port:9200, scheme: http}]
args {
timeout: 60
dead_timeout: 10
max_retries: 3
retry_on_timeout: true
}

View File

@@ -18,8 +18,9 @@ aws {
{
# This will apply to all buckets in this host (unless key/value is specifically provided for a given bucket)
host: "localhost:9000"
key: "evg_user"
secret: "evg_pass"
key: "minioadmin"
secret: "minioadmin"
# region: my-server
multipart: false
secure: false
}

View File

@@ -5,7 +5,7 @@ from textwrap import shorten
import dpath
from dpath.exceptions import InvalidKeyName
from elasticsearch import ElasticsearchException
from elastic_transport import TransportError, ApiError
from elasticsearch.helpers import BulkIndexError
from jsonmodels.errors import ValidationError as JsonschemaValidationError
from mongoengine.errors import (
@@ -210,9 +210,9 @@ def translate_errors_context(message=None, **kwargs):
raise errors.bad_request.ValidationError(e.args[0])
except BulkIndexError as e:
ElasticErrorsHandler.bulk_error(e, message, **kwargs)
except ElasticsearchException as e:
except (TransportError, ApiError) as e:
raise errors.server_error.DataError(e, message, **kwargs)
except InvalidKeyName:
raise errors.server_error.DataError("invalid empty key encountered in data")
except Exception as ex:
except Exception:
raise

View File

@@ -4,6 +4,7 @@ from mongoengine import (
EmbeddedDocumentListField,
EmailField,
DateTimeField,
BooleanField,
)
from apiserver.database import Database, strict
@@ -76,3 +77,6 @@ class User(DbModelMixin, AuthDocument):
email = EmailField(unique=True, sparse=True)
""" Email uniquely identifying the user """
autocreated = BooleanField(default=False)
""" Set to true if the user was auto created based on config settings"""

View File

@@ -146,9 +146,10 @@ class GetMixin(PropsMixin):
"__$any": Q.OR,
"__$or": Q.OR,
}
default_operator = Q.OR
default_global_operator = Q.AND
default_context = Q.OR
# not_all modifier currently not supported due to the backwards compatibility
mongo_modifiers = {
# not_all modifier currently not supported due to the backwards compatibility
Q.AND: {True: "all", False: "nin"},
Q.OR: {True: "in", False: "nin"},
}
@@ -165,24 +166,22 @@ class GetMixin(PropsMixin):
self.allow_empty = False
self.global_operator = None
self.actions = defaultdict(list)
self.explicit_operator = False
self._support_legacy = legacy
current_context = self.default_operator
current_context = self.default_context
for d in self._get_next_term(data):
if d.operator is not None:
current_context = d.operator
self._support_legacy = False
if self.global_operator is None:
self.global_operator = d.operator
self.explicit_operator = True
continue
if self.global_operator is None:
self.global_operator = self.default_operator
self.global_operator = self.default_global_operator
if d.reset:
current_context = self.default_operator
current_context = self.default_context
self._support_legacy = legacy
continue
@@ -195,7 +194,7 @@ class GetMixin(PropsMixin):
)
if self.global_operator is None:
self.global_operator = self.default_operator
self.global_operator = self.default_global_operator
def _get_next_term(self, data: Sequence[str]) -> Generator[Term, None, None]:
unary_operator = None
@@ -618,7 +617,20 @@ class GetMixin(PropsMixin):
):
if not vals:
continue
operations[self._db_modifiers[(op, include)]] = list(set(vals))
unique = set(vals)
if None in unique:
# noinspection PyTypeChecker
unique.remove(None)
if include:
operations["size"] = 0
else:
operations["not__size"] = 0
if not unique:
continue
operations[self._db_modifiers[(op, include)]] = list(unique)
self.db_query[op] = operations
@@ -656,7 +668,8 @@ class GetMixin(PropsMixin):
ops = []
for action, vals in actions.items():
if not vals:
# cannot just check vals here since 0 is acceptable value
if vals is None or vals == []:
continue
ops.append(RegexQ(**{f"{mongoengine_field}__{action}": vals}))
@@ -1283,22 +1296,6 @@ class GetMixin(PropsMixin):
)
return result
@classmethod
def get_many_for_writing(cls, company, *args, **kwargs):
    """
    Fetch many objects with the intention of modifying them.
    The objects are fetched through cls.get_many with return_dicts=False and
    allow_public=True; any fetched object that has no company is reported as
    not writable
    :param company: the company to fetch the objects for
    :param args: forwarded to cls.get_many
    :param kwargs: forwarded to cls.get_many. Must not contain return_dicts
    :except errors.forbidden.NoWritePermission: if at least one of the fetched
        objects has no company
    """
    options = dict(return_dicts=False, **kwargs)
    result = cls.get_many(company=company, *args, **options, allow_public=True)

    forbidden_objects = {obj.id for obj in result if not obj.company}
    if not forbidden_objects:
        return result

    object_name = cls.__name__.lower()
    raise errors.forbidden.NoWritePermission(
        f"cannot modify public {object_name}(s), ids={tuple(forbidden_objects)}"
    )
class UpdateMixin(object):
__user_set_allowed_fields = None

View File

@@ -231,11 +231,12 @@ class Task(AttributedDocument):
"parent",
"hyperparams.*",
"execution.queue",
"models.input.model",
),
range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
datetime_fields=("status_changed", "last_update"),
pattern_fields=("name", "comment", "report"),
fields=("runtime.*", "models.input.model"),
fields=("runtime.*",),
)
id = StringField(primary_key=True)

View File

@@ -4,34 +4,89 @@ Apply elasticsearch mappings to given hosts.
"""
import argparse
import json
import logging
from pathlib import Path
from typing import Optional, Sequence, Tuple
from elasticsearch import Elasticsearch
from elasticsearch import Elasticsearch, exceptions
HERE = Path(__file__).resolve().parent
logging.getLogger("elasticsearch").setLevel(logging.WARNING)
logging.getLogger("elastic_transport").setLevel(logging.WARNING)
def apply_mappings_to_cluster(
hosts: Sequence, key: Optional[str] = None, es_args: dict = None, http_auth: Tuple = None
hosts: Sequence,
key: Optional[str] = None,
es_args: dict = None,
http_auth: Tuple = None,
):
"""Hosts maybe a sequence of strings or dicts in the form {"host": <host>, "port": <port>}"""
def _send_template(f):
with f.open() as json_data:
data = json.load(json_data)
template_name = f.stem
res = es.indices.put_template(name=template_name, body=data)
return {"mapping": template_name, "result": res}
def _send_component_template(ct_file):
    """
    Upload the component template stored in the given json file.
    The template is registered under the file name without the extension
    """
    with ct_file.open() as fp:
        template_body = json.load(fp)

    name = f"{ct_file.stem}"
    response = es.cluster.put_component_template(name=name, body=template_body)
    return {"component_template": name, "result": response}
p = HERE / "mappings"
if key:
files = (p / key).glob("*.json")
else:
files = p.glob("**/*.json")
def _send_index_template(it_file):
    """
    Upload the index template stored in the given json file.
    The template is registered under the file name without the extension
    """
    with it_file.open() as fp:
        template_body = json.load(fp)

    name = f"{it_file.stem}"
    response = es.indices.put_index_template(name=name, body=template_body)
    return {"index_template": name, "result": response}
# def _send_legacy_template(f):
# with f.open() as json_data:
# data = json.load(json_data)
# template_name = f.stem
# res = es.indices.put_template(name=template_name, body=data)
# return {"mapping": template_name, "result": res}
def _delete_legacy_templates(legacy_folder):
    """
    Delete from the cluster the legacy templates named after the json files
    found in the given folder. Templates that the cluster does not have are skipped
    :return: a list with one entry per deleted template
    """
    deleted = []
    for legacy_file in legacy_folder.glob("*.json"):
        template_name = legacy_file.stem
        try:
            found = bool(es.indices.get_template(name=template_name))
            res = es.indices.delete_template(name=template_name) if found else None
        except exceptions.NotFoundError:
            # nothing to delete for this file
            continue

        if found:
            deleted.append({"deleted legacy mapping": template_name, "result": res})

    return deleted
es = Elasticsearch(hosts=hosts, http_auth=http_auth, **(es_args or {}))
return [_send_template(f) for f in files]
root = HERE / "index_templates"
if key:
folders = [root / key]
else:
folders = [f for f in root.iterdir() if f.is_dir()]
ret = []
for f in folders:
for ct in (f / "component_templates").glob("*.json"):
ret.append(_send_component_template(ct))
for it in f.glob("*.json"):
ret.append(_send_index_template(it))
legacy_root = HERE / "mappings"
for f in folders:
legacy_f = legacy_root / f.stem
if not legacy_f.exists() or not legacy_f.is_dir():
continue
ret.extend(_delete_legacy_templates(legacy_f))
return ret
# p = HERE / "mappings"
# if key:
# files = (p / key).glob("*.json")
# else:
# files = p.glob("**/*.json")
#
# return [_send_template(f) for f in files]
def parse_args():

View File

@@ -0,0 +1,48 @@
{
"template": {
"settings": {
"number_of_replicas": 0,
"number_of_shards": 1
},
"mappings": {
"_source": {
"enabled": true
},
"properties": {
"@timestamp": {
"type": "date"
},
"task": {
"type": "keyword"
},
"type": {
"type": "keyword"
},
"worker": {
"type": "keyword"
},
"timestamp": {
"type": "date"
},
"iter": {
"type": "long"
},
"metric": {
"type": "keyword"
},
"variant": {
"type": "keyword"
},
"value": {
"type": "float"
},
"company_id": {
"type": "keyword"
},
"model_event": {
"type": "boolean"
}
}
}
}
}

View File

@@ -0,0 +1,18 @@
{
"index_patterns": "events-log-*",
"template": {
"mappings": {
"properties": {
"msg": {
"type": "text",
"index": false
},
"level": {
"type": "keyword"
}
}
}
},
"priority": 500,
"composed_of": ["events_common"]
}

View File

@@ -0,0 +1,18 @@
{
"index_patterns": "events-plot-*",
"template": {
"mappings": {
"properties": {
"plot_str": {
"type": "text",
"index": false
},
"plot_data": {
"type": "binary"
}
}
}
},
"priority": 500,
"composed_of": ["events_common"]
}

View File

@@ -0,0 +1,17 @@
{
"index_patterns": "events-training_debug_image-*",
"template": {
"mappings": {
"properties": {
"key": {
"type": "keyword"
},
"url": {
"type": "keyword"
}
}
}
},
"priority": 500,
"composed_of": ["events_common"]
}

View File

@@ -0,0 +1,5 @@
{
"index_patterns": "events-training_stats_scalar-*",
"priority": 500,
"composed_of": ["events_common"]
}

View File

@@ -0,0 +1,31 @@
{
"index_patterns": "queue_metrics_*",
"template": {
"settings": {
"number_of_replicas": 0,
"number_of_shards": 1
},
"mappings": {
"_source": {
"enabled": true
},
"properties": {
"timestamp": {
"type": "date"
},
"queue": {
"type": "keyword"
},
"average_waiting_time": {
"type": "float"
},
"queue_length": {
"type": "integer"
},
"company_id": {
"type": "keyword"
}
}
}
}
}

View File

@@ -0,0 +1,43 @@
{
"index_patterns": "worker_stats_*",
"template": {
"settings": {
"number_of_replicas": 0,
"number_of_shards": 1
},
"mappings": {
"_source": {
"enabled": true
},
"properties": {
"timestamp": {
"type": "date"
},
"worker": {
"type": "keyword"
},
"category": {
"type": "keyword"
},
"metric": {
"type": "keyword"
},
"variant": {
"type": "keyword"
},
"value": {
"type": "float"
},
"unit": {
"type": "keyword"
},
"task": {
"type": "keyword"
},
"company_id": {
"type": "keyword"
}
}
}
}
}

View File

@@ -10,6 +10,8 @@ from apiserver.config_repo import config
from apiserver.elastic.apply_mappings import apply_mappings_to_cluster
log = config.logger(__file__)
logging.getLogger("elasticsearch").setLevel(logging.WARNING)
logging.getLogger("elastic_transport").setLevel(logging.WARNING)
class MissingElasticConfiguration(Exception):
@@ -78,6 +80,18 @@ def check_elastic_empty() -> bool:
err_type=urllib3.exceptions.NewConnectionError, args_prefix=("GET",)
)
def events_legacy_template():
    """
    Return the legacy templates matching "events*"
    or False if the lookup ends with a NotFoundError
    """
    try:
        found = es.indices.get_template(name="events*")
    except exceptions.NotFoundError:
        found = False
    return found
def events_template():
    """
    Return the index templates matching "events*"
    or False if the lookup ends with a NotFoundError
    """
    try:
        found = es.indices.get_index_template(name="events*")
    except exceptions.NotFoundError:
        found = False
    return found
try:
es_logger.addFilter(log_filter)
for retry in range(max_retries):
@@ -87,10 +101,7 @@ def check_elastic_empty() -> bool:
http_auth=es_factory.get_credentials("events", cluster_conf),
**cluster_conf.get("args", {}),
)
return not es.indices.get_template(name="events*")
except exceptions.NotFoundError as ex:
log.error(ex)
return True
return not (events_template() or events_legacy_template())
except exceptions.ConnectionError as ex:
if retry >= max_retries - 1:
raise ElasticConnectionError(

View File

@@ -1,3 +1,4 @@
import logging
from datetime import datetime
from functools import lru_cache
from os import getenv
@@ -9,6 +10,8 @@ from elasticsearch import Elasticsearch
from apiserver.config_repo import config
log = config.logger(__file__)
logging.getLogger('elasticsearch').setLevel(logging.WARNING)
logging.getLogger('elastic_transport').setLevel(logging.WARNING)
OVERRIDE_HOST_ENV_KEY = (
"CLEARML_ELASTIC_SERVICE_HOST",
@@ -32,6 +35,7 @@ if OVERRIDE_HOST:
OVERRIDE_PORT = first(filter(None, map(getenv, OVERRIDE_PORT_ENV_KEY)))
if OVERRIDE_PORT:
OVERRIDE_PORT = int(OVERRIDE_PORT)
log.info(f"Using override elastic port {OVERRIDE_PORT}")
OVERRIDE_USERNAME = first(filter(None, map(getenv, OVERRIDE_USERNAME_ENV_KEY)))

View File

@@ -450,6 +450,7 @@ class AWSStorage(Storage):
else None,
"use_ssl": cfg.secure,
"verify": cfg.verify,
"region_name": cfg.region or None,
}
name = base[len(scheme_prefix(self.scheme)) :]
bucket_name = name[len(cfg.host) + 1 :] if cfg.host else name

View File

@@ -3,7 +3,7 @@ from typing import Sequence, Union
from apiserver.config_repo import config
from apiserver.config.info import get_default_company
from apiserver.database.model.auth import Role
from apiserver.database.model.auth import Role, User as AuthUser
from apiserver.service_repo.auth.fixed_user import FixedUser
from .migration import _apply_migrations, check_mongo_empty, get_last_server_version
from .pre_populate import PrePopulate
@@ -60,14 +60,18 @@ def init_mongo_data():
fixed_mode = FixedUser.enabled()
internal_user_emails = set()
for user, credentials in config.get("secure.credentials", {}).items():
email = f"{user}@example.com"
user_data = {
"name": user,
"role": credentials.role,
"email": f"{user}@example.com",
"email": email,
"key": credentials.user_key,
"secret": credentials.user_secret,
"autocreated": True,
}
internal_user_emails.add(email.lower())
revoke = fixed_mode and credentials.get("revoke_in_fixed_mode", False)
user_id = _ensure_auth_user(user_data, company_id, log=log, revoke=revoke)
if credentials.role == Role.user:
@@ -82,8 +86,20 @@ def init_mongo_data():
for user in FixedUser.from_config():
try:
ensure_fixed_user(user, log=log)
ensure_fixed_user(user, log=log, emails=internal_user_emails)
except Exception as ex:
log.error(f"Failed creating fixed user {user.name}: {ex}")
if internal_user_emails and config.get(
f"apiserver.auth.delete_missing_autocreated_users", True
):
for user in AuthUser.objects(
company=company_id, autocreated=True, email__nin=internal_user_emails
):
log.info(
f"Removing user that is no longer in configuration: {user['id']}\t{user['email']}\t{user['name']}"
)
user.delete()
except Exception as ex:
log.exception("Failed initializing mongodb")
log.exception(f"Failed initializing mongodb: {str(ex)}")

View File

@@ -44,6 +44,7 @@ from apiserver.bll.task.param_utils import (
from apiserver.config_repo import config
from apiserver.config.info import get_default_company
from apiserver.database.model import EntityVisibility, User
from apiserver.database.model.auth import Role, User as AuthUser
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import (
@@ -54,6 +55,7 @@ from apiserver.database.model.task.task import (
TaskModelNames,
)
from apiserver.database.utils import get_options
from apiserver.service_repo.auth import Identity
from apiserver.utilities import json
from apiserver.utilities.dicts import nested_get, nested_set, nested_delete
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
@@ -66,6 +68,7 @@ class PrePopulate:
export_tag_prefix = "Exported:"
export_tag = f"{export_tag_prefix} %Y-%m-%d %H:%M:%S"
metadata_filename = "metadata.json"
users_filename = "users.json"
zip_args = dict(mode="w", compression=ZIP_BZIP2)
artifacts_ext = ".artifacts"
img_source_regex = re.compile(
@@ -78,6 +81,7 @@ class PrePopulate:
project_cls: Type[Project]
model_cls: Type[Model]
user_cls: Type[User]
auth_user_cls: Type[AuthUser]
# noinspection PyTypeChecker
@classmethod
@@ -90,6 +94,8 @@ class PrePopulate:
cls.project_cls = cls._get_entity_type("database.model.project.Project")
if not hasattr(cls, "user_cls"):
cls.user_cls = cls._get_entity_type("database.model.User")
if not hasattr(cls, "auth_user_cls"):
cls.auth_user_cls = cls._get_entity_type("database.model.auth.User")
class JsonLinesWriter:
def __init__(self, file: BinaryIO):
@@ -205,6 +211,8 @@ class PrePopulate:
task_statuses: Sequence[str] = None,
tag_exported_entities: bool = False,
metadata: Mapping[str, Any] = None,
export_events: bool = True,
export_users: bool = False,
) -> Sequence[str]:
cls._init_entity_types()
@@ -240,11 +248,15 @@ class PrePopulate:
with ZipFile(file, **cls.zip_args) as zfile:
if metadata:
zfile.writestr(cls.metadata_filename, meta_str)
if export_users:
cls._export_users(zfile)
artifacts = cls._export(
zfile,
entities=entities,
hash_=hash_,
tag_entities=tag_exported_entities,
export_events=export_events,
cleanup_users=not export_users,
)
file_with_hash = file.with_name(f"{file.stem}_{hash_.hexdigest()}{file.suffix}")
@@ -265,6 +277,9 @@ class PrePopulate:
metadata_hash=metadata_hash,
)
if created_files:
print("Created files:\n" + "\n".join(file for file in created_files))
return created_files
@classmethod
@@ -296,18 +311,26 @@ class PrePopulate:
except Exception:
pass
if not user_id:
user_id, user_name = "__allegroai__", "Allegro.ai"
# Make sure we won't end up with an invalid company ID
if company_id is None:
company_id = ""
user_mapping = cls._import_users(zfile, company_id)
if not user_id:
user_id, user_name = "__allegroai__", "Allegro.ai"
existing_user = cls.user_cls.objects(id=user_id).only("id").first()
if not existing_user:
cls.user_cls(id=user_id, name=user_name, company=company_id).save()
cls._import(zfile, company_id, user_id, metadata)
cls._import(
zfile,
company_id=company_id,
user_id=user_id,
metadata=metadata,
user_mapping=user_mapping,
)
if artifacts_path and os.path.isdir(artifacts_path):
artifacts_file = Path(filename).with_suffix(cls.artifacts_ext)
@@ -438,7 +461,7 @@ class PrePopulate:
projects: Sequence[str] = None,
task_statuses: Sequence[str] = None,
) -> Dict[Type[mongoengine.Document], Set[mongoengine.Document]]:
entities = defaultdict(set)
entities: Dict[Any] = defaultdict(set)
if projects:
print("Reading projects...")
@@ -497,7 +520,6 @@ class PrePopulate:
@classmethod
def _cleanup_model(cls, model: Model):
model.company = ""
model.user = ""
model.tags = cls._filter_out_export_tags(model.tags)
@classmethod
@@ -505,7 +527,6 @@ class PrePopulate:
task.comment = "Auto generated by Allegro.ai"
task.status_message = ""
task.status_reason = ""
task.user = ""
task.company = ""
task.tags = cls._filter_out_export_tags(task.tags)
if task.output:
@@ -513,17 +534,32 @@ class PrePopulate:
@classmethod
def _cleanup_project(cls, project: Project):
project.user = ""
project.company = ""
project.tags = cls._filter_out_export_tags(project.tags)
@classmethod
def _cleanup_entity(cls, entity_cls, entity):
def _cleanup_auth_user(cls, user: AuthUser):
    """
    Strip company information from an auth user before it is exported.

    The user's own company is blanked, and so is the company of every
    credential that currently carries a non-empty one.
    The same (mutated) user object is returned so the call can be used inline
    """
    user.company = ""
    company_bound = (c for c in user.credentials if getattr(c, "company", None))
    for credential in company_bound:
        credential["company"] = ""
    return user
@classmethod
def _cleanup_be_user(cls, user: User):
    """
    Strip installation specific data from a backend user before it is exported.

    Blanks the user's company and drops the stored preferences.
    The same (mutated) user object is returned so the call can be used inline
    """
    user.company = ""
    user.preferences = None
    return user
@classmethod
def _cleanup_entity(cls, entity_cls, entity, cleanup_users):
if cleanup_users:
entity.user = ""
if entity_cls == cls.task_cls:
cls._cleanup_task(entity)
elif entity_cls == cls.model_cls:
cls._cleanup_model(entity)
elif entity == cls.project_cls:
elif entity_cls == cls.project_cls:
cls._cleanup_project(entity)
@classmethod
@@ -633,6 +669,38 @@ class PrePopulate:
else:
print(f"Artifact {full_path} not found")
@classmethod
def _export_users(cls, writer: ZipFile):
    """
    Write the admin and regular users into the users file of the export archive.

    Only auth users that also have a matching backend user are exported.
    Nothing is written if there are no such auth users or no backend users for them.
    The file is a single object with the "auth" and "backend" fields,
    each holding the output of the JsonLinesWriter for the corresponding users
    """
    exported_roles = (Role.admin, Role.user)
    auth_users = {}
    for auth_user in cls.auth_user_cls.objects(role__in=exported_roles):
        auth_users[auth_user.id] = cls._cleanup_auth_user(auth_user)
    if not auth_users:
        return

    be_users = {}
    for be_user in cls.user_cls.objects(id__in=list(auth_users)):
        be_users[be_user.id] = cls._cleanup_be_user(be_user)
    if not be_users:
        return

    # keep only the auth users that have a backend counterpart
    auth_users = {uid: doc for uid, doc in auth_users.items() if uid in be_users}
    print(f"Writing {len(auth_users)} users into {writer.filename}")

    def dump_users(users: dict) -> bytes:
        # the value is taken after the lines writer is closed but before the buffer is
        with BytesIO() as buffer:
            with cls.JsonLinesWriter(buffer) as lines:
                for doc in users.values():
                    lines.write(doc.to_json())
            return buffer.getvalue()

    sections = [
        f'"{field}": '.encode("utf-8") + dump_users(users)
        for field, users in (("auth", auth_users), ("backend", be_users))
    ]
    writer.writestr(cls.users_filename, b"{\n" + b",\n".join(sections) + b"\n}")
@classmethod
def _get_base_filename(cls, cls_: type):
name = f"{cls_.__module__}.{cls_.__name__}"
@@ -642,7 +710,13 @@ class PrePopulate:
@classmethod
def _export(
cls, writer: ZipFile, entities: dict, hash_, tag_entities: bool = False
cls,
writer: ZipFile,
entities: dict,
hash_,
tag_entities: bool = False,
export_events: bool = True,
cleanup_users: bool = True,
) -> Sequence[str]:
"""
Export the requested experiments, projects and models and return the list of artifact files
@@ -656,18 +730,19 @@ class PrePopulate:
if not items:
continue
base_filename = cls._get_base_filename(cls_)
for item in items:
artifacts.extend(
cls._export_entity_related_data(
cls_, item, base_filename, writer, hash_
if export_events:
for item in items:
artifacts.extend(
cls._export_entity_related_data(
cls_, item, base_filename, writer, hash_
)
)
)
filename = base_filename + ".json"
print(f"Writing {len(items)} items into {writer.filename}:{filename}")
with BytesIO() as f:
with cls.JsonLinesWriter(f) as w:
for item in items:
cls._cleanup_entity(cls_, item)
cls._cleanup_entity(cls_, item, cleanup_users=cleanup_users)
w.write(item.to_json())
data = f.getvalue()
hash_.update(data)
@@ -717,7 +792,10 @@ class PrePopulate:
@classmethod
def _generate_new_ids(
cls, reader: ZipFile, entity_files: Sequence, metadata: Mapping[str, Any],
cls,
reader: ZipFile,
entity_files: Sequence,
metadata: Mapping[str, Any],
) -> Mapping[str, str]:
if not metadata or not any(
metadata.get(key) for key in ("new_ids", "example_ids", "private_ids")
@@ -745,6 +823,68 @@ class PrePopulate:
)
return ids
@classmethod
def _import_users(cls, reader: ZipFile, company_id: str = "") -> dict:
    """
    Import users to db and return the mapping of old user ids to the new ones
    If the archive contains no users file then the mapping is empty
    If the user in the file has the same email as one of the existing ones then this user is skipped
    and its id is mapped to the existing user with the same email
    If the user with the same id exists in backend or auth db then its creation is skipped
    (the id is still mapped to itself)
    :param reader: the opened archive to read the users file from
    :param company_id: the company that the created users and their credentials are assigned to
    :return: the mapping from the user ids in the file to the user ids in the db
    """
    users_file = first(
        fi for fi in reader.filelist if fi.orig_filename == cls.users_filename
    )
    if not users_file:
        return {}

    # ids known to either the backend or the auth users collection
    existing_user_ids = set(cls.user_cls.objects().scalar("id")) | set(
        cls.auth_user_cls.objects().scalar("id")
    )
    existing_user_emails = {u.email: u.id for u in cls.auth_user_cls.objects()}
    user_id_mappings = {}

    with reader.open(users_file) as f:
        data = json.loads(f.read())

        auth_users = {u["_id"]: u for u in data["auth"]}
        be_users = {u["_id"]: u for u in data["backend"]}
        for uid, user in auth_users.items():
            email = user.get("email")
            existing_user_id = existing_user_emails.get(email)
            if existing_user_id:
                # same email already registered: reuse the existing user
                user_id_mappings[uid] = existing_user_id
                continue

            user_id_mappings[uid] = uid
            if uid in existing_user_ids:
                continue

            # credentials exported with a blank company are bound to the target company
            credentials = user.get("credentials", [])
            for c in credentials:
                if c.get("company") == "":
                    c["company"] = company_id

            # NOTE(review): applies only when the auth user model declares sec_groups;
            # the hard-coded ids are presumably predefined security groups - confirm
            if hasattr(cls.auth_user_cls, "sec_groups"):
                user_role = user.get("role", Role.user)
                if user_role == Role.user:
                    user["sec_groups"] = ["30795571-a470-4717-a80d-e8705fc776bf"]
                else:
                    user["sec_groups"] = [
                        "c14a3cc6-1144-4896-8ea6-fb186ee19896",
                        "30795571-a470-4717-a80d-e8705fc776bf",
                        "30795571a4704717a80de8705897ytuyg",
                    ]

            auth_user = cls.auth_user_cls.from_json(json.dumps(user), created=True)
            auth_user.company = company_id
            auth_user.save()
            # raises KeyError if the file has no backend entry for this auth user
            be_user = cls.user_cls.from_json(json.dumps(be_users[uid]), created=True)
            be_user.company = company_id
            be_user.save()

    return user_id_mappings
@classmethod
def _import(
cls,
@@ -753,6 +893,7 @@ class PrePopulate:
user_id: str = None,
metadata: Mapping[str, Any] = None,
sort_tasks_by_last_updated: bool = True,
user_mapping: Mapping[str, str] = None,
):
"""
Import entities and events from the zip file
@@ -763,7 +904,7 @@ class PrePopulate:
fi
for fi in reader.filelist
if not fi.orig_filename.endswith(event_file_ending)
and fi.orig_filename != cls.metadata_filename
and fi.orig_filename not in (cls.metadata_filename, cls.users_filename)
]
metadata = metadata or {}
old_to_new_ids = cls._generate_new_ids(reader, entity_files, metadata)
@@ -773,7 +914,13 @@ class PrePopulate:
full_name = splitext(entity_file.orig_filename)[0]
print(f"Reading {reader.filename}:{full_name}...")
res = cls._import_entity(
f, full_name, company_id, user_id, metadata, old_to_new_ids
f,
full_name=full_name,
company_id=company_id,
user_id=user_id,
metadata=metadata,
old_to_new_ids=old_to_new_ids,
user_mapping=user_mapping,
)
if res:
tasks = res
@@ -794,7 +941,7 @@ class PrePopulate:
with reader.open(events_file) as f:
full_name = splitext(events_file.orig_filename)[0]
print(f"Reading {reader.filename}:{full_name}...")
cls._import_events(f, company_id, user_id, task.id)
cls._import_events(f, company_id, task.user, task.id)
@classmethod
def _get_entity_type(cls, full_name) -> Type[mongoengine.Document]:
@@ -874,7 +1021,7 @@ class PrePopulate:
):
old_path = old_field.split(".")
old_model = nested_get(task_data, old_path)
new_models = models.get(type_, [])
new_models = [m for m in models.get(type_, []) if m.get("model") is not None]
name = TaskModelNames[type_]
if old_model and not any(
m
@@ -908,7 +1055,9 @@ class PrePopulate:
user_id: str,
metadata: Mapping[str, Any],
old_to_new_ids: Mapping[str, str] = None,
user_mapping: Mapping[str, str] = None,
) -> Optional[Sequence[Task]]:
user_mapping = user_mapping or {}
cls_ = cls._get_entity_type(full_name)
print(f"Writing {cls_.__name__.lower()}s into database")
tasks = []
@@ -930,7 +1079,7 @@ class PrePopulate:
doc = cls_.from_json(item, created=True)
if hasattr(doc, "user"):
doc.user = user_id
doc.user = user_mapping.get(doc.user, user_id) if doc.user else user_id
if hasattr(doc, "company"):
doc.company = company_id
if isinstance(doc, cls.project_cls):
@@ -970,7 +1119,7 @@ class PrePopulate:
ev["allow_locked"] = True
cls.event_bll.add_events(
company_id=company_id,
user_id=user_id,
identity=Identity(user_id, company=company_id, role=Role.admin),
events=events,
worker="",
)

View File

@@ -26,6 +26,7 @@ def _ensure_auth_user(user_data: dict, company_id: str, log: Logger, revoke: boo
credentials = [] if revoke else [creds]
user_id = user_data.get("id", f"__{user_data['name']}__")
autocreated = user_data.get("autocreated", False)
log.info(f"Creating user: {user_data['name']}")
@@ -37,6 +38,7 @@ def _ensure_auth_user(user_data: dict, company_id: str, log: Logger, revoke: boo
email=user_data["email"],
created=datetime.utcnow(),
credentials=credentials,
autocreated=autocreated,
)
user.save()
@@ -59,7 +61,7 @@ def _ensure_backend_user(user_id: str, company_id: str, user_name: str):
return user_id
def ensure_fixed_user(user: FixedUser, log: Logger):
def ensure_fixed_user(user: FixedUser, log: Logger, emails: set):
db_user = User.objects(company=user.company, id=user.user_id).first()
if db_user:
# noinspection PyBroadException
@@ -73,9 +75,12 @@ def ensure_fixed_user(user: FixedUser, log: Logger):
data = attr.asdict(user)
data["id"] = user.user_id
data["email"] = f"{user.user_id}@example.com"
email = f"{user.user_id}@example.com"
data["email"] = email
data["role"] = Role.guest if user.is_guest else Role.user
data["autocreated"] = True
_ensure_auth_user(user_data=data, company_id=user.company, log=log)
emails.add(email)
return _ensure_backend_user(user.user_id, user.company, user.name)

View File

@@ -6,11 +6,11 @@ boto3>=1.26
boto3-stubs[s3]>=1.26
clearml>=1.10.3
dpath>=1.4.2,<2.0
elasticsearch==7.17.9
elasticsearch==8.12.0
fastjsonschema>=2.8
flask-compress>=1.4.0
flask-cors>=3.0.5
flask>=2.3.2
flask>=2.3.3
furl>=2.0.0
google-cloud-storage>=2.8.0
gunicorn>=20.1.0
@@ -34,3 +34,4 @@ setuptools>=65.5.1
six
validators>=0.12.4
urllib3>=1.26.18
werkzeug>=3.0.1

View File

@@ -754,6 +754,42 @@ get_task_metrics{
}
}
}
get_multi_task_metrics {
"2.28" {
description: """Get unique metrics and variants from the events of the specified type.
Only events reported for the passed task or model ids are analyzed."""
request {
type: object
required: [ tasks ]
properties {
tasks {
description: task ids to get metrics from
type: array
items {type: string}
}
model_events {
description: If not set or set to false then passed ids are task ids otherwise model ids
type: boolean
default: false
}
event_type {
"description": Event type. If not specified then metrics are collected from the reported events of all types
"$ref": "#/definitions/event_type_enum"
}
}
}
response {
type: object
properties {
metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/metric_variants" }
}
}
}
}
}
get_task_log {
"1.5" {
description: "Get all 'log' events for this task"
@@ -971,10 +1007,17 @@ get_task_events {
}
}
"2.22": ${get_task_events."2.1"} {
request.properties.model_events {
type: boolean
description: If set then get retrieving model events. Otherwise task events
default: false
request.properties {
model_events {
type: boolean
description: If set then retrieve model events. Otherwise task events
default: false
}
metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/metric_variants" }
}
}
}
}
@@ -1156,6 +1199,13 @@ get_multi_task_plots {
default: true
}
}
"2.28": ${get_multi_task_plots."2.26"} {
request.properties.metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/metric_variants" }
}
}
}
get_vector_metrics_and_variants {
"2.1" {
@@ -1342,6 +1392,13 @@ multi_task_scalar_metrics_iter_histogram {
default: false
}
}
"2.28": ${multi_task_scalar_metrics_iter_histogram."2.22"} {
request.properties.metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/metric_variants" }
}
}
}
get_task_single_value_metrics {
"2.20" {
@@ -1369,6 +1426,13 @@ get_task_single_value_metrics {
default: false
}
}
"2.28": ${get_task_single_value_metrics."2.22"} {
request.properties.metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/metric_variants" }
}
}
}
get_task_latest_scalar_values {
"2.1" {

View File

@@ -11,16 +11,7 @@ supported_modes {
description: """ Return supported login modes."""
request {
type: object
properties {
state {
description: "ASCII base64 encoded application state"
type: string
}
callback_url_prefix {
description: "URL prefix used to generate the callback URL for each supported SSO provider"
type: string
}
}
additionalProperties: false
}
response {
type: object

View File

@@ -79,4 +79,15 @@ start_pipeline {
}
}
}
"2.28": ${start_pipeline."2.17"} {
request.properties.verify_watched_queue {
description: If passed then check whether there are any workers watching the queue
type: boolean
default: false
}
response.properties.queue_watched {
description: Returns true if there are workers or autoscalers working with the queue
type: boolean
}
}
}

View File

@@ -949,6 +949,13 @@ get_unique_metric_variants {
default: false
}
}
"2.28": ${get_unique_metric_variants."2.25"} {
request.properties.ids {
description: IDs of the tasks or models to get metrics from
type: array
items {type: string}
}
}
}
get_hyperparam_values {
"2.13" {

View File

@@ -21,6 +21,11 @@ log = config.logger(__file__)
class RequestHandlers:
_request_strip_prefix = config.get("apiserver.request.strip_prefix", None)
_server_header = config.get("apiserver.response.headers.server", "clearml")
_custom_cookie_settings = {
c["name"]: c["settings"]
for c in config.get("apiserver.auth.custom_cookies", {}).values()
if c.get("enabled") and c.get("settings")
}
def before_request(self):
if request.method == "OPTIONS":
@@ -29,7 +34,10 @@ class RequestHandlers:
return
if request.content_encoding:
return f"Content encoding is not supported ({request.content_encoding})", 415
return (
f"Content encoding is not supported ({request.content_encoding})",
415,
)
try:
call = self._create_api_call(request)
@@ -42,7 +50,10 @@ class RequestHandlers:
response = redirect(call.result.redirect.url, call.result.redirect.code)
else:
headers = None
disable_cache = False
if call.result.filename:
# make sure that downloaded files are not cached by the client
disable_cache = True
try:
call.result.filename.encode("ascii")
except UnicodeEncodeError:
@@ -61,10 +72,16 @@ class RequestHandlers:
status=call.result.code,
headers=headers,
)
if disable_cache:
response.cache_control.no_store = True
response.cache_control.max_age = 0
if call.result.cookies:
for key, value in call.result.cookies.items():
kwargs = config.get("apiserver.auth.cookies").copy()
kwargs = (
self._custom_cookie_settings.get(key)
or config.get("apiserver.auth.cookies")
).copy()
if value is None:
# Removing a cookie
kwargs["max_age"] = 0
@@ -81,7 +98,9 @@ class RequestHandlers:
if company:
try:
# use no default value to allow setting a null domain as well
kwargs["domain"] = config.get(f"apiserver.auth.cookies_domain_override.{company}")
kwargs["domain"] = config.get(
f"apiserver.auth.cookies_domain_override.{company}"
)
except KeyError:
pass
@@ -108,11 +127,15 @@ class RequestHandlers:
return v
for k, v in md.lists():
v = [convert_value(x) for x in v] if (len(v) > 1 or k.endswith("[]")) else convert_value(v[0])
v = (
[convert_value(x) for x in v]
if (len(v) > 1 or k.endswith("[]"))
else convert_value(v[0])
)
nested_set(body, k.rstrip("[]").split("."), v)
def _update_call_data(self, call, req):
""" Use request payload/form to fill call data or batched data """
"""Use request payload/form to fill call data or batched data"""
if req.content_type == "application/json-lines":
items = []
for i, line in enumerate(req.data.splitlines()):
@@ -142,6 +165,9 @@ class RequestHandlers:
call.set_error_result(msg=msg, code=code, subcode=subcode)
return call
def _get_session_auth_cookie(self, req):
    """Look up the session auth cookie (its name is taken from the configuration) in the request cookies"""
    cookie_name = config.get("apiserver.auth.session_auth_cookie_name")
    return req.cookies.get(cookie_name)
def _create_api_call(self, req):
call = None
try:
@@ -155,9 +181,7 @@ class RequestHandlers:
# Resolve authorization: if cookies contain an authorization token, use it as a starting point.
# in any case, request headers always take precedence.
auth_cookie = req.cookies.get(
config.get("apiserver.auth.session_auth_cookie_name")
)
auth_cookie = self._get_session_auth_cookie(req)
headers = (
{}
if not auth_cookie

View File

@@ -1,4 +1,4 @@
from .auth import get_auth_func, authorize_impersonation
from .auth import get_auth_func, authorize_impersonation, revoke_auth_token
from .payload import Token, Basic, AuthType, Payload
from .identity import Identity
from .utils import get_client_id, get_secret_key

View File

@@ -1,5 +1,6 @@
import base64
from datetime import datetime
from time import time
import bcrypt
import jwt
@@ -11,15 +12,16 @@ from apiserver.database.errors import translate_errors_context
from apiserver.database.model.auth import User, Entities, Credentials
from apiserver.database.model.company import Company
from apiserver.database.utils import get_options
from apiserver.redis_manager import redman
from .fixed_user import FixedUser
from .identity import Identity
from .payload import Payload, Token, Basic, AuthType
log = config.logger(__file__)
entity_keys = set(get_options(Entities))
verify_user_tokens = config.get("apiserver.auth.verify_user_tokens", True)
_revoked_tokens_key = "revoked_tokens"
redis = redman.connection("apiserver")
def get_auth_func(auth_type):
@@ -41,8 +43,10 @@ def authorize_token(jwt_token, service, action, call):
log.error(f"{msg} Call info: {info}")
try:
return Token.from_encoded_token(jwt_token)
token = Token.from_encoded_token(jwt_token)
if is_token_revoked(token):
raise errors.unauthorized.InvalidToken("revoked token")
return token
except jwt.exceptions.InvalidKeyError as ex:
log_error("Failed parsing token.")
raise errors.unauthorized.InvalidToken(
@@ -154,3 +158,23 @@ def compare_secret_key_hash(secret_key: str, hashed_secret: str) -> bool:
return bcrypt.checkpw(
secret_key.encode(), base64.b64decode(hashed_secret.encode("ascii"))
)
def is_token_revoked(token: Token) -> bool:
    """
    Tell whether the session of the passed token is registered in the revoked tokens set.

    Anything that is not a Token, or a token without a session id, is never reported as revoked
    """
    if isinstance(token, Token) and token.session_id:
        score = redis.zscore(_revoked_tokens_key, token.session_id)
        return score is not None

    return False
def revoke_auth_token(token: Token):
    """
    Register the session of the passed token in the revoked tokens set.

    The session id is stored with the token expiration timestamp as its score
    (the default expiration counted from now is used when the token has none).
    Entries whose score is not later than the current time are then removed from the set.
    Anything that is not a Token, or a token without a session id, is ignored
    """
    if not (isinstance(token, Token) and token.session_id):
        return

    now_ts = int(time())
    expires_at = token.exp or (now_ts + Token.default_expiration_sec)

    redis.zadd(_revoked_tokens_key, {token.session_id: expires_at})
    redis.zremrangebyscore(_revoked_tokens_key, min=0, max=now_ts)

View File

@@ -1,3 +1,5 @@
from uuid import uuid4
import jwt
from datetime import datetime, timedelta
@@ -20,7 +22,15 @@ class Token(Payload):
default_expiration_sec = config.get("apiserver.auth.default_expiration_sec")
def __init__(
self, exp=None, iat=None, nbf=None, env=None, identity=None, entities=None, **_
self,
exp=None,
iat=None,
nbf=None,
env=None,
identity=None,
session_id=None,
entities=None,
**_,
):
super(Token, self).__init__(
AuthType.bearer_token, identity=identity, entities=entities
@@ -28,8 +38,13 @@ class Token(Payload):
self.exp = exp
self.iat = iat
self.nbf = nbf
self._session_id = session_id
self._env = env or config.get("env", "<unknown>")
@property
def session_id(self):
    """The session id passed to the constructor (None if not provided)"""
    return self._session_id
@property
def env(self):
return self._env
@@ -102,8 +117,11 @@ class Token(Payload):
expiration_sec = expiration_sec or cls.default_expiration_sec
now = datetime.utcnow()
session_id = uuid4().hex
token = cls(identity=identity, entities=entities, iat=now)
token = cls(
identity=identity, entities=entities, iat=now, session_id=session_id
)
if expiration_sec:
# add 'expiration' claim

View File

@@ -39,7 +39,7 @@ class ServiceRepo(object):
"""If the check is set, parsing will fail for endpoint request with the version that is greater than the current
maximum """
_max_version = PartialVersion("2.27")
_max_version = PartialVersion("2.29")
""" Maximum version number (the highest min_version value across all endpoints) """
_endpoint_exp = (

View File

@@ -24,6 +24,7 @@ from apiserver.database.errors import translate_errors_context
from apiserver.database.model.auth import User, Role
from apiserver.service_repo import APICall, endpoint
from apiserver.service_repo.auth import Token
from apiserver.service_repo.auth.auth import is_token_revoked, revoke_auth_token
from apiserver.service_repo.auth.fixed_user import FixedUser
log = config.logger(__file__)
@@ -35,7 +36,7 @@ log = config.logger(__file__)
response_data_model=GetTokenResponse,
)
def login(call: APICall, *_, **__):
""" Generates a token based on the authenticated user (intended for use with credentials) """
"""Generates a token based on the authenticated user (intended for use with credentials)"""
call.result.data_model = AuthBLL.get_token_for_user(
user_id=call.identity.user,
company_id=call.identity.company,
@@ -48,6 +49,7 @@ def login(call: APICall, *_, **__):
@endpoint("auth.logout", min_version="2.2")
def logout(call: APICall, *_, **__):
revoke_auth_token(call.auth)
call.result.set_auth_cookie(None)
@@ -57,7 +59,7 @@ def logout(call: APICall, *_, **__):
response_data_model=GetTokenResponse,
)
def get_token_for_user(call: APICall, _: str, request: GetTokenForUserRequest):
""" Generates a token based on a requested user and company. INTERNAL. """
"""Generates a token based on a requested user and company. INTERNAL."""
if call.identity.role not in Role.get_system_roles():
if call.identity.role != Role.admin and call.identity.user != request.user:
raise errors.bad_request.InvalidUserId(
@@ -81,12 +83,14 @@ def get_token_for_user(call: APICall, _: str, request: GetTokenForUserRequest):
response_data_model=ValidateResponse,
)
def validate_token_endpoint(call: APICall, _, __):
""" Validate a token and return identity if valid. INTERNAL. """
"""Validate a token and return identity if valid. INTERNAL."""
try:
# if invalid, decoding will fail
token = Token.from_encoded_token(call.data_model.token)
call.result.data_model = ValidateResponse(
valid=True, user=token.identity.user, company=token.identity.company
valid=not is_token_revoked(token),
user=token.identity.user,
company=token.identity.company,
)
except Exception as e:
call.result.data_model = ValidateResponse(valid=False, msg=e.args[0])
@@ -98,7 +102,7 @@ def validate_token_endpoint(call: APICall, _, __):
response_data_model=CreateUserResponse,
)
def create_user(call: APICall, _, request: CreateUserRequest):
""" Create a user from. INTERNAL. """
"""Create a user from. INTERNAL."""
if (
call.identity.role not in Role.get_system_roles()
and request.company != call.identity.company

View File

@@ -31,6 +31,15 @@ from apiserver.apimodels.events import (
GetMetricSamplesRequest,
TaskMetric,
MultiTaskPlotsRequest,
MultiTaskMetricsRequest,
LegacyLogEventsRequest,
TaskRequest,
GetMetricsAndVariantsRequest,
ModelRequest,
LegacyMetricEventsRequest,
GetScalarMetricDataRequest,
VectorMetricsIterHistogramRequest,
LegacyMultiTaskEventsRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_common import EventType, MetricVariants, TaskCompanies
@@ -38,6 +47,7 @@ from apiserver.bll.event.events_iterator import Scroll
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
from apiserver.bll.model import ModelBLL
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.utils import get_task_with_write_access
from apiserver.config_repo import config
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
@@ -73,7 +83,7 @@ def add(call: APICall, company_id, _):
data = call.data.copy()
added, err_count, err_info = event_bll.add_events(
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
events=[data],
worker=call.worker,
)
@@ -88,22 +98,22 @@ def add_batch(call: APICall, company_id, _):
added, err_count, err_info = event_bll.add_events(
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
events=events,
worker=call.worker,
)
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
@endpoint("events.get_task_log", required_fields=["task"])
def get_task_log_v1_5(call, company_id, _):
task_id = call.data["task"]
@endpoint("events.get_task_log")
def get_task_log_v1_5(call, company_id, request: LegacyLogEventsRequest):
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
order = call.data.get("order") or "desc"
scroll_id = call.data.get("scroll_id")
batch_size = int(call.data.get("batch_size") or 500)
order = request.order
scroll_id = request.scroll_id
batch_size = request.batch_size
events, scroll_id, total_events = event_bll.scroll_task_events(
task.get_index_company(),
task_id,
@@ -117,17 +127,17 @@ def get_task_log_v1_5(call, company_id, _):
)
@endpoint("events.get_task_log", min_version="1.7", required_fields=["task"])
def get_task_log_v1_7(call, company_id, _):
task_id = call.data["task"]
@endpoint("events.get_task_log", min_version="1.7")
def get_task_log_v1_7(call, company_id, request: LegacyLogEventsRequest):
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
order = call.data.get("order") or "desc"
order = request.order
from_ = call.data.get("from") or "head"
scroll_id = call.data.get("scroll_id")
batch_size = int(call.data.get("batch_size") or 500)
scroll_id = request.scroll_id
batch_size = request.batch_size
scroll_order = "asc" if (from_ == "head") else "desc"
@@ -175,9 +185,9 @@ def get_task_log(call, company_id, request: LogEventsRequest):
)
@endpoint("events.download_task_log", required_fields=["task"])
def download_task_log(call, company_id, _):
task_id = call.data["task"]
@endpoint("events.download_task_log")
def download_task_log(call, company_id, request: TaskRequest):
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
@@ -255,10 +265,12 @@ def download_task_log(call, company_id, _):
call.result.raw_data = generate()
@endpoint("events.get_vector_metrics_and_variants", required_fields=["task"])
def get_vector_metrics_and_variants(call, company_id, _):
task_id = call.data["task"]
model_events = call.data["model_events"]
@endpoint("events.get_vector_metrics_and_variants")
def get_vector_metrics_and_variants(
call, company_id, request: GetMetricsAndVariantsRequest
):
task_id = request.task
model_events = request.model_events
task_or_model = _assert_task_or_model_exists(
company_id,
task_id,
@@ -271,10 +283,12 @@ def get_vector_metrics_and_variants(call, company_id, _):
)
@endpoint("events.get_scalar_metrics_and_variants", required_fields=["task"])
def get_scalar_metrics_and_variants(call, company_id, _):
task_id = call.data["task"]
model_events = call.data["model_events"]
@endpoint("events.get_scalar_metrics_and_variants")
def get_scalar_metrics_and_variants(
call, company_id, request: GetMetricsAndVariantsRequest
):
task_id = request.task
model_events = request.model_events
task_or_model = _assert_task_or_model_exists(
company_id,
task_id,
@@ -290,18 +304,19 @@ def get_scalar_metrics_and_variants(call, company_id, _):
# todo: !!! currently returning 10,000 records. should decide on a better way to control it
@endpoint(
"events.vector_metrics_iter_histogram",
required_fields=["task", "metric", "variant"],
)
def vector_metrics_iter_histogram(call, company_id, _):
task_id = call.data["task"]
model_events = call.data["model_events"]
def vector_metrics_iter_histogram(
call, company_id, request: VectorMetricsIterHistogramRequest
):
task_id = request.task
model_events = request.model_events
task_or_model = _assert_task_or_model_exists(
company_id,
task_id,
model_events=model_events,
)[0]
metric = call.data["metric"]
variant = call.data["variant"]
metric = request.metric
variant = request.variant
iterations, vectors = event_bll.get_vector_metrics_per_iter(
task_or_model.get_index_company(), task_id, metric, variant
)
@@ -402,13 +417,13 @@ def get_task_events(_, company_id, request: TaskEventsRequest):
)
@endpoint("events.get_scalar_metric_data", required_fields=["task", "metric"])
def get_scalar_metric_data(call, company_id, _):
task_id = call.data["task"]
metric = call.data["metric"]
scroll_id = call.data.get("scroll_id")
no_scroll = call.data.get("no_scroll", False)
model_events = call.data.get("model_events", False)
@endpoint("events.get_scalar_metric_data")
def get_scalar_metric_data(call, company_id, request: GetScalarMetricDataRequest):
task_id = request.task
metric = request.metric
scroll_id = request.scroll_id
no_scroll = request.no_scroll
model_events = request.model_events
task_or_model = _assert_task_or_model_exists(
company_id,
@@ -433,9 +448,9 @@ def get_scalar_metric_data(call, company_id, _):
)
@endpoint("events.get_task_latest_scalar_values", required_fields=["task"])
def get_task_latest_scalar_values(call, company_id, _):
task_id = call.data["task"]
@endpoint("events.get_task_latest_scalar_values")
def get_task_latest_scalar_values(call, company_id, request: TaskRequest):
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
@@ -521,6 +536,7 @@ def multi_task_scalar_metrics_iter_histogram(
),
samples=request.samples,
key=request.key,
metric_variants=_get_metric_variants_from_request(request.metrics),
)
)
@@ -548,17 +564,18 @@ def get_task_single_value_metrics(
tasks=_get_single_value_metrics_response(
companies=companies,
value_metrics=event_bll.metrics.get_task_single_value_metrics(
companies=companies
companies=companies,
metric_variants=_get_metric_variants_from_request(request.metrics),
),
)
)
@endpoint("events.get_multi_task_plots", required_fields=["tasks"])
def get_multi_task_plots_v1_7(call, company_id, _):
task_ids = call.data["tasks"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
@endpoint("events.get_multi_task_plots")
def get_multi_task_plots_v1_7(call, company_id, request: LegacyMultiTaskEventsRequest):
task_ids = request.tasks
iters = request.iters
scroll_id = request.scroll_id
companies = _get_task_or_model_index_companies(company_id, task_ids)
@@ -591,10 +608,11 @@ def _get_multitask_plots(
companies: TaskCompanies,
last_iters: int,
last_iters_per_task_metric: bool,
metrics: MetricVariants = None,
request_metrics: Sequence[ApiMetrics] = None,
scroll_id=None,
no_scroll=True,
) -> Tuple[dict, int, str]:
metrics = _get_metric_variants_from_request(request_metrics)
task_names = {
t.id: t.name for t in itertools.chain.from_iterable(companies.values())
}
@@ -629,6 +647,7 @@ def get_multi_task_plots(call, company_id, request: MultiTaskPlotsRequest):
scroll_id=request.scroll_id,
no_scroll=request.no_scroll,
last_iters_per_task_metric=request.last_iters_per_task_metric,
request_metrics=request.metrics,
)
call.result.data = dict(
plots=return_events,
@@ -638,11 +657,11 @@ def get_multi_task_plots(call, company_id, request: MultiTaskPlotsRequest):
)
@endpoint("events.get_task_plots", required_fields=["task"])
def get_task_plots_v1_7(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
@endpoint("events.get_task_plots")
def get_task_plots_v1_7(call, company_id, request: LegacyMetricEventsRequest):
task_id = request.task
iters = request.iters
scroll_id = request.scroll_id
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
@@ -760,11 +779,11 @@ def task_plots(call, company_id, request: MetricEventsRequest):
)
@endpoint("events.debug_images", required_fields=["task"])
def get_debug_images_v1_7(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters") or 1
scroll_id = call.data.get("scroll_id")
@endpoint("events.debug_images")
def get_debug_images_v1_7(call, company_id, request: LegacyMetricEventsRequest):
task_id = request.task
iters = request.iters
scroll_id = request.scroll_id
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
@@ -797,12 +816,12 @@ def get_debug_images_v1_7(call, company_id, _):
)
@endpoint("events.debug_images", min_version="1.8", required_fields=["task"])
def get_debug_images_v1_8(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters") or 1
scroll_id = call.data.get("scroll_id")
model_events = call.data.get("model_events", False)
@endpoint("events.debug_images", min_version="1.8")
def get_debug_images_v1_8(call, company_id, request: LegacyMetricEventsRequest):
task_id = request.task
iters = request.iters
scroll_id = request.scroll_id
model_events = request.model_events
tasks_or_model = _assert_task_or_model_exists(
company_id,
@@ -960,12 +979,35 @@ def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest):
}
@endpoint("events.delete_for_task", required_fields=["task"])
def delete_for_task(call, company_id, _):
task_id = call.data["task"]
@endpoint("events.get_multi_task_metrics")
def get_multi_task_metrics(call: APICall, company_id, request: MultiTaskMetricsRequest):
    """
    Collect the metrics and their variants for the tasks listed in the request.

    The result has the form ``{"metrics": [{"metric": ..., "variants": [...]}]}``:
    entries are ordered by metric name and each variants list is sorted.
    When no index companies are resolved for the requested ids, a dict with an
    empty metrics list is returned instead of being set on ``call.result``.
    """
    index_companies = _get_task_or_model_index_companies(
        company_id, request.tasks, model_events=request.model_events
    )
    # Nothing to query: short-circuit with an empty metrics list
    if not index_companies:
        return {"metrics": []}

    metric_variants = event_bll.metrics.get_multi_task_metrics(
        companies=index_companies, event_type=request.event_type
    )
    entries = sorted(
        (
            {"metric": name, "variants": sorted(variants)}
            for name, variants in metric_variants.items()
        ),
        key=itemgetter("metric"),
    )
    call.result.data = {"metrics": entries}
@endpoint("events.delete_for_task")
def delete_for_task(call, company_id, request: TaskRequest):
task_id = request.task
allow_locked = call.data.get("allow_locked", False)
task_bll.assert_exists(company_id, task_id, return_tasks=False)
get_task_with_write_access(
task_id=task_id, company_id=company_id, identity=call.identity, only=("id",)
)
call.result.data = dict(
deleted=event_bll.delete_task_events(
company_id, task_id, allow_locked=allow_locked
@@ -973,9 +1015,9 @@ def delete_for_task(call, company_id, _):
)
@endpoint("events.delete_for_model", required_fields=["model"])
def delete_for_model(call: APICall, company_id: str, _):
model_id = call.data["model"]
@endpoint("events.delete_for_model")
def delete_for_model(call: APICall, company_id: str, request: ModelRequest):
model_id = request.model
allow_locked = call.data.get("allow_locked", False)
model_bll.assert_exists(company_id, model_id, return_models=False)
@@ -990,7 +1032,9 @@ def delete_for_model(call: APICall, company_id: str, _):
def clear_task_log(call: APICall, company_id: str, request: ClearTaskLogRequest):
task_id = request.task
task_bll.assert_exists(company_id, task_id, return_tasks=False)
get_task_with_write_access(
task_id=task_id, company_id=company_id, identity=call.identity, only=("id",)
)
call.result.data = dict(
deleted=event_bll.clear_task_log(
company_id=company_id,

View File

@@ -7,6 +7,7 @@ from apiserver.apimodels.login import (
)
from apiserver.config import info
from apiserver.service_repo import endpoint, APICall
from apiserver.service_repo.auth import revoke_auth_token
from apiserver.service_repo.auth.fixed_user import FixedUser
@@ -37,4 +38,5 @@ def supported_modes(call: APICall, _, __: GetSupportedModesRequest):
@endpoint("login.logout", min_version="2.13")
def logout(call: APICall, _, __):
    """Revoke the auth token of the call and set the result's auth cookie to None."""
    auth = call.auth
    revoke_auth_token(auth)
    # NOTE(review): a None cookie value presumably instructs the client to drop
    # its auth cookie - confirm against set_auth_cookie
    call.result.set_auth_cookie(None)

View File

@@ -21,6 +21,10 @@ from apiserver.apimodels.models import (
ModelsPublishManyRequest,
ModelsDeleteManyRequest,
ModelsGetRequest,
ModelRequest,
TaskRequest,
UpdateForTaskRequest,
UpdateModelRequest,
)
from apiserver.apimodels.tasks import UpdateTagsRequest
from apiserver.bll.model import ModelBLL, Metadata
@@ -28,6 +32,7 @@ from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.task_operations import publish_task
from apiserver.bll.task.utils import get_task_with_write_access
from apiserver.bll.util import run_batch_operation
from apiserver.config_repo import config
from apiserver.database.model import validate_id
@@ -46,6 +51,7 @@ from apiserver.database.utils import (
filter_fields,
)
from apiserver.service_repo import APICall, endpoint
from apiserver.service_repo.auth import Identity
from apiserver.services.utils import (
conform_tag_fields,
conform_output_tags,
@@ -65,9 +71,9 @@ def conform_model_data(call: APICall, model_data: Union[Sequence[dict], dict]):
unescape_metadata(call, model_data)
@endpoint("models.get_by_id", required_fields=["model"])
def get_by_id(call: APICall, company_id, _):
model_id = call.data["model"]
@endpoint("models.get_by_id")
def get_by_id(call: APICall, company_id, request: ModelRequest):
model_id = request.model
call_data = Metadata.escape_query_parameters(call.data)
models = Model.get_many(
company=company_id,
@@ -85,12 +91,12 @@ def get_by_id(call: APICall, company_id, _):
call.result.data = {"model": models[0]}
@endpoint("models.get_by_task_id", required_fields=["task"])
def get_by_task_id(call: APICall, company_id, _):
@endpoint("models.get_by_task_id")
def get_by_task_id(call: APICall, company_id, request: TaskRequest):
if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
raise errors.moved_permanently.NotSupported("use models.get_by_id/get_all apis")
task_id = call.data["task"]
task_id = request.task
query = dict(id=task_id, company=company_id)
task = Task.get(_only=["models"], **query)
@@ -155,7 +161,7 @@ def get_by_id_ex(call: APICall, company_id, _):
call.result.data = {"models": models}
@endpoint("models.get_all", required_fields=[])
@endpoint("models.get_all")
def get_all(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
call_data = Metadata.escape_query_parameters(call.data)
@@ -191,7 +197,7 @@ create_fields = {
"project": Project,
"parent": Model,
"framework": None,
"design": None,
"design": dict,
"labels": dict,
"ready": None,
"metadata": list,
@@ -234,28 +240,27 @@ def _reset_cached_tags(company: str, projects: Sequence[str]):
)
@endpoint("models.update_for_task", required_fields=["task"])
def update_for_task(call: APICall, company_id, _):
@endpoint("models.update_for_task")
def update_for_task(call: APICall, company_id, request: UpdateForTaskRequest):
if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
raise errors.moved_permanently.NotSupported("use tasks.add_or_update_model")
task_id = call.data["task"]
uri = call.data.get("uri")
iteration = call.data.get("iteration")
override_model_id = call.data.get("override_model_id")
task_id = request.task
uri = request.uri
iteration = request.iteration
override_model_id = request.override_model_id
if not (uri or override_model_id) or (uri and override_model_id):
raise errors.bad_request.MissingRequiredFields(
"exactly one field is required", fields=("uri", "override_model_id")
)
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(
id=task_id,
company=company_id,
_only=["models", "execution", "name", "status", "project"],
task = get_task_with_write_access(
task_id=task_id,
company_id=company_id,
identity=call.identity,
only=("models", "execution", "name", "status", "project"),
)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
allowed_states = [TaskStatus.created, TaskStatus.in_progress]
if task.status not in allowed_states:
@@ -343,7 +348,7 @@ def create(call: APICall, company_id, req_model: CreateModelRequest):
task = req_model.task
req_data = req_model.to_struct()
if task:
validate_task(company_id, req_data)
validate_task(company_id, call.identity, req_data)
fields = filter_fields(Model, req_data)
conform_tag_fields(call, fields, validate=True)
@@ -373,7 +378,7 @@ def prepare_update_fields(call, company_id, fields: dict):
# clear UI cache if URI is provided (model updated)
fields["ui_cache"] = fields.pop("ui_cache", {})
if "task" in fields:
validate_task(company_id, fields)
validate_task(company_id, call.identity, fields)
if "labels" in fields:
labels = fields["labels"]
@@ -403,13 +408,16 @@ def prepare_update_fields(call, company_id, fields: dict):
return fields
def validate_task(company_id, fields: dict):
Task.get_for_writing(company=company_id, id=fields["task"], _only=["id"])
def validate_task(company_id: str, identity: Identity, fields: dict):
    """
    Look up the task referenced by ``fields["task"]`` with write access.

    The lookup is delegated to get_task_with_write_access (only the "id" field
    is requested) and its result is discarded; any error it raises propagates.

    :param company_id: company the task is looked up in
    :param identity: identity of the caller
    :param fields: request fields; must contain the "task" key
    """
    get_task_with_write_access(
        task_id=fields["task"],
        company_id=company_id,
        identity=identity,
        only=("id",),
    )
@endpoint("models.edit", required_fields=["model"], response_data_model=UpdateResponse)
def edit(call: APICall, company_id, _):
model_id = call.data["model"]
@endpoint("models.edit", response_data_model=UpdateResponse)
def edit(call: APICall, company_id, request: UpdateModelRequest):
model_id = request.model
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
@@ -424,7 +432,7 @@ def edit(call: APICall, company_id, _):
d.update(value)
fields[key] = d
iteration = call.data.get("iteration")
iteration = request.iteration
task_id = model.task or fields.get("task")
if task_id and iteration is not None:
TaskBLL.update_statistics(
@@ -456,13 +464,9 @@ def edit(call: APICall, company_id, _):
call.result.data_model = UpdateResponse(updated=0)
def _update_model(call: APICall, company_id, model_id=None):
model_id = model_id or call.data["model"]
def _update_model(call: APICall, company_id, model_id):
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
data = prepare_update_fields(call, company_id, call.data)
task_id = data.get("task")
iteration = data.get("iteration")
if task_id and iteration is not None:
@@ -498,11 +502,9 @@ def _update_model(call: APICall, company_id, model_id=None):
return UpdateResponse(updated=updated_count, fields=updated_fields)
@endpoint(
"models.update", required_fields=["model"], response_data_model=UpdateResponse
)
def update(call, company_id, _):
call.result.data_model = _update_model(call, company_id)
@endpoint("models.update", response_data_model=UpdateResponse)
def update(call, company_id, request: UpdateModelRequest):
    """Update the model given by ``request.model`` and store the outcome as the call's result model."""
    response = _update_model(call, company_id, model_id=request.model)
    call.result.data_model = response
@endpoint(
@@ -514,7 +516,7 @@ def set_ready(call: APICall, company_id: str, request: PublishModelRequest):
updated, published_task = ModelBLL.publish_model(
model_id=request.model,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
force_publish_task=request.force_publish_task,
publish_task_func=publish_task if request.publish_task else None,
)
@@ -533,7 +535,7 @@ def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest):
func=partial(
ModelBLL.publish_model,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
force_publish_task=request.force_publish_task,
publish_task_func=publish_task if request.publish_task else None,
),
@@ -625,7 +627,9 @@ def archive_many(call: APICall, company_id, request: BatchRequest):
)
def unarchive_many(call: APICall, company_id, request: BatchRequest):
results, failures = run_batch_operation(
func=partial(ModelBLL.unarchive_model, company_id=company_id, user_id=call.identity.user),
func=partial(
ModelBLL.unarchive_model, company_id=company_id, user_id=call.identity.user
),
ids=request.ids,
)
call.result.data_model = BatchResponse(

View File

@@ -5,22 +5,24 @@ import attr
from apiserver.apierrors.errors.bad_request import CannotRemoveAllRuns
from apiserver.apimodels.pipelines import (
StartPipelineResponse,
StartPipelineRequest,
DeleteRunsRequest,
)
from apiserver.bll.organization import OrgBLL
from apiserver.bll.project import ProjectBLL
from apiserver.bll.queue import QueueBLL
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.task_operations import enqueue_task, delete_task
from apiserver.bll.util import run_batch_operation
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskType
from apiserver.service_repo import APICall, endpoint
from apiserver.utilities.dicts import nested_get
org_bll = OrgBLL()
project_bll = ProjectBLL()
task_bll = TaskBLL()
queue_bll = QueueBLL()
def _update_task_name(task: Task):
@@ -57,7 +59,7 @@ def delete_runs(call: APICall, company_id: str, request: DeleteRunsRequest):
func=partial(
delete_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
move_to_trash=False,
force=True,
return_file_urls=False,
@@ -79,9 +81,7 @@ def delete_runs(call: APICall, company_id: str, request: DeleteRunsRequest):
call.result.data = dict(succeeded=succeeded, failed=failures)
@endpoint(
"pipelines.start_pipeline", response_data_model=StartPipelineResponse,
)
@endpoint("pipelines.start_pipeline")
def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest):
hyperparams = None
if request.args:
@@ -108,10 +108,19 @@ def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest
queued, res = enqueue_task(
task_id=task.id,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
queue_id=request.queue,
status_message="Starting pipeline",
status_reason="",
)
extra = {}
if request.verify_watched_queue and queued:
res_queue = nested_get(res, ("fields", "execution.queue"))
if res_queue:
extra["queue_watched"] = queue_bll.check_for_workers(company_id, res_queue)
return StartPipelineResponse(pipeline=task.id, enqueued=bool(queued))
call.result.data = dict(
pipeline=task.id,
enqueued=bool(queued),
**extra,
)

View File

@@ -59,13 +59,12 @@ create_fields = {
}
@endpoint("projects.get_by_id", required_fields=["project"])
def get_by_id(call):
assert isinstance(call, APICall)
project_id = call.data["project"]
@endpoint("projects.get_by_id")
def get_by_id(call: APICall, company: str, request: ProjectRequest):
project_id = request.project
with translate_errors_context():
query = Q(id=project_id) & get_company_or_none_constraint(call.identity.company)
query = Q(id=project_id) & get_company_or_none_constraint(company)
project = Project.objects(query).first()
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
@@ -147,8 +146,10 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
requested_ids = data.get("id")
if isinstance(requested_ids, str):
requested_ids = [requested_ids]
_adjust_search_parameters(
data, shallow_search=request.shallow_search,
data,
shallow_search=request.shallow_search,
)
selected_project_ids = None
if request.active_users or request.children_type:
@@ -246,7 +247,9 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
if request.include_dataset_stats:
dataset_stats = project_bll.get_dataset_stats(
company=company_id, project_ids=project_ids, users=request.active_users,
company=company_id,
project_ids=project_ids,
users=request.active_users,
)
for project in projects:
project["dataset_stats"] = dataset_stats.get(project["id"])
@@ -255,15 +258,16 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
@endpoint("projects.get_all")
def get_all(call: APICall):
def get_all(call: APICall, company: str, _):
data = call.data
conform_tag_fields(call, data)
_adjust_search_parameters(
data, shallow_search=data.get("shallow_search", False),
data,
shallow_search=data.get("shallow_search", False),
)
ret_params = {}
projects = Project.get_many(
company=call.identity.company,
company=company,
query_dict=data,
query=_hidden_query(
search_hidden=data.get("search_hidden"), ids=data.get("id")
@@ -277,9 +281,11 @@ def get_all(call: APICall):
@endpoint(
"projects.create", required_fields=["name"], response_data_model=IdResponse,
"projects.create",
required_fields=["name"],
response_data_model=IdResponse,
)
def create(call: APICall):
def create(call: APICall, company: str, _):
identity = call.identity
with translate_errors_context():
@@ -288,15 +294,15 @@ def create(call: APICall):
return IdResponse(
id=ProjectBLL.create(
user=identity.user, company=identity.company, **fields,
user=identity.user,
company=company,
**fields,
)
)
@endpoint(
"projects.update", required_fields=["project"], response_data_model=UpdateResponse
)
def update(call: APICall):
@endpoint("projects.update", response_data_model=UpdateResponse)
def update(call: APICall, company: str, request: ProjectRequest):
"""
update
@@ -309,9 +315,7 @@ def update(call: APICall):
call.data, create_fields, Project.get_fields(), discard_none_values=False
)
conform_tag_fields(call, fields, validate=True)
updated = ProjectBLL.update(
company=call.identity.company, project_id=call.data["project"], **fields
)
updated = ProjectBLL.update(company=company, project_id=request.project, **fields)
conform_output_tags(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
@@ -375,11 +379,11 @@ def delete(call: APICall, company_id: str, request: DeleteRequest):
def get_unique_metric_variants(
call: APICall, company_id: str, request: GetUniqueMetricsRequest
):
metrics = project_queries.get_unique_metric_variants(
company_id,
[request.project] if request.project else None,
include_subprojects=request.include_subprojects,
ids=request.ids,
model_metrics=request.model_metrics,
)
@@ -428,7 +432,6 @@ def get_model_metadata_values(
request_data_model=GetParamsRequest,
)
def get_hyper_parameters(call: APICall, company_id: str, request: GetParamsRequest):
total, remaining, parameters = project_queries.get_aggregated_project_parameters(
company_id,
project_ids=[request.project] if request.project else None,

View File

@@ -19,7 +19,9 @@ from apiserver.apimodels.reports import (
from apiserver.apierrors import errors
from apiserver.apimodels.base import UpdateResponse
from apiserver.bll.project.project_bll import reports_project_name, reports_tag
from apiserver.bll.task.utils import get_task_with_write_access
from apiserver.database.model.model import Model
from apiserver.service_repo.auth import Identity
from apiserver.services.models import conform_model_data
from apiserver.services.utils import process_include_subprojects, sort_tags_response
from apiserver.bll.organization import OrgBLL
@@ -57,15 +59,15 @@ update_fields = {
}
def _assert_report(company_id, task_id, only_fields=None, requires_write_access=True):
def _assert_report(company_id: str, task_id: str, identity: Identity, only_fields=None):
if only_fields and "type" not in only_fields:
only_fields += ("type",)
task = TaskBLL.get_task_with_access(
task = get_task_with_write_access(
task_id=task_id,
company_id=company_id,
identity=identity,
only=only_fields,
requires_write_access=requires_write_access,
)
if task.type != TaskType.report:
raise errors.bad_request.OperationSupportedOnReportsOnly(id=task_id)
@@ -78,6 +80,7 @@ def update_report(call: APICall, company_id: str, request: UpdateReportRequest):
task = _assert_report(
task_id=request.task,
company_id=company_id,
identity=call.identity,
only_fields=("status",),
)
@@ -265,7 +268,7 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
res["plots"] = _get_multitask_plots(
companies=companies,
last_iters=request.plots.iters,
metrics=_get_metric_variants_from_request(request.plots.metrics),
request_metrics=request.plots.metrics,
last_iters_per_task_metric=request.plots.last_iters_per_task_metric,
)[0]
@@ -302,6 +305,7 @@ def move(call: APICall, company_id: str, request: MoveReportRequest):
task = _assert_report(
company_id=company_id,
task_id=request.task,
identity=call.identity,
only_fields=("project",),
)
user_id = call.identity.user
@@ -337,7 +341,9 @@ def move(call: APICall, company_id: str, request: MoveReportRequest):
response_data_model=UpdateResponse,
)
def publish(call: APICall, company_id, request: PublishReportRequest):
task = _assert_report(company_id=company_id, task_id=request.task)
task = _assert_report(
company_id=company_id, task_id=request.task, identity=call.identity
)
updates = ChangeStatusRequest(
task=task,
new_status=TaskStatus.published,
@@ -352,7 +358,9 @@ def publish(call: APICall, company_id, request: PublishReportRequest):
@endpoint("reports.archive")
def archive(call: APICall, company_id, request: ArchiveReportRequest):
task = _assert_report(company_id=company_id, task_id=request.task)
task = _assert_report(
company_id=company_id, task_id=request.task, identity=call.identity
)
archived = task.update(
status_message=request.message,
status_reason="",
@@ -366,7 +374,9 @@ def archive(call: APICall, company_id, request: ArchiveReportRequest):
@endpoint("reports.unarchive")
def unarchive(call: APICall, company_id, request: ArchiveReportRequest):
task = _assert_report(company_id=company_id, task_id=request.task)
task = _assert_report(
company_id=company_id, task_id=request.task, identity=call.identity
)
unarchived = task.update(
status_message=request.message,
status_reason="",
@@ -394,6 +404,7 @@ def delete(call: APICall, company_id, request: DeleteReportRequest):
task = _assert_report(
company_id=company_id,
task_id=request.task,
identity=call.identity,
only_fields=("project",),
)
if (

View File

@@ -3,7 +3,11 @@ from datetime import datetime
from pyhocon.config_tree import NoneValue
from apiserver.apierrors import errors
from apiserver.apimodels.server import ReportStatsOptionRequest, ReportStatsOptionResponse
from apiserver.apimodels.server import (
ReportStatsOptionRequest,
ReportStatsOptionResponse,
GetConfigRequest,
)
from apiserver.bll.statistics.stats_reporter import StatisticsReporter
from apiserver.config_repo import config
from apiserver.config.info import get_version, get_build_number, get_commit_number
@@ -22,8 +26,8 @@ def get_stats(call: APICall):
@endpoint("server.config")
def get_config(call: APICall):
path = call.data.get("path")
def get_config(call: APICall, _, request: GetConfigRequest):
path = request.path
if path:
c = dict(config.get(path))
else:

View File

@@ -100,10 +100,17 @@ from apiserver.bll.task.task_operations import (
unarchive_task,
move_tasks_to_trash,
)
from apiserver.bll.task.utils import update_task, get_task_for_update, deleted_prefix
from apiserver.bll.task.utils import (
update_task,
get_task_for_update,
deleted_prefix,
get_many_tasks_for_writing,
get_task_with_write_access,
)
from apiserver.bll.util import run_batch_operation, update_project_time
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
from apiserver.database.model.project import Project
from apiserver.database.model.task.output import Output
from apiserver.database.model.task.task import (
Task,
@@ -118,6 +125,7 @@ from apiserver.database.utils import (
get_options,
)
from apiserver.service_repo import APICall, endpoint
from apiserver.service_repo.auth import Identity
from apiserver.services.utils import (
conform_tag_fields,
conform_output_tags,
@@ -142,14 +150,34 @@ org_bll = OrgBLL()
project_bll = ProjectBLL()
def _assert_writable_tasks(
    company_id: str, identity: Identity, ids: Sequence[str], only=("id",)
) -> Sequence[Task]:
    """
    Fetch the tasks with the given ids through get_many_tasks_for_writing.

    :param company_id: company the tasks are looked up in
    :param identity: identity of the caller
    :param ids: ids of the tasks to fetch
    :param only: task fields to load
    :return: the fetched tasks
    :raise errors.bad_request.InvalidTaskId: if any requested id is absent
        from the fetched tasks
    """
    writable = get_many_tasks_for_writing(
        company_id=company_id,
        identity=identity,
        query=Q(id__in=ids),
        only=only,
    )
    found_ids = {task.id for task in writable}
    not_found = set(ids).difference(found_ids)
    if not not_found:
        return writable

    # Report every requested id that did not come back from the query
    raise errors.bad_request.InvalidTaskId(ids=list(not_found))
def set_task_status_from_call(
request: UpdateRequest, company_id: str, user_id: str, new_status=None, **set_fields
request: UpdateRequest,
company_id: str,
identity: Identity,
new_status=None,
**set_fields,
) -> dict:
task = TaskBLL.get_task_with_access(
task = get_task_with_write_access(
request.task,
company_id=company_id,
identity=identity,
only=("id", "status", "project"),
requires_write_access=True,
)
status_reason = request.status_reason
@@ -161,15 +189,17 @@ def set_task_status_from_call(
status_reason=status_reason,
status_message=status_message,
force=force,
user_id=user_id,
user_id=identity.user,
).execute(**set_fields)
@endpoint("tasks.get_by_id", request_data_model=TaskRequest)
def get_by_id(call: APICall, company_id, req_model: TaskRequest):
task = TaskBLL.get_task_with_access(
req_model.task, company_id=company_id, allow_public=True
)
def get_by_id(call: APICall, company_id, request: TaskRequest):
task = TaskBLL.assert_exists(
company_id,
task_ids=request.task,
allow_public=True,
)[0]
task_dict = task.to_proper_dict()
conform_task_data(call, task_dict)
call.result.data = {"task": task_dict}
@@ -227,14 +257,16 @@ def get_by_id_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
call_data = escape_execution_parameters(call.data)
tasks = Task.get_many_with_join(
company=company_id, query_dict=call_data, allow_public=True,
company=company_id,
query_dict=call_data,
allow_public=True,
)
conform_task_data(call, tasks)
call.result.data = {"tasks": tasks}
@endpoint("tasks.get_all", required_fields=[])
@endpoint("tasks.get_all")
def get_all(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
call_data = escape_execution_parameters(call.data)
@@ -278,7 +310,7 @@ def stop(call: APICall, company_id, req_model: UpdateRequest):
**stop_task(
task_id=req_model.task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
user_name=call.identity.user_name,
status_reason=req_model.status_reason,
force=req_model.force,
@@ -296,7 +328,7 @@ def stop_many(call: APICall, company_id, request: StopManyRequest):
func=partial(
stop_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
user_name=call.identity.user_name,
status_reason=request.status_reason,
force=request.force,
@@ -319,7 +351,7 @@ def stopped(call: APICall, company_id, req_model: UpdateRequest):
**set_task_status_from_call(
req_model,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
new_status=TaskStatus.stopped,
completed=datetime.utcnow(),
)
@@ -332,13 +364,21 @@ def stopped(call: APICall, company_id, req_model: UpdateRequest):
response_data_model=StartedResponse,
)
def started(call: APICall, company_id, req_model: UpdateRequest):
started_update = {}
if Task.objects(id=req_model.task, started=None).only("id"):
# this is the fix for older versions putting started to None on reset
started_update["started"] = datetime.utcnow()
else:
# don't override a previous, smaller "started" field value
started_update["min__started"] = datetime.utcnow()
res = StartedResponse(
**set_task_status_from_call(
req_model,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
new_status=TaskStatus.in_progress,
min__started=datetime.utcnow(), # don't override a previous, smaller "started" field value
**started_update,
)
)
res.started = res.updated
@@ -353,7 +393,7 @@ def failed(call: APICall, company_id, req_model: UpdateRequest):
**set_task_status_from_call(
req_model,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
new_status=TaskStatus.failed,
)
)
@@ -367,7 +407,7 @@ def close(call: APICall, company_id, req_model: UpdateRequest):
**set_task_status_from_call(
req_model,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
new_status=TaskStatus.closed,
)
)
@@ -381,18 +421,19 @@ create_fields = {
"error": None,
"comment": None,
"parent": Task,
"project": None,
"project": Project,
"input": None,
"models": None,
"container": None,
"container": dict,
"output_dest": None,
"execution": None,
"hyperparams": None,
"configuration": None,
"hyperparams": dict,
"configuration": dict,
"script": None,
"runtime": None,
"runtime": dict,
}
dict_fields_paths = [("execution", "model_labels"), "container"]
@@ -433,13 +474,17 @@ def conform_task_data(call: APICall, tasks_data: Union[Sequence[dict], dict]):
for data in tasks_data:
params_unprepare_from_saved(
fields=data, copy_to_legacy=need_legacy_params,
fields=data,
copy_to_legacy=need_legacy_params,
)
artifacts_unprepare_from_saved(fields=data)
def prepare_create_fields(
call: APICall, valid_fields=None, output=None, previous_task: Task = None,
call: APICall,
valid_fields=None,
output=None,
previous_task: Task = None,
):
valid_fields = valid_fields if valid_fields is not None else create_fields
t_fields = task_fields
@@ -566,11 +611,12 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
task_id = req_model.task
with translate_errors_context():
task = Task.get_for_writing(
id=task_id, company=company_id, _only=["id", "project"]
task = get_task_with_write_access(
task_id=task_id,
company_id=company_id,
identity=call.identity,
only=("id", "project"),
)
if not task:
raise errors.bad_request.InvalidTaskId(id=task_id)
partial_update_dict, valid_fields = prepare_update_fields(call, call.data)
@@ -582,7 +628,8 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
id=task_id,
partial_update_dict=partial_update_dict,
injected_update=dict(
last_change=datetime.utcnow(), last_changed_by=call.identity.user,
last_change=datetime.utcnow(),
last_changed_by=call.identity.user,
),
)
if updated_count:
@@ -606,11 +653,11 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
def set_requirements(call: APICall, company_id, req_model: SetRequirementsRequest):
requirements = req_model.requirements
with translate_errors_context():
task = TaskBLL.get_task_with_access(
task = get_task_with_write_access(
req_model.task,
company_id=company_id,
identity=call.identity,
only=("status", "script"),
requires_write_access=True,
)
if not task.script:
raise errors.bad_request.MissingTaskFields(
@@ -636,8 +683,11 @@ def update_batch(call: APICall, company_id, _):
items = {i["task"]: i for i in items}
tasks = {
t.id: t
for t in Task.get_many_for_writing(
company=company_id, query=Q(id__in=list(items))
for t in _assert_writable_tasks(
identity=call.identity,
company_id=company_id,
ids=list(items),
only=("id", "project"),
)
}
@@ -656,7 +706,8 @@ def update_batch(call: APICall, company_id, _):
if not partial_update_dict:
continue
partial_update_dict.update(
last_change=now, last_changed_by=call.identity.user,
last_change=now,
last_changed_by=call.identity.user,
)
update_op = UpdateOne(
{"_id": id, "company": company_id}, {"$set": partial_update_dict}
@@ -690,9 +741,11 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
force = req_model.force
with translate_errors_context():
task = Task.get_for_writing(id=task_id, company=company_id)
if not task:
raise errors.bad_request.InvalidTaskId(id=task_id)
task = get_task_with_write_access(
task_id=task_id,
company_id=company_id,
identity=call.identity,
)
if not force and task.status != TaskStatus.created:
raise errors.bad_request.InvalidTaskStatus(
@@ -756,7 +809,8 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
@endpoint(
"tasks.get_hyper_params", request_data_model=GetHyperParamsRequest,
"tasks.get_hyper_params",
request_data_model=GetHyperParamsRequest,
)
def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest):
tasks_params = HyperParams.get_params(company_id, task_ids=request.tasks)
@@ -771,7 +825,7 @@ def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest
call.result.data = {
"updated": HyperParams.edit_params(
company_id,
user_id=call.identity.user,
identity=call.identity,
task_id=request.task,
hyperparams=request.hyperparams,
replace_hyperparams=request.replace_hyperparams,
@@ -785,7 +839,7 @@ def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsReq
call.result.data = {
"deleted": HyperParams.delete_params(
company_id,
user_id=call.identity.user,
identity=call.identity,
task_id=request.task,
hyperparams=request.hyperparams,
force=request.force,
@@ -794,7 +848,8 @@ def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsReq
@endpoint(
"tasks.get_configurations", request_data_model=GetConfigurationsRequest,
"tasks.get_configurations",
request_data_model=GetConfigurationsRequest,
)
def get_configurations(call: APICall, company_id, request: GetConfigurationsRequest):
tasks_params = HyperParams.get_configurations(
@@ -809,7 +864,8 @@ def get_configurations(call: APICall, company_id, request: GetConfigurationsRequ
@endpoint(
"tasks.get_configuration_names", request_data_model=GetConfigurationNamesRequest,
"tasks.get_configuration_names",
request_data_model=GetConfigurationNamesRequest,
)
def get_configuration_names(
call: APICall, company_id, request: GetConfigurationNamesRequest
@@ -830,7 +886,7 @@ def edit_configuration(call: APICall, company_id, request: EditConfigurationRequ
call.result.data = {
"updated": HyperParams.edit_configuration(
company_id,
user_id=call.identity.user,
identity=call.identity,
task_id=request.task,
configuration=request.configuration,
replace_configuration=request.replace_configuration,
@@ -846,7 +902,7 @@ def delete_configuration(
call.result.data = {
"deleted": HyperParams.delete_configuration(
company_id,
user_id=call.identity.user,
identity=call.identity,
task_id=request.task,
configuration=request.configuration,
force=request.force,
@@ -863,7 +919,7 @@ def enqueue(call: APICall, company_id, request: EnqueueRequest):
queued, res = enqueue_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
queue_id=request.queue,
status_message=request.status_message,
status_reason=request.status_reason,
@@ -888,7 +944,7 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
func=partial(
enqueue_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
queue_id=request.queue,
status_message=request.status_message,
status_reason=request.status_reason,
@@ -915,13 +971,14 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
@endpoint(
"tasks.dequeue", response_data_model=DequeueResponse,
"tasks.dequeue",
response_data_model=DequeueResponse,
)
def dequeue(call: APICall, company_id, request: DequeueRequest):
dequeued, res = dequeue_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
status_message=request.status_message,
status_reason=request.status_reason,
remove_from_all_queues=request.remove_from_all_queues,
@@ -931,14 +988,15 @@ def dequeue(call: APICall, company_id, request: DequeueRequest):
@endpoint(
"tasks.dequeue_many", response_data_model=DequeueManyResponse,
"tasks.dequeue_many",
response_data_model=DequeueManyResponse,
)
def dequeue_many(call: APICall, company_id, request: DequeueManyRequest):
results, failures = run_batch_operation(
func=partial(
dequeue_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
status_message=request.status_message,
status_reason=request.status_reason,
remove_from_all_queues=request.remove_from_all_queues,
@@ -962,7 +1020,7 @@ def reset(call: APICall, company_id, request: ResetRequest):
dequeued, cleanup_res, updates = reset_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
@@ -990,7 +1048,7 @@ def reset_many(call: APICall, company_id, request: ResetManyRequest):
func=partial(
reset_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
@@ -1027,9 +1085,11 @@ def reset_many(call: APICall, company_id, request: ResetManyRequest):
response_data_model=ArchiveResponse,
)
def archive(call: APICall, company_id, request: ArchiveRequest):
tasks = TaskBLL.assert_exists(
archived = 0
tasks = _assert_writable_tasks(
company_id,
task_ids=request.tasks,
call.identity,
ids=request.tasks,
only=(
"id",
"company",
@@ -1040,11 +1100,10 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
"enqueue_status",
),
)
archived = 0
for task in tasks:
archived += archive_task(
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
task=task,
status_message=request.status_message,
status_reason=request.status_reason,
@@ -1063,7 +1122,7 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest):
func=partial(
archive_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
status_message=request.status_message,
status_reason=request.status_reason,
),
@@ -1085,7 +1144,7 @@ def unarchive_many(call: APICall, company_id, request: TaskBatchRequest):
func=partial(
unarchive_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
status_message=request.status_message,
status_reason=request.status_reason,
),
@@ -1104,7 +1163,7 @@ def delete(call: APICall, company_id, request: DeleteRequest):
deleted, task, cleanup_res = delete_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
move_to_trash=request.move_to_trash,
force=request.force,
return_file_urls=request.return_file_urls,
@@ -1126,7 +1185,7 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest):
func=partial(
delete_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
move_to_trash=request.move_to_trash,
force=request.force,
return_file_urls=request.return_file_urls,
@@ -1164,7 +1223,7 @@ def publish(call: APICall, company_id, request: PublishRequest):
updates = publish_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
force=request.force,
publish_model_func=ModelBLL.publish_model if request.publish_model else None,
status_reason=request.status_reason,
@@ -1183,7 +1242,7 @@ def publish_many(call: APICall, company_id, request: PublishManyRequest):
func=partial(
publish_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
force=request.force,
publish_model_func=ModelBLL.publish_model
if request.publish_model
@@ -1211,7 +1270,7 @@ def completed(call: APICall, company_id, request: CompletedRequest):
**set_task_status_from_call(
request,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
new_status=TaskStatus.completed,
completed=datetime.utcnow(),
)
@@ -1221,7 +1280,7 @@ def completed(call: APICall, company_id, request: CompletedRequest):
publish_res = publish_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
force=request.force,
publish_model_func=ModelBLL.publish_model,
status_reason=request.status_reason,
@@ -1256,7 +1315,7 @@ def add_or_update_artifacts(
call.result.data = {
"updated": Artifacts.add_or_update_artifacts(
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
task_id=request.task,
artifacts=request.artifacts,
force=True,
@@ -1273,7 +1332,7 @@ def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest)
call.result.data = {
"deleted": Artifacts.delete_artifacts(
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
task_id=request.task,
artifact_ids=request.artifacts,
force=True,
@@ -1310,6 +1369,7 @@ def move(call: APICall, company_id: str, request: MoveRequest):
"project or project_name is required"
)
_assert_writable_tasks(company_id, call.identity, request.ids)
updated_projects = set(
t.project for t in Task.objects(id__in=request.ids).only("project") if t.project
)
@@ -1330,7 +1390,8 @@ def move(call: APICall, company_id: str, request: MoveRequest):
@endpoint("tasks.update_tags")
def update_tags(_, company_id: str, request: UpdateTagsRequest):
def update_tags(call: APICall, company_id: str, request: UpdateTagsRequest):
_assert_writable_tasks(company_id, call.identity, request.ids)
return {
"updated": org_bll.edit_entity_tags(
company_id=company_id,
@@ -1344,7 +1405,9 @@ def update_tags(_, company_id: str, request: UpdateTagsRequest):
@endpoint("tasks.add_or_update_model", min_version="2.13")
def add_or_update_model(call: APICall, company_id: str, request: AddUpdateModelRequest):
get_task_for_update(company_id=company_id, task_id=request.task, force=True)
get_task_for_update(
company_id=company_id, task_id=request.task, force=True, identity=call.identity
)
models_field = f"models__{request.type}"
model = ModelItem(name=request.name, model=request.model, updated=datetime.utcnow())
@@ -1364,7 +1427,9 @@ def add_or_update_model(call: APICall, company_id: str, request: AddUpdateModelR
@endpoint("tasks.delete_models", min_version="2.13")
def delete_models(call: APICall, company_id: str, request: DeleteModelsRequest):
task = get_task_for_update(company_id=company_id, task_id=request.task, force=True)
task = get_task_for_update(
company_id=company_id, task_id=request.task, force=True, identity=call.identity
)
delete_names = {
type_: [m.name for m in request.models if m.type == type_]
@@ -1377,6 +1442,8 @@ def delete_models(call: APICall, company_id: str, request: DeleteModelsRequest):
}
updated = task.update(
last_change=datetime.utcnow(), last_changed_by=call.identity.user, **commands,
last_change=datetime.utcnow(),
last_changed_by=call.identity.user,
**commands,
)
return {"updated": updated}

View File

@@ -7,12 +7,16 @@ from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.apimodels.base import UpdateResponse
from apiserver.apimodels.users import CreateRequest, SetPreferencesRequest
from apiserver.apimodels.users import (
CreateRequest,
SetPreferencesRequest,
UserRequest,
)
from apiserver.bll.project import ProjectBLL
from apiserver.bll.user import UserBLL
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.auth import Role
from apiserver.database.model.auth import Role, User as AuthUser
from apiserver.database.model.company import Company
from apiserver.database.model.user import User
from apiserver.database.utils import parse_from_call
@@ -48,13 +52,13 @@ def get_user(call, company_id, user_id, only=None):
return res.to_proper_dict()
@endpoint("users.get_by_id", required_fields=["user"])
def get_by_id(call: APICall, company_id, _):
user_id = call.data["user"]
@endpoint("users.get_by_id")
def get_by_id(call: APICall, company_id, request: UserRequest):
user_id = request.user
call.result.data = {"user": get_user(call, company_id, user_id)}
@endpoint("users.get_all_ex", required_fields=[])
@endpoint("users.get_all_ex")
def get_all_ex(call: APICall, company_id, _):
with translate_errors_context("retrieving users"):
res = User.get_many_with_join(company=company_id, query_dict=call.data)
@@ -62,7 +66,7 @@ def get_all_ex(call: APICall, company_id, _):
call.result.data = {"users": res}
@endpoint("users.get_all_ex", min_version="2.8", required_fields=[])
@endpoint("users.get_all_ex", min_version="2.8")
def get_all_ex2_8(call: APICall, company_id, _):
with translate_errors_context("retrieving users"):
data = call.data
@@ -83,7 +87,7 @@ def get_all_ex2_8(call: APICall, company_id, _):
call.result.data = {"users": res}
@endpoint("users.get_all", required_fields=[])
@endpoint("users.get_all")
def get_all(call: APICall, company_id, _):
with translate_errors_context("retrieving users"):
res = User.get_many(
@@ -138,9 +142,9 @@ def create(call: APICall):
UserBLL.create(call.data_model)
@endpoint("users.delete", required_fields=["user"])
def delete(call: APICall):
UserBLL.delete(call.data["user"])
@endpoint("users.delete")
def delete(_: APICall, __, request: UserRequest):
UserBLL.delete(request.user)
def update_user(user_id, company_id, data: dict) -> Tuple[int, dict]:
@@ -154,14 +158,22 @@ def update_user(user_id, company_id, data: dict) -> Tuple[int, dict]:
update_fields = {
k: v for k, v in create_fields.items() if k in User.user_set_allowed()
}
auth_user_update_fields = ("name",)
partial_update_dict = parse_from_call(data, update_fields, User.get_fields())
with translate_errors_context("updating user"):
return User.safe_update(company_id, user_id, partial_update_dict)
ret = User.safe_update(company_id, user_id, partial_update_dict)
auth_update = {
k: v for k, v in partial_update_dict.items() if k in auth_user_update_fields
}
if auth_update:
AuthUser.objects(id=user_id).update(**auth_update)
return ret
@endpoint("users.update", required_fields=["user"], response_data_model=UpdateResponse)
def update(call, company_id, _):
user_id = call.data["user"]
@endpoint("users.update", response_data_model=UpdateResponse)
def update(call, company_id, request: UserRequest):
user_id = request.user
update_count, updated_fields = update_user(user_id, company_id, call.data)
call.result.data_model = UpdateResponse(updated=update_count, fields=updated_fields)

View File

@@ -3,6 +3,30 @@ from apiserver.tests.automated import TestService
class TestGetAllExFilters(TestService):
def test_no_tags_filter(self):
    """Exercise tags filters that contain ``None`` next to (or instead of) a real tag.

    One task is created with the tag "test" and one without tags; every
    any/all x include/exclude combination is then sent to tasks.get_all_ex
    and the returned ids are compared with the expected ones.
    """
    tagged = self._temp_task(tags=["test"])
    untagged = self._temp_task()
    candidates = [tagged, untagged]

    # (condition, operation, tags sent in the filter, task ids expected back)
    scenarios = (
        ("any", "include", [None], [untagged]),
        ("any", "include", ["test"], [tagged]),
        ("any", "include", ["test", None], [tagged, untagged]),
        ("any", "exclude", [None], [tagged]),
        ("any", "exclude", ["test"], [untagged]),
        ("any", "exclude", ["test", None], [tagged, untagged]),
        ("all", "include", [None], [untagged]),
        ("all", "include", ["test"], [tagged]),
        ("all", "include", ["test", None], []),
        ("all", "exclude", [None], [tagged]),
        ("all", "exclude", ["test"], [untagged]),
        ("all", "exclude", ["test", None], []),
    )
    for cond, op, tags, expected_tasks in scenarios:
        response = self.api.tasks.get_all_ex(
            id=candidates, filters={"tags": {cond: {op: tags}}}
        )
        returned_ids = {t.id for t in response.tasks}
        self.assertEqual(returned_ids, set(expected_tasks))
def test_list_filters(self):
tags = ["a", "b", "c", "d"]
tasks = [self._temp_task(tags=tags[:i]) for i in range(len(tags) + 1)]

View File

@@ -37,29 +37,44 @@ class TestPipelines(TestService):
res = self.api.pipelines.start_pipeline(task=task, queue=queue, args=args)
pipeline_task = res.pipeline
try:
self.assertTrue(res.enqueued)
pipeline = self.api.tasks.get_all_ex(id=[pipeline_task]).tasks[0]
self.assertTrue(pipeline.name.startswith(task_name))
self.assertEqual(pipeline.status, "queued")
self.assertEqual(pipeline.project.id, project)
self.assertEqual(
pipeline.hyperparams.Args,
{
a["name"]: {
"section": "Args",
"name": a["name"],
"value": a["value"],
}
for a in args
},
)
finally:
self.api.tasks.delete(task=pipeline_task, force=True)
self.assertTrue(res.enqueued)
pipeline = self.api.tasks.get_all_ex(id=[pipeline_task]).tasks[0]
self.assertTrue(pipeline.name.startswith(task_name))
self.assertEqual(pipeline.status, "queued")
self.assertEqual(pipeline.project.id, project)
self.assertEqual(
pipeline.hyperparams.Args,
{
a["name"]: {
"section": "Args",
"name": a["name"],
"value": a["value"],
}
for a in args
},
)
# watched queue
queue = self._temp_queue("test pipelines")
project, task = self._temp_project_and_task(name="pipelines test1")
res = self.api.pipelines.start_pipeline(
task=task, queue=queue, verify_watched_queue=True
)
self.assertEqual(res.queue_watched, False)
self.api.workers.register(worker="test pipelines", queues=[queue])
project, task = self._temp_project_and_task(name="pipelines test2")
res = self.api.pipelines.start_pipeline(
task=task, queue=queue, verify_watched_queue=True
)
self.assertEqual(res.queue_watched, True)
def _temp_project_and_task(self, name) -> Tuple[str, str]:
project = self.create_temp(
"projects", name=name, description="test", delete_params=dict(force=True),
"projects",
name=name,
description="test",
delete_params=dict(force=True, delete_contents=True),
)
return (
@@ -72,3 +87,6 @@ class TestPipelines(TestService):
system_tags=["pipeline"],
),
)
def _temp_queue(self, queue_name, **kwargs):
    """Create a "queues" entity named ``queue_name`` through ``create_temp``.

    Any extra keyword arguments are forwarded to ``create_temp`` unchanged;
    whatever ``create_temp`` returns is handed back to the caller.
    """
    queue = self.create_temp("queues", name=queue_name, **kwargs)
    return queue

View File

@@ -113,7 +113,7 @@ class TestProjectTags(TestService):
new_tags = ["New model tag"]
self.api.models.update_tags(ids=[model], add_tags=new_tags)
data = self.api.projects.get_model_tags(projects=[p])
self.assertEqual(set(data.tags), set([*new_tags, *initial_tags]))
self.assertEqual(set(data.tags), {*new_tags, *initial_tags})
def new_task(self, **kwargs):
self.update_missing(

View File

@@ -16,10 +16,18 @@ class TestTaskEvents(TestService):
delete_params = dict(can_fail=True, force=True)
default_task_name = "test task events"
def _temp_task(self, name=default_task_name):
task_input = dict(name=name, type="training",)
def _temp_project(self, name=default_task_name):
return self.create_temp(
"tasks", delete_paramse=self.delete_params, **task_input
"projects",
name=name,
description="test",
delete_params=self.delete_params,
)
def _temp_task(self, name=default_task_name, **kwargs):
self.update_missing(kwargs, name=name, type="training")
return self.create_temp(
"tasks", delete_paramse=self.delete_params, **kwargs
)
def _temp_model(self, name="test model events", **kwargs):
@@ -62,6 +70,26 @@ class TestTaskEvents(TestService):
self._assert_task_metrics(tasks, "log")
self._assert_task_metrics(tasks, "training_stats_scalar")
self._assert_multitask_metrics(
tasks=list(tasks), metrics=["Metric1", "Metric2", "Metric3"]
)
self._assert_multitask_metrics(
tasks=list(tasks),
event_type="training_debug_image",
metrics=["Metric1", "Metric2", "Metric3"],
)
self._assert_multitask_metrics(tasks=list(tasks), event_type="plot", metrics=[])
def _assert_multitask_metrics(
    self, tasks: Sequence[str], metrics: Sequence[str], event_type: str = None
):
    """Assert what events.get_multi_task_metrics reports for ``tasks``.

    The returned metric names must equal ``metrics`` (same order) and every
    returned entry must list exactly the variant "Test variant".
    ``event_type`` is only sent with the request when it is truthy.
    """
    request = {"tasks": tasks}
    if event_type:
        request["event_type"] = event_type

    reported = self.api.events.get_multi_task_metrics(**request).metrics

    reported_names = [entry.metric for entry in reported]
    self.assertEqual(reported_names, metrics)

    variants_ok = all(entry.variants == ["Test variant"] for entry in reported)
    self.assertTrue(variants_ok)
def _assert_task_metrics(self, tasks: dict, event_type: str):
res = self.api.events.get_task_metrics(tasks=list(tasks), event_type=event_type)
for task, metrics in tasks.items():
@@ -122,6 +150,15 @@ class TestTaskEvents(TestService):
self.assertEqual(value.metric, metric)
self.assertEqual(value.variant, variant)
self.assertEqual(value.value, 0)
# test metrics parameter
res = self.api.events.get_task_single_value_metrics(
tasks=[task], metrics=[{"metric": metric, "variants": [variant]}]
).tasks
self.assertEqual(len(res), 1)
res = self.api.events.get_task_single_value_metrics(
tasks=[task], metrics=[{"metric": "non_existing", "variants": [variant]}]
).tasks
self.assertEqual(len(res), 0)
# update is working
task_data = self.api.tasks.get_by_id(task=task).task
@@ -156,33 +193,33 @@ class TestTaskEvents(TestService):
def test_last_scalar_metrics(self):
metric = "Metric1"
variant = "Variant1"
iter_count = 100
task = self._temp_task()
events = [
{
**self._create_task_event("training_stats_scalar", task, iteration),
"metric": metric,
"variant": variant,
"value": iteration,
}
for iteration in range(iter_count)
]
# send 2 batches to check the interaction with already stored db value
# each batch contains multiple iterations
self.send_batch(events[:50])
self.send_batch(events[50:])
for variant in ("Variant1", None):
iter_count = 100
task = self._temp_task()
events = [
{
**self._create_task_event("training_stats_scalar", task, iteration),
"metric": metric,
"variant": variant,
"value": iteration,
}
for iteration in range(iter_count)
]
# send 2 batches to check the interaction with already stored db value
# each batch contains multiple iterations
self.send_batch(events[:50])
self.send_batch(events[50:])
task_data = self.api.tasks.get_by_id(task=task).task
metric_data = first(first(task_data.last_metrics.values()).values())
self.assertEqual(iter_count - 1, metric_data.value)
self.assertEqual(iter_count - 1, metric_data.max_value)
self.assertEqual(iter_count - 1, metric_data.max_value_iteration)
self.assertEqual(0, metric_data.min_value)
self.assertEqual(0, metric_data.min_value_iteration)
task_data = self.api.tasks.get_by_id(task=task).task
metric_data = first(first(task_data.last_metrics.values()).values())
self.assertEqual(iter_count - 1, metric_data.value)
self.assertEqual(iter_count - 1, metric_data.max_value)
self.assertEqual(iter_count - 1, metric_data.max_value_iteration)
self.assertEqual(0, metric_data.min_value)
self.assertEqual(0, metric_data.min_value_iteration)
res = self.api.events.get_task_latest_scalar_values(task=task)
self.assertEqual(iter_count - 1, res.last_iter)
res = self.api.events.get_task_latest_scalar_values(task=task)
self.assertEqual(iter_count - 1, res.last_iter)
def test_model_events(self):
model = self._temp_model(ready=False)
@@ -248,6 +285,15 @@ class TestTaskEvents(TestService):
self._assert_log_events(task=task, expected_total=1)
metrics = self.api.events.get_multi_task_metrics(
tasks=[model],
event_type="training_stats_scalar",
model_events=True,
).metrics
self.assertEqual([m.metric for m in metrics], [f"Metric{i}" for i in range(5)])
variants = [f"Variant{i}" for i in range(5)]
self.assertTrue(all(m.variants == variants for m in metrics))
def test_error_events(self):
task = self._temp_task()
events = [
@@ -340,6 +386,30 @@ class TestTaskEvents(TestService):
else (None, None)
)
def test_task_unique_metric_variants(self):
    """Check projects.get_unique_metric_variants by project and by task ids.

    Two tasks in one project each report a scalar event for a different
    metric; the call is then made for the project, for both task ids and
    for the first task id alone.
    """
    project = self._temp_project()
    task1 = self._temp_task(project=project)
    task2 = self._temp_task(project=project)
    metric1, metric2 = "Metric1", "Metric2"

    events = []
    for task, metric in ((task1, metric1), (task2, metric2)):
        base_event = self._create_task_event("training_stats_scalar", task, 0)
        # copy the base event and add the scalar payload on top of it
        events.append(dict(base_event, metric=metric, variant="Variant", value=10))
    self.send_batch(events)

    get_variants = self.api.projects.get_unique_metric_variants

    by_project = get_variants(project=project).metrics
    self.assertEqual({m.metric for m in by_project}, {metric1, metric2})

    by_both_tasks = get_variants(ids=[task1, task2]).metrics
    self.assertEqual({m.metric for m in by_both_tasks}, {metric1, metric2})

    by_first_task = get_variants(ids=[task1]).metrics
    self.assertEqual([m.metric for m in by_first_task], [metric1])
def test_task_metric_value_intervals_keys(self):
metric = "Metric1"
variant = "Variant1"
@@ -395,6 +465,25 @@ class TestTaskEvents(TestService):
iterations=iter_count,
)
# test metrics
data = self.api.events.multi_task_scalar_metrics_iter_histogram(
tasks=tasks,
metrics=[
{
"metric": f"Metric{m_idx}",
"variants": [f"Variant{v_idx}" for v_idx in range(4)],
}
for m_idx in range(2)
],
)
self._assert_metrics_and_variants(
data.metrics,
metrics=2,
variants=4,
tasks=tasks,
iterations=iter_count,
)
def _assert_metrics_and_variants(
self, data: dict, metrics: int, variants: int, tasks: Sequence, iterations: int
):
@@ -515,6 +604,13 @@ class TestTaskEvents(TestService):
self.assertEqual(plots.C.CX[task1]["3"]["plots"][0]["plot_str"], "Task1_3_C_CX")
self.assertEqual(plots.C.CX[task2]["1"]["plots"][0]["plot_str"], "Task2_1_C_CX")
# test metrics
plots = self.api.events.get_multi_task_plots(
tasks=[task1, task2], metrics=[{"metric": "A"}]
).plots
self.assertEqual(len(plots), 1)
self.assertEqual(len(plots.A), 2)
def test_task_plots(self):
task = self._temp_task()
event = self._create_task_event("plot", task, 0)

View File

@@ -7,13 +7,17 @@ from humanfriendly import parse_timespan
def setup():
from apiserver.database import db
db.initialize()
def gen_token(args):
from apiserver.bll.auth import AuthBLL
resp = AuthBLL.get_token_for_user(args.user_id, args.company_id, parse_timespan(args.expiration))
print('Token:\n%s' % resp.token)
resp = AuthBLL.get_token_for_user(
args.user_id, args.company_id, int(parse_timespan(args.expiration))
)
print("Token:\n%s" % resp.token)
def safe_get(obj, glob, default=None, separator="/"):
@@ -23,19 +27,24 @@ def safe_get(obj, glob, default=None, separator="/"):
return default
if __name__ == '__main__':
if __name__ == "__main__":
top_parser = ArgumentParser(__doc__)
subparsers = top_parser.add_subparsers(title='Sections')
subparsers = top_parser.add_subparsers(title="Sections")
token = subparsers.add_parser('token')
token_commands = token.add_subparsers(title='Commands')
token_create = token_commands.add_parser('generate', description='Generate a new token')
token_create.add_argument('--user-id', '-u', help='User ID', required=True)
token_create.add_argument('--company-id', '-c', help='Company ID', required=True)
token_create.add_argument('--expiration', '-exp',
help="Token expiration (time span, shorthand suffixes are supported, default 1m)",
default=parse_timespan('1m'))
token = subparsers.add_parser("token")
token_commands = token.add_subparsers(title="Commands")
token_create = token_commands.add_parser(
"generate", description="Generate a new token"
)
token_create.add_argument("--user-id", "-u", help="User ID", required=True)
token_create.add_argument("--company-id", "-c", help="Company ID", required=True)
token_create.add_argument(
"--expiration",
"-exp",
help="Token expiration (time span, shorthand suffixes are supported, default 1m)",
default=parse_timespan("1m"),
)
token_create.set_defaults(_func=gen_token)
args = top_parser.parse_args()

View File

@@ -1 +1 @@
__version__ = "1.13.0"
__version__ = "1.15.0"

View File

@@ -1,4 +1,4 @@
FROM node:18-bullseye as webapp_builder
FROM node:20-bookworm-slim as webapp_builder
ARG CLEARML_WEB_GIT_URL=https://github.com/allegroai/clearml-web.git
@@ -10,8 +10,9 @@ RUN mv clearml-web /opt/open-webapp
COPY --chmod=744 docker/build/internal_files/build_webapp.sh /tmp/internal_files/
RUN /bin/bash -c '/tmp/internal_files/build_webapp.sh'
FROM python:3.9-slim-bullseye
FROM python:3.9-slim-bookworm
COPY --chmod=744 docker/build/internal_files/entrypoint.sh /opt/clearml/
COPY --chmod=744 docker/build/internal_files/update_from_env.py /opt/clearml/utilities/
COPY fileserver /opt/clearml/fileserver/
COPY apiserver /opt/clearml/apiserver/

View File

@@ -29,7 +29,12 @@ server {
include /etc/nginx/default.d/*.conf;
location / {
try_files $uri$args $uri$args/ $uri index.html /index.html;
add_header Strict-Transport-Security "max-age=31536000; includeSubDomains" always;
add_header Content-Security-Policy "frame-ancestors 'self'";
add_header X-XSS-Protection "1; mode=block";
add_header X-Content-Type-Options "nosniff" always;
add_header Referrer-Policy "no-referrer-when-downgrade";
try_files $uri $uri/ /index.html;
}
location /version.json {
@@ -50,6 +55,12 @@ server {
rewrite /files/(.*) /$1 break;
}
location /widgets {
alias /usr/share/nginx/widgets;
try_files $uri $uri/ /widgets/index.html;
add_header Content-Security-Policy "frame-ancestors *";
}
error_page 404 /404.html;
location = /40x.html {
}
@@ -57,4 +68,4 @@ server {
error_page 500 502 503 504 /50x.html;
location = /50x.html {
}
}
}

View File

@@ -46,10 +46,26 @@ elif [[ ${SERVER_TYPE} == "webserver" ]]; then
EOF
fi
# Create an empty configuration json
echo "{}" > /tmp/configuration.json
# Copy the external configuration file if it exists
if test -f "/mnt/external_files/configs/configuration.json"; then
echo "Copying external configuration"
cp /mnt/external_files/configs/configuration.json /tmp/configuration.json
fi
# Update from env variables
echo "Updating configuration from env"
/opt/clearml/utilities/update_from_env.py \
--verbose \
/tmp/configuration.json \
/usr/share/nginx/html/configuration.json
export NGINX_APISERVER_ADDR=${NGINX_APISERVER_ADDRESS:-http://apiserver:8008}
export NGINX_FILESERVER_ADDR=${NGINX_FILESERVER_ADDRESS:-http://fileserver:8081}
COMMENT_IPV6_LISTEN=$([ "$DISABLE_NGINX_IPV6" = "true" ] && echo "#" || echo "") \
envsubst '${COMMENT_IPV6_LISTEN} ${NGINX_APISERVER_ADDR} ${NGINX_FILESERVER_ADDR}' < /etc/nginx/clearml.conf.template > /etc/nginx/sites-enabled/default
export COMMENT_IPV6_LISTEN=$([ "$DISABLE_NGINX_IPV6" = "true" ] && echo "#" || echo "")
envsubst '${COMMENT_IPV6_LISTEN} ${NGINX_APISERVER_ADDR} ${NGINX_FILESERVER_ADDR}' < /etc/nginx/clearml.conf.template > /etc/nginx/sites-enabled/default
if [[ -n "${CLEARML_SERVER_SUB_PATH}" ]]; then
mkdir -p /etc/nginx/default.d/

View File

@@ -1,11 +1,11 @@
#!/usr/bin/env bash
set -x
set -o errexit
set -o nounset
set -o pipefail
apt-get update -y
apt-get install -y python3-setuptools python3-dev build-essential nginx gettext
apt-get install -y vim curl
apt-get install -y python3-setuptools python3-dev build-essential nginx gettext vim curl
python3 -m ensurepip
python3 -m pip install --upgrade pip

View File

@@ -0,0 +1,104 @@
#!/usr/bin/env python3
""" Update json configuration file from environment variables """
from argparse import ArgumentParser, FileType
import json
from os import environ
from typing import Any, Generator, Tuple, Optional, List
class PathConflictError(Exception):
    """Signals that a key path cannot be created inside a nested dict.

    ``set_path`` raises it when an existing non-dict value sits where a nested
    dict is needed. ``path`` stores the list handed to the constructor
    (``set_path`` passes the keys that remain after the conflicting one).
    """

    def __init__(self, path_: List[str]):
        super().__init__(path_)
        self.path = path_
def scan(
    obj: Any, path_: str = None, sep: str = ".", parent_=None, key_=None,
) -> Generator[Tuple[str, Any, Optional[dict], str], None, None]:
    """Walk a nested dict and yield one tuple per non-dict (leaf) value.

    :param obj: Object to walk; nested dicts are descended into, anything
        else is treated as a leaf.
    :param path_: Path accumulated so far (``None`` for the root).
    :param sep: Separator used to join the keys into a path.
    :param parent_: Dict holding ``obj`` (``None`` for the root).
    :param key_: Key under which ``obj`` is stored in ``parent_``.
    :return: Generator of ``(lower-cased path, value, parent dict, key)``.
        A root that is not a dict yields a single ``("", obj, None, None)``.
    """
    if not isinstance(obj, dict):
        # path_ is still None when the root itself is a leaf; report an empty
        # path instead of failing on None.lower()
        yield (path_ or "").lower(), obj, parent_, key_
        return

    for k, v in obj.items():
        # filter(None, ...) drops the missing root path (and empty keys) so
        # the joined path never starts with a separator
        child_path = sep.join(filter(None, (path_, k)))
        yield from scan(v, path_=child_path, parent_=obj, key_=k, sep=sep)
def set_path(p: List[str], obj: dict, v: Any):
    """Store ``v`` in the nested dict ``obj`` under the key path ``p``.

    Missing intermediate dicts are created on the way down. When an
    intermediate key already holds a non-dict value, PathConflictError is
    raised carrying the keys that follow the conflicting one.
    """
    *parents, leaf = p
    node = obj
    for depth, name in enumerate(parents):
        if name not in node:
            node[name] = {}
        elif not isinstance(node[name], dict):
            # report the keys that could not be reached below the conflict
            raise PathConflictError(list(p[depth + 1:]))
        node = node[name]
    node[leaf] = v
def parse_env_value(name: str, raw: str) -> Any:
    """Decode the environment variable value ``raw`` as JSON.

    :param name: Variable name, used only in the error message.
    :param raw: Raw string value of the variable.
    :return: The decoded JSON value.
    :raises SystemExit: With status 2 (after printing an error) when ``raw``
        is not valid JSON.
    """
    try:
        return json.loads(raw)
    except json.JSONDecodeError as ex:
        print(f"Error parsing {name} JSON value `{raw}`: {str(ex)}")
        raise SystemExit(2)


def collect_env_overrides(env, prefix: str, parse_env: bool = True) -> dict:
    """Pick the variables starting with ``prefix`` out of ``env``.

    :param env: Mapping of variable names to string values (e.g. os.environ).
    :param prefix: Full prefix, including the trailing separator.
    :param parse_env: Decode each value as JSON when true, keep the raw
        string otherwise.
    :return: Dict keyed by the variable name with ``prefix`` removed.
    """
    # Slice the prefix off by length: str.lstrip() takes a set of characters,
    # so it would also eat leading key characters that occur in the prefix
    # (WEBSERVER__SERVER_URL would become URL).
    return {
        name[len(prefix):]: parse_env_value(name, raw) if parse_env else raw
        for name, raw in env.items()
        if name.startswith(prefix)
    }


def main():
    """Read the input JSON, apply environment overrides and write the result.

    Exit status: 1 for an empty prefix or unusable input file, 2 for an
    environment value that is not valid JSON, 3 when writing the output fails.
    """
    parser = ArgumentParser(description=__doc__)
    parser.add_argument("input_file", type=FileType(), help="Input JSON file")
    parser.add_argument("output_file", type=FileType("w"), help="Output JSON file")
    parser.add_argument(
        "--env-prefix",
        "-p",
        default="WEBSERVER",
        help="Environment variables prefix (default=%(default)s)",
        dest="prefix",
        required=False,
    )
    parser.add_argument(
        "--env-separator",
        "-s",
        default="__",
        help="Environment variable name separator (default=%(default)s)",
        dest="sep",
    )
    parser.add_argument("--verbose", "-v", action="store_true", default=False)
    parser.add_argument(
        "--disable-parse-env-value",
        action="store_false",
        default=True,
        help="Don't parse env value as JSON",
        dest="parse_env",
    )
    args = parser.parse_args()

    if not args.prefix:
        print("Error: script does not support an empty prefix")
        raise SystemExit(1)

    try:
        data = json.load(args.input_file)
    except json.JSONDecodeError as ex:
        print(f"Error parsing JSON file {args.input_file.name}: {str(ex)}")
        raise SystemExit(1)

    # Overrides are addressed by key path, which only makes sense for an object
    if not isinstance(data, dict):
        print(f"Error: JSON file {args.input_file.name} does not contain a JSON object")
        raise SystemExit(1)

    env_vars = collect_env_overrides(
        environ, args.prefix + args.sep, parse_env=args.parse_env
    )

    # First pass: replace values that already exist in the file. Paths are
    # compared case-insensitively (scan yields lower-cased paths).
    for path, value, parent, key in scan(data, sep=args.sep):
        if not (parent and key):
            continue
        match = next((k for k in env_vars if k.lower() == path), None)
        if match:
            replace = env_vars.pop(match)
            parent[key] = replace
            if args.verbose:
                print(f"Replacing {path}={value} with {replace}")

    # Second pass: whatever was not matched above is added as a new path
    for k, v in env_vars.items():
        path = k.split(args.sep)
        try:
            set_path(path, data, v)
        except PathConflictError as ex:
            # ex.path holds the keys after the conflicting one, so cutting
            # them off leaves the path of the value that is not a dictionary
            print(f"Error: failed setting value into {k}: {path[:-len(ex.path)]} is not a dictionary")

    try:
        json.dump(data, args.output_file, sort_keys=True, indent=2)
    except Exception as ex:
        print(f"Error writing JSON file {args.output_file.name}: {str(ex)}")
        raise SystemExit(3)


if __name__ == "__main__":
    main()

View File

@@ -49,13 +49,10 @@ services:
cluster.routing.allocation.disk.watermark.low: 500mb
cluster.routing.allocation.disk.watermark.high: 500mb
cluster.routing.allocation.disk.watermark.flood_stage: 500mb
discovery.zen.minimum_master_nodes: "1"
discovery.type: "single-node"
http.compression_level: "7"
node.ingest: "true"
node.name: clearml
reindex.remote.whitelist: '*.*'
xpack.monitoring.enabled: "false"
reindex.remote.whitelist: "'*.*'"
xpack.security.enabled: "false"
ulimits:
memlock:
@@ -64,7 +61,7 @@ services:
nofile:
soft: 65536
hard: 65536
image: docker.elastic.co/elasticsearch/elasticsearch:7.17.7
image: docker.elastic.co/elasticsearch/elasticsearch:7.17.18
restart: unless-stopped
volumes:
- c:/opt/clearml/data/elastic_7:/usr/share/elasticsearch/data
@@ -93,7 +90,7 @@ services:
networks:
- backend
container_name: clearml-mongo
image: mongo:4.4.9
image: mongo:4.4.29
restart: unless-stopped
command: --setParameter internalQueryMaxBlockingSortMemoryUsageBytes=196100200
volumes:
@@ -104,7 +101,7 @@ services:
networks:
- backend
container_name: clearml-redis
image: redis:5.0
image: redis:6.2
restart: unless-stopped
volumes:
- c:/opt/clearml/data/redis:/data

View File

@@ -49,13 +49,10 @@ services:
cluster.routing.allocation.disk.watermark.low: 500mb
cluster.routing.allocation.disk.watermark.high: 500mb
cluster.routing.allocation.disk.watermark.flood_stage: 500mb
discovery.zen.minimum_master_nodes: "1"
discovery.type: "single-node"
http.compression_level: "7"
node.ingest: "true"
node.name: clearml
reindex.remote.whitelist: '*.*'
xpack.monitoring.enabled: "false"
reindex.remote.whitelist: "'*.*'"
xpack.security.enabled: "false"
ulimits:
memlock:
@@ -64,7 +61,7 @@ services:
nofile:
soft: 65536
hard: 65536
image: docker.elastic.co/elasticsearch/elasticsearch:7.17.7
image: docker.elastic.co/elasticsearch/elasticsearch:7.17.18
restart: unless-stopped
volumes:
- /opt/clearml/data/elastic_7:/usr/share/elasticsearch/data
@@ -92,7 +89,7 @@ services:
networks:
- backend
container_name: clearml-mongo
image: mongo:4.4.9
image: mongo:4.4.29
restart: unless-stopped
command: --setParameter internalQueryMaxBlockingSortMemoryUsageBytes=196100200
volumes:
@@ -103,7 +100,7 @@ services:
networks:
- backend
container_name: clearml-redis
image: redis:5.0
image: redis:6.2
restart: unless-stopped
volumes:
- /opt/clearml/data/redis:/data

View File

@@ -1,8 +1,9 @@
boltons>=19.1.0
flask-compress>=1.4.0
flask-cors>=3.0.5
flask>=2.3.2
flask>=2.3.3
gunicorn>=20.1.0
pyhocon>=0.3.35
setuptools>=65.5.1
urllib3>=1.26.18
urllib3>=1.26.18
werkzeug>=3.0.1