mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Compare commits
43 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9c95c63ce0 | ||
|
|
73179f53c2 | ||
|
|
ddc8a76279 | ||
|
|
ac7ea0d477 | ||
|
|
3544ed19f8 | ||
|
|
5e68f053a0 | ||
|
|
7bd5fdad59 | ||
|
|
484c72aa0c | ||
|
|
2027afbed5 | ||
|
|
7d649f1964 | ||
|
|
8d237b3cae | ||
|
|
e8ee6ce72e | ||
|
|
5749ff0454 | ||
|
|
5189adf4f1 | ||
|
|
92a4e56c1f | ||
|
|
33528870ae | ||
|
|
85f5b8b6f6 | ||
|
|
6112910768 | ||
|
|
d3013ac285 | ||
|
|
88abf28287 | ||
|
|
6a1fc04d1e | ||
|
|
ee8eb03698 | ||
|
|
5799baae45 | ||
|
|
801e536c5e | ||
|
|
6e484ea8f4 | ||
|
|
a47e65d974 | ||
|
|
702b6dc9c8 | ||
|
|
db15f235e4 | ||
|
|
8c347f8fa9 | ||
|
|
768c3d80ff | ||
|
|
a5c3ef6385 | ||
|
|
11b7a384af | ||
|
|
9a70ade4a6 | ||
|
|
91ce140901 | ||
|
|
49084a9c49 | ||
|
|
8a99eb6812 | ||
|
|
811ab2bf4f | ||
|
|
3752db122b | ||
|
|
439911b84c | ||
|
|
262a301e28 | ||
|
|
a604451b01 | ||
|
|
88a7773621 | ||
|
|
35c4061992 |
@@ -13,6 +13,14 @@ from apiserver.config_repo import config
|
||||
from apiserver.utilities.stringenum import StringEnum
|
||||
|
||||
|
||||
class TaskRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
|
||||
|
||||
class ModelRequest(Base):
|
||||
model: str = StringField(required=True)
|
||||
|
||||
|
||||
class HistogramRequestBase(Base):
|
||||
samples: int = IntField(default=2000, validators=[Min(1), Max(6000)])
|
||||
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
|
||||
@@ -29,6 +37,11 @@ class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class GetMetricsAndVariantsRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
tasks: Sequence[str] = ListField(
|
||||
items_types=str,
|
||||
@@ -41,6 +54,7 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
)
|
||||
],
|
||||
)
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
@@ -50,6 +64,12 @@ class TaskMetric(Base):
|
||||
variants: Sequence[str] = ListField(items_types=str)
|
||||
|
||||
|
||||
class LegacyMetricEventsRequest(TaskRequest):
|
||||
iters: int = IntField(default=1, validators=validators.Min(1))
|
||||
scroll_id: str = StringField()
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class MetricEventsRequest(Base):
|
||||
metrics: Sequence[TaskMetric] = ListField(
|
||||
items_types=TaskMetric, validators=[Length(minimum_value=1)]
|
||||
@@ -58,7 +78,14 @@ class MetricEventsRequest(Base):
|
||||
navigate_earlier: bool = BoolField(default=True)
|
||||
refresh: bool = BoolField(default=False)
|
||||
scroll_id: str = StringField()
|
||||
model_events: bool = BoolField()
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class VectorMetricsIterHistogramRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
metric: str = StringField(required=True)
|
||||
variant: str = StringField(required=True)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class GetVariantSampleRequest(Base):
|
||||
@@ -109,6 +136,11 @@ class TaskEventsRequest(TaskEventsRequestBase):
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class LegacyLogEventsRequest(TaskEventsRequestBase):
|
||||
order: Optional[str] = ActualEnumField(LogOrderEnum, default=LogOrderEnum.desc)
|
||||
scroll_id: str = StringField()
|
||||
|
||||
|
||||
class LogEventsRequest(TaskEventsRequestBase):
|
||||
batch_size: int = IntField(default=5000)
|
||||
navigate_earlier: bool = BoolField(default=True)
|
||||
@@ -148,18 +180,28 @@ class MultiTasksRequestBase(Base):
|
||||
|
||||
|
||||
class SingleValueMetricsRequest(MultiTasksRequestBase):
|
||||
pass
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
|
||||
|
||||
class TaskMetricsRequest(MultiTasksRequestBase):
|
||||
event_type: EventType = ActualEnumField(EventType, required=True)
|
||||
|
||||
|
||||
class MultiTaskMetricsRequest(MultiTasksRequestBase):
|
||||
event_type: EventType = ActualEnumField(EventType, default=EventType.all)
|
||||
|
||||
|
||||
class LegacyMultiTaskEventsRequest(MultiTasksRequestBase):
|
||||
iters: int = IntField(default=1, validators=validators.Min(1))
|
||||
scroll_id: str = StringField()
|
||||
|
||||
|
||||
class MultiTaskPlotsRequest(MultiTasksRequestBase):
|
||||
iters: int = IntField(default=1)
|
||||
scroll_id: str = StringField()
|
||||
no_scroll: bool = BoolField(default=False)
|
||||
last_iters_per_task_metric: bool = BoolField(default=True)
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
|
||||
|
||||
class TaskPlotsRequest(Base):
|
||||
@@ -171,6 +213,14 @@ class TaskPlotsRequest(Base):
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class GetScalarMetricDataRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
metric: str = StringField(required=True)
|
||||
scroll_id: str = StringField()
|
||||
no_scroll: bool = BoolField(default=False)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class ClearScrollRequest(Base):
|
||||
scroll_id: str = StringField()
|
||||
|
||||
|
||||
@@ -5,8 +5,9 @@ from apiserver.apimodels import DictField, callable_default
|
||||
|
||||
|
||||
class GetSupportedModesRequest(Base):
|
||||
state = StringField(help_text="ASCII base64 encoded application state")
|
||||
callback_url_prefix = StringField()
|
||||
pass
|
||||
# state = StringField(help_text="ASCII base64 encoded application state")
|
||||
# callback_url_prefix = StringField()
|
||||
|
||||
|
||||
class BasicGuestMode(Base):
|
||||
|
||||
@@ -42,6 +42,21 @@ class ModelRequest(models.Base):
|
||||
model = fields.StringField(required=True)
|
||||
|
||||
|
||||
class TaskRequest(models.Base):
|
||||
task = fields.StringField(required=True)
|
||||
|
||||
|
||||
class UpdateForTaskRequest(TaskRequest):
|
||||
uri = fields.StringField()
|
||||
iteration = fields.IntField()
|
||||
override_model_id = fields.StringField()
|
||||
|
||||
|
||||
class UpdateModelRequest(ModelRequest):
|
||||
task = fields.StringField()
|
||||
iteration = fields.IntField()
|
||||
|
||||
|
||||
class DeleteModelRequest(ModelRequest):
|
||||
force = fields.BoolField(default=False)
|
||||
delete_external_artifacts = fields.BoolField(default=True)
|
||||
|
||||
@@ -18,8 +18,4 @@ class StartPipelineRequest(models.Base):
|
||||
task = fields.StringField(required=True)
|
||||
queue = fields.StringField(required=True)
|
||||
args = ListField(Arg)
|
||||
|
||||
|
||||
class StartPipelineResponse(models.Base):
|
||||
pipeline = fields.StringField(required=True)
|
||||
enqueued = fields.BoolField(required=True)
|
||||
verify_watched_queue = fields.BoolField(default=False)
|
||||
|
||||
@@ -33,6 +33,7 @@ class ProjectOrNoneRequest(models.Base):
|
||||
|
||||
class GetUniqueMetricsRequest(ProjectOrNoneRequest):
|
||||
model_metrics = fields.BoolField(default=False)
|
||||
ids = fields.ListField(str)
|
||||
|
||||
|
||||
class GetParamsRequest(ProjectOrNoneRequest):
|
||||
@@ -45,7 +46,7 @@ class ProjectTagsRequest(TagsRequest):
|
||||
|
||||
|
||||
class MultiProjectRequest(models.Base):
|
||||
projects = fields.ListField(str)
|
||||
projects = fields.ListField(items_types=[str, type(None)])
|
||||
include_subprojects = fields.BoolField(default=True)
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,10 @@ class ReportStatsOptionRequest(Base):
|
||||
enabled = BoolField(default=None, nullable=True)
|
||||
|
||||
|
||||
class GetConfigRequest(Base):
|
||||
path = StringField()
|
||||
|
||||
|
||||
class ReportStatsOptionResponse(Base):
|
||||
supported = BoolField(default=True)
|
||||
enabled = BoolField()
|
||||
|
||||
@@ -4,6 +4,10 @@ from jsonmodels.models import Base
|
||||
from apiserver.apimodels import DictField
|
||||
|
||||
|
||||
class UserRequest(Base):
|
||||
user = StringField(required=True)
|
||||
|
||||
|
||||
class CreateRequest(Base):
|
||||
id = StringField(required=True)
|
||||
name = StringField(required=True)
|
||||
|
||||
@@ -31,6 +31,7 @@ from apiserver.bll.event.history_plots_iterator import HistoryPlotsIterator
|
||||
from apiserver.bll.event.metric_debug_images_iterator import MetricDebugImagesIterator
|
||||
from apiserver.bll.event.metric_plots_iterator import MetricPlotsIterator
|
||||
from apiserver.bll.model import ModelBLL
|
||||
from apiserver.bll.task.utils import get_many_tasks_for_writing
|
||||
from apiserver.bll.util import parallel_chunked_decorator
|
||||
from apiserver.database import utils as dbutils
|
||||
from apiserver.database.model.model import Model
|
||||
@@ -42,7 +43,7 @@ from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.task.task import Task, TaskStatus
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.tools import safe_get
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from apiserver.utilities.json import loads
|
||||
|
||||
@@ -55,7 +56,9 @@ MIN_LONG = -(2**63)
|
||||
|
||||
log = config.logger(__file__)
|
||||
async_task_events_delete = config.get("services.tasks.async_events_delete", False)
|
||||
async_delete_threshold = config.get("services.tasks.async_events_delete_threshold", 100_000)
|
||||
async_delete_threshold = config.get(
|
||||
"services.tasks.async_events_delete_threshold", 100_000
|
||||
)
|
||||
|
||||
|
||||
class EventBLL(object):
|
||||
@@ -97,7 +100,9 @@ class EventBLL(object):
|
||||
return self._metrics
|
||||
|
||||
@staticmethod
|
||||
def _get_valid_entities(company_id, ids: Mapping[str, bool], model=False) -> Set:
|
||||
def _get_valid_entities(
|
||||
company_id, ids: Mapping[str, bool], identity: Identity, model=False
|
||||
) -> Set:
|
||||
"""Verify that task or model exists and can be updated"""
|
||||
if not ids:
|
||||
return set()
|
||||
@@ -116,20 +121,34 @@ class EventBLL(object):
|
||||
):
|
||||
if not requested_ids:
|
||||
continue
|
||||
query = Q(id__in=requested_ids, company=company_id)
|
||||
res.update(
|
||||
(Model if model else Task).objects(query & locked_q).scalar("id")
|
||||
)
|
||||
|
||||
query = Q(id__in=requested_ids) & locked_q
|
||||
if model:
|
||||
ids = Model.objects(query & Q(company=company_id)).scalar("id")
|
||||
else:
|
||||
ids = {
|
||||
t.id
|
||||
for t in get_many_tasks_for_writing(
|
||||
company_id=company_id,
|
||||
identity=identity,
|
||||
query=query,
|
||||
only=("id",),
|
||||
throw_on_forbidden=False,
|
||||
)
|
||||
}
|
||||
|
||||
res.update(ids)
|
||||
|
||||
return res
|
||||
|
||||
def add_events(
|
||||
self,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
events: Sequence[dict],
|
||||
worker: str,
|
||||
) -> Tuple[int, int, dict]:
|
||||
user_id = identity.user
|
||||
task_ids = {}
|
||||
model_ids = {}
|
||||
for event in events:
|
||||
@@ -161,8 +180,12 @@ class EventBLL(object):
|
||||
"Inconsistent model_event setting in the passed events",
|
||||
tasks=found_in_both,
|
||||
)
|
||||
valid_models = self._get_valid_entities(company_id, ids=model_ids, model=True)
|
||||
valid_tasks = self._get_valid_entities(company_id, ids=task_ids)
|
||||
valid_models = self._get_valid_entities(
|
||||
company_id, ids=model_ids, identity=identity, model=True
|
||||
)
|
||||
valid_tasks = self._get_valid_entities(
|
||||
company_id, ids=task_ids, identity=identity
|
||||
)
|
||||
|
||||
actions: List[dict] = []
|
||||
used_task_ids = set()
|
||||
@@ -351,7 +374,7 @@ class EventBLL(object):
|
||||
if invalid_iterations_count:
|
||||
raise BulkIndexError(
|
||||
f"{invalid_iterations_count} document(s) failed to index.",
|
||||
[invalid_iteration_error],
|
||||
[{"_index": invalid_iteration_error}],
|
||||
)
|
||||
|
||||
if not added:
|
||||
@@ -415,10 +438,8 @@ class EventBLL(object):
|
||||
last_events contains [hashed_metric_name -> hashed_variant_name -> event]. Keys are hashed to avoid mongodb
|
||||
key conflicts due to invalid characters and/or long field names.
|
||||
"""
|
||||
metric = event.get("metric")
|
||||
variant = event.get("variant")
|
||||
if not (metric and variant):
|
||||
return
|
||||
metric = event.get("metric") or ""
|
||||
variant = event.get("variant") or ""
|
||||
|
||||
metric_hash = dbutils.hash_field_name(metric)
|
||||
variant_hash = dbutils.hash_field_name(variant)
|
||||
@@ -463,9 +484,9 @@ class EventBLL(object):
|
||||
recent than the currently stored event for its metric/event_type combination.
|
||||
last_events contains [metric_name -> event_type -> event]
|
||||
"""
|
||||
metric = event.get("metric")
|
||||
metric = event.get("metric") or ""
|
||||
event_type = event.get("type")
|
||||
if not (metric and event_type):
|
||||
if not event_type:
|
||||
return
|
||||
|
||||
timestamp = last_events[metric][event_type].get("timestamp", None)
|
||||
@@ -637,8 +658,8 @@ class EventBLL(object):
|
||||
Return events and next scroll id from the scrolled query
|
||||
Release the scroll once it is exhausted
|
||||
"""
|
||||
total_events = safe_get(es_res, "hits/total/value", default=0)
|
||||
events = [doc["_source"] for doc in safe_get(es_res, "hits/hits", default=[])]
|
||||
total_events = nested_get(es_res, ("hits", "total", "value"), default=0)
|
||||
events = [doc["_source"] for doc in nested_get(es_res, ("hits", "hits"), default=[])]
|
||||
next_scroll_id = es_res.get("_scroll_id")
|
||||
if next_scroll_id and not events:
|
||||
self.clear_scroll(next_scroll_id)
|
||||
|
||||
@@ -9,7 +9,7 @@ from elasticsearch import Elasticsearch
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.tools import safe_get
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
|
||||
|
||||
class EventType(Enum):
|
||||
@@ -123,8 +123,8 @@ def get_max_metric_and_variant_counts(
|
||||
es, company_id=company_id, event_type=event_type, body=es_req, **kwargs,
|
||||
)
|
||||
|
||||
metrics_count = safe_get(
|
||||
es_res, "aggregations/metrics_count/value", max_metrics_count
|
||||
metrics_count = nested_get(
|
||||
es_res, ("aggregations", "metrics_count", "value"), max_metrics_count
|
||||
)
|
||||
if not metrics_count:
|
||||
return max_metrics_count, max_variants_count
|
||||
|
||||
@@ -21,9 +21,10 @@ from apiserver.bll.event.event_common import (
|
||||
TaskCompanies,
|
||||
)
|
||||
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
|
||||
from apiserver.bll.query import Builder as QueryBuilder
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.tools import safe_get
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
@@ -161,7 +162,9 @@ class EventMetrics:
|
||||
return res
|
||||
|
||||
def get_task_single_value_metrics(
|
||||
self, companies: TaskCompanies
|
||||
self,
|
||||
companies: TaskCompanies,
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> Mapping[str, dict]:
|
||||
"""
|
||||
For the requested tasks return all the events delivered for the single iteration (-2**31)
|
||||
@@ -179,7 +182,13 @@ class EventMetrics:
|
||||
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
|
||||
task_events = list(
|
||||
itertools.chain.from_iterable(
|
||||
pool.map(self._get_task_single_value_metrics, companies.items())
|
||||
pool.map(
|
||||
partial(
|
||||
self._get_task_single_value_metrics,
|
||||
metric_variants=metric_variants,
|
||||
),
|
||||
companies.items(),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
@@ -195,19 +204,19 @@ class EventMetrics:
|
||||
}
|
||||
|
||||
def _get_task_single_value_metrics(
|
||||
self, tasks: Tuple[str, Sequence[str]]
|
||||
self, tasks: Tuple[str, Sequence[str]], metric_variants: MetricVariants = None
|
||||
) -> Sequence[dict]:
|
||||
company_id, task_ids = tasks
|
||||
must = [
|
||||
{"terms": {"task": task_ids}},
|
||||
{"term": {"iter": SINGLE_SCALAR_ITERATION}},
|
||||
]
|
||||
if metric_variants:
|
||||
must.append(get_metric_variants_condition(metric_variants))
|
||||
|
||||
es_req = {
|
||||
"size": 10000,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"terms": {"task": task_ids}},
|
||||
{"term": {"iter": SINGLE_SCALAR_ITERATION}},
|
||||
]
|
||||
}
|
||||
},
|
||||
"query": {"bool": {"must": must}},
|
||||
}
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
@@ -280,7 +289,8 @@ class EventMetrics:
|
||||
query = {"bool": {"must": must}}
|
||||
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args,
|
||||
query=query,
|
||||
**search_args,
|
||||
)
|
||||
max_variants = int(max_variants // 2)
|
||||
es_req = {
|
||||
@@ -332,12 +342,12 @@ class EventMetrics:
|
||||
total amount of intervals does not exceeds the samples
|
||||
Return the interval and resulting amount of intervals
|
||||
"""
|
||||
count = safe_get(data, "count/value", default=0)
|
||||
count = nested_get(data, ("count", "value"), default=0)
|
||||
if count < samples:
|
||||
return metric, variant, 1, count
|
||||
|
||||
min_index = safe_get(data, "min_index/value", default=0)
|
||||
max_index = safe_get(data, "max_index/value", default=min_index)
|
||||
min_index = nested_get(data, ("min_index", "value"), default=0)
|
||||
max_index = nested_get(data, ("max_index", "value"), default=min_index)
|
||||
index_range = max_index - min_index + 1
|
||||
interval = max(1, math.ceil(float(index_range) / samples))
|
||||
max_samples = math.ceil(float(index_range) / interval)
|
||||
@@ -366,7 +376,8 @@ class EventMetrics:
|
||||
query = self._get_task_metrics_query(task_id=task_id, metrics=metrics)
|
||||
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args,
|
||||
query=query,
|
||||
**search_args,
|
||||
)
|
||||
max_variants = int(max_variants // 2)
|
||||
es_req = {
|
||||
@@ -432,7 +443,9 @@ class EventMetrics:
|
||||
|
||||
@classmethod
|
||||
def _get_task_metrics_query(
|
||||
cls, task_id: str, metrics: Sequence[Tuple[str, str]],
|
||||
cls,
|
||||
task_id: str,
|
||||
metrics: Sequence[Tuple[str, str]],
|
||||
):
|
||||
must = cls._task_conditions(task_id)
|
||||
if metrics:
|
||||
@@ -451,12 +464,96 @@ class EventMetrics:
|
||||
|
||||
return {"bool": {"must": must}}
|
||||
|
||||
def get_multi_task_metrics(self, companies: TaskCompanies, event_type: EventType) -> Mapping[str, list]:
|
||||
"""
|
||||
For the requested tasks return reported metrics and variants
|
||||
"""
|
||||
tasks_ids = {
|
||||
company: [t.id for t in tasks]
|
||||
for company, tasks in companies.items()
|
||||
}
|
||||
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
|
||||
companies_res: Sequence = list(
|
||||
pool.map(
|
||||
partial(
|
||||
self._get_multi_task_metrics,
|
||||
event_type=event_type,
|
||||
),
|
||||
tasks_ids.items(),
|
||||
)
|
||||
)
|
||||
|
||||
if len(companies_res) == 1:
|
||||
return companies_res[0]
|
||||
|
||||
res = defaultdict(set)
|
||||
for c_res in companies_res:
|
||||
for m, vars_ in c_res.items():
|
||||
res[m].update(vars_)
|
||||
|
||||
return {
|
||||
k: list(v)
|
||||
for k, v in res.items()
|
||||
}
|
||||
|
||||
def _get_multi_task_metrics(
|
||||
self, company_tasks: Tuple[str, Sequence[str]], event_type: EventType
|
||||
) -> Mapping[str, list]:
|
||||
company_id, task_ids = company_tasks
|
||||
if check_empty_data(self.es, company_id, event_type):
|
||||
return {}
|
||||
|
||||
search_args = dict(
|
||||
es=self.es,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
)
|
||||
query = QueryBuilder.terms("task", task_ids)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query,
|
||||
**search_args,
|
||||
)
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": max_metrics,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
es_res = search_company_events(
|
||||
body=es_req,
|
||||
**search_args,
|
||||
)
|
||||
aggs_result = es_res.get("aggregations")
|
||||
if not aggs_result:
|
||||
return {}
|
||||
|
||||
return {
|
||||
mb["key"]: [vb["key"] for vb in mb["variants"]["buckets"]]
|
||||
for mb in aggs_result["metrics"]["buckets"]
|
||||
}
|
||||
|
||||
def get_task_metrics(
|
||||
self, company_id, task_ids: Sequence, event_type: EventType
|
||||
) -> Sequence:
|
||||
"""
|
||||
For the requested tasks return all the metrics that
|
||||
reported events of the requested types
|
||||
For the requested tasks return reported metrics per task
|
||||
"""
|
||||
if check_empty_data(self.es, company_id, event_type):
|
||||
return {}
|
||||
@@ -495,5 +592,5 @@ class EventMetrics:
|
||||
|
||||
return [
|
||||
metric["key"]
|
||||
for metric in safe_get(es_res, "aggregations/metrics/buckets", default=[])
|
||||
for metric in nested_get(es_res, ("aggregations", "metrics", "buckets"), default=[])
|
||||
]
|
||||
|
||||
@@ -6,7 +6,6 @@ from operator import itemgetter
|
||||
from typing import Sequence, Tuple, Optional, Mapping, Callable
|
||||
|
||||
import attr
|
||||
import dpath
|
||||
from boltons.iterutils import first
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, ListField, IntField
|
||||
@@ -27,6 +26,7 @@ from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.task.metrics import MetricEventStats
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
|
||||
|
||||
class VariantState(Base):
|
||||
@@ -305,13 +305,13 @@ class MetricEventsIterator:
|
||||
return [
|
||||
MetricState(
|
||||
metric=metric["key"],
|
||||
timestamp=dpath.get(metric, "last_event_timestamp/value"),
|
||||
timestamp=nested_get(metric, ("last_event_timestamp", "value")),
|
||||
variants=[
|
||||
init_variant_state(variant)
|
||||
for variant in dpath.get(metric, "variants/buckets")
|
||||
for variant in nested_get(metric, ("variants", "buckets"))
|
||||
],
|
||||
)
|
||||
for metric in dpath.get(es_res, "aggregations/metrics/buckets")
|
||||
for metric in nested_get(es_res, ("aggregations", "metrics", "buckets"))
|
||||
]
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -430,14 +430,14 @@ class MetricEventsIterator:
|
||||
def get_iteration_events(it_: dict) -> Sequence:
|
||||
return [
|
||||
self._process_event(ev["_source"])
|
||||
for m in dpath.get(it_, "metrics/buckets")
|
||||
for v in dpath.get(m, "variants/buckets")
|
||||
for ev in dpath.get(v, "events/hits/hits")
|
||||
for m in nested_get(it_, ("metrics", "buckets"))
|
||||
for v in nested_get(m, ("variants", "buckets"))
|
||||
for ev in nested_get(v, ("events", "hits", "hits"))
|
||||
if is_valid_event(ev["_source"])
|
||||
]
|
||||
|
||||
iterations = []
|
||||
for it in dpath.get(es_res, "aggregations/iters/buckets"):
|
||||
for it in nested_get(es_res, ("aggregations", "iters", "buckets")):
|
||||
events = get_iteration_events(it)
|
||||
if events:
|
||||
iterations.append({"iter": it["key"], "events": events})
|
||||
|
||||
@@ -10,6 +10,7 @@ from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task, TaskStatus
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from .metadata import Metadata
|
||||
|
||||
|
||||
@@ -57,14 +58,15 @@ class ModelBLL:
|
||||
cls,
|
||||
model_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
force_publish_task: bool = False,
|
||||
publish_task_func: Callable[[str, str, str, bool], dict] = None,
|
||||
publish_task_func: Callable[[str, str, Identity, bool], dict] = None,
|
||||
) -> Tuple[int, ModelTaskPublishResponse]:
|
||||
model = cls.get_company_model_by_id(company_id=company_id, model_id=model_id)
|
||||
if model.ready:
|
||||
raise errors.bad_request.ModelIsReady(company=company_id, model=model_id)
|
||||
|
||||
user_id = identity.user
|
||||
published_task = None
|
||||
if model.task and publish_task_func:
|
||||
task = (
|
||||
@@ -74,7 +76,7 @@ class ModelBLL:
|
||||
)
|
||||
if task and task.status != TaskStatus.published:
|
||||
task_publish_res = publish_task_func(
|
||||
model.task, company_id, user_id, force_publish_task
|
||||
model.task, company_id, identity, force_publish_task
|
||||
)
|
||||
published_task = ModelTaskPublishResponse(
|
||||
id=model.task, data=task_publish_res
|
||||
|
||||
@@ -341,6 +341,17 @@ class ProjectBLL:
|
||||
) -> Tuple[Sequence, Sequence]:
|
||||
archived = EntityVisibility.archived.value
|
||||
|
||||
def project_task_fields():
|
||||
return {
|
||||
"$project": {
|
||||
"project": 1,
|
||||
"status": 1,
|
||||
"system_tags": 1,
|
||||
"started": 1,
|
||||
"completed": 1,
|
||||
}
|
||||
}
|
||||
|
||||
def ensure_valid_fields():
|
||||
"""
|
||||
Make sure system tags is always an array (required by subsequent $in in archived_tasks_cond
|
||||
@@ -368,6 +379,7 @@ class ProjectBLL:
|
||||
users=users,
|
||||
)
|
||||
},
|
||||
project_task_fields(),
|
||||
ensure_valid_fields(),
|
||||
{
|
||||
"$group": {
|
||||
@@ -516,6 +528,7 @@ class ProjectBLL:
|
||||
users=users,
|
||||
)
|
||||
},
|
||||
project_task_fields(),
|
||||
ensure_valid_fields(),
|
||||
{
|
||||
# for each project
|
||||
@@ -856,7 +869,7 @@ class ProjectBLL:
|
||||
company,
|
||||
project_ids: Sequence[str],
|
||||
user_ids: Optional[Sequence[str]] = None,
|
||||
) -> Set[str]:
|
||||
) -> Set[Union[str, type(None)]]:
|
||||
"""
|
||||
Get the set of user ids that created tasks/models in the given projects
|
||||
If project_ids is empty then all projects are examined
|
||||
@@ -1112,11 +1125,7 @@ class ProjectBLL:
|
||||
helper = GetMixin.NewListFieldBucketHelper(
|
||||
field, data=field_filter, legacy=True
|
||||
)
|
||||
op = (
|
||||
Q.OR
|
||||
if helper.explicit_operator and helper.global_operator == Q.OR
|
||||
else Q.AND
|
||||
)
|
||||
op = helper.global_operator
|
||||
db_query = {op: helper.actions}
|
||||
else:
|
||||
helper = GetMixin.ListQueryFilter.from_data(field, field_filter)
|
||||
@@ -1125,7 +1134,7 @@ class ProjectBLL:
|
||||
for op, actions in db_query.items():
|
||||
field_conditions = {}
|
||||
for action, values in actions.items():
|
||||
value = list(set(values))
|
||||
value = list(set(values)) if isinstance(values, list) else values
|
||||
for key in reversed(action.split("__")):
|
||||
value = {f"${key}": value}
|
||||
field_conditions.update(value)
|
||||
|
||||
@@ -239,6 +239,7 @@ class ProjectQueries:
|
||||
company_id,
|
||||
project_ids: Sequence[str],
|
||||
include_subprojects: bool,
|
||||
ids: Sequence[str],
|
||||
model_metrics: bool = False,
|
||||
):
|
||||
pipeline = [
|
||||
@@ -246,6 +247,7 @@ class ProjectQueries:
|
||||
"$match": {
|
||||
**cls._get_company_constraint(company_id),
|
||||
**cls._get_project_constraint(project_ids, include_subprojects),
|
||||
**({"_id": {"$in": ids}} if ids else {}),
|
||||
}
|
||||
},
|
||||
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
|
||||
|
||||
@@ -152,7 +152,7 @@ class QueueBLL(object):
|
||||
|
||||
for item in queue.entries:
|
||||
try:
|
||||
task = Task.get_for_writing(
|
||||
task = Task.get(
|
||||
company=company_id,
|
||||
id=item.task,
|
||||
_only=[
|
||||
|
||||
@@ -18,7 +18,7 @@ from apiserver.config.info import get_deployment_type
|
||||
from apiserver.database.model import Company, User
|
||||
from apiserver.database.model.queue import Queue
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.tools import safe_get
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from apiserver.utilities.json import dumps
|
||||
from apiserver.version import __version__ as current_version
|
||||
from .resource_monitor import ResourceMonitor, stat_threads
|
||||
@@ -162,7 +162,7 @@ class StatisticsReporter:
|
||||
def _get_cardinality_fields(categories: Sequence[dict]) -> dict:
|
||||
names = {"cpu": "num_cores"}
|
||||
return {
|
||||
names[c["key"]]: safe_get(c, "count/value")
|
||||
names[c["key"]]: nested_get(c, ("count", "value"))
|
||||
for c in categories
|
||||
if c["key"] in names
|
||||
}
|
||||
@@ -175,21 +175,21 @@ class StatisticsReporter:
|
||||
}
|
||||
return {
|
||||
names[m["key"]]: {
|
||||
"min": safe_get(m, "min/value"),
|
||||
"max": safe_get(m, "max/value"),
|
||||
"avg": safe_get(m, "avg/value"),
|
||||
"min": nested_get(m, ("min", "value")),
|
||||
"max": nested_get(m, ("max", "value")),
|
||||
"avg": nested_get(m, ("avg", "value")),
|
||||
}
|
||||
for m in metrics
|
||||
if m["key"] in names
|
||||
}
|
||||
|
||||
buckets = safe_get(res, "aggregations/workers/buckets", default=[])
|
||||
buckets = nested_get(res, ("aggregations", "workers", "buckets"), default=[])
|
||||
return {
|
||||
b["key"]: {
|
||||
key: {
|
||||
"interval_sec": agent_resource_threshold_sec,
|
||||
**_get_cardinality_fields(safe_get(b, "categories/buckets", [])),
|
||||
**_get_metric_fields(safe_get(b, "metrics/buckets", [])),
|
||||
**_get_cardinality_fields(nested_get(b, ("categories", "buckets"), [])),
|
||||
**_get_metric_fields(nested_get(b, ("metrics", "buckets"), [])),
|
||||
}
|
||||
}
|
||||
for b in buckets
|
||||
@@ -227,7 +227,7 @@ class StatisticsReporter:
|
||||
},
|
||||
}
|
||||
res = cls._run_worker_stats_query(company_id, es_req)
|
||||
buckets = safe_get(res, "aggregations/workers/buckets", default=[])
|
||||
buckets = nested_get(res, ("aggregations", "workers", "buckets"), default=[])
|
||||
return {
|
||||
b["key"]: {"last_activity_time": b["last_activity_time"]["value"]}
|
||||
for b in buckets
|
||||
@@ -254,6 +254,14 @@ class StatisticsReporter:
|
||||
**({"last_worker": {"$in": workers}} if workers else {}),
|
||||
}
|
||||
},
|
||||
{
|
||||
"$project": {
|
||||
"last_worker": 1,
|
||||
"last_update": 1,
|
||||
"started": 1,
|
||||
"last_iteration": 1,
|
||||
}
|
||||
},
|
||||
{
|
||||
"$group": {
|
||||
"_id": "$last_worker" if workers else None,
|
||||
|
||||
@@ -5,6 +5,7 @@ from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
|
||||
from apiserver.bll.task.utils import get_task_for_update, update_task
|
||||
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
|
||||
from apiserver.database.utils import hash_field_name
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.utilities.dicts import nested_get, nested_set
|
||||
from apiserver.utilities.parameter_key_escaper import mongoengine_safe
|
||||
|
||||
@@ -48,12 +49,14 @@ class Artifacts:
|
||||
def add_or_update_artifacts(
|
||||
cls,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
task_id: str,
|
||||
artifacts: Sequence[ApiArtifact],
|
||||
force: bool,
|
||||
) -> int:
|
||||
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=task_id, force=force, identity=identity
|
||||
)
|
||||
|
||||
artifacts = {
|
||||
get_artifact_id(a): Artifact(**a)
|
||||
@@ -64,18 +67,20 @@ class Artifacts:
|
||||
f"set__execution__artifacts__{mongoengine_safe(name)}": value
|
||||
for name, value in artifacts.items()
|
||||
}
|
||||
return update_task(task, user_id=user_id, update_cmds=update_cmds)
|
||||
return update_task(task, user_id=identity.user, update_cmds=update_cmds)
|
||||
|
||||
@classmethod
|
||||
def delete_artifacts(
|
||||
cls,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
task_id: str,
|
||||
artifact_ids: Sequence[ArtifactId],
|
||||
force: bool,
|
||||
) -> int:
|
||||
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=task_id, force=force, identity=identity
|
||||
)
|
||||
|
||||
artifact_ids = [
|
||||
get_artifact_id(a)
|
||||
@@ -85,4 +90,4 @@ class Artifacts:
|
||||
f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
|
||||
}
|
||||
|
||||
return update_task(task, user_id=user_id, update_cmds=delete_cmds)
|
||||
return update_task(task, user_id=identity.user, update_cmds=delete_cmds)
|
||||
|
||||
@@ -15,6 +15,7 @@ from apiserver.bll.task import TaskBLL
|
||||
from apiserver.bll.task.utils import get_task_for_update, update_task
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.utilities.parameter_key_escaper import (
|
||||
ParameterKeyEscaper,
|
||||
mongoengine_safe,
|
||||
@@ -31,7 +32,10 @@ class HyperParams:
|
||||
def get_params(cls, company_id: str, task_ids: Sequence[str]) -> Dict[str, dict]:
|
||||
only = ("id", "hyperparams")
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
|
||||
company_id=company_id,
|
||||
task_ids=task_ids,
|
||||
only=only,
|
||||
allow_public=True,
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -63,7 +67,7 @@ class HyperParams:
|
||||
def delete_params(
|
||||
cls,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
task_id: str,
|
||||
hyperparams: Sequence[HyperParamKey],
|
||||
force: bool,
|
||||
@@ -74,6 +78,7 @@ class HyperParams:
|
||||
task_id=task_id,
|
||||
allow_all_statuses=properties_only,
|
||||
force=force,
|
||||
identity=identity,
|
||||
)
|
||||
|
||||
with_param, without_param = iterutils.partition(
|
||||
@@ -96,7 +101,7 @@ class HyperParams:
|
||||
|
||||
return update_task(
|
||||
task,
|
||||
user_id=user_id,
|
||||
user_id=identity.user,
|
||||
update_cmds=delete_cmds,
|
||||
set_last_update=not properties_only,
|
||||
)
|
||||
@@ -105,7 +110,7 @@ class HyperParams:
|
||||
def edit_params(
|
||||
cls,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
task_id: str,
|
||||
hyperparams: Sequence[HyperParamItem],
|
||||
replace_hyperparams: str,
|
||||
@@ -117,6 +122,7 @@ class HyperParams:
|
||||
task_id=task_id,
|
||||
allow_all_statuses=properties_only,
|
||||
force=force,
|
||||
identity=identity,
|
||||
)
|
||||
|
||||
update_cmds = dict()
|
||||
@@ -135,7 +141,7 @@ class HyperParams:
|
||||
|
||||
return update_task(
|
||||
task,
|
||||
user_id=user_id,
|
||||
user_id=identity.user,
|
||||
update_cmds=update_cmds,
|
||||
set_last_update=not properties_only,
|
||||
)
|
||||
@@ -163,7 +169,10 @@ class HyperParams:
|
||||
else:
|
||||
only.append("configuration")
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
|
||||
company_id=company_id,
|
||||
task_ids=task_ids,
|
||||
only=only,
|
||||
allow_public=True,
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -209,13 +218,15 @@ class HyperParams:
|
||||
def edit_configuration(
|
||||
cls,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
task_id: str,
|
||||
configuration: Sequence[Configuration],
|
||||
replace_configuration: bool,
|
||||
force: bool,
|
||||
) -> int:
|
||||
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=task_id, force=force, identity=identity
|
||||
)
|
||||
|
||||
update_cmds = dict()
|
||||
configuration = {
|
||||
@@ -228,22 +239,24 @@ class HyperParams:
|
||||
for name, value in configuration.items():
|
||||
update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
|
||||
|
||||
return update_task(task, user_id=user_id, update_cmds=update_cmds)
|
||||
return update_task(task, user_id=identity.user, update_cmds=update_cmds)
|
||||
|
||||
@classmethod
|
||||
def delete_configuration(
|
||||
cls,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
task_id: str,
|
||||
configuration: Sequence[str],
|
||||
force: bool,
|
||||
) -> int:
|
||||
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=task_id, force=force, identity=identity
|
||||
)
|
||||
|
||||
delete_cmds = {
|
||||
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
|
||||
for name in set(configuration)
|
||||
}
|
||||
|
||||
return update_task(task, user_id=user_id, update_cmds=delete_cmds)
|
||||
return update_task(task, user_id=identity.user, update_cmds=delete_cmds)
|
||||
|
||||
@@ -58,27 +58,6 @@ class TaskBLL:
|
||||
self.events_es = events_es or es_factory.connect("events")
|
||||
self.redis: StrictRedis = redis or redman.connection("apiserver")
|
||||
|
||||
@staticmethod
|
||||
def get_task_with_access(
|
||||
task_id, company_id, only=None, allow_public=False, requires_write_access=False
|
||||
) -> Task:
|
||||
"""
|
||||
Gets a task that has a required write access
|
||||
:except errors.bad_request.InvalidTaskId: if the task is not found
|
||||
:except errors.forbidden.NoWritePermission: if write_access was required and the task cannot be modified
|
||||
"""
|
||||
with translate_errors_context():
|
||||
query = dict(id=task_id, company=company_id)
|
||||
if requires_write_access:
|
||||
task = Task.get_for_writing(_only=only, **query)
|
||||
else:
|
||||
task = Task.get(_only=only, **query, include_public=allow_public)
|
||||
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
|
||||
return task
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(
|
||||
company_id,
|
||||
|
||||
@@ -9,6 +9,7 @@ from apiserver.bll.task import (
|
||||
ChangeStatusRequest,
|
||||
)
|
||||
from apiserver.bll.task.task_cleanup import cleanup_task, CleanupResult
|
||||
from apiserver.bll.task.utils import get_task_with_write_access
|
||||
from apiserver.bll.util import update_project_time
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility
|
||||
@@ -24,6 +25,7 @@ from apiserver.database.model.task.task import (
|
||||
DEFAULT_LAST_ITERATION,
|
||||
)
|
||||
from apiserver.database.utils import get_options
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.utilities.dicts import nested_set
|
||||
|
||||
log = config.logger(__file__)
|
||||
@@ -33,7 +35,7 @@ queue_bll = QueueBLL()
|
||||
def archive_task(
|
||||
task: Union[str, Task],
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
) -> int:
|
||||
@@ -42,9 +44,10 @@ def archive_task(
|
||||
Return 1 if successful
|
||||
"""
|
||||
if isinstance(task, str):
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task = get_task_with_write_access(
|
||||
task,
|
||||
company_id=company_id,
|
||||
identity=identity,
|
||||
only=(
|
||||
"id",
|
||||
"company",
|
||||
@@ -54,8 +57,9 @@ def archive_task(
|
||||
"system_tags",
|
||||
"enqueue_status",
|
||||
),
|
||||
requires_write_access=True,
|
||||
)
|
||||
|
||||
user_id = identity.user
|
||||
try:
|
||||
TaskBLL.dequeue_and_change_status(
|
||||
task,
|
||||
@@ -79,34 +83,34 @@ def archive_task(
|
||||
|
||||
|
||||
def unarchive_task(
|
||||
task: str,
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
) -> int:
|
||||
"""
|
||||
Unarchive task. Return 1 if successful
|
||||
"""
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task,
|
||||
task = get_task_with_write_access(
|
||||
task_id,
|
||||
company_id=company_id,
|
||||
identity=identity,
|
||||
only=("id",),
|
||||
requires_write_access=True,
|
||||
)
|
||||
return task.update(
|
||||
status_message=status_message,
|
||||
status_reason=status_reason,
|
||||
pull__system_tags=EntityVisibility.archived.value,
|
||||
last_change=datetime.utcnow(),
|
||||
last_changed_by=user_id,
|
||||
last_changed_by=identity.user,
|
||||
)
|
||||
|
||||
|
||||
def dequeue_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
remove_from_all_queues: bool = False,
|
||||
@@ -119,7 +123,19 @@ def dequeue_task(
|
||||
task = Task.get(
|
||||
id=task_id,
|
||||
company=company_id,
|
||||
_only=(
|
||||
_only=("id",),
|
||||
include_public=True,
|
||||
)
|
||||
if not task:
|
||||
TaskBLL.remove_task_from_all_queues(company_id, task_id=task_id)
|
||||
return 1, {"updated": 0}
|
||||
|
||||
user_id = identity.user
|
||||
task = get_task_with_write_access(
|
||||
task_id,
|
||||
company_id=company_id,
|
||||
identity=identity,
|
||||
only=(
|
||||
"id",
|
||||
"company",
|
||||
"execution",
|
||||
@@ -127,11 +143,7 @@ def dequeue_task(
|
||||
"project",
|
||||
"enqueue_status",
|
||||
),
|
||||
include_public=True,
|
||||
)
|
||||
if not task:
|
||||
TaskBLL.remove_task_from_all_queues(company_id, task_id=task_id)
|
||||
return 1, {"updated": 0}
|
||||
|
||||
res = TaskBLL.dequeue_and_change_status(
|
||||
task,
|
||||
@@ -148,7 +160,7 @@ def dequeue_task(
|
||||
def enqueue_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
queue_id: str,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
@@ -173,11 +185,11 @@ def enqueue_task(
|
||||
# try to get default queue
|
||||
queue_id = queue_bll.get_default(company_id).id
|
||||
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get_for_writing(**query)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
task = get_task_with_write_access(
|
||||
task_id=task_id, company_id=company_id, identity=identity
|
||||
)
|
||||
|
||||
user_id = identity.user
|
||||
if validate:
|
||||
TaskBLL.validate(task)
|
||||
|
||||
@@ -207,9 +219,9 @@ def enqueue_task(
|
||||
|
||||
# set the current queue ID in the task
|
||||
if task.execution:
|
||||
Task.objects(**query).update(execution__queue=queue_id, multi=False)
|
||||
Task.objects(id=task_id).update(execution__queue=queue_id, multi=False)
|
||||
else:
|
||||
Task.objects(**query).update(execution=Execution(queue=queue_id), multi=False)
|
||||
Task.objects(id=task_id).update(execution=Execution(queue=queue_id), multi=False)
|
||||
|
||||
nested_set(res, ("fields", "execution.queue"), queue_id)
|
||||
return 1, res
|
||||
@@ -242,7 +254,7 @@ def move_tasks_to_trash(tasks: Sequence[str]) -> int:
|
||||
def delete_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
move_to_trash: bool,
|
||||
force: bool,
|
||||
return_file_urls: bool,
|
||||
@@ -251,8 +263,9 @@ def delete_task(
|
||||
status_reason: str,
|
||||
delete_external_artifacts: bool,
|
||||
) -> Tuple[int, Task, CleanupResult]:
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task_id, company_id=company_id, requires_write_access=True
|
||||
user_id = identity.user
|
||||
task = get_task_with_write_access(
|
||||
task_id, company_id=company_id, identity=identity
|
||||
)
|
||||
|
||||
if (
|
||||
@@ -305,15 +318,16 @@ def delete_task(
|
||||
def reset_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
force: bool,
|
||||
return_file_urls: bool,
|
||||
delete_output_models: bool,
|
||||
clear_all: bool,
|
||||
delete_external_artifacts: bool,
|
||||
) -> Tuple[dict, CleanupResult, dict]:
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task_id, company_id=company_id, requires_write_access=True
|
||||
user_id = identity.user
|
||||
task = get_task_with_write_access(
|
||||
task_id, company_id=company_id, identity=identity
|
||||
)
|
||||
|
||||
if not force and task.status == TaskStatus.published:
|
||||
@@ -392,14 +406,15 @@ def reset_task(
|
||||
def publish_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
force: bool,
|
||||
publish_model_func: Callable[[str, str, str], Any] = None,
|
||||
publish_model_func: Callable[[str, str, Identity], Any] = None,
|
||||
status_message: str = "",
|
||||
status_reason: str = "",
|
||||
) -> dict:
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task_id, company_id=company_id, requires_write_access=True
|
||||
user_id = identity.user
|
||||
task = get_task_with_write_access(
|
||||
task_id, company_id=company_id, identity=identity
|
||||
)
|
||||
if not force:
|
||||
validate_status_change(task.status, TaskStatus.published)
|
||||
@@ -422,7 +437,7 @@ def publish_task(
|
||||
.first()
|
||||
)
|
||||
if model and not model.ready:
|
||||
publish_model_func(model.id, company_id, user_id)
|
||||
publish_model_func(model.id, company_id, identity)
|
||||
|
||||
# set task status to published, and update (or set) it's new output (view and models)
|
||||
return ChangeStatusRequest(
|
||||
@@ -446,7 +461,7 @@ def publish_task(
|
||||
def stop_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
user_name: str,
|
||||
status_reason: str,
|
||||
force: bool,
|
||||
@@ -459,10 +474,11 @@ def stop_task(
|
||||
is set to 'stopping' to allow the worker to stop the task and report by itself
|
||||
:return: updated task fields
|
||||
"""
|
||||
|
||||
task = TaskBLL.get_task_with_access(
|
||||
user_id = identity.user
|
||||
task = get_task_with_write_access(
|
||||
task_id,
|
||||
company_id=company_id,
|
||||
identity=identity,
|
||||
only=(
|
||||
"status",
|
||||
"project",
|
||||
@@ -472,7 +488,6 @@ def stop_task(
|
||||
"last_update",
|
||||
"execution.queue",
|
||||
),
|
||||
requires_write_access=True,
|
||||
)
|
||||
|
||||
def is_run_by_worker(t: Task) -> bool:
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from datetime import datetime
|
||||
from typing import Sequence
|
||||
|
||||
import attr
|
||||
import six
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.util import update_project_time
|
||||
@@ -10,6 +12,7 @@ from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
|
||||
from apiserver.database.utils import get_options
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.utilities.attrs import typed_attrs
|
||||
|
||||
valid_statuses = get_options(TaskStatus)
|
||||
@@ -157,15 +160,78 @@ def get_possible_status_changes(current_status):
|
||||
return possible
|
||||
|
||||
|
||||
def get_many_tasks_for_writing(
|
||||
company_id: str,
|
||||
identity: Identity,
|
||||
query: Q = None,
|
||||
only: Sequence = None,
|
||||
throw_on_forbidden: bool = True,
|
||||
) -> Sequence[Task]:
|
||||
if only:
|
||||
missing = [f for f in ("company", ) if f not in only]
|
||||
if missing:
|
||||
only = [*only, *missing]
|
||||
|
||||
result = list(
|
||||
Task.get_many(
|
||||
company=company_id,
|
||||
query=query,
|
||||
override_projection=only,
|
||||
allow_public=True,
|
||||
return_dicts=False,
|
||||
)
|
||||
)
|
||||
|
||||
if not company_id:
|
||||
return result
|
||||
|
||||
forbidden_tasks = {task.id for task in result if not task.company}
|
||||
if forbidden_tasks:
|
||||
if throw_on_forbidden:
|
||||
raise errors.forbidden.NoWritePermission(
|
||||
f"cannot modify public task(s), ids={tuple(forbidden_tasks)}"
|
||||
)
|
||||
result = [task for task in result if task.id not in forbidden_tasks]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_task_with_write_access(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
identity: Identity,
|
||||
only=None,
|
||||
) -> Task:
|
||||
"""
|
||||
Gets a task that has a required write access
|
||||
:except errors.bad_request.InvalidTaskId: if the task is not found
|
||||
:except errors.forbidden.NoWritePermission: if write_access was required and the task cannot be modified
|
||||
"""
|
||||
query = dict(id=task_id, company=company_id)
|
||||
|
||||
task = Task.get_for_writing(_only=only, **query)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
|
||||
return task
|
||||
|
||||
|
||||
def get_task_for_update(
|
||||
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
identity: Identity,
|
||||
allow_all_statuses: bool = False,
|
||||
force: bool = False
|
||||
) -> Task:
|
||||
"""
|
||||
Loads only task id and return the task only if it is updatable (status == 'created')
|
||||
"""
|
||||
task = Task.get_for_writing(company=company_id, id=task_id, _only=("id", "status"))
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=task_id)
|
||||
task = get_task_with_write_access(
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
only=("id", "status"),
|
||||
identity=identity,
|
||||
)
|
||||
|
||||
if allow_all_statuses:
|
||||
return task
|
||||
|
||||
@@ -27,10 +27,9 @@ from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.queue import Queue
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.tools import safe_get
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from .stats import WorkerStats
|
||||
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
@@ -287,7 +286,7 @@ class WorkerBLL:
|
||||
filter(
|
||||
None,
|
||||
(
|
||||
safe_get(info, "next_entry/task")
|
||||
nested_get(info, ("next_entry", "task"))
|
||||
for info in queues_info.values()
|
||||
),
|
||||
)
|
||||
@@ -311,7 +310,7 @@ class WorkerBLL:
|
||||
continue
|
||||
entry.name = info.get("name", None)
|
||||
entry.num_tasks = info.get("num_entries", 0)
|
||||
task_id = safe_get(info, "next_entry/task")
|
||||
task_id = nested_get(info, ("next_entry", "task"))
|
||||
if task_id:
|
||||
task = tasks_info.get(task_id, None)
|
||||
entry.next_task = IdNameEntry(
|
||||
|
||||
@@ -6,7 +6,7 @@ from functools import reduce
|
||||
from os import getenv
|
||||
from os.path import expandvars
|
||||
from pathlib import Path
|
||||
from typing import List, Any, TypeVar, Sequence
|
||||
from typing import List, Any, TypeVar, Sequence, Set
|
||||
|
||||
from boltons.iterutils import first
|
||||
from pyhocon import ConfigTree, ConfigFactory, ConfigValues
|
||||
@@ -35,6 +35,7 @@ class BasicConfig:
|
||||
folder: str = None,
|
||||
verbose: bool = True,
|
||||
prefix: Sequence[str] = DEFAULT_PREFIXES,
|
||||
exclude_files_from_base_folder: Sequence[str] = None,
|
||||
):
|
||||
folder = (
|
||||
Path(folder)
|
||||
@@ -44,6 +45,11 @@ class BasicConfig:
|
||||
if not folder.is_dir():
|
||||
raise ValueError("Invalid configuration folder")
|
||||
|
||||
self.exclude_files_from_base_folder = (
|
||||
set(exclude_files_from_base_folder)
|
||||
if exclude_files_from_base_folder
|
||||
else set()
|
||||
)
|
||||
self.verbose = verbose
|
||||
|
||||
self.extra_config_path_override_var = [
|
||||
@@ -85,7 +91,7 @@ class BasicConfig:
|
||||
return logging.getLogger(path)
|
||||
|
||||
def _read_extra_env_config_values(self) -> ConfigTree:
|
||||
""" Loads extra configuration from environment-injected values """
|
||||
"""Loads extra configuration from environment-injected values"""
|
||||
result = ConfigTree()
|
||||
|
||||
for prefix in self.extra_config_values_env_key_prefix:
|
||||
@@ -125,12 +131,18 @@ class BasicConfig:
|
||||
def _reload(self) -> ConfigTree:
|
||||
extra_config_values = self._read_extra_env_config_values()
|
||||
|
||||
configs = [self._read_recursive(path) for path in self._paths]
|
||||
configs = [
|
||||
self._read_recursive(
|
||||
path,
|
||||
exclude_files=(
|
||||
self.exclude_files_from_base_folder if idx == 0 else None
|
||||
),
|
||||
)
|
||||
for idx, path in enumerate(self._paths)
|
||||
]
|
||||
|
||||
return reduce(
|
||||
lambda last, config: self._merge_configs(
|
||||
last, config, copy_trees=True
|
||||
),
|
||||
lambda last, config: self._merge_configs(last, config, copy_trees=True),
|
||||
configs + [extra_config_values],
|
||||
ConfigTree(),
|
||||
)
|
||||
@@ -141,9 +153,14 @@ class BasicConfig:
|
||||
for key, value in b.items():
|
||||
override = key.startswith(override_prefix)
|
||||
if override:
|
||||
key = key[len(override_prefix):]
|
||||
key = key[len(override_prefix) :]
|
||||
# if key is in both a and b and both values are dictionary then merge it otherwise override it
|
||||
if not override and key in a and isinstance(a[key], ConfigTree) and isinstance(b[key], ConfigTree):
|
||||
if (
|
||||
not override
|
||||
and key in a
|
||||
and isinstance(a[key], ConfigTree)
|
||||
and isinstance(b[key], ConfigTree)
|
||||
):
|
||||
if copy_trees:
|
||||
a[key] = a[key].copy()
|
||||
cls._merge_configs(a[key], b[key], copy_trees=copy_trees)
|
||||
@@ -156,13 +173,15 @@ class BasicConfig:
|
||||
a[key] = value
|
||||
if a.root:
|
||||
if b.root:
|
||||
a.history[key] = a.history.get(key, []) + b.history.get(key, [value])
|
||||
a.history[key] = a.history.get(key, []) + b.history.get(
|
||||
key, [value]
|
||||
)
|
||||
else:
|
||||
a.history[key] = a.history.get(key, []) + [value]
|
||||
|
||||
return a
|
||||
|
||||
def _read_recursive(self, conf_root) -> ConfigTree:
|
||||
def _read_recursive(self, conf_root, exclude_files: Set[str]) -> ConfigTree:
|
||||
conf = ConfigTree()
|
||||
|
||||
if not conf_root:
|
||||
@@ -180,6 +199,8 @@ class BasicConfig:
|
||||
print(f"Loading config from {conf_root}")
|
||||
|
||||
for file in conf_root.rglob("*.conf"):
|
||||
if exclude_files and file.name in exclude_files:
|
||||
continue
|
||||
key = ".".join(file.relative_to(conf_root).with_suffix("").parts)
|
||||
conf.put(key, self._read_single_file(file))
|
||||
|
||||
|
||||
@@ -58,6 +58,9 @@
|
||||
# verify user tokens
|
||||
verify_user_tokens: false
|
||||
|
||||
# If set then users that were created from secure credentials or fixed user settings and are no longer in these settings will be deleted on startup
|
||||
delete_missing_autocreated_users: true
|
||||
|
||||
# max token expiration timeout in seconds (1 year)
|
||||
max_expiration_sec: 31536000
|
||||
|
||||
@@ -72,6 +75,7 @@
|
||||
httponly: true # allow only http to access the cookies (no JS etc)
|
||||
secure: false # not using HTTPS
|
||||
domain: null # Limit to localhost is not supported
|
||||
samesite: Lax
|
||||
max_age: 99999999999
|
||||
}
|
||||
|
||||
|
||||
@@ -2,10 +2,9 @@ fileserver = "http://localhost:8081"
|
||||
|
||||
elastic {
|
||||
events {
|
||||
hosts: [{host: "127.0.0.1", port: 9200}]
|
||||
hosts: [{host: "127.0.0.1", port: 9200, scheme: http}]
|
||||
args {
|
||||
timeout: 60
|
||||
dead_timeout: 10
|
||||
max_retries: 3
|
||||
retry_on_timeout: true
|
||||
}
|
||||
@@ -13,10 +12,9 @@ elastic {
|
||||
}
|
||||
|
||||
workers {
|
||||
hosts: [{host:"127.0.0.1", port:9200}]
|
||||
hosts: [{host:"127.0.0.1", port:9200, scheme: http}]
|
||||
args {
|
||||
timeout: 60
|
||||
dead_timeout: 10
|
||||
max_retries: 3
|
||||
retry_on_timeout: true
|
||||
}
|
||||
|
||||
@@ -18,8 +18,9 @@ aws {
|
||||
{
|
||||
# This will apply to all buckets in this host (unless key/value is specifically provided for a given bucket)
|
||||
host: "localhost:9000"
|
||||
key: "evg_user"
|
||||
secret: "evg_pass"
|
||||
key: "minioadmin"
|
||||
secret: "minioadmin"
|
||||
# region: my-server
|
||||
multipart: false
|
||||
secure: false
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ from textwrap import shorten
|
||||
|
||||
import dpath
|
||||
from dpath.exceptions import InvalidKeyName
|
||||
from elasticsearch import ElasticsearchException
|
||||
from elastic_transport import TransportError, ApiError
|
||||
from elasticsearch.helpers import BulkIndexError
|
||||
from jsonmodels.errors import ValidationError as JsonschemaValidationError
|
||||
from mongoengine.errors import (
|
||||
@@ -210,9 +210,9 @@ def translate_errors_context(message=None, **kwargs):
|
||||
raise errors.bad_request.ValidationError(e.args[0])
|
||||
except BulkIndexError as e:
|
||||
ElasticErrorsHandler.bulk_error(e, message, **kwargs)
|
||||
except ElasticsearchException as e:
|
||||
except (TransportError, ApiError) as e:
|
||||
raise errors.server_error.DataError(e, message, **kwargs)
|
||||
except InvalidKeyName:
|
||||
raise errors.server_error.DataError("invalid empty key encountered in data")
|
||||
except Exception as ex:
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
@@ -4,6 +4,7 @@ from mongoengine import (
|
||||
EmbeddedDocumentListField,
|
||||
EmailField,
|
||||
DateTimeField,
|
||||
BooleanField,
|
||||
)
|
||||
|
||||
from apiserver.database import Database, strict
|
||||
@@ -76,3 +77,6 @@ class User(DbModelMixin, AuthDocument):
|
||||
|
||||
email = EmailField(unique=True, sparse=True)
|
||||
""" Email uniquely identifying the user """
|
||||
|
||||
autocreated = BooleanField(default=False)
|
||||
""" Set to true if the user was auto created based on config settings"""
|
||||
|
||||
@@ -146,9 +146,10 @@ class GetMixin(PropsMixin):
|
||||
"__$any": Q.OR,
|
||||
"__$or": Q.OR,
|
||||
}
|
||||
default_operator = Q.OR
|
||||
default_global_operator = Q.AND
|
||||
default_context = Q.OR
|
||||
# not_all modifier currently not supported due to the backwards compatibility
|
||||
mongo_modifiers = {
|
||||
# not_all modifier currently not supported due to the backwards compatibility
|
||||
Q.AND: {True: "all", False: "nin"},
|
||||
Q.OR: {True: "in", False: "nin"},
|
||||
}
|
||||
@@ -165,24 +166,22 @@ class GetMixin(PropsMixin):
|
||||
self.allow_empty = False
|
||||
self.global_operator = None
|
||||
self.actions = defaultdict(list)
|
||||
self.explicit_operator = False
|
||||
|
||||
self._support_legacy = legacy
|
||||
current_context = self.default_operator
|
||||
current_context = self.default_context
|
||||
for d in self._get_next_term(data):
|
||||
if d.operator is not None:
|
||||
current_context = d.operator
|
||||
self._support_legacy = False
|
||||
if self.global_operator is None:
|
||||
self.global_operator = d.operator
|
||||
self.explicit_operator = True
|
||||
continue
|
||||
|
||||
if self.global_operator is None:
|
||||
self.global_operator = self.default_operator
|
||||
self.global_operator = self.default_global_operator
|
||||
|
||||
if d.reset:
|
||||
current_context = self.default_operator
|
||||
current_context = self.default_context
|
||||
self._support_legacy = legacy
|
||||
continue
|
||||
|
||||
@@ -195,7 +194,7 @@ class GetMixin(PropsMixin):
|
||||
)
|
||||
|
||||
if self.global_operator is None:
|
||||
self.global_operator = self.default_operator
|
||||
self.global_operator = self.default_global_operator
|
||||
|
||||
def _get_next_term(self, data: Sequence[str]) -> Generator[Term, None, None]:
|
||||
unary_operator = None
|
||||
@@ -618,7 +617,20 @@ class GetMixin(PropsMixin):
|
||||
):
|
||||
if not vals:
|
||||
continue
|
||||
operations[self._db_modifiers[(op, include)]] = list(set(vals))
|
||||
|
||||
unique = set(vals)
|
||||
if None in unique:
|
||||
# noinspection PyTypeChecker
|
||||
unique.remove(None)
|
||||
if include:
|
||||
operations["size"] = 0
|
||||
else:
|
||||
operations["not__size"] = 0
|
||||
|
||||
if not unique:
|
||||
continue
|
||||
|
||||
operations[self._db_modifiers[(op, include)]] = list(unique)
|
||||
|
||||
self.db_query[op] = operations
|
||||
|
||||
@@ -656,7 +668,8 @@ class GetMixin(PropsMixin):
|
||||
|
||||
ops = []
|
||||
for action, vals in actions.items():
|
||||
if not vals:
|
||||
# cannot just check vals here since 0 is acceptable value
|
||||
if vals is None or vals == []:
|
||||
continue
|
||||
|
||||
ops.append(RegexQ(**{f"{mongoengine_field}__{action}": vals}))
|
||||
@@ -1283,22 +1296,6 @@ class GetMixin(PropsMixin):
|
||||
)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_many_for_writing(cls, company, *args, **kwargs):
|
||||
result = cls.get_many(
|
||||
company=company,
|
||||
*args,
|
||||
**dict(return_dicts=False, **kwargs),
|
||||
allow_public=True,
|
||||
)
|
||||
forbidden_objects = {obj.id for obj in result if not obj.company}
|
||||
if forbidden_objects:
|
||||
object_name = cls.__name__.lower()
|
||||
raise errors.forbidden.NoWritePermission(
|
||||
f"cannot modify public {object_name}(s), ids={tuple(forbidden_objects)}"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class UpdateMixin(object):
|
||||
__user_set_allowed_fields = None
|
||||
|
||||
@@ -231,11 +231,12 @@ class Task(AttributedDocument):
|
||||
"parent",
|
||||
"hyperparams.*",
|
||||
"execution.queue",
|
||||
"models.input.model",
|
||||
),
|
||||
range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
|
||||
datetime_fields=("status_changed", "last_update"),
|
||||
pattern_fields=("name", "comment", "report"),
|
||||
fields=("runtime.*", "models.input.model"),
|
||||
fields=("runtime.*",),
|
||||
)
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
|
||||
@@ -4,34 +4,89 @@ Apply elasticsearch mappings to given hosts.
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional, Sequence, Tuple
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from elasticsearch import Elasticsearch, exceptions
|
||||
|
||||
HERE = Path(__file__).resolve().parent
|
||||
logging.getLogger("elasticsearch").setLevel(logging.WARNING)
|
||||
logging.getLogger("elastic_transport").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def apply_mappings_to_cluster(
|
||||
hosts: Sequence, key: Optional[str] = None, es_args: dict = None, http_auth: Tuple = None
|
||||
hosts: Sequence,
|
||||
key: Optional[str] = None,
|
||||
es_args: dict = None,
|
||||
http_auth: Tuple = None,
|
||||
):
|
||||
"""Hosts maybe a sequence of strings or dicts in the form {"host": <host>, "port": <port>}"""
|
||||
|
||||
def _send_template(f):
|
||||
with f.open() as json_data:
|
||||
data = json.load(json_data)
|
||||
template_name = f.stem
|
||||
res = es.indices.put_template(name=template_name, body=data)
|
||||
return {"mapping": template_name, "result": res}
|
||||
def _send_component_template(ct_file):
|
||||
with ct_file.open() as json_data:
|
||||
body = json.load(json_data)
|
||||
template_name = f"{ct_file.stem}"
|
||||
res = es.cluster.put_component_template(name=template_name, body=body)
|
||||
return {"component_template": template_name, "result": res}
|
||||
|
||||
p = HERE / "mappings"
|
||||
if key:
|
||||
files = (p / key).glob("*.json")
|
||||
else:
|
||||
files = p.glob("**/*.json")
|
||||
def _send_index_template(it_file):
|
||||
with it_file.open() as json_data:
|
||||
body = json.load(json_data)
|
||||
template_name = f"{it_file.stem}"
|
||||
res = es.indices.put_index_template(name=template_name, body=body)
|
||||
return {"index_template": template_name, "result": res}
|
||||
|
||||
# def _send_legacy_template(f):
|
||||
# with f.open() as json_data:
|
||||
# data = json.load(json_data)
|
||||
# template_name = f.stem
|
||||
# res = es.indices.put_template(name=template_name, body=data)
|
||||
# return {"mapping": template_name, "result": res}
|
||||
|
||||
def _delete_legacy_templates(legacy_folder):
|
||||
res_list = []
|
||||
for lt in legacy_folder.glob("*.json"):
|
||||
template_name = lt.stem
|
||||
try:
|
||||
if not es.indices.get_template(name=template_name):
|
||||
continue
|
||||
res = es.indices.delete_template(name=template_name)
|
||||
except exceptions.NotFoundError:
|
||||
continue
|
||||
res_list.append({"deleted legacy mapping": template_name, "result": res})
|
||||
|
||||
return res_list
|
||||
|
||||
es = Elasticsearch(hosts=hosts, http_auth=http_auth, **(es_args or {}))
|
||||
return [_send_template(f) for f in files]
|
||||
root = HERE / "index_templates"
|
||||
if key:
|
||||
folders = [root / key]
|
||||
else:
|
||||
folders = [f for f in root.iterdir() if f.is_dir()]
|
||||
|
||||
ret = []
|
||||
for f in folders:
|
||||
for ct in (f / "component_templates").glob("*.json"):
|
||||
ret.append(_send_component_template(ct))
|
||||
for it in f.glob("*.json"):
|
||||
ret.append(_send_index_template(it))
|
||||
|
||||
legacy_root = HERE / "mappings"
|
||||
for f in folders:
|
||||
legacy_f = legacy_root / f.stem
|
||||
if not legacy_f.exists() or not legacy_f.is_dir():
|
||||
continue
|
||||
ret.extend(_delete_legacy_templates(legacy_f))
|
||||
|
||||
return ret
|
||||
# p = HERE / "mappings"
|
||||
# if key:
|
||||
# files = (p / key).glob("*.json")
|
||||
# else:
|
||||
# files = p.glob("**/*.json")
|
||||
#
|
||||
# return [_send_template(f) for f in files]
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
{
|
||||
"template": {
|
||||
"settings": {
|
||||
"number_of_replicas": 0,
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"properties": {
|
||||
"@timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"task": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"type": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"worker": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"iter": {
|
||||
"type": "long"
|
||||
},
|
||||
"metric": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"variant": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"value": {
|
||||
"type": "float"
|
||||
},
|
||||
"company_id": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"model_event": {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
18
apiserver/elastic/index_templates/events/events_log.json
Normal file
18
apiserver/elastic/index_templates/events/events_log.json
Normal file
@@ -0,0 +1,18 @@
|
||||
{
|
||||
"index_patterns": "events-log-*",
|
||||
"template": {
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"msg": {
|
||||
"type": "text",
|
||||
"index": false
|
||||
},
|
||||
"level": {
|
||||
"type": "keyword"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"priority": 500,
|
||||
"composed_of": ["events_common"]
|
||||
}
|
||||
18
apiserver/elastic/index_templates/events/events_plot.json
Normal file
18
apiserver/elastic/index_templates/events/events_plot.json
Normal file
@@ -0,0 +1,18 @@
|
||||
{
|
||||
"index_patterns": "events-plot-*",
|
||||
"template": {
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"plot_str": {
|
||||
"type": "text",
|
||||
"index": false
|
||||
},
|
||||
"plot_data": {
|
||||
"type": "binary"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"priority": 500,
|
||||
"composed_of": ["events_common"]
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
{
|
||||
"index_patterns": "events-training_debug_image-*",
|
||||
"template": {
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"key": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"url": {
|
||||
"type": "keyword"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"priority": 500,
|
||||
"composed_of": ["events_common"]
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"index_patterns": "events-training_stats_scalar-*",
|
||||
"priority": 500,
|
||||
"composed_of": ["events_common"]
|
||||
}
|
||||
31
apiserver/elastic/index_templates/workers/queue_metrics.json
Normal file
31
apiserver/elastic/index_templates/workers/queue_metrics.json
Normal file
@@ -0,0 +1,31 @@
|
||||
{
|
||||
"index_patterns": "queue_metrics_*",
|
||||
"template": {
|
||||
"settings": {
|
||||
"number_of_replicas": 0,
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"properties": {
|
||||
"timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"queue": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"average_waiting_time": {
|
||||
"type": "float"
|
||||
},
|
||||
"queue_length": {
|
||||
"type": "integer"
|
||||
},
|
||||
"company_id": {
|
||||
"type": "keyword"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
43
apiserver/elastic/index_templates/workers/worker_stats.json
Normal file
43
apiserver/elastic/index_templates/workers/worker_stats.json
Normal file
@@ -0,0 +1,43 @@
|
||||
{
|
||||
"index_patterns": "worker_stats_*",
|
||||
"template": {
|
||||
"settings": {
|
||||
"number_of_replicas": 0,
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"properties": {
|
||||
"timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"worker": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"category": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"metric": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"variant": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"value": {
|
||||
"type": "float"
|
||||
},
|
||||
"unit": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"task": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"company_id": {
|
||||
"type": "keyword"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,8 @@ from apiserver.config_repo import config
|
||||
from apiserver.elastic.apply_mappings import apply_mappings_to_cluster
|
||||
|
||||
log = config.logger(__file__)
|
||||
logging.getLogger("elasticsearch").setLevel(logging.WARNING)
|
||||
logging.getLogger("elastic_transport").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class MissingElasticConfiguration(Exception):
|
||||
@@ -78,6 +80,18 @@ def check_elastic_empty() -> bool:
|
||||
err_type=urllib3.exceptions.NewConnectionError, args_prefix=("GET",)
|
||||
)
|
||||
|
||||
def events_legacy_template():
|
||||
try:
|
||||
return es.indices.get_template(name="events*")
|
||||
except exceptions.NotFoundError:
|
||||
return False
|
||||
|
||||
def events_template():
|
||||
try:
|
||||
return es.indices.get_index_template(name="events*")
|
||||
except exceptions.NotFoundError:
|
||||
return False
|
||||
|
||||
try:
|
||||
es_logger.addFilter(log_filter)
|
||||
for retry in range(max_retries):
|
||||
@@ -87,10 +101,7 @@ def check_elastic_empty() -> bool:
|
||||
http_auth=es_factory.get_credentials("events", cluster_conf),
|
||||
**cluster_conf.get("args", {}),
|
||||
)
|
||||
return not es.indices.get_template(name="events*")
|
||||
except exceptions.NotFoundError as ex:
|
||||
log.error(ex)
|
||||
return True
|
||||
return not (events_template() or events_legacy_template())
|
||||
except exceptions.ConnectionError as ex:
|
||||
if retry >= max_retries - 1:
|
||||
raise ElasticConnectionError(
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from os import getenv
|
||||
@@ -9,6 +10,8 @@ from elasticsearch import Elasticsearch
|
||||
from apiserver.config_repo import config
|
||||
|
||||
log = config.logger(__file__)
|
||||
logging.getLogger('elasticsearch').setLevel(logging.WARNING)
|
||||
logging.getLogger('elastic_transport').setLevel(logging.WARNING)
|
||||
|
||||
OVERRIDE_HOST_ENV_KEY = (
|
||||
"CLEARML_ELASTIC_SERVICE_HOST",
|
||||
@@ -32,6 +35,7 @@ if OVERRIDE_HOST:
|
||||
|
||||
OVERRIDE_PORT = first(filter(None, map(getenv, OVERRIDE_PORT_ENV_KEY)))
|
||||
if OVERRIDE_PORT:
|
||||
OVERRIDE_PORT = int(OVERRIDE_PORT)
|
||||
log.info(f"Using override elastic port {OVERRIDE_PORT}")
|
||||
|
||||
OVERRIDE_USERNAME = first(filter(None, map(getenv, OVERRIDE_USERNAME_ENV_KEY)))
|
||||
|
||||
@@ -450,6 +450,7 @@ class AWSStorage(Storage):
|
||||
else None,
|
||||
"use_ssl": cfg.secure,
|
||||
"verify": cfg.verify,
|
||||
"region_name": cfg.region or None,
|
||||
}
|
||||
name = base[len(scheme_prefix(self.scheme)) :]
|
||||
bucket_name = name[len(cfg.host) + 1 :] if cfg.host else name
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Sequence, Union
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.config.info import get_default_company
|
||||
from apiserver.database.model.auth import Role
|
||||
from apiserver.database.model.auth import Role, User as AuthUser
|
||||
from apiserver.service_repo.auth.fixed_user import FixedUser
|
||||
from .migration import _apply_migrations, check_mongo_empty, get_last_server_version
|
||||
from .pre_populate import PrePopulate
|
||||
@@ -60,14 +60,18 @@ def init_mongo_data():
|
||||
|
||||
fixed_mode = FixedUser.enabled()
|
||||
|
||||
internal_user_emails = set()
|
||||
for user, credentials in config.get("secure.credentials", {}).items():
|
||||
email = f"{user}@example.com"
|
||||
user_data = {
|
||||
"name": user,
|
||||
"role": credentials.role,
|
||||
"email": f"{user}@example.com",
|
||||
"email": email,
|
||||
"key": credentials.user_key,
|
||||
"secret": credentials.user_secret,
|
||||
"autocreated": True,
|
||||
}
|
||||
internal_user_emails.add(email.lower())
|
||||
revoke = fixed_mode and credentials.get("revoke_in_fixed_mode", False)
|
||||
user_id = _ensure_auth_user(user_data, company_id, log=log, revoke=revoke)
|
||||
if credentials.role == Role.user:
|
||||
@@ -82,8 +86,20 @@ def init_mongo_data():
|
||||
|
||||
for user in FixedUser.from_config():
|
||||
try:
|
||||
ensure_fixed_user(user, log=log)
|
||||
ensure_fixed_user(user, log=log, emails=internal_user_emails)
|
||||
except Exception as ex:
|
||||
log.error(f"Failed creating fixed user {user.name}: {ex}")
|
||||
|
||||
if internal_user_emails and config.get(
|
||||
f"apiserver.auth.delete_missing_autocreated_users", True
|
||||
):
|
||||
for user in AuthUser.objects(
|
||||
company=company_id, autocreated=True, email__nin=internal_user_emails
|
||||
):
|
||||
log.info(
|
||||
f"Removing user that is no longer in configuration: {user['id']}\t{user['email']}\t{user['name']}"
|
||||
)
|
||||
user.delete()
|
||||
|
||||
except Exception as ex:
|
||||
log.exception("Failed initializing mongodb")
|
||||
log.exception(f"Failed initializing mongodb: {str(ex)}")
|
||||
|
||||
@@ -44,6 +44,7 @@ from apiserver.bll.task.param_utils import (
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.config.info import get_default_company
|
||||
from apiserver.database.model import EntityVisibility, User
|
||||
from apiserver.database.model.auth import Role, User as AuthUser
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import (
|
||||
@@ -54,6 +55,7 @@ from apiserver.database.model.task.task import (
|
||||
TaskModelNames,
|
||||
)
|
||||
from apiserver.database.utils import get_options
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.utilities import json
|
||||
from apiserver.utilities.dicts import nested_get, nested_set, nested_delete
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
@@ -66,6 +68,7 @@ class PrePopulate:
|
||||
export_tag_prefix = "Exported:"
|
||||
export_tag = f"{export_tag_prefix} %Y-%m-%d %H:%M:%S"
|
||||
metadata_filename = "metadata.json"
|
||||
users_filename = "users.json"
|
||||
zip_args = dict(mode="w", compression=ZIP_BZIP2)
|
||||
artifacts_ext = ".artifacts"
|
||||
img_source_regex = re.compile(
|
||||
@@ -78,6 +81,7 @@ class PrePopulate:
|
||||
project_cls: Type[Project]
|
||||
model_cls: Type[Model]
|
||||
user_cls: Type[User]
|
||||
auth_user_cls: Type[AuthUser]
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
@classmethod
|
||||
@@ -90,6 +94,8 @@ class PrePopulate:
|
||||
cls.project_cls = cls._get_entity_type("database.model.project.Project")
|
||||
if not hasattr(cls, "user_cls"):
|
||||
cls.user_cls = cls._get_entity_type("database.model.User")
|
||||
if not hasattr(cls, "auth_user_cls"):
|
||||
cls.auth_user_cls = cls._get_entity_type("database.model.auth.User")
|
||||
|
||||
class JsonLinesWriter:
|
||||
def __init__(self, file: BinaryIO):
|
||||
@@ -205,6 +211,8 @@ class PrePopulate:
|
||||
task_statuses: Sequence[str] = None,
|
||||
tag_exported_entities: bool = False,
|
||||
metadata: Mapping[str, Any] = None,
|
||||
export_events: bool = True,
|
||||
export_users: bool = False,
|
||||
) -> Sequence[str]:
|
||||
cls._init_entity_types()
|
||||
|
||||
@@ -240,11 +248,15 @@ class PrePopulate:
|
||||
with ZipFile(file, **cls.zip_args) as zfile:
|
||||
if metadata:
|
||||
zfile.writestr(cls.metadata_filename, meta_str)
|
||||
if export_users:
|
||||
cls._export_users(zfile)
|
||||
artifacts = cls._export(
|
||||
zfile,
|
||||
entities=entities,
|
||||
hash_=hash_,
|
||||
tag_entities=tag_exported_entities,
|
||||
export_events=export_events,
|
||||
cleanup_users=not export_users,
|
||||
)
|
||||
|
||||
file_with_hash = file.with_name(f"{file.stem}_{hash_.hexdigest()}{file.suffix}")
|
||||
@@ -265,6 +277,9 @@ class PrePopulate:
|
||||
metadata_hash=metadata_hash,
|
||||
)
|
||||
|
||||
if created_files:
|
||||
print("Created files:\n" + "\n".join(file for file in created_files))
|
||||
|
||||
return created_files
|
||||
|
||||
@classmethod
|
||||
@@ -296,18 +311,26 @@ class PrePopulate:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not user_id:
|
||||
user_id, user_name = "__allegroai__", "Allegro.ai"
|
||||
|
||||
# Make sure we won't end up with an invalid company ID
|
||||
if company_id is None:
|
||||
company_id = ""
|
||||
|
||||
user_mapping = cls._import_users(zfile, company_id)
|
||||
|
||||
if not user_id:
|
||||
user_id, user_name = "__allegroai__", "Allegro.ai"
|
||||
|
||||
existing_user = cls.user_cls.objects(id=user_id).only("id").first()
|
||||
if not existing_user:
|
||||
cls.user_cls(id=user_id, name=user_name, company=company_id).save()
|
||||
|
||||
cls._import(zfile, company_id, user_id, metadata)
|
||||
cls._import(
|
||||
zfile,
|
||||
company_id=company_id,
|
||||
user_id=user_id,
|
||||
metadata=metadata,
|
||||
user_mapping=user_mapping,
|
||||
)
|
||||
|
||||
if artifacts_path and os.path.isdir(artifacts_path):
|
||||
artifacts_file = Path(filename).with_suffix(cls.artifacts_ext)
|
||||
@@ -438,7 +461,7 @@ class PrePopulate:
|
||||
projects: Sequence[str] = None,
|
||||
task_statuses: Sequence[str] = None,
|
||||
) -> Dict[Type[mongoengine.Document], Set[mongoengine.Document]]:
|
||||
entities = defaultdict(set)
|
||||
entities: Dict[Any] = defaultdict(set)
|
||||
|
||||
if projects:
|
||||
print("Reading projects...")
|
||||
@@ -497,7 +520,6 @@ class PrePopulate:
|
||||
@classmethod
|
||||
def _cleanup_model(cls, model: Model):
|
||||
model.company = ""
|
||||
model.user = ""
|
||||
model.tags = cls._filter_out_export_tags(model.tags)
|
||||
|
||||
@classmethod
|
||||
@@ -505,7 +527,6 @@ class PrePopulate:
|
||||
task.comment = "Auto generated by Allegro.ai"
|
||||
task.status_message = ""
|
||||
task.status_reason = ""
|
||||
task.user = ""
|
||||
task.company = ""
|
||||
task.tags = cls._filter_out_export_tags(task.tags)
|
||||
if task.output:
|
||||
@@ -513,17 +534,32 @@ class PrePopulate:
|
||||
|
||||
@classmethod
|
||||
def _cleanup_project(cls, project: Project):
|
||||
project.user = ""
|
||||
project.company = ""
|
||||
project.tags = cls._filter_out_export_tags(project.tags)
|
||||
|
||||
@classmethod
|
||||
def _cleanup_entity(cls, entity_cls, entity):
|
||||
def _cleanup_auth_user(cls, user: AuthUser):
|
||||
user.company = ""
|
||||
for cred in user.credentials:
|
||||
if getattr(cred, "company", None):
|
||||
cred["company"] = ""
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
def _cleanup_be_user(cls, user: User):
|
||||
user.company = ""
|
||||
user.preferences = None
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
def _cleanup_entity(cls, entity_cls, entity, cleanup_users):
|
||||
if cleanup_users:
|
||||
entity.user = ""
|
||||
if entity_cls == cls.task_cls:
|
||||
cls._cleanup_task(entity)
|
||||
elif entity_cls == cls.model_cls:
|
||||
cls._cleanup_model(entity)
|
||||
elif entity == cls.project_cls:
|
||||
elif entity_cls == cls.project_cls:
|
||||
cls._cleanup_project(entity)
|
||||
|
||||
@classmethod
|
||||
@@ -633,6 +669,38 @@ class PrePopulate:
|
||||
else:
|
||||
print(f"Artifact {full_path} not found")
|
||||
|
||||
@classmethod
|
||||
def _export_users(cls, writer: ZipFile):
|
||||
auth_users = {
|
||||
user.id: cls._cleanup_auth_user(user)
|
||||
for user in cls.auth_user_cls.objects(role__in=(Role.admin, Role.user))
|
||||
}
|
||||
if not auth_users:
|
||||
return
|
||||
|
||||
be_users = {
|
||||
user.id: cls._cleanup_be_user(user)
|
||||
for user in cls.user_cls.objects(id__in=list(auth_users))
|
||||
}
|
||||
if not be_users:
|
||||
return
|
||||
|
||||
auth_users = {uid: data for uid, data in auth_users.items() if uid in be_users}
|
||||
print(f"Writing {len(auth_users)} users into {writer.filename}")
|
||||
data = {}
|
||||
for field, users in (("auth", auth_users), ("backend", be_users)):
|
||||
with BytesIO() as f:
|
||||
with cls.JsonLinesWriter(f) as w:
|
||||
for user in users.values():
|
||||
w.write(user.to_json())
|
||||
data[field] = f.getvalue()
|
||||
|
||||
def get_field_bytes(k: str, v: bytes) -> bytes:
|
||||
return f'"{k}": '.encode("utf-8") + v
|
||||
|
||||
data_str = b",\n".join(get_field_bytes(k, v) for k, v in data.items())
|
||||
writer.writestr(cls.users_filename, b"{\n" + data_str + b"\n}")
|
||||
|
||||
@classmethod
|
||||
def _get_base_filename(cls, cls_: type):
|
||||
name = f"{cls_.__module__}.{cls_.__name__}"
|
||||
@@ -642,7 +710,13 @@ class PrePopulate:
|
||||
|
||||
@classmethod
|
||||
def _export(
|
||||
cls, writer: ZipFile, entities: dict, hash_, tag_entities: bool = False
|
||||
cls,
|
||||
writer: ZipFile,
|
||||
entities: dict,
|
||||
hash_,
|
||||
tag_entities: bool = False,
|
||||
export_events: bool = True,
|
||||
cleanup_users: bool = True,
|
||||
) -> Sequence[str]:
|
||||
"""
|
||||
Export the requested experiments, projects and models and return the list of artifact files
|
||||
@@ -656,18 +730,19 @@ class PrePopulate:
|
||||
if not items:
|
||||
continue
|
||||
base_filename = cls._get_base_filename(cls_)
|
||||
for item in items:
|
||||
artifacts.extend(
|
||||
cls._export_entity_related_data(
|
||||
cls_, item, base_filename, writer, hash_
|
||||
if export_events:
|
||||
for item in items:
|
||||
artifacts.extend(
|
||||
cls._export_entity_related_data(
|
||||
cls_, item, base_filename, writer, hash_
|
||||
)
|
||||
)
|
||||
)
|
||||
filename = base_filename + ".json"
|
||||
print(f"Writing {len(items)} items into {writer.filename}:{filename}")
|
||||
with BytesIO() as f:
|
||||
with cls.JsonLinesWriter(f) as w:
|
||||
for item in items:
|
||||
cls._cleanup_entity(cls_, item)
|
||||
cls._cleanup_entity(cls_, item, cleanup_users=cleanup_users)
|
||||
w.write(item.to_json())
|
||||
data = f.getvalue()
|
||||
hash_.update(data)
|
||||
@@ -717,7 +792,10 @@ class PrePopulate:
|
||||
|
||||
@classmethod
|
||||
def _generate_new_ids(
|
||||
cls, reader: ZipFile, entity_files: Sequence, metadata: Mapping[str, Any],
|
||||
cls,
|
||||
reader: ZipFile,
|
||||
entity_files: Sequence,
|
||||
metadata: Mapping[str, Any],
|
||||
) -> Mapping[str, str]:
|
||||
if not metadata or not any(
|
||||
metadata.get(key) for key in ("new_ids", "example_ids", "private_ids")
|
||||
@@ -745,6 +823,68 @@ class PrePopulate:
|
||||
)
|
||||
return ids
|
||||
|
||||
@classmethod
|
||||
def _import_users(cls, reader: ZipFile, company_id: str = "") -> dict:
|
||||
"""
|
||||
Import users to db and return the mapping of old user ids to the new ones
|
||||
If no users were in the users file then the mapping was empty
|
||||
If the user in the file has the same email as one of the existing ones then this user is skipped
|
||||
and its id is mapped to the existing user with the same email
|
||||
If the user with the same id exists in backend or auth db then its creation is skipped
|
||||
"""
|
||||
users_file = first(
|
||||
fi for fi in reader.filelist if fi.orig_filename == cls.users_filename
|
||||
)
|
||||
if not users_file:
|
||||
return {}
|
||||
|
||||
existing_user_ids = set(cls.user_cls.objects().scalar("id")) | set(
|
||||
cls.auth_user_cls.objects().scalar("id")
|
||||
)
|
||||
existing_user_emails = {u.email: u.id for u in cls.auth_user_cls.objects()}
|
||||
user_id_mappings = {}
|
||||
|
||||
with reader.open(users_file) as f:
|
||||
data = json.loads(f.read())
|
||||
|
||||
auth_users = {u["_id"]: u for u in data["auth"]}
|
||||
be_users = {u["_id"]: u for u in data["backend"]}
|
||||
for uid, user in auth_users.items():
|
||||
email = user.get("email")
|
||||
existing_user_id = existing_user_emails.get(email)
|
||||
if existing_user_id:
|
||||
user_id_mappings[uid] = existing_user_id
|
||||
continue
|
||||
|
||||
user_id_mappings[uid] = uid
|
||||
if uid in existing_user_ids:
|
||||
continue
|
||||
|
||||
credentials = user.get("credentials", [])
|
||||
for c in credentials:
|
||||
if c.get("company") == "":
|
||||
c["company"] = company_id
|
||||
|
||||
if hasattr(cls.auth_user_cls, "sec_groups"):
|
||||
user_role = user.get("role", Role.user)
|
||||
if user_role == Role.user:
|
||||
user["sec_groups"] = ["30795571-a470-4717-a80d-e8705fc776bf"]
|
||||
else:
|
||||
user["sec_groups"] = [
|
||||
"c14a3cc6-1144-4896-8ea6-fb186ee19896",
|
||||
"30795571-a470-4717-a80d-e8705fc776bf",
|
||||
"30795571a4704717a80de8705897ytuyg",
|
||||
]
|
||||
|
||||
auth_user = cls.auth_user_cls.from_json(json.dumps(user), created=True)
|
||||
auth_user.company = company_id
|
||||
auth_user.save()
|
||||
be_user = cls.user_cls.from_json(json.dumps(be_users[uid]), created=True)
|
||||
be_user.company = company_id
|
||||
be_user.save()
|
||||
|
||||
return user_id_mappings
|
||||
|
||||
@classmethod
|
||||
def _import(
|
||||
cls,
|
||||
@@ -753,6 +893,7 @@ class PrePopulate:
|
||||
user_id: str = None,
|
||||
metadata: Mapping[str, Any] = None,
|
||||
sort_tasks_by_last_updated: bool = True,
|
||||
user_mapping: Mapping[str, str] = None,
|
||||
):
|
||||
"""
|
||||
Import entities and events from the zip file
|
||||
@@ -763,7 +904,7 @@ class PrePopulate:
|
||||
fi
|
||||
for fi in reader.filelist
|
||||
if not fi.orig_filename.endswith(event_file_ending)
|
||||
and fi.orig_filename != cls.metadata_filename
|
||||
and fi.orig_filename not in (cls.metadata_filename, cls.users_filename)
|
||||
]
|
||||
metadata = metadata or {}
|
||||
old_to_new_ids = cls._generate_new_ids(reader, entity_files, metadata)
|
||||
@@ -773,7 +914,13 @@ class PrePopulate:
|
||||
full_name = splitext(entity_file.orig_filename)[0]
|
||||
print(f"Reading {reader.filename}:{full_name}...")
|
||||
res = cls._import_entity(
|
||||
f, full_name, company_id, user_id, metadata, old_to_new_ids
|
||||
f,
|
||||
full_name=full_name,
|
||||
company_id=company_id,
|
||||
user_id=user_id,
|
||||
metadata=metadata,
|
||||
old_to_new_ids=old_to_new_ids,
|
||||
user_mapping=user_mapping,
|
||||
)
|
||||
if res:
|
||||
tasks = res
|
||||
@@ -794,7 +941,7 @@ class PrePopulate:
|
||||
with reader.open(events_file) as f:
|
||||
full_name = splitext(events_file.orig_filename)[0]
|
||||
print(f"Reading {reader.filename}:{full_name}...")
|
||||
cls._import_events(f, company_id, user_id, task.id)
|
||||
cls._import_events(f, company_id, task.user, task.id)
|
||||
|
||||
@classmethod
|
||||
def _get_entity_type(cls, full_name) -> Type[mongoengine.Document]:
|
||||
@@ -874,7 +1021,7 @@ class PrePopulate:
|
||||
):
|
||||
old_path = old_field.split(".")
|
||||
old_model = nested_get(task_data, old_path)
|
||||
new_models = models.get(type_, [])
|
||||
new_models = [m for m in models.get(type_, []) if m.get("model") is not None]
|
||||
name = TaskModelNames[type_]
|
||||
if old_model and not any(
|
||||
m
|
||||
@@ -908,7 +1055,9 @@ class PrePopulate:
|
||||
user_id: str,
|
||||
metadata: Mapping[str, Any],
|
||||
old_to_new_ids: Mapping[str, str] = None,
|
||||
user_mapping: Mapping[str, str] = None,
|
||||
) -> Optional[Sequence[Task]]:
|
||||
user_mapping = user_mapping or {}
|
||||
cls_ = cls._get_entity_type(full_name)
|
||||
print(f"Writing {cls_.__name__.lower()}s into database")
|
||||
tasks = []
|
||||
@@ -930,7 +1079,7 @@ class PrePopulate:
|
||||
|
||||
doc = cls_.from_json(item, created=True)
|
||||
if hasattr(doc, "user"):
|
||||
doc.user = user_id
|
||||
doc.user = user_mapping.get(doc.user, user_id) if doc.user else user_id
|
||||
if hasattr(doc, "company"):
|
||||
doc.company = company_id
|
||||
if isinstance(doc, cls.project_cls):
|
||||
@@ -970,7 +1119,7 @@ class PrePopulate:
|
||||
ev["allow_locked"] = True
|
||||
cls.event_bll.add_events(
|
||||
company_id=company_id,
|
||||
user_id=user_id,
|
||||
identity=Identity(user_id, company=company_id, role=Role.admin),
|
||||
events=events,
|
||||
worker="",
|
||||
)
|
||||
|
||||
@@ -26,6 +26,7 @@ def _ensure_auth_user(user_data: dict, company_id: str, log: Logger, revoke: boo
|
||||
credentials = [] if revoke else [creds]
|
||||
|
||||
user_id = user_data.get("id", f"__{user_data['name']}__")
|
||||
autocreated = user_data.get("autocreated", False)
|
||||
|
||||
log.info(f"Creating user: {user_data['name']}")
|
||||
|
||||
@@ -37,6 +38,7 @@ def _ensure_auth_user(user_data: dict, company_id: str, log: Logger, revoke: boo
|
||||
email=user_data["email"],
|
||||
created=datetime.utcnow(),
|
||||
credentials=credentials,
|
||||
autocreated=autocreated,
|
||||
)
|
||||
|
||||
user.save()
|
||||
@@ -59,7 +61,7 @@ def _ensure_backend_user(user_id: str, company_id: str, user_name: str):
|
||||
return user_id
|
||||
|
||||
|
||||
def ensure_fixed_user(user: FixedUser, log: Logger):
|
||||
def ensure_fixed_user(user: FixedUser, log: Logger, emails: set):
|
||||
db_user = User.objects(company=user.company, id=user.user_id).first()
|
||||
if db_user:
|
||||
# noinspection PyBroadException
|
||||
@@ -73,9 +75,12 @@ def ensure_fixed_user(user: FixedUser, log: Logger):
|
||||
|
||||
data = attr.asdict(user)
|
||||
data["id"] = user.user_id
|
||||
data["email"] = f"{user.user_id}@example.com"
|
||||
email = f"{user.user_id}@example.com"
|
||||
data["email"] = email
|
||||
data["role"] = Role.guest if user.is_guest else Role.user
|
||||
data["autocreated"] = True
|
||||
|
||||
_ensure_auth_user(user_data=data, company_id=user.company, log=log)
|
||||
emails.add(email)
|
||||
|
||||
return _ensure_backend_user(user.user_id, user.company, user.name)
|
||||
|
||||
@@ -6,11 +6,11 @@ boto3>=1.26
|
||||
boto3-stubs[s3]>=1.26
|
||||
clearml>=1.10.3
|
||||
dpath>=1.4.2,<2.0
|
||||
elasticsearch==7.17.9
|
||||
elasticsearch==8.12.0
|
||||
fastjsonschema>=2.8
|
||||
flask-compress>=1.4.0
|
||||
flask-cors>=3.0.5
|
||||
flask>=2.3.2
|
||||
flask>=2.3.3
|
||||
furl>=2.0.0
|
||||
google-cloud-storage>=2.8.0
|
||||
gunicorn>=20.1.0
|
||||
@@ -34,3 +34,4 @@ setuptools>=65.5.1
|
||||
six
|
||||
validators>=0.12.4
|
||||
urllib3>=1.26.18
|
||||
werkzeug>=3.0.1
|
||||
@@ -754,6 +754,42 @@ get_task_metrics{
|
||||
}
|
||||
}
|
||||
}
|
||||
get_multi_task_metrics {
|
||||
"2.28" {
|
||||
description: """Get unique metrics and variants from the events of the specified type.
|
||||
Only events reported for the passed task or model ids are analyzed."""
|
||||
request {
|
||||
type: object
|
||||
required: [ tasks ]
|
||||
properties {
|
||||
tasks {
|
||||
description: task ids to get metrics from
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
model_events {
|
||||
description: If not set or set to false then passed ids are task ids otherwise model ids
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
event_type {
|
||||
"description": Event type. If not specified then metrics are collected from the reported events of all types
|
||||
"$ref": "#/definitions/event_type_enum"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
metrics {
|
||||
type: array
|
||||
description: List of metrics and variants
|
||||
items { "$ref": "#/definitions/metric_variants" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_task_log {
|
||||
"1.5" {
|
||||
description: "Get all 'log' events for this task"
|
||||
@@ -971,10 +1007,17 @@ get_task_events {
|
||||
}
|
||||
}
|
||||
"2.22": ${get_task_events."2.1"} {
|
||||
request.properties.model_events {
|
||||
type: boolean
|
||||
description: If set then get retrieving model events. Otherwise task events
|
||||
default: false
|
||||
request.properties {
|
||||
model_events {
|
||||
type: boolean
|
||||
description: If set then get retrieving model events. Otherwise task events
|
||||
default: false
|
||||
}
|
||||
metrics {
|
||||
type: array
|
||||
description: List of metrics and variants
|
||||
items { "$ref": "#/definitions/metric_variants" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1156,6 +1199,13 @@ get_multi_task_plots {
|
||||
default: true
|
||||
}
|
||||
}
|
||||
"2.28": ${get_multi_task_plots."2.26"} {
|
||||
request.properties.metrics {
|
||||
type: array
|
||||
description: List of metrics and variants
|
||||
items { "$ref": "#/definitions/metric_variants" }
|
||||
}
|
||||
}
|
||||
}
|
||||
get_vector_metrics_and_variants {
|
||||
"2.1" {
|
||||
@@ -1342,6 +1392,13 @@ multi_task_scalar_metrics_iter_histogram {
|
||||
default: false
|
||||
}
|
||||
}
|
||||
"2.28": ${multi_task_scalar_metrics_iter_histogram."2.22"} {
|
||||
request.properties.metrics {
|
||||
type: array
|
||||
description: List of metrics and variants
|
||||
items { "$ref": "#/definitions/metric_variants" }
|
||||
}
|
||||
}
|
||||
}
|
||||
get_task_single_value_metrics {
|
||||
"2.20" {
|
||||
@@ -1369,6 +1426,13 @@ get_task_single_value_metrics {
|
||||
default: false
|
||||
}
|
||||
}
|
||||
"2.28": ${get_task_single_value_metrics."2.22"} {
|
||||
request.properties.metrics {
|
||||
type: array
|
||||
description: List of metrics and variants
|
||||
items { "$ref": "#/definitions/metric_variants" }
|
||||
}
|
||||
}
|
||||
}
|
||||
get_task_latest_scalar_values {
|
||||
"2.1" {
|
||||
|
||||
@@ -11,16 +11,7 @@ supported_modes {
|
||||
description: """ Return supported login modes."""
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
state {
|
||||
description: "ASCII base64 encoded application state"
|
||||
type: string
|
||||
}
|
||||
callback_url_prefix {
|
||||
description: "URL prefix used to generate the callback URL for each supported SSO provider"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
additionalProperties: false
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
|
||||
@@ -79,4 +79,15 @@ start_pipeline {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.28": ${start_pipeline."2.17"} {
|
||||
request.properties.verify_watched_queue {
|
||||
description: If passed then check wheter there are any workers watiching the queue
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
response.properties.queue_watched {
|
||||
description: Returns true if there are workers or autscalers working with the queue
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -949,6 +949,13 @@ get_unique_metric_variants {
|
||||
default: false
|
||||
}
|
||||
}
|
||||
"2.28": ${get_unique_metric_variants."2.25"} {
|
||||
request.properties.ids {
|
||||
description: IDs of the tasks or models to get metrics from
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_hyperparam_values {
|
||||
"2.13" {
|
||||
|
||||
@@ -21,6 +21,11 @@ log = config.logger(__file__)
|
||||
class RequestHandlers:
|
||||
_request_strip_prefix = config.get("apiserver.request.strip_prefix", None)
|
||||
_server_header = config.get("apiserver.response.headers.server", "clearml")
|
||||
_custom_cookie_settings = {
|
||||
c["name"]: c["settings"]
|
||||
for c in config.get("apiserver.auth.custom_cookies", {}).values()
|
||||
if c.get("enabled") and c.get("settings")
|
||||
}
|
||||
|
||||
def before_request(self):
|
||||
if request.method == "OPTIONS":
|
||||
@@ -29,7 +34,10 @@ class RequestHandlers:
|
||||
return
|
||||
|
||||
if request.content_encoding:
|
||||
return f"Content encoding is not supported ({request.content_encoding})", 415
|
||||
return (
|
||||
f"Content encoding is not supported ({request.content_encoding})",
|
||||
415,
|
||||
)
|
||||
|
||||
try:
|
||||
call = self._create_api_call(request)
|
||||
@@ -42,7 +50,10 @@ class RequestHandlers:
|
||||
response = redirect(call.result.redirect.url, call.result.redirect.code)
|
||||
else:
|
||||
headers = None
|
||||
disable_cache = False
|
||||
if call.result.filename:
|
||||
# make sure that downloaded files are not cached by the client
|
||||
disable_cache = True
|
||||
try:
|
||||
call.result.filename.encode("ascii")
|
||||
except UnicodeEncodeError:
|
||||
@@ -61,10 +72,16 @@ class RequestHandlers:
|
||||
status=call.result.code,
|
||||
headers=headers,
|
||||
)
|
||||
if disable_cache:
|
||||
response.cache_control.no_store = True
|
||||
response.cache_control.max_age = 0
|
||||
|
||||
if call.result.cookies:
|
||||
for key, value in call.result.cookies.items():
|
||||
kwargs = config.get("apiserver.auth.cookies").copy()
|
||||
kwargs = (
|
||||
self._custom_cookie_settings.get(key)
|
||||
or config.get("apiserver.auth.cookies")
|
||||
).copy()
|
||||
if value is None:
|
||||
# Removing a cookie
|
||||
kwargs["max_age"] = 0
|
||||
@@ -81,7 +98,9 @@ class RequestHandlers:
|
||||
if company:
|
||||
try:
|
||||
# use no default value to allow setting a null domain as well
|
||||
kwargs["domain"] = config.get(f"apiserver.auth.cookies_domain_override.{company}")
|
||||
kwargs["domain"] = config.get(
|
||||
f"apiserver.auth.cookies_domain_override.{company}"
|
||||
)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
@@ -108,11 +127,15 @@ class RequestHandlers:
|
||||
return v
|
||||
|
||||
for k, v in md.lists():
|
||||
v = [convert_value(x) for x in v] if (len(v) > 1 or k.endswith("[]")) else convert_value(v[0])
|
||||
v = (
|
||||
[convert_value(x) for x in v]
|
||||
if (len(v) > 1 or k.endswith("[]"))
|
||||
else convert_value(v[0])
|
||||
)
|
||||
nested_set(body, k.rstrip("[]").split("."), v)
|
||||
|
||||
def _update_call_data(self, call, req):
|
||||
""" Use request payload/form to fill call data or batched data """
|
||||
"""Use request payload/form to fill call data or batched data"""
|
||||
if req.content_type == "application/json-lines":
|
||||
items = []
|
||||
for i, line in enumerate(req.data.splitlines()):
|
||||
@@ -142,6 +165,9 @@ class RequestHandlers:
|
||||
call.set_error_result(msg=msg, code=code, subcode=subcode)
|
||||
return call
|
||||
|
||||
def _get_session_auth_cookie(self, req):
|
||||
return req.cookies.get(config.get("apiserver.auth.session_auth_cookie_name"))
|
||||
|
||||
def _create_api_call(self, req):
|
||||
call = None
|
||||
try:
|
||||
@@ -155,9 +181,7 @@ class RequestHandlers:
|
||||
|
||||
# Resolve authorization: if cookies contain an authorization token, use it as a starting point.
|
||||
# in any case, request headers always take precedence.
|
||||
auth_cookie = req.cookies.get(
|
||||
config.get("apiserver.auth.session_auth_cookie_name")
|
||||
)
|
||||
auth_cookie = self._get_session_auth_cookie(req)
|
||||
headers = (
|
||||
{}
|
||||
if not auth_cookie
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .auth import get_auth_func, authorize_impersonation
|
||||
from .auth import get_auth_func, authorize_impersonation, revoke_auth_token
|
||||
from .payload import Token, Basic, AuthType, Payload
|
||||
from .identity import Identity
|
||||
from .utils import get_client_id, get_secret_key
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import base64
|
||||
from datetime import datetime
|
||||
from time import time
|
||||
|
||||
import bcrypt
|
||||
import jwt
|
||||
@@ -11,15 +12,16 @@ from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.auth import User, Entities, Credentials
|
||||
from apiserver.database.model.company import Company
|
||||
from apiserver.database.utils import get_options
|
||||
from apiserver.redis_manager import redman
|
||||
from .fixed_user import FixedUser
|
||||
from .identity import Identity
|
||||
from .payload import Payload, Token, Basic, AuthType
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
entity_keys = set(get_options(Entities))
|
||||
|
||||
verify_user_tokens = config.get("apiserver.auth.verify_user_tokens", True)
|
||||
_revoked_tokens_key = "revoked_tokens"
|
||||
redis = redman.connection("apiserver")
|
||||
|
||||
|
||||
def get_auth_func(auth_type):
|
||||
@@ -41,8 +43,10 @@ def authorize_token(jwt_token, service, action, call):
|
||||
log.error(f"{msg} Call info: {info}")
|
||||
|
||||
try:
|
||||
return Token.from_encoded_token(jwt_token)
|
||||
|
||||
token = Token.from_encoded_token(jwt_token)
|
||||
if is_token_revoked(token):
|
||||
raise errors.unauthorized.InvalidToken("revoked token")
|
||||
return token
|
||||
except jwt.exceptions.InvalidKeyError as ex:
|
||||
log_error("Failed parsing token.")
|
||||
raise errors.unauthorized.InvalidToken(
|
||||
@@ -154,3 +158,23 @@ def compare_secret_key_hash(secret_key: str, hashed_secret: str) -> bool:
|
||||
return bcrypt.checkpw(
|
||||
secret_key.encode(), base64.b64decode(hashed_secret.encode("ascii"))
|
||||
)
|
||||
|
||||
|
||||
def is_token_revoked(token: Token) -> bool:
|
||||
if not isinstance(token, Token) or not token.session_id:
|
||||
return False
|
||||
|
||||
return redis.zscore(_revoked_tokens_key, token.session_id) is not None
|
||||
|
||||
|
||||
def revoke_auth_token(token: Token):
|
||||
if not isinstance(token, Token) or not token.session_id:
|
||||
return
|
||||
|
||||
timestamp_now = int(time())
|
||||
expiration_timestamp = token.exp
|
||||
if not expiration_timestamp:
|
||||
expiration_timestamp = timestamp_now + Token.default_expiration_sec
|
||||
|
||||
redis.zadd(_revoked_tokens_key, {token.session_id: expiration_timestamp})
|
||||
redis.zremrangebyscore(_revoked_tokens_key, min=0, max=timestamp_now)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import jwt
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
@@ -20,7 +22,15 @@ class Token(Payload):
|
||||
default_expiration_sec = config.get("apiserver.auth.default_expiration_sec")
|
||||
|
||||
def __init__(
|
||||
self, exp=None, iat=None, nbf=None, env=None, identity=None, entities=None, **_
|
||||
self,
|
||||
exp=None,
|
||||
iat=None,
|
||||
nbf=None,
|
||||
env=None,
|
||||
identity=None,
|
||||
session_id=None,
|
||||
entities=None,
|
||||
**_,
|
||||
):
|
||||
super(Token, self).__init__(
|
||||
AuthType.bearer_token, identity=identity, entities=entities
|
||||
@@ -28,8 +38,13 @@ class Token(Payload):
|
||||
self.exp = exp
|
||||
self.iat = iat
|
||||
self.nbf = nbf
|
||||
self._session_id = session_id
|
||||
self._env = env or config.get("env", "<unknown>")
|
||||
|
||||
@property
|
||||
def session_id(self):
|
||||
return self._session_id
|
||||
|
||||
@property
|
||||
def env(self):
|
||||
return self._env
|
||||
@@ -102,8 +117,11 @@ class Token(Payload):
|
||||
expiration_sec = expiration_sec or cls.default_expiration_sec
|
||||
|
||||
now = datetime.utcnow()
|
||||
session_id = uuid4().hex
|
||||
|
||||
token = cls(identity=identity, entities=entities, iat=now)
|
||||
token = cls(
|
||||
identity=identity, entities=entities, iat=now, session_id=session_id
|
||||
)
|
||||
|
||||
if expiration_sec:
|
||||
# add 'expiration' claim
|
||||
|
||||
@@ -39,7 +39,7 @@ class ServiceRepo(object):
|
||||
"""If the check is set, parsing will fail for endpoint request with the version that is grater than the current
|
||||
maximum """
|
||||
|
||||
_max_version = PartialVersion("2.27")
|
||||
_max_version = PartialVersion("2.29")
|
||||
""" Maximum version number (the highest min_version value across all endpoints) """
|
||||
|
||||
_endpoint_exp = (
|
||||
|
||||
@@ -24,6 +24,7 @@ from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.auth import User, Role
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.service_repo.auth import Token
|
||||
from apiserver.service_repo.auth.auth import is_token_revoked, revoke_auth_token
|
||||
from apiserver.service_repo.auth.fixed_user import FixedUser
|
||||
|
||||
log = config.logger(__file__)
|
||||
@@ -35,7 +36,7 @@ log = config.logger(__file__)
|
||||
response_data_model=GetTokenResponse,
|
||||
)
|
||||
def login(call: APICall, *_, **__):
|
||||
""" Generates a token based on the authenticated user (intended for use with credentials) """
|
||||
"""Generates a token based on the authenticated user (intended for use with credentials)"""
|
||||
call.result.data_model = AuthBLL.get_token_for_user(
|
||||
user_id=call.identity.user,
|
||||
company_id=call.identity.company,
|
||||
@@ -48,6 +49,7 @@ def login(call: APICall, *_, **__):
|
||||
|
||||
@endpoint("auth.logout", min_version="2.2")
|
||||
def logout(call: APICall, *_, **__):
|
||||
revoke_auth_token(call.auth)
|
||||
call.result.set_auth_cookie(None)
|
||||
|
||||
|
||||
@@ -57,7 +59,7 @@ def logout(call: APICall, *_, **__):
|
||||
response_data_model=GetTokenResponse,
|
||||
)
|
||||
def get_token_for_user(call: APICall, _: str, request: GetTokenForUserRequest):
|
||||
""" Generates a token based on a requested user and company. INTERNAL. """
|
||||
"""Generates a token based on a requested user and company. INTERNAL."""
|
||||
if call.identity.role not in Role.get_system_roles():
|
||||
if call.identity.role != Role.admin and call.identity.user != request.user:
|
||||
raise errors.bad_request.InvalidUserId(
|
||||
@@ -81,12 +83,14 @@ def get_token_for_user(call: APICall, _: str, request: GetTokenForUserRequest):
|
||||
response_data_model=ValidateResponse,
|
||||
)
|
||||
def validate_token_endpoint(call: APICall, _, __):
|
||||
""" Validate a token and return identity if valid. INTERNAL. """
|
||||
"""Validate a token and return identity if valid. INTERNAL."""
|
||||
try:
|
||||
# if invalid, decoding will fail
|
||||
token = Token.from_encoded_token(call.data_model.token)
|
||||
call.result.data_model = ValidateResponse(
|
||||
valid=True, user=token.identity.user, company=token.identity.company
|
||||
valid=not is_token_revoked(token),
|
||||
user=token.identity.user,
|
||||
company=token.identity.company,
|
||||
)
|
||||
except Exception as e:
|
||||
call.result.data_model = ValidateResponse(valid=False, msg=e.args[0])
|
||||
@@ -98,7 +102,7 @@ def validate_token_endpoint(call: APICall, _, __):
|
||||
response_data_model=CreateUserResponse,
|
||||
)
|
||||
def create_user(call: APICall, _, request: CreateUserRequest):
|
||||
""" Create a user from. INTERNAL. """
|
||||
"""Create a user from. INTERNAL."""
|
||||
if (
|
||||
call.identity.role not in Role.get_system_roles()
|
||||
and request.company != call.identity.company
|
||||
|
||||
@@ -31,6 +31,15 @@ from apiserver.apimodels.events import (
|
||||
GetMetricSamplesRequest,
|
||||
TaskMetric,
|
||||
MultiTaskPlotsRequest,
|
||||
MultiTaskMetricsRequest,
|
||||
LegacyLogEventsRequest,
|
||||
TaskRequest,
|
||||
GetMetricsAndVariantsRequest,
|
||||
ModelRequest,
|
||||
LegacyMetricEventsRequest,
|
||||
GetScalarMetricDataRequest,
|
||||
VectorMetricsIterHistogramRequest,
|
||||
LegacyMultiTaskEventsRequest,
|
||||
)
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.event.event_common import EventType, MetricVariants, TaskCompanies
|
||||
@@ -38,6 +47,7 @@ from apiserver.bll.event.events_iterator import Scroll
|
||||
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
|
||||
from apiserver.bll.model import ModelBLL
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.bll.task.utils import get_task_with_write_access
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task
|
||||
@@ -73,7 +83,7 @@ def add(call: APICall, company_id, _):
|
||||
data = call.data.copy()
|
||||
added, err_count, err_info = event_bll.add_events(
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
events=[data],
|
||||
worker=call.worker,
|
||||
)
|
||||
@@ -88,22 +98,22 @@ def add_batch(call: APICall, company_id, _):
|
||||
|
||||
added, err_count, err_info = event_bll.add_events(
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
events=events,
|
||||
worker=call.worker,
|
||||
)
|
||||
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
|
||||
|
||||
|
||||
@endpoint("events.get_task_log", required_fields=["task"])
|
||||
def get_task_log_v1_5(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
@endpoint("events.get_task_log")
|
||||
def get_task_log_v1_5(call, company_id, request: LegacyLogEventsRequest):
|
||||
task_id = request.task
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
)[0]
|
||||
order = call.data.get("order") or "desc"
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
batch_size = int(call.data.get("batch_size") or 500)
|
||||
order = request.order
|
||||
scroll_id = request.scroll_id
|
||||
batch_size = request.batch_size
|
||||
events, scroll_id, total_events = event_bll.scroll_task_events(
|
||||
task.get_index_company(),
|
||||
task_id,
|
||||
@@ -117,17 +127,17 @@ def get_task_log_v1_5(call, company_id, _):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.get_task_log", min_version="1.7", required_fields=["task"])
|
||||
def get_task_log_v1_7(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
@endpoint("events.get_task_log", min_version="1.7")
|
||||
def get_task_log_v1_7(call, company_id, request: LegacyLogEventsRequest):
|
||||
task_id = request.task
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
)[0]
|
||||
|
||||
order = call.data.get("order") or "desc"
|
||||
order = request.order
|
||||
from_ = call.data.get("from") or "head"
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
batch_size = int(call.data.get("batch_size") or 500)
|
||||
scroll_id = request.scroll_id
|
||||
batch_size = request.batch_size
|
||||
|
||||
scroll_order = "asc" if (from_ == "head") else "desc"
|
||||
|
||||
@@ -175,9 +185,9 @@ def get_task_log(call, company_id, request: LogEventsRequest):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.download_task_log", required_fields=["task"])
|
||||
def download_task_log(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
@endpoint("events.download_task_log")
|
||||
def download_task_log(call, company_id, request: TaskRequest):
|
||||
task_id = request.task
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
)[0]
|
||||
@@ -255,10 +265,12 @@ def download_task_log(call, company_id, _):
|
||||
call.result.raw_data = generate()
|
||||
|
||||
|
||||
@endpoint("events.get_vector_metrics_and_variants", required_fields=["task"])
|
||||
def get_vector_metrics_and_variants(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
model_events = call.data["model_events"]
|
||||
@endpoint("events.get_vector_metrics_and_variants")
|
||||
def get_vector_metrics_and_variants(
|
||||
call, company_id, request: GetMetricsAndVariantsRequest
|
||||
):
|
||||
task_id = request.task
|
||||
model_events = request.model_events
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id,
|
||||
task_id,
|
||||
@@ -271,10 +283,12 @@ def get_vector_metrics_and_variants(call, company_id, _):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.get_scalar_metrics_and_variants", required_fields=["task"])
|
||||
def get_scalar_metrics_and_variants(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
model_events = call.data["model_events"]
|
||||
@endpoint("events.get_scalar_metrics_and_variants")
|
||||
def get_scalar_metrics_and_variants(
|
||||
call, company_id, request: GetMetricsAndVariantsRequest
|
||||
):
|
||||
task_id = request.task
|
||||
model_events = request.model_events
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id,
|
||||
task_id,
|
||||
@@ -290,18 +304,19 @@ def get_scalar_metrics_and_variants(call, company_id, _):
|
||||
# todo: !!! currently returning 10,000 records. should decide on a better way to control it
|
||||
@endpoint(
|
||||
"events.vector_metrics_iter_histogram",
|
||||
required_fields=["task", "metric", "variant"],
|
||||
)
|
||||
def vector_metrics_iter_histogram(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
model_events = call.data["model_events"]
|
||||
def vector_metrics_iter_histogram(
|
||||
call, company_id, request: VectorMetricsIterHistogramRequest
|
||||
):
|
||||
task_id = request.task
|
||||
model_events = request.model_events
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id,
|
||||
task_id,
|
||||
model_events=model_events,
|
||||
)[0]
|
||||
metric = call.data["metric"]
|
||||
variant = call.data["variant"]
|
||||
metric = request.metric
|
||||
variant = request.variant
|
||||
iterations, vectors = event_bll.get_vector_metrics_per_iter(
|
||||
task_or_model.get_index_company(), task_id, metric, variant
|
||||
)
|
||||
@@ -402,13 +417,13 @@ def get_task_events(_, company_id, request: TaskEventsRequest):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.get_scalar_metric_data", required_fields=["task", "metric"])
|
||||
def get_scalar_metric_data(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
metric = call.data["metric"]
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
no_scroll = call.data.get("no_scroll", False)
|
||||
model_events = call.data.get("model_events", False)
|
||||
@endpoint("events.get_scalar_metric_data")
|
||||
def get_scalar_metric_data(call, company_id, request: GetScalarMetricDataRequest):
|
||||
task_id = request.task
|
||||
metric = request.metric
|
||||
scroll_id = request.scroll_id
|
||||
no_scroll = request.no_scroll
|
||||
model_events = request.model_events
|
||||
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id,
|
||||
@@ -433,9 +448,9 @@ def get_scalar_metric_data(call, company_id, _):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.get_task_latest_scalar_values", required_fields=["task"])
|
||||
def get_task_latest_scalar_values(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
@endpoint("events.get_task_latest_scalar_values")
|
||||
def get_task_latest_scalar_values(call, company_id, request: TaskRequest):
|
||||
task_id = request.task
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
)[0]
|
||||
@@ -521,6 +536,7 @@ def multi_task_scalar_metrics_iter_histogram(
|
||||
),
|
||||
samples=request.samples,
|
||||
key=request.key,
|
||||
metric_variants=_get_metric_variants_from_request(request.metrics),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -548,17 +564,18 @@ def get_task_single_value_metrics(
|
||||
tasks=_get_single_value_metrics_response(
|
||||
companies=companies,
|
||||
value_metrics=event_bll.metrics.get_task_single_value_metrics(
|
||||
companies=companies
|
||||
companies=companies,
|
||||
metric_variants=_get_metric_variants_from_request(request.metrics),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.get_multi_task_plots", required_fields=["tasks"])
|
||||
def get_multi_task_plots_v1_7(call, company_id, _):
|
||||
task_ids = call.data["tasks"]
|
||||
iters = call.data.get("iters", 1)
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
@endpoint("events.get_multi_task_plots")
|
||||
def get_multi_task_plots_v1_7(call, company_id, request: LegacyMultiTaskEventsRequest):
|
||||
task_ids = request.tasks
|
||||
iters = request.iters
|
||||
scroll_id = request.scroll_id
|
||||
|
||||
companies = _get_task_or_model_index_companies(company_id, task_ids)
|
||||
|
||||
@@ -591,10 +608,11 @@ def _get_multitask_plots(
|
||||
companies: TaskCompanies,
|
||||
last_iters: int,
|
||||
last_iters_per_task_metric: bool,
|
||||
metrics: MetricVariants = None,
|
||||
request_metrics: Sequence[ApiMetrics] = None,
|
||||
scroll_id=None,
|
||||
no_scroll=True,
|
||||
) -> Tuple[dict, int, str]:
|
||||
metrics = _get_metric_variants_from_request(request_metrics)
|
||||
task_names = {
|
||||
t.id: t.name for t in itertools.chain.from_iterable(companies.values())
|
||||
}
|
||||
@@ -629,6 +647,7 @@ def get_multi_task_plots(call, company_id, request: MultiTaskPlotsRequest):
|
||||
scroll_id=request.scroll_id,
|
||||
no_scroll=request.no_scroll,
|
||||
last_iters_per_task_metric=request.last_iters_per_task_metric,
|
||||
request_metrics=request.metrics,
|
||||
)
|
||||
call.result.data = dict(
|
||||
plots=return_events,
|
||||
@@ -638,11 +657,11 @@ def get_multi_task_plots(call, company_id, request: MultiTaskPlotsRequest):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.get_task_plots", required_fields=["task"])
|
||||
def get_task_plots_v1_7(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
iters = call.data.get("iters", 1)
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
@endpoint("events.get_task_plots")
|
||||
def get_task_plots_v1_7(call, company_id, request: LegacyMetricEventsRequest):
|
||||
task_id = request.task
|
||||
iters = request.iters
|
||||
scroll_id = request.scroll_id
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
@@ -760,11 +779,11 @@ def task_plots(call, company_id, request: MetricEventsRequest):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.debug_images", required_fields=["task"])
|
||||
def get_debug_images_v1_7(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
iters = call.data.get("iters") or 1
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
@endpoint("events.debug_images")
|
||||
def get_debug_images_v1_7(call, company_id, request: LegacyMetricEventsRequest):
|
||||
task_id = request.task
|
||||
iters = request.iters
|
||||
scroll_id = request.scroll_id
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
@@ -797,12 +816,12 @@ def get_debug_images_v1_7(call, company_id, _):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.debug_images", min_version="1.8", required_fields=["task"])
|
||||
def get_debug_images_v1_8(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
iters = call.data.get("iters") or 1
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
model_events = call.data.get("model_events", False)
|
||||
@endpoint("events.debug_images", min_version="1.8")
|
||||
def get_debug_images_v1_8(call, company_id, request: LegacyMetricEventsRequest):
|
||||
task_id = request.task
|
||||
iters = request.iters
|
||||
scroll_id = request.scroll_id
|
||||
model_events = request.model_events
|
||||
|
||||
tasks_or_model = _assert_task_or_model_exists(
|
||||
company_id,
|
||||
@@ -960,12 +979,35 @@ def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest):
|
||||
}
|
||||
|
||||
|
||||
@endpoint("events.delete_for_task", required_fields=["task"])
|
||||
def delete_for_task(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
@endpoint("events.get_multi_task_metrics")
|
||||
def get_multi_task_metrics(call: APICall, company_id, request: MultiTaskMetricsRequest):
|
||||
companies = _get_task_or_model_index_companies(
|
||||
company_id, request.tasks, model_events=request.model_events
|
||||
)
|
||||
if not companies:
|
||||
return {"metrics": []}
|
||||
|
||||
metrics = event_bll.metrics.get_multi_task_metrics(
|
||||
companies=companies, event_type=request.event_type
|
||||
)
|
||||
res = [
|
||||
{
|
||||
"metric": m,
|
||||
"variants": sorted(vars_),
|
||||
}
|
||||
for m, vars_ in metrics.items()
|
||||
]
|
||||
call.result.data = {"metrics": sorted(res, key=itemgetter("metric"))}
|
||||
|
||||
|
||||
@endpoint("events.delete_for_task")
|
||||
def delete_for_task(call, company_id, request: TaskRequest):
|
||||
task_id = request.task
|
||||
allow_locked = call.data.get("allow_locked", False)
|
||||
|
||||
task_bll.assert_exists(company_id, task_id, return_tasks=False)
|
||||
get_task_with_write_access(
|
||||
task_id=task_id, company_id=company_id, identity=call.identity, only=("id",)
|
||||
)
|
||||
call.result.data = dict(
|
||||
deleted=event_bll.delete_task_events(
|
||||
company_id, task_id, allow_locked=allow_locked
|
||||
@@ -973,9 +1015,9 @@ def delete_for_task(call, company_id, _):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.delete_for_model", required_fields=["model"])
|
||||
def delete_for_model(call: APICall, company_id: str, _):
|
||||
model_id = call.data["model"]
|
||||
@endpoint("events.delete_for_model")
|
||||
def delete_for_model(call: APICall, company_id: str, request: ModelRequest):
|
||||
model_id = request.model
|
||||
allow_locked = call.data.get("allow_locked", False)
|
||||
|
||||
model_bll.assert_exists(company_id, model_id, return_models=False)
|
||||
@@ -990,7 +1032,9 @@ def delete_for_model(call: APICall, company_id: str, _):
|
||||
def clear_task_log(call: APICall, company_id: str, request: ClearTaskLogRequest):
|
||||
task_id = request.task
|
||||
|
||||
task_bll.assert_exists(company_id, task_id, return_tasks=False)
|
||||
get_task_with_write_access(
|
||||
task_id=task_id, company_id=company_id, identity=call.identity, only=("id",)
|
||||
)
|
||||
call.result.data = dict(
|
||||
deleted=event_bll.clear_task_log(
|
||||
company_id=company_id,
|
||||
|
||||
@@ -7,6 +7,7 @@ from apiserver.apimodels.login import (
|
||||
)
|
||||
from apiserver.config import info
|
||||
from apiserver.service_repo import endpoint, APICall
|
||||
from apiserver.service_repo.auth import revoke_auth_token
|
||||
from apiserver.service_repo.auth.fixed_user import FixedUser
|
||||
|
||||
|
||||
@@ -37,4 +38,5 @@ def supported_modes(call: APICall, _, __: GetSupportedModesRequest):
|
||||
|
||||
@endpoint("login.logout", min_version="2.13")
|
||||
def logout(call: APICall, _, __):
|
||||
revoke_auth_token(call.auth)
|
||||
call.result.set_auth_cookie(None)
|
||||
|
||||
@@ -21,6 +21,10 @@ from apiserver.apimodels.models import (
|
||||
ModelsPublishManyRequest,
|
||||
ModelsDeleteManyRequest,
|
||||
ModelsGetRequest,
|
||||
ModelRequest,
|
||||
TaskRequest,
|
||||
UpdateForTaskRequest,
|
||||
UpdateModelRequest,
|
||||
)
|
||||
from apiserver.apimodels.tasks import UpdateTagsRequest
|
||||
from apiserver.bll.model import ModelBLL, Metadata
|
||||
@@ -28,6 +32,7 @@ from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.bll.task.task_operations import publish_task
|
||||
from apiserver.bll.task.utils import get_task_with_write_access
|
||||
from apiserver.bll.util import run_batch_operation
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import validate_id
|
||||
@@ -46,6 +51,7 @@ from apiserver.database.utils import (
|
||||
filter_fields,
|
||||
)
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.services.utils import (
|
||||
conform_tag_fields,
|
||||
conform_output_tags,
|
||||
@@ -65,9 +71,9 @@ def conform_model_data(call: APICall, model_data: Union[Sequence[dict], dict]):
|
||||
unescape_metadata(call, model_data)
|
||||
|
||||
|
||||
@endpoint("models.get_by_id", required_fields=["model"])
|
||||
def get_by_id(call: APICall, company_id, _):
|
||||
model_id = call.data["model"]
|
||||
@endpoint("models.get_by_id")
|
||||
def get_by_id(call: APICall, company_id, request: ModelRequest):
|
||||
model_id = request.model
|
||||
call_data = Metadata.escape_query_parameters(call.data)
|
||||
models = Model.get_many(
|
||||
company=company_id,
|
||||
@@ -85,12 +91,12 @@ def get_by_id(call: APICall, company_id, _):
|
||||
call.result.data = {"model": models[0]}
|
||||
|
||||
|
||||
@endpoint("models.get_by_task_id", required_fields=["task"])
|
||||
def get_by_task_id(call: APICall, company_id, _):
|
||||
@endpoint("models.get_by_task_id")
|
||||
def get_by_task_id(call: APICall, company_id, request: TaskRequest):
|
||||
if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
|
||||
raise errors.moved_permanently.NotSupported("use models.get_by_id/get_all apis")
|
||||
|
||||
task_id = call.data["task"]
|
||||
task_id = request.task
|
||||
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get(_only=["models"], **query)
|
||||
@@ -155,7 +161,7 @@ def get_by_id_ex(call: APICall, company_id, _):
|
||||
call.result.data = {"models": models}
|
||||
|
||||
|
||||
@endpoint("models.get_all", required_fields=[])
|
||||
@endpoint("models.get_all")
|
||||
def get_all(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
call_data = Metadata.escape_query_parameters(call.data)
|
||||
@@ -191,7 +197,7 @@ create_fields = {
|
||||
"project": Project,
|
||||
"parent": Model,
|
||||
"framework": None,
|
||||
"design": None,
|
||||
"design": dict,
|
||||
"labels": dict,
|
||||
"ready": None,
|
||||
"metadata": list,
|
||||
@@ -234,28 +240,27 @@ def _reset_cached_tags(company: str, projects: Sequence[str]):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("models.update_for_task", required_fields=["task"])
|
||||
def update_for_task(call: APICall, company_id, _):
|
||||
@endpoint("models.update_for_task")
|
||||
def update_for_task(call: APICall, company_id, request: UpdateForTaskRequest):
|
||||
if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
|
||||
raise errors.moved_permanently.NotSupported("use tasks.add_or_update_model")
|
||||
|
||||
task_id = call.data["task"]
|
||||
uri = call.data.get("uri")
|
||||
iteration = call.data.get("iteration")
|
||||
override_model_id = call.data.get("override_model_id")
|
||||
task_id = request.task
|
||||
uri = request.uri
|
||||
iteration = request.iteration
|
||||
override_model_id = request.override_model_id
|
||||
if not (uri or override_model_id) or (uri and override_model_id):
|
||||
raise errors.bad_request.MissingRequiredFields(
|
||||
"exactly one field is required", fields=("uri", "override_model_id")
|
||||
)
|
||||
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get_for_writing(
|
||||
id=task_id,
|
||||
company=company_id,
|
||||
_only=["models", "execution", "name", "status", "project"],
|
||||
task = get_task_with_write_access(
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
identity=call.identity,
|
||||
only=("models", "execution", "name", "status", "project"),
|
||||
)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
|
||||
allowed_states = [TaskStatus.created, TaskStatus.in_progress]
|
||||
if task.status not in allowed_states:
|
||||
@@ -343,7 +348,7 @@ def create(call: APICall, company_id, req_model: CreateModelRequest):
|
||||
task = req_model.task
|
||||
req_data = req_model.to_struct()
|
||||
if task:
|
||||
validate_task(company_id, req_data)
|
||||
validate_task(company_id, call.identity, req_data)
|
||||
|
||||
fields = filter_fields(Model, req_data)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
@@ -373,7 +378,7 @@ def prepare_update_fields(call, company_id, fields: dict):
|
||||
# clear UI cache if URI is provided (model updated)
|
||||
fields["ui_cache"] = fields.pop("ui_cache", {})
|
||||
if "task" in fields:
|
||||
validate_task(company_id, fields)
|
||||
validate_task(company_id, call.identity, fields)
|
||||
|
||||
if "labels" in fields:
|
||||
labels = fields["labels"]
|
||||
@@ -403,13 +408,16 @@ def prepare_update_fields(call, company_id, fields: dict):
|
||||
return fields
|
||||
|
||||
|
||||
def validate_task(company_id, fields: dict):
|
||||
Task.get_for_writing(company=company_id, id=fields["task"], _only=["id"])
|
||||
def validate_task(company_id: str, identity: Identity, fields: dict):
|
||||
task_id = fields["task"]
|
||||
get_task_with_write_access(
|
||||
task_id=task_id, company_id=company_id, identity=identity, only=("id",)
|
||||
)
|
||||
|
||||
|
||||
@endpoint("models.edit", required_fields=["model"], response_data_model=UpdateResponse)
|
||||
def edit(call: APICall, company_id, _):
|
||||
model_id = call.data["model"]
|
||||
@endpoint("models.edit", response_data_model=UpdateResponse)
|
||||
def edit(call: APICall, company_id, request: UpdateModelRequest):
|
||||
model_id = request.model
|
||||
|
||||
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
|
||||
|
||||
@@ -424,7 +432,7 @@ def edit(call: APICall, company_id, _):
|
||||
d.update(value)
|
||||
fields[key] = d
|
||||
|
||||
iteration = call.data.get("iteration")
|
||||
iteration = request.iteration
|
||||
task_id = model.task or fields.get("task")
|
||||
if task_id and iteration is not None:
|
||||
TaskBLL.update_statistics(
|
||||
@@ -456,13 +464,9 @@ def edit(call: APICall, company_id, _):
|
||||
call.result.data_model = UpdateResponse(updated=0)
|
||||
|
||||
|
||||
def _update_model(call: APICall, company_id, model_id=None):
|
||||
model_id = model_id or call.data["model"]
|
||||
|
||||
def _update_model(call: APICall, company_id, model_id):
|
||||
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
|
||||
|
||||
data = prepare_update_fields(call, company_id, call.data)
|
||||
|
||||
task_id = data.get("task")
|
||||
iteration = data.get("iteration")
|
||||
if task_id and iteration is not None:
|
||||
@@ -498,11 +502,9 @@ def _update_model(call: APICall, company_id, model_id=None):
|
||||
return UpdateResponse(updated=updated_count, fields=updated_fields)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"models.update", required_fields=["model"], response_data_model=UpdateResponse
|
||||
)
|
||||
def update(call, company_id, _):
|
||||
call.result.data_model = _update_model(call, company_id)
|
||||
@endpoint("models.update", response_data_model=UpdateResponse)
|
||||
def update(call, company_id, request: UpdateModelRequest):
|
||||
call.result.data_model = _update_model(call, company_id, model_id=request.model)
|
||||
|
||||
|
||||
@endpoint(
|
||||
@@ -514,7 +516,7 @@ def set_ready(call: APICall, company_id: str, request: PublishModelRequest):
|
||||
updated, published_task = ModelBLL.publish_model(
|
||||
model_id=request.model,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
force_publish_task=request.force_publish_task,
|
||||
publish_task_func=publish_task if request.publish_task else None,
|
||||
)
|
||||
@@ -533,7 +535,7 @@ def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest):
|
||||
func=partial(
|
||||
ModelBLL.publish_model,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
force_publish_task=request.force_publish_task,
|
||||
publish_task_func=publish_task if request.publish_task else None,
|
||||
),
|
||||
@@ -625,7 +627,9 @@ def archive_many(call: APICall, company_id, request: BatchRequest):
|
||||
)
|
||||
def unarchive_many(call: APICall, company_id, request: BatchRequest):
|
||||
results, failures = run_batch_operation(
|
||||
func=partial(ModelBLL.unarchive_model, company_id=company_id, user_id=call.identity.user),
|
||||
func=partial(
|
||||
ModelBLL.unarchive_model, company_id=company_id, user_id=call.identity.user
|
||||
),
|
||||
ids=request.ids,
|
||||
)
|
||||
call.result.data_model = BatchResponse(
|
||||
|
||||
@@ -5,22 +5,24 @@ import attr
|
||||
|
||||
from apiserver.apierrors.errors.bad_request import CannotRemoveAllRuns
|
||||
from apiserver.apimodels.pipelines import (
|
||||
StartPipelineResponse,
|
||||
StartPipelineRequest,
|
||||
DeleteRunsRequest,
|
||||
)
|
||||
from apiserver.bll.organization import OrgBLL
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.bll.task.task_operations import enqueue_task, delete_task
|
||||
from apiserver.bll.util import run_batch_operation
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, TaskType
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
|
||||
org_bll = OrgBLL()
|
||||
project_bll = ProjectBLL()
|
||||
task_bll = TaskBLL()
|
||||
queue_bll = QueueBLL()
|
||||
|
||||
|
||||
def _update_task_name(task: Task):
|
||||
@@ -57,7 +59,7 @@ def delete_runs(call: APICall, company_id: str, request: DeleteRunsRequest):
|
||||
func=partial(
|
||||
delete_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
move_to_trash=False,
|
||||
force=True,
|
||||
return_file_urls=False,
|
||||
@@ -79,9 +81,7 @@ def delete_runs(call: APICall, company_id: str, request: DeleteRunsRequest):
|
||||
call.result.data = dict(succeeded=succeeded, failed=failures)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"pipelines.start_pipeline", response_data_model=StartPipelineResponse,
|
||||
)
|
||||
@endpoint("pipelines.start_pipeline")
|
||||
def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest):
|
||||
hyperparams = None
|
||||
if request.args:
|
||||
@@ -108,10 +108,19 @@ def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest
|
||||
queued, res = enqueue_task(
|
||||
task_id=task.id,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
queue_id=request.queue,
|
||||
status_message="Starting pipeline",
|
||||
status_reason="",
|
||||
)
|
||||
extra = {}
|
||||
if request.verify_watched_queue and queued:
|
||||
res_queue = nested_get(res, ("fields", "execution.queue"))
|
||||
if res_queue:
|
||||
extra["queue_watched"] = queue_bll.check_for_workers(company_id, res_queue)
|
||||
|
||||
return StartPipelineResponse(pipeline=task.id, enqueued=bool(queued))
|
||||
call.result.data = dict(
|
||||
pipeline=task.id,
|
||||
enqueued=bool(queued),
|
||||
**extra,
|
||||
)
|
||||
|
||||
@@ -59,13 +59,12 @@ create_fields = {
|
||||
}
|
||||
|
||||
|
||||
@endpoint("projects.get_by_id", required_fields=["project"])
|
||||
def get_by_id(call):
|
||||
assert isinstance(call, APICall)
|
||||
project_id = call.data["project"]
|
||||
@endpoint("projects.get_by_id")
|
||||
def get_by_id(call: APICall, company: str, request: ProjectRequest):
|
||||
project_id = request.project
|
||||
|
||||
with translate_errors_context():
|
||||
query = Q(id=project_id) & get_company_or_none_constraint(call.identity.company)
|
||||
query = Q(id=project_id) & get_company_or_none_constraint(company)
|
||||
project = Project.objects(query).first()
|
||||
if not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||
@@ -147,8 +146,10 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
requested_ids = data.get("id")
|
||||
if isinstance(requested_ids, str):
|
||||
requested_ids = [requested_ids]
|
||||
|
||||
_adjust_search_parameters(
|
||||
data, shallow_search=request.shallow_search,
|
||||
data,
|
||||
shallow_search=request.shallow_search,
|
||||
)
|
||||
selected_project_ids = None
|
||||
if request.active_users or request.children_type:
|
||||
@@ -246,7 +247,9 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
|
||||
if request.include_dataset_stats:
|
||||
dataset_stats = project_bll.get_dataset_stats(
|
||||
company=company_id, project_ids=project_ids, users=request.active_users,
|
||||
company=company_id,
|
||||
project_ids=project_ids,
|
||||
users=request.active_users,
|
||||
)
|
||||
for project in projects:
|
||||
project["dataset_stats"] = dataset_stats.get(project["id"])
|
||||
@@ -255,15 +258,16 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
|
||||
|
||||
@endpoint("projects.get_all")
|
||||
def get_all(call: APICall):
|
||||
def get_all(call: APICall, company: str, _):
|
||||
data = call.data
|
||||
conform_tag_fields(call, data)
|
||||
_adjust_search_parameters(
|
||||
data, shallow_search=data.get("shallow_search", False),
|
||||
data,
|
||||
shallow_search=data.get("shallow_search", False),
|
||||
)
|
||||
ret_params = {}
|
||||
projects = Project.get_many(
|
||||
company=call.identity.company,
|
||||
company=company,
|
||||
query_dict=data,
|
||||
query=_hidden_query(
|
||||
search_hidden=data.get("search_hidden"), ids=data.get("id")
|
||||
@@ -277,9 +281,11 @@ def get_all(call: APICall):
|
||||
|
||||
|
||||
@endpoint(
|
||||
"projects.create", required_fields=["name"], response_data_model=IdResponse,
|
||||
"projects.create",
|
||||
required_fields=["name"],
|
||||
response_data_model=IdResponse,
|
||||
)
|
||||
def create(call: APICall):
|
||||
def create(call: APICall, company: str, _):
|
||||
identity = call.identity
|
||||
|
||||
with translate_errors_context():
|
||||
@@ -288,15 +294,15 @@ def create(call: APICall):
|
||||
|
||||
return IdResponse(
|
||||
id=ProjectBLL.create(
|
||||
user=identity.user, company=identity.company, **fields,
|
||||
user=identity.user,
|
||||
company=company,
|
||||
**fields,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"projects.update", required_fields=["project"], response_data_model=UpdateResponse
|
||||
)
|
||||
def update(call: APICall):
|
||||
@endpoint("projects.update", response_data_model=UpdateResponse)
|
||||
def update(call: APICall, company: str, request: ProjectRequest):
|
||||
"""
|
||||
update
|
||||
|
||||
@@ -309,9 +315,7 @@ def update(call: APICall):
|
||||
call.data, create_fields, Project.get_fields(), discard_none_values=False
|
||||
)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
updated = ProjectBLL.update(
|
||||
company=call.identity.company, project_id=call.data["project"], **fields
|
||||
)
|
||||
updated = ProjectBLL.update(company=company, project_id=request.project, **fields)
|
||||
conform_output_tags(call, fields)
|
||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||
|
||||
@@ -375,11 +379,11 @@ def delete(call: APICall, company_id: str, request: DeleteRequest):
|
||||
def get_unique_metric_variants(
|
||||
call: APICall, company_id: str, request: GetUniqueMetricsRequest
|
||||
):
|
||||
|
||||
metrics = project_queries.get_unique_metric_variants(
|
||||
company_id,
|
||||
[request.project] if request.project else None,
|
||||
include_subprojects=request.include_subprojects,
|
||||
ids=request.ids,
|
||||
model_metrics=request.model_metrics,
|
||||
)
|
||||
|
||||
@@ -428,7 +432,6 @@ def get_model_metadata_values(
|
||||
request_data_model=GetParamsRequest,
|
||||
)
|
||||
def get_hyper_parameters(call: APICall, company_id: str, request: GetParamsRequest):
|
||||
|
||||
total, remaining, parameters = project_queries.get_aggregated_project_parameters(
|
||||
company_id,
|
||||
project_ids=[request.project] if request.project else None,
|
||||
|
||||
@@ -19,7 +19,9 @@ from apiserver.apimodels.reports import (
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.base import UpdateResponse
|
||||
from apiserver.bll.project.project_bll import reports_project_name, reports_tag
|
||||
from apiserver.bll.task.utils import get_task_with_write_access
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.services.models import conform_model_data
|
||||
from apiserver.services.utils import process_include_subprojects, sort_tags_response
|
||||
from apiserver.bll.organization import OrgBLL
|
||||
@@ -57,15 +59,15 @@ update_fields = {
|
||||
}
|
||||
|
||||
|
||||
def _assert_report(company_id, task_id, only_fields=None, requires_write_access=True):
|
||||
def _assert_report(company_id: str, task_id: str, identity: Identity, only_fields=None):
|
||||
if only_fields and "type" not in only_fields:
|
||||
only_fields += ("type",)
|
||||
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task = get_task_with_write_access(
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
identity=identity,
|
||||
only=only_fields,
|
||||
requires_write_access=requires_write_access,
|
||||
)
|
||||
if task.type != TaskType.report:
|
||||
raise errors.bad_request.OperationSupportedOnReportsOnly(id=task_id)
|
||||
@@ -78,6 +80,7 @@ def update_report(call: APICall, company_id: str, request: UpdateReportRequest):
|
||||
task = _assert_report(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
identity=call.identity,
|
||||
only_fields=("status",),
|
||||
)
|
||||
|
||||
@@ -265,7 +268,7 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
|
||||
res["plots"] = _get_multitask_plots(
|
||||
companies=companies,
|
||||
last_iters=request.plots.iters,
|
||||
metrics=_get_metric_variants_from_request(request.plots.metrics),
|
||||
request_metrics=request.plots.metrics,
|
||||
last_iters_per_task_metric=request.plots.last_iters_per_task_metric,
|
||||
)[0]
|
||||
|
||||
@@ -302,6 +305,7 @@ def move(call: APICall, company_id: str, request: MoveReportRequest):
|
||||
task = _assert_report(
|
||||
company_id=company_id,
|
||||
task_id=request.task,
|
||||
identity=call.identity,
|
||||
only_fields=("project",),
|
||||
)
|
||||
user_id = call.identity.user
|
||||
@@ -337,7 +341,9 @@ def move(call: APICall, company_id: str, request: MoveReportRequest):
|
||||
response_data_model=UpdateResponse,
|
||||
)
|
||||
def publish(call: APICall, company_id, request: PublishReportRequest):
|
||||
task = _assert_report(company_id=company_id, task_id=request.task)
|
||||
task = _assert_report(
|
||||
company_id=company_id, task_id=request.task, identity=call.identity
|
||||
)
|
||||
updates = ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=TaskStatus.published,
|
||||
@@ -352,7 +358,9 @@ def publish(call: APICall, company_id, request: PublishReportRequest):
|
||||
|
||||
@endpoint("reports.archive")
|
||||
def archive(call: APICall, company_id, request: ArchiveReportRequest):
|
||||
task = _assert_report(company_id=company_id, task_id=request.task)
|
||||
task = _assert_report(
|
||||
company_id=company_id, task_id=request.task, identity=call.identity
|
||||
)
|
||||
archived = task.update(
|
||||
status_message=request.message,
|
||||
status_reason="",
|
||||
@@ -366,7 +374,9 @@ def archive(call: APICall, company_id, request: ArchiveReportRequest):
|
||||
|
||||
@endpoint("reports.unarchive")
|
||||
def unarchive(call: APICall, company_id, request: ArchiveReportRequest):
|
||||
task = _assert_report(company_id=company_id, task_id=request.task)
|
||||
task = _assert_report(
|
||||
company_id=company_id, task_id=request.task, identity=call.identity
|
||||
)
|
||||
unarchived = task.update(
|
||||
status_message=request.message,
|
||||
status_reason="",
|
||||
@@ -394,6 +404,7 @@ def delete(call: APICall, company_id, request: DeleteReportRequest):
|
||||
task = _assert_report(
|
||||
company_id=company_id,
|
||||
task_id=request.task,
|
||||
identity=call.identity,
|
||||
only_fields=("project",),
|
||||
)
|
||||
if (
|
||||
|
||||
@@ -3,7 +3,11 @@ from datetime import datetime
|
||||
from pyhocon.config_tree import NoneValue
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.server import ReportStatsOptionRequest, ReportStatsOptionResponse
|
||||
from apiserver.apimodels.server import (
|
||||
ReportStatsOptionRequest,
|
||||
ReportStatsOptionResponse,
|
||||
GetConfigRequest,
|
||||
)
|
||||
from apiserver.bll.statistics.stats_reporter import StatisticsReporter
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.config.info import get_version, get_build_number, get_commit_number
|
||||
@@ -22,8 +26,8 @@ def get_stats(call: APICall):
|
||||
|
||||
|
||||
@endpoint("server.config")
|
||||
def get_config(call: APICall):
|
||||
path = call.data.get("path")
|
||||
def get_config(call: APICall, _, request: GetConfigRequest):
|
||||
path = request.path
|
||||
if path:
|
||||
c = dict(config.get(path))
|
||||
else:
|
||||
|
||||
@@ -100,10 +100,17 @@ from apiserver.bll.task.task_operations import (
|
||||
unarchive_task,
|
||||
move_tasks_to_trash,
|
||||
)
|
||||
from apiserver.bll.task.utils import update_task, get_task_for_update, deleted_prefix
|
||||
from apiserver.bll.task.utils import (
|
||||
update_task,
|
||||
get_task_for_update,
|
||||
deleted_prefix,
|
||||
get_many_tasks_for_writing,
|
||||
get_task_with_write_access,
|
||||
)
|
||||
from apiserver.bll.util import run_batch_operation, update_project_time
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.output import Output
|
||||
from apiserver.database.model.task.task import (
|
||||
Task,
|
||||
@@ -118,6 +125,7 @@ from apiserver.database.utils import (
|
||||
get_options,
|
||||
)
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.services.utils import (
|
||||
conform_tag_fields,
|
||||
conform_output_tags,
|
||||
@@ -142,14 +150,34 @@ org_bll = OrgBLL()
|
||||
project_bll = ProjectBLL()
|
||||
|
||||
|
||||
def _assert_writable_tasks(
|
||||
company_id: str, identity: Identity, ids: Sequence[str], only=("id",)
|
||||
) -> Sequence[Task]:
|
||||
tasks = get_many_tasks_for_writing(
|
||||
company_id=company_id,
|
||||
identity=identity,
|
||||
query=Q(id__in=ids),
|
||||
only=only,
|
||||
)
|
||||
missing_ids = set(ids) - {t.id for t in tasks}
|
||||
if missing_ids:
|
||||
raise errors.bad_request.InvalidTaskId(ids=list(missing_ids))
|
||||
|
||||
return tasks
|
||||
|
||||
|
||||
def set_task_status_from_call(
|
||||
request: UpdateRequest, company_id: str, user_id: str, new_status=None, **set_fields
|
||||
request: UpdateRequest,
|
||||
company_id: str,
|
||||
identity: Identity,
|
||||
new_status=None,
|
||||
**set_fields,
|
||||
) -> dict:
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task = get_task_with_write_access(
|
||||
request.task,
|
||||
company_id=company_id,
|
||||
identity=identity,
|
||||
only=("id", "status", "project"),
|
||||
requires_write_access=True,
|
||||
)
|
||||
|
||||
status_reason = request.status_reason
|
||||
@@ -161,15 +189,17 @@ def set_task_status_from_call(
|
||||
status_reason=status_reason,
|
||||
status_message=status_message,
|
||||
force=force,
|
||||
user_id=user_id,
|
||||
user_id=identity.user,
|
||||
).execute(**set_fields)
|
||||
|
||||
|
||||
@endpoint("tasks.get_by_id", request_data_model=TaskRequest)
|
||||
def get_by_id(call: APICall, company_id, req_model: TaskRequest):
|
||||
task = TaskBLL.get_task_with_access(
|
||||
req_model.task, company_id=company_id, allow_public=True
|
||||
)
|
||||
def get_by_id(call: APICall, company_id, request: TaskRequest):
|
||||
task = TaskBLL.assert_exists(
|
||||
company_id,
|
||||
task_ids=request.task,
|
||||
allow_public=True,
|
||||
)[0]
|
||||
task_dict = task.to_proper_dict()
|
||||
conform_task_data(call, task_dict)
|
||||
call.result.data = {"task": task_dict}
|
||||
@@ -227,14 +257,16 @@ def get_by_id_ex(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
call_data = escape_execution_parameters(call.data)
|
||||
tasks = Task.get_many_with_join(
|
||||
company=company_id, query_dict=call_data, allow_public=True,
|
||||
company=company_id,
|
||||
query_dict=call_data,
|
||||
allow_public=True,
|
||||
)
|
||||
|
||||
conform_task_data(call, tasks)
|
||||
call.result.data = {"tasks": tasks}
|
||||
|
||||
|
||||
@endpoint("tasks.get_all", required_fields=[])
|
||||
@endpoint("tasks.get_all")
|
||||
def get_all(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
call_data = escape_execution_parameters(call.data)
|
||||
@@ -278,7 +310,7 @@ def stop(call: APICall, company_id, req_model: UpdateRequest):
|
||||
**stop_task(
|
||||
task_id=req_model.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
user_name=call.identity.user_name,
|
||||
status_reason=req_model.status_reason,
|
||||
force=req_model.force,
|
||||
@@ -296,7 +328,7 @@ def stop_many(call: APICall, company_id, request: StopManyRequest):
|
||||
func=partial(
|
||||
stop_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
user_name=call.identity.user_name,
|
||||
status_reason=request.status_reason,
|
||||
force=request.force,
|
||||
@@ -319,7 +351,7 @@ def stopped(call: APICall, company_id, req_model: UpdateRequest):
|
||||
**set_task_status_from_call(
|
||||
req_model,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
new_status=TaskStatus.stopped,
|
||||
completed=datetime.utcnow(),
|
||||
)
|
||||
@@ -332,13 +364,21 @@ def stopped(call: APICall, company_id, req_model: UpdateRequest):
|
||||
response_data_model=StartedResponse,
|
||||
)
|
||||
def started(call: APICall, company_id, req_model: UpdateRequest):
|
||||
started_update = {}
|
||||
if Task.objects(id=req_model.task, started=None).only("id"):
|
||||
# this is the fix for older versions putting started to None on reset
|
||||
started_update["started"] = datetime.utcnow()
|
||||
else:
|
||||
# don't override a previous, smaller "started" field value
|
||||
started_update["min__started"] = datetime.utcnow()
|
||||
|
||||
res = StartedResponse(
|
||||
**set_task_status_from_call(
|
||||
req_model,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
new_status=TaskStatus.in_progress,
|
||||
min__started=datetime.utcnow(), # don't override a previous, smaller "started" field value
|
||||
**started_update,
|
||||
)
|
||||
)
|
||||
res.started = res.updated
|
||||
@@ -353,7 +393,7 @@ def failed(call: APICall, company_id, req_model: UpdateRequest):
|
||||
**set_task_status_from_call(
|
||||
req_model,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
new_status=TaskStatus.failed,
|
||||
)
|
||||
)
|
||||
@@ -367,7 +407,7 @@ def close(call: APICall, company_id, req_model: UpdateRequest):
|
||||
**set_task_status_from_call(
|
||||
req_model,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
new_status=TaskStatus.closed,
|
||||
)
|
||||
)
|
||||
@@ -381,18 +421,19 @@ create_fields = {
|
||||
"error": None,
|
||||
"comment": None,
|
||||
"parent": Task,
|
||||
"project": None,
|
||||
"project": Project,
|
||||
"input": None,
|
||||
"models": None,
|
||||
"container": None,
|
||||
"container": dict,
|
||||
"output_dest": None,
|
||||
"execution": None,
|
||||
"hyperparams": None,
|
||||
"configuration": None,
|
||||
"hyperparams": dict,
|
||||
"configuration": dict,
|
||||
"script": None,
|
||||
"runtime": None,
|
||||
"runtime": dict,
|
||||
}
|
||||
|
||||
|
||||
dict_fields_paths = [("execution", "model_labels"), "container"]
|
||||
|
||||
|
||||
@@ -433,13 +474,17 @@ def conform_task_data(call: APICall, tasks_data: Union[Sequence[dict], dict]):
|
||||
|
||||
for data in tasks_data:
|
||||
params_unprepare_from_saved(
|
||||
fields=data, copy_to_legacy=need_legacy_params,
|
||||
fields=data,
|
||||
copy_to_legacy=need_legacy_params,
|
||||
)
|
||||
artifacts_unprepare_from_saved(fields=data)
|
||||
|
||||
|
||||
def prepare_create_fields(
|
||||
call: APICall, valid_fields=None, output=None, previous_task: Task = None,
|
||||
call: APICall,
|
||||
valid_fields=None,
|
||||
output=None,
|
||||
previous_task: Task = None,
|
||||
):
|
||||
valid_fields = valid_fields if valid_fields is not None else create_fields
|
||||
t_fields = task_fields
|
||||
@@ -566,11 +611,12 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
|
||||
task_id = req_model.task
|
||||
|
||||
with translate_errors_context():
|
||||
task = Task.get_for_writing(
|
||||
id=task_id, company=company_id, _only=["id", "project"]
|
||||
task = get_task_with_write_access(
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
identity=call.identity,
|
||||
only=("id", "project"),
|
||||
)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=task_id)
|
||||
|
||||
partial_update_dict, valid_fields = prepare_update_fields(call, call.data)
|
||||
|
||||
@@ -582,7 +628,8 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
|
||||
id=task_id,
|
||||
partial_update_dict=partial_update_dict,
|
||||
injected_update=dict(
|
||||
last_change=datetime.utcnow(), last_changed_by=call.identity.user,
|
||||
last_change=datetime.utcnow(),
|
||||
last_changed_by=call.identity.user,
|
||||
),
|
||||
)
|
||||
if updated_count:
|
||||
@@ -606,11 +653,11 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
|
||||
def set_requirements(call: APICall, company_id, req_model: SetRequirementsRequest):
|
||||
requirements = req_model.requirements
|
||||
with translate_errors_context():
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task = get_task_with_write_access(
|
||||
req_model.task,
|
||||
company_id=company_id,
|
||||
identity=call.identity,
|
||||
only=("status", "script"),
|
||||
requires_write_access=True,
|
||||
)
|
||||
if not task.script:
|
||||
raise errors.bad_request.MissingTaskFields(
|
||||
@@ -636,8 +683,11 @@ def update_batch(call: APICall, company_id, _):
|
||||
items = {i["task"]: i for i in items}
|
||||
tasks = {
|
||||
t.id: t
|
||||
for t in Task.get_many_for_writing(
|
||||
company=company_id, query=Q(id__in=list(items))
|
||||
for t in _assert_writable_tasks(
|
||||
identity=call.identity,
|
||||
company_id=company_id,
|
||||
ids=list(items),
|
||||
only=("id", "project"),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -656,7 +706,8 @@ def update_batch(call: APICall, company_id, _):
|
||||
if not partial_update_dict:
|
||||
continue
|
||||
partial_update_dict.update(
|
||||
last_change=now, last_changed_by=call.identity.user,
|
||||
last_change=now,
|
||||
last_changed_by=call.identity.user,
|
||||
)
|
||||
update_op = UpdateOne(
|
||||
{"_id": id, "company": company_id}, {"$set": partial_update_dict}
|
||||
@@ -690,9 +741,11 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
|
||||
force = req_model.force
|
||||
|
||||
with translate_errors_context():
|
||||
task = Task.get_for_writing(id=task_id, company=company_id)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=task_id)
|
||||
task = get_task_with_write_access(
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
identity=call.identity,
|
||||
)
|
||||
|
||||
if not force and task.status != TaskStatus.created:
|
||||
raise errors.bad_request.InvalidTaskStatus(
|
||||
@@ -756,7 +809,8 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.get_hyper_params", request_data_model=GetHyperParamsRequest,
|
||||
"tasks.get_hyper_params",
|
||||
request_data_model=GetHyperParamsRequest,
|
||||
)
|
||||
def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest):
|
||||
tasks_params = HyperParams.get_params(company_id, task_ids=request.tasks)
|
||||
@@ -771,7 +825,7 @@ def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest
|
||||
call.result.data = {
|
||||
"updated": HyperParams.edit_params(
|
||||
company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
task_id=request.task,
|
||||
hyperparams=request.hyperparams,
|
||||
replace_hyperparams=request.replace_hyperparams,
|
||||
@@ -785,7 +839,7 @@ def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsReq
|
||||
call.result.data = {
|
||||
"deleted": HyperParams.delete_params(
|
||||
company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
task_id=request.task,
|
||||
hyperparams=request.hyperparams,
|
||||
force=request.force,
|
||||
@@ -794,7 +848,8 @@ def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsReq
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.get_configurations", request_data_model=GetConfigurationsRequest,
|
||||
"tasks.get_configurations",
|
||||
request_data_model=GetConfigurationsRequest,
|
||||
)
|
||||
def get_configurations(call: APICall, company_id, request: GetConfigurationsRequest):
|
||||
tasks_params = HyperParams.get_configurations(
|
||||
@@ -809,7 +864,8 @@ def get_configurations(call: APICall, company_id, request: GetConfigurationsRequ
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.get_configuration_names", request_data_model=GetConfigurationNamesRequest,
|
||||
"tasks.get_configuration_names",
|
||||
request_data_model=GetConfigurationNamesRequest,
|
||||
)
|
||||
def get_configuration_names(
|
||||
call: APICall, company_id, request: GetConfigurationNamesRequest
|
||||
@@ -830,7 +886,7 @@ def edit_configuration(call: APICall, company_id, request: EditConfigurationRequ
|
||||
call.result.data = {
|
||||
"updated": HyperParams.edit_configuration(
|
||||
company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
task_id=request.task,
|
||||
configuration=request.configuration,
|
||||
replace_configuration=request.replace_configuration,
|
||||
@@ -846,7 +902,7 @@ def delete_configuration(
|
||||
call.result.data = {
|
||||
"deleted": HyperParams.delete_configuration(
|
||||
company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
task_id=request.task,
|
||||
configuration=request.configuration,
|
||||
force=request.force,
|
||||
@@ -863,7 +919,7 @@ def enqueue(call: APICall, company_id, request: EnqueueRequest):
|
||||
queued, res = enqueue_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
queue_id=request.queue,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
@@ -888,7 +944,7 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
|
||||
func=partial(
|
||||
enqueue_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
queue_id=request.queue,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
@@ -915,13 +971,14 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.dequeue", response_data_model=DequeueResponse,
|
||||
"tasks.dequeue",
|
||||
response_data_model=DequeueResponse,
|
||||
)
|
||||
def dequeue(call: APICall, company_id, request: DequeueRequest):
|
||||
dequeued, res = dequeue_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
remove_from_all_queues=request.remove_from_all_queues,
|
||||
@@ -931,14 +988,15 @@ def dequeue(call: APICall, company_id, request: DequeueRequest):
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.dequeue_many", response_data_model=DequeueManyResponse,
|
||||
"tasks.dequeue_many",
|
||||
response_data_model=DequeueManyResponse,
|
||||
)
|
||||
def dequeue_many(call: APICall, company_id, request: DequeueManyRequest):
|
||||
results, failures = run_batch_operation(
|
||||
func=partial(
|
||||
dequeue_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
remove_from_all_queues=request.remove_from_all_queues,
|
||||
@@ -962,7 +1020,7 @@ def reset(call: APICall, company_id, request: ResetRequest):
|
||||
dequeued, cleanup_res, updates = reset_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
force=request.force,
|
||||
return_file_urls=request.return_file_urls,
|
||||
delete_output_models=request.delete_output_models,
|
||||
@@ -990,7 +1048,7 @@ def reset_many(call: APICall, company_id, request: ResetManyRequest):
|
||||
func=partial(
|
||||
reset_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
force=request.force,
|
||||
return_file_urls=request.return_file_urls,
|
||||
delete_output_models=request.delete_output_models,
|
||||
@@ -1027,9 +1085,11 @@ def reset_many(call: APICall, company_id, request: ResetManyRequest):
|
||||
response_data_model=ArchiveResponse,
|
||||
)
|
||||
def archive(call: APICall, company_id, request: ArchiveRequest):
|
||||
tasks = TaskBLL.assert_exists(
|
||||
archived = 0
|
||||
tasks = _assert_writable_tasks(
|
||||
company_id,
|
||||
task_ids=request.tasks,
|
||||
call.identity,
|
||||
ids=request.tasks,
|
||||
only=(
|
||||
"id",
|
||||
"company",
|
||||
@@ -1040,11 +1100,10 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
|
||||
"enqueue_status",
|
||||
),
|
||||
)
|
||||
archived = 0
|
||||
for task in tasks:
|
||||
archived += archive_task(
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
task=task,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
@@ -1063,7 +1122,7 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest):
|
||||
func=partial(
|
||||
archive_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
),
|
||||
@@ -1085,7 +1144,7 @@ def unarchive_many(call: APICall, company_id, request: TaskBatchRequest):
|
||||
func=partial(
|
||||
unarchive_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
),
|
||||
@@ -1104,7 +1163,7 @@ def delete(call: APICall, company_id, request: DeleteRequest):
|
||||
deleted, task, cleanup_res = delete_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
move_to_trash=request.move_to_trash,
|
||||
force=request.force,
|
||||
return_file_urls=request.return_file_urls,
|
||||
@@ -1126,7 +1185,7 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest):
|
||||
func=partial(
|
||||
delete_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
move_to_trash=request.move_to_trash,
|
||||
force=request.force,
|
||||
return_file_urls=request.return_file_urls,
|
||||
@@ -1164,7 +1223,7 @@ def publish(call: APICall, company_id, request: PublishRequest):
|
||||
updates = publish_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
force=request.force,
|
||||
publish_model_func=ModelBLL.publish_model if request.publish_model else None,
|
||||
status_reason=request.status_reason,
|
||||
@@ -1183,7 +1242,7 @@ def publish_many(call: APICall, company_id, request: PublishManyRequest):
|
||||
func=partial(
|
||||
publish_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
force=request.force,
|
||||
publish_model_func=ModelBLL.publish_model
|
||||
if request.publish_model
|
||||
@@ -1211,7 +1270,7 @@ def completed(call: APICall, company_id, request: CompletedRequest):
|
||||
**set_task_status_from_call(
|
||||
request,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
new_status=TaskStatus.completed,
|
||||
completed=datetime.utcnow(),
|
||||
)
|
||||
@@ -1221,7 +1280,7 @@ def completed(call: APICall, company_id, request: CompletedRequest):
|
||||
publish_res = publish_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
force=request.force,
|
||||
publish_model_func=ModelBLL.publish_model,
|
||||
status_reason=request.status_reason,
|
||||
@@ -1256,7 +1315,7 @@ def add_or_update_artifacts(
|
||||
call.result.data = {
|
||||
"updated": Artifacts.add_or_update_artifacts(
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
task_id=request.task,
|
||||
artifacts=request.artifacts,
|
||||
force=True,
|
||||
@@ -1273,7 +1332,7 @@ def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest)
|
||||
call.result.data = {
|
||||
"deleted": Artifacts.delete_artifacts(
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
task_id=request.task,
|
||||
artifact_ids=request.artifacts,
|
||||
force=True,
|
||||
@@ -1310,6 +1369,7 @@ def move(call: APICall, company_id: str, request: MoveRequest):
|
||||
"project or project_name is required"
|
||||
)
|
||||
|
||||
_assert_writable_tasks(company_id, call.identity, request.ids)
|
||||
updated_projects = set(
|
||||
t.project for t in Task.objects(id__in=request.ids).only("project") if t.project
|
||||
)
|
||||
@@ -1330,7 +1390,8 @@ def move(call: APICall, company_id: str, request: MoveRequest):
|
||||
|
||||
|
||||
@endpoint("tasks.update_tags")
|
||||
def update_tags(_, company_id: str, request: UpdateTagsRequest):
|
||||
def update_tags(call: APICall, company_id: str, request: UpdateTagsRequest):
|
||||
_assert_writable_tasks(company_id, call.identity, request.ids)
|
||||
return {
|
||||
"updated": org_bll.edit_entity_tags(
|
||||
company_id=company_id,
|
||||
@@ -1344,7 +1405,9 @@ def update_tags(_, company_id: str, request: UpdateTagsRequest):
|
||||
|
||||
@endpoint("tasks.add_or_update_model", min_version="2.13")
|
||||
def add_or_update_model(call: APICall, company_id: str, request: AddUpdateModelRequest):
|
||||
get_task_for_update(company_id=company_id, task_id=request.task, force=True)
|
||||
get_task_for_update(
|
||||
company_id=company_id, task_id=request.task, force=True, identity=call.identity
|
||||
)
|
||||
|
||||
models_field = f"models__{request.type}"
|
||||
model = ModelItem(name=request.name, model=request.model, updated=datetime.utcnow())
|
||||
@@ -1364,7 +1427,9 @@ def add_or_update_model(call: APICall, company_id: str, request: AddUpdateModelR
|
||||
|
||||
@endpoint("tasks.delete_models", min_version="2.13")
|
||||
def delete_models(call: APICall, company_id: str, request: DeleteModelsRequest):
|
||||
task = get_task_for_update(company_id=company_id, task_id=request.task, force=True)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=request.task, force=True, identity=call.identity
|
||||
)
|
||||
|
||||
delete_names = {
|
||||
type_: [m.name for m in request.models if m.type == type_]
|
||||
@@ -1377,6 +1442,8 @@ def delete_models(call: APICall, company_id: str, request: DeleteModelsRequest):
|
||||
}
|
||||
|
||||
updated = task.update(
|
||||
last_change=datetime.utcnow(), last_changed_by=call.identity.user, **commands,
|
||||
last_change=datetime.utcnow(),
|
||||
last_changed_by=call.identity.user,
|
||||
**commands,
|
||||
)
|
||||
return {"updated": updated}
|
||||
|
||||
@@ -7,12 +7,16 @@ from mongoengine import Q
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.base import UpdateResponse
|
||||
from apiserver.apimodels.users import CreateRequest, SetPreferencesRequest
|
||||
from apiserver.apimodels.users import (
|
||||
CreateRequest,
|
||||
SetPreferencesRequest,
|
||||
UserRequest,
|
||||
)
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.user import UserBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.auth import Role
|
||||
from apiserver.database.model.auth import Role, User as AuthUser
|
||||
from apiserver.database.model.company import Company
|
||||
from apiserver.database.model.user import User
|
||||
from apiserver.database.utils import parse_from_call
|
||||
@@ -48,13 +52,13 @@ def get_user(call, company_id, user_id, only=None):
|
||||
return res.to_proper_dict()
|
||||
|
||||
|
||||
@endpoint("users.get_by_id", required_fields=["user"])
|
||||
def get_by_id(call: APICall, company_id, _):
|
||||
user_id = call.data["user"]
|
||||
@endpoint("users.get_by_id")
|
||||
def get_by_id(call: APICall, company_id, request: UserRequest):
|
||||
user_id = request.user
|
||||
call.result.data = {"user": get_user(call, company_id, user_id)}
|
||||
|
||||
|
||||
@endpoint("users.get_all_ex", required_fields=[])
|
||||
@endpoint("users.get_all_ex")
|
||||
def get_all_ex(call: APICall, company_id, _):
|
||||
with translate_errors_context("retrieving users"):
|
||||
res = User.get_many_with_join(company=company_id, query_dict=call.data)
|
||||
@@ -62,7 +66,7 @@ def get_all_ex(call: APICall, company_id, _):
|
||||
call.result.data = {"users": res}
|
||||
|
||||
|
||||
@endpoint("users.get_all_ex", min_version="2.8", required_fields=[])
|
||||
@endpoint("users.get_all_ex", min_version="2.8")
|
||||
def get_all_ex2_8(call: APICall, company_id, _):
|
||||
with translate_errors_context("retrieving users"):
|
||||
data = call.data
|
||||
@@ -83,7 +87,7 @@ def get_all_ex2_8(call: APICall, company_id, _):
|
||||
call.result.data = {"users": res}
|
||||
|
||||
|
||||
@endpoint("users.get_all", required_fields=[])
|
||||
@endpoint("users.get_all")
|
||||
def get_all(call: APICall, company_id, _):
|
||||
with translate_errors_context("retrieving users"):
|
||||
res = User.get_many(
|
||||
@@ -138,9 +142,9 @@ def create(call: APICall):
|
||||
UserBLL.create(call.data_model)
|
||||
|
||||
|
||||
@endpoint("users.delete", required_fields=["user"])
|
||||
def delete(call: APICall):
|
||||
UserBLL.delete(call.data["user"])
|
||||
@endpoint("users.delete")
|
||||
def delete(_: APICall, __, request: UserRequest):
|
||||
UserBLL.delete(request.user)
|
||||
|
||||
|
||||
def update_user(user_id, company_id, data: dict) -> Tuple[int, dict]:
|
||||
@@ -154,14 +158,22 @@ def update_user(user_id, company_id, data: dict) -> Tuple[int, dict]:
|
||||
update_fields = {
|
||||
k: v for k, v in create_fields.items() if k in User.user_set_allowed()
|
||||
}
|
||||
auth_user_update_fields = ("name",)
|
||||
partial_update_dict = parse_from_call(data, update_fields, User.get_fields())
|
||||
with translate_errors_context("updating user"):
|
||||
return User.safe_update(company_id, user_id, partial_update_dict)
|
||||
ret = User.safe_update(company_id, user_id, partial_update_dict)
|
||||
auth_update = {
|
||||
k: v for k, v in partial_update_dict.items() if k in auth_user_update_fields
|
||||
}
|
||||
if auth_update:
|
||||
AuthUser.objects(id=user_id).update(**auth_update)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
@endpoint("users.update", required_fields=["user"], response_data_model=UpdateResponse)
|
||||
def update(call, company_id, _):
|
||||
user_id = call.data["user"]
|
||||
@endpoint("users.update", response_data_model=UpdateResponse)
|
||||
def update(call, company_id, request: UserRequest):
|
||||
user_id = request.user
|
||||
update_count, updated_fields = update_user(user_id, company_id, call.data)
|
||||
call.result.data_model = UpdateResponse(updated=update_count, fields=updated_fields)
|
||||
|
||||
|
||||
@@ -3,6 +3,30 @@ from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestGetAllExFilters(TestService):
|
||||
def test_no_tags_filter(self):
|
||||
task = self._temp_task(tags=["test"])
|
||||
task_no_tags = self._temp_task()
|
||||
tasks = [task, task_no_tags]
|
||||
|
||||
for cond, op, tags, expected_tasks in (
|
||||
("any", "include", [None], [task_no_tags]),
|
||||
("any", "include", ["test"], [task]),
|
||||
("any", "include", ["test", None], [task, task_no_tags]),
|
||||
("any", "exclude", [None], [task]),
|
||||
("any", "exclude", ["test"], [task_no_tags]),
|
||||
("any", "exclude", ["test", None], [task, task_no_tags]),
|
||||
("all", "include", [None], [task_no_tags]),
|
||||
("all", "include", ["test"], [task]),
|
||||
("all", "include", ["test", None], []),
|
||||
("all", "exclude", [None], [task]),
|
||||
("all", "exclude", ["test"], [task_no_tags]),
|
||||
("all", "exclude", ["test", None], []),
|
||||
):
|
||||
res = self.api.tasks.get_all_ex(
|
||||
id=tasks, filters={"tags": {cond: {op: tags}}}
|
||||
).tasks
|
||||
self.assertEqual({t.id for t in res}, set(expected_tasks))
|
||||
|
||||
def test_list_filters(self):
|
||||
tags = ["a", "b", "c", "d"]
|
||||
tasks = [self._temp_task(tags=tags[:i]) for i in range(len(tags) + 1)]
|
||||
|
||||
@@ -37,29 +37,44 @@ class TestPipelines(TestService):
|
||||
|
||||
res = self.api.pipelines.start_pipeline(task=task, queue=queue, args=args)
|
||||
pipeline_task = res.pipeline
|
||||
try:
|
||||
self.assertTrue(res.enqueued)
|
||||
pipeline = self.api.tasks.get_all_ex(id=[pipeline_task]).tasks[0]
|
||||
self.assertTrue(pipeline.name.startswith(task_name))
|
||||
self.assertEqual(pipeline.status, "queued")
|
||||
self.assertEqual(pipeline.project.id, project)
|
||||
self.assertEqual(
|
||||
pipeline.hyperparams.Args,
|
||||
{
|
||||
a["name"]: {
|
||||
"section": "Args",
|
||||
"name": a["name"],
|
||||
"value": a["value"],
|
||||
}
|
||||
for a in args
|
||||
},
|
||||
)
|
||||
finally:
|
||||
self.api.tasks.delete(task=pipeline_task, force=True)
|
||||
self.assertTrue(res.enqueued)
|
||||
pipeline = self.api.tasks.get_all_ex(id=[pipeline_task]).tasks[0]
|
||||
self.assertTrue(pipeline.name.startswith(task_name))
|
||||
self.assertEqual(pipeline.status, "queued")
|
||||
self.assertEqual(pipeline.project.id, project)
|
||||
self.assertEqual(
|
||||
pipeline.hyperparams.Args,
|
||||
{
|
||||
a["name"]: {
|
||||
"section": "Args",
|
||||
"name": a["name"],
|
||||
"value": a["value"],
|
||||
}
|
||||
for a in args
|
||||
},
|
||||
)
|
||||
|
||||
# watched queue
|
||||
queue = self._temp_queue("test pipelines")
|
||||
project, task = self._temp_project_and_task(name="pipelines test1")
|
||||
res = self.api.pipelines.start_pipeline(
|
||||
task=task, queue=queue, verify_watched_queue=True
|
||||
)
|
||||
self.assertEqual(res.queue_watched, False)
|
||||
|
||||
self.api.workers.register(worker="test pipelines", queues=[queue])
|
||||
project, task = self._temp_project_and_task(name="pipelines test2")
|
||||
res = self.api.pipelines.start_pipeline(
|
||||
task=task, queue=queue, verify_watched_queue=True
|
||||
)
|
||||
self.assertEqual(res.queue_watched, True)
|
||||
|
||||
def _temp_project_and_task(self, name) -> Tuple[str, str]:
|
||||
project = self.create_temp(
|
||||
"projects", name=name, description="test", delete_params=dict(force=True),
|
||||
"projects",
|
||||
name=name,
|
||||
description="test",
|
||||
delete_params=dict(force=True, delete_contents=True),
|
||||
)
|
||||
|
||||
return (
|
||||
@@ -72,3 +87,6 @@ class TestPipelines(TestService):
|
||||
system_tags=["pipeline"],
|
||||
),
|
||||
)
|
||||
|
||||
def _temp_queue(self, queue_name, **kwargs):
|
||||
return self.create_temp("queues", name=queue_name, **kwargs)
|
||||
|
||||
@@ -113,7 +113,7 @@ class TestProjectTags(TestService):
|
||||
new_tags = ["New model tag"]
|
||||
self.api.models.update_tags(ids=[model], add_tags=new_tags)
|
||||
data = self.api.projects.get_model_tags(projects=[p])
|
||||
self.assertEqual(set(data.tags), set([*new_tags, *initial_tags]))
|
||||
self.assertEqual(set(data.tags), {*new_tags, *initial_tags})
|
||||
|
||||
def new_task(self, **kwargs):
|
||||
self.update_missing(
|
||||
|
||||
@@ -16,10 +16,18 @@ class TestTaskEvents(TestService):
|
||||
delete_params = dict(can_fail=True, force=True)
|
||||
default_task_name = "test task events"
|
||||
|
||||
def _temp_task(self, name=default_task_name):
|
||||
task_input = dict(name=name, type="training",)
|
||||
def _temp_project(self, name=default_task_name):
|
||||
return self.create_temp(
|
||||
"tasks", delete_paramse=self.delete_params, **task_input
|
||||
"projects",
|
||||
name=name,
|
||||
description="test",
|
||||
delete_params=self.delete_params,
|
||||
)
|
||||
|
||||
def _temp_task(self, name=default_task_name, **kwargs):
|
||||
self.update_missing(kwargs, name=name, type="training")
|
||||
return self.create_temp(
|
||||
"tasks", delete_paramse=self.delete_params, **kwargs
|
||||
)
|
||||
|
||||
def _temp_model(self, name="test model events", **kwargs):
|
||||
@@ -62,6 +70,26 @@ class TestTaskEvents(TestService):
|
||||
self._assert_task_metrics(tasks, "log")
|
||||
self._assert_task_metrics(tasks, "training_stats_scalar")
|
||||
|
||||
self._assert_multitask_metrics(
|
||||
tasks=list(tasks), metrics=["Metric1", "Metric2", "Metric3"]
|
||||
)
|
||||
self._assert_multitask_metrics(
|
||||
tasks=list(tasks),
|
||||
event_type="training_debug_image",
|
||||
metrics=["Metric1", "Metric2", "Metric3"],
|
||||
)
|
||||
self._assert_multitask_metrics(tasks=list(tasks), event_type="plot", metrics=[])
|
||||
|
||||
def _assert_multitask_metrics(
|
||||
self, tasks: Sequence[str], metrics: Sequence[str], event_type: str = None
|
||||
):
|
||||
res = self.api.events.get_multi_task_metrics(
|
||||
tasks=tasks,
|
||||
**({"event_type": event_type} if event_type else {}),
|
||||
).metrics
|
||||
self.assertEqual([r.metric for r in res], metrics)
|
||||
self.assertTrue(all(r.variants == ["Test variant"] for r in res))
|
||||
|
||||
def _assert_task_metrics(self, tasks: dict, event_type: str):
|
||||
res = self.api.events.get_task_metrics(tasks=list(tasks), event_type=event_type)
|
||||
for task, metrics in tasks.items():
|
||||
@@ -122,6 +150,15 @@ class TestTaskEvents(TestService):
|
||||
self.assertEqual(value.metric, metric)
|
||||
self.assertEqual(value.variant, variant)
|
||||
self.assertEqual(value.value, 0)
|
||||
# test metrics parameter
|
||||
res = self.api.events.get_task_single_value_metrics(
|
||||
tasks=[task], metrics=[{"metric": metric, "variants": [variant]}]
|
||||
).tasks
|
||||
self.assertEqual(len(res), 1)
|
||||
res = self.api.events.get_task_single_value_metrics(
|
||||
tasks=[task], metrics=[{"metric": "non_existing", "variants": [variant]}]
|
||||
).tasks
|
||||
self.assertEqual(len(res), 0)
|
||||
|
||||
# update is working
|
||||
task_data = self.api.tasks.get_by_id(task=task).task
|
||||
@@ -156,33 +193,33 @@ class TestTaskEvents(TestService):
|
||||
|
||||
def test_last_scalar_metrics(self):
|
||||
metric = "Metric1"
|
||||
variant = "Variant1"
|
||||
iter_count = 100
|
||||
task = self._temp_task()
|
||||
events = [
|
||||
{
|
||||
**self._create_task_event("training_stats_scalar", task, iteration),
|
||||
"metric": metric,
|
||||
"variant": variant,
|
||||
"value": iteration,
|
||||
}
|
||||
for iteration in range(iter_count)
|
||||
]
|
||||
# send 2 batches to check the interaction with already stored db value
|
||||
# each batch contains multiple iterations
|
||||
self.send_batch(events[:50])
|
||||
self.send_batch(events[50:])
|
||||
for variant in ("Variant1", None):
|
||||
iter_count = 100
|
||||
task = self._temp_task()
|
||||
events = [
|
||||
{
|
||||
**self._create_task_event("training_stats_scalar", task, iteration),
|
||||
"metric": metric,
|
||||
"variant": variant,
|
||||
"value": iteration,
|
||||
}
|
||||
for iteration in range(iter_count)
|
||||
]
|
||||
# send 2 batches to check the interaction with already stored db value
|
||||
# each batch contains multiple iterations
|
||||
self.send_batch(events[:50])
|
||||
self.send_batch(events[50:])
|
||||
|
||||
task_data = self.api.tasks.get_by_id(task=task).task
|
||||
metric_data = first(first(task_data.last_metrics.values()).values())
|
||||
self.assertEqual(iter_count - 1, metric_data.value)
|
||||
self.assertEqual(iter_count - 1, metric_data.max_value)
|
||||
self.assertEqual(iter_count - 1, metric_data.max_value_iteration)
|
||||
self.assertEqual(0, metric_data.min_value)
|
||||
self.assertEqual(0, metric_data.min_value_iteration)
|
||||
task_data = self.api.tasks.get_by_id(task=task).task
|
||||
metric_data = first(first(task_data.last_metrics.values()).values())
|
||||
self.assertEqual(iter_count - 1, metric_data.value)
|
||||
self.assertEqual(iter_count - 1, metric_data.max_value)
|
||||
self.assertEqual(iter_count - 1, metric_data.max_value_iteration)
|
||||
self.assertEqual(0, metric_data.min_value)
|
||||
self.assertEqual(0, metric_data.min_value_iteration)
|
||||
|
||||
res = self.api.events.get_task_latest_scalar_values(task=task)
|
||||
self.assertEqual(iter_count - 1, res.last_iter)
|
||||
res = self.api.events.get_task_latest_scalar_values(task=task)
|
||||
self.assertEqual(iter_count - 1, res.last_iter)
|
||||
|
||||
def test_model_events(self):
|
||||
model = self._temp_model(ready=False)
|
||||
@@ -248,6 +285,15 @@ class TestTaskEvents(TestService):
|
||||
|
||||
self._assert_log_events(task=task, expected_total=1)
|
||||
|
||||
metrics = self.api.events.get_multi_task_metrics(
|
||||
tasks=[model],
|
||||
event_type="training_stats_scalar",
|
||||
model_events=True,
|
||||
).metrics
|
||||
self.assertEqual([m.metric for m in metrics], [f"Metric{i}" for i in range(5)])
|
||||
variants = [f"Variant{i}" for i in range(5)]
|
||||
self.assertTrue(all(m.variants == variants for m in metrics))
|
||||
|
||||
def test_error_events(self):
|
||||
task = self._temp_task()
|
||||
events = [
|
||||
@@ -340,6 +386,30 @@ class TestTaskEvents(TestService):
|
||||
else (None, None)
|
||||
)
|
||||
|
||||
def test_task_unique_metric_variants(self):
|
||||
project = self._temp_project()
|
||||
task1 = self._temp_task(project=project)
|
||||
task2 = self._temp_task(project=project)
|
||||
metric1 = "Metric1"
|
||||
metric2 = "Metric2"
|
||||
events = [
|
||||
{
|
||||
**self._create_task_event("training_stats_scalar", task, 0),
|
||||
"metric": metric,
|
||||
"variant": "Variant",
|
||||
"value": 10,
|
||||
}
|
||||
for task, metric in ((task1, metric1), (task2, metric2))
|
||||
]
|
||||
self.send_batch(events)
|
||||
|
||||
metrics = self.api.projects.get_unique_metric_variants(project=project).metrics
|
||||
self.assertEqual({m.metric for m in metrics}, {metric1, metric2})
|
||||
metrics = self.api.projects.get_unique_metric_variants(ids=[task1, task2]).metrics
|
||||
self.assertEqual({m.metric for m in metrics}, {metric1, metric2})
|
||||
metrics = self.api.projects.get_unique_metric_variants(ids=[task1]).metrics
|
||||
self.assertEqual([m.metric for m in metrics], [metric1])
|
||||
|
||||
def test_task_metric_value_intervals_keys(self):
|
||||
metric = "Metric1"
|
||||
variant = "Variant1"
|
||||
@@ -395,6 +465,25 @@ class TestTaskEvents(TestService):
|
||||
iterations=iter_count,
|
||||
)
|
||||
|
||||
# test metrics
|
||||
data = self.api.events.multi_task_scalar_metrics_iter_histogram(
|
||||
tasks=tasks,
|
||||
metrics=[
|
||||
{
|
||||
"metric": f"Metric{m_idx}",
|
||||
"variants": [f"Variant{v_idx}" for v_idx in range(4)],
|
||||
}
|
||||
for m_idx in range(2)
|
||||
],
|
||||
)
|
||||
self._assert_metrics_and_variants(
|
||||
data.metrics,
|
||||
metrics=2,
|
||||
variants=4,
|
||||
tasks=tasks,
|
||||
iterations=iter_count,
|
||||
)
|
||||
|
||||
def _assert_metrics_and_variants(
|
||||
self, data: dict, metrics: int, variants: int, tasks: Sequence, iterations: int
|
||||
):
|
||||
@@ -515,6 +604,13 @@ class TestTaskEvents(TestService):
|
||||
self.assertEqual(plots.C.CX[task1]["3"]["plots"][0]["plot_str"], "Task1_3_C_CX")
|
||||
self.assertEqual(plots.C.CX[task2]["1"]["plots"][0]["plot_str"], "Task2_1_C_CX")
|
||||
|
||||
# test metrics
|
||||
plots = self.api.events.get_multi_task_plots(
|
||||
tasks=[task1, task2], metrics=[{"metric": "A"}]
|
||||
).plots
|
||||
self.assertEqual(len(plots), 1)
|
||||
self.assertEqual(len(plots.A), 2)
|
||||
|
||||
def test_task_plots(self):
|
||||
task = self._temp_task()
|
||||
event = self._create_task_event("plot", task, 0)
|
||||
|
||||
@@ -7,13 +7,17 @@ from humanfriendly import parse_timespan
|
||||
|
||||
def setup():
|
||||
from apiserver.database import db
|
||||
|
||||
db.initialize()
|
||||
|
||||
|
||||
def gen_token(args):
|
||||
from apiserver.bll.auth import AuthBLL
|
||||
resp = AuthBLL.get_token_for_user(args.user_id, args.company_id, parse_timespan(args.expiration))
|
||||
print('Token:\n%s' % resp.token)
|
||||
|
||||
resp = AuthBLL.get_token_for_user(
|
||||
args.user_id, args.company_id, int(parse_timespan(args.expiration))
|
||||
)
|
||||
print("Token:\n%s" % resp.token)
|
||||
|
||||
|
||||
def safe_get(obj, glob, default=None, separator="/"):
|
||||
@@ -23,19 +27,24 @@ def safe_get(obj, glob, default=None, separator="/"):
|
||||
return default
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
top_parser = ArgumentParser(__doc__)
|
||||
|
||||
subparsers = top_parser.add_subparsers(title='Sections')
|
||||
subparsers = top_parser.add_subparsers(title="Sections")
|
||||
|
||||
token = subparsers.add_parser('token')
|
||||
token_commands = token.add_subparsers(title='Commands')
|
||||
token_create = token_commands.add_parser('generate', description='Generate a new token')
|
||||
token_create.add_argument('--user-id', '-u', help='User ID', required=True)
|
||||
token_create.add_argument('--company-id', '-c', help='Company ID', required=True)
|
||||
token_create.add_argument('--expiration', '-exp',
|
||||
help="Token expiration (time span, shorthand suffixes are supported, default 1m)",
|
||||
default=parse_timespan('1m'))
|
||||
token = subparsers.add_parser("token")
|
||||
token_commands = token.add_subparsers(title="Commands")
|
||||
token_create = token_commands.add_parser(
|
||||
"generate", description="Generate a new token"
|
||||
)
|
||||
token_create.add_argument("--user-id", "-u", help="User ID", required=True)
|
||||
token_create.add_argument("--company-id", "-c", help="Company ID", required=True)
|
||||
token_create.add_argument(
|
||||
"--expiration",
|
||||
"-exp",
|
||||
help="Token expiration (time span, shorthand suffixes are supported, default 1m)",
|
||||
default=parse_timespan("1m"),
|
||||
)
|
||||
token_create.set_defaults(_func=gen_token)
|
||||
|
||||
args = top_parser.parse_args()
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "1.13.0"
|
||||
__version__ = "1.15.0"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM node:18-bullseye as webapp_builder
|
||||
FROM node:20-bookworm-slim as webapp_builder
|
||||
|
||||
ARG CLEARML_WEB_GIT_URL=https://github.com/allegroai/clearml-web.git
|
||||
|
||||
@@ -10,8 +10,9 @@ RUN mv clearml-web /opt/open-webapp
|
||||
COPY --chmod=744 docker/build/internal_files/build_webapp.sh /tmp/internal_files/
|
||||
RUN /bin/bash -c '/tmp/internal_files/build_webapp.sh'
|
||||
|
||||
FROM python:3.9-slim-bullseye
|
||||
FROM python:3.9-slim-bookworm
|
||||
COPY --chmod=744 docker/build/internal_files/entrypoint.sh /opt/clearml/
|
||||
COPY --chmod=744 docker/build/internal_files/update_from_env.py /opt/clearml/utilities/
|
||||
COPY fileserver /opt/clearml/fileserver/
|
||||
COPY apiserver /opt/clearml/apiserver/
|
||||
|
||||
|
||||
@@ -29,7 +29,12 @@ server {
|
||||
include /etc/nginx/default.d/*.conf;
|
||||
|
||||
location / {
|
||||
try_files $uri$args $uri$args/ $uri index.html /index.html;
|
||||
add_header Strict-Transport-Security "max-age=31536000; includeSubDomains" always;
|
||||
add_header Content-Security-Policy "frame-ancestors 'self'";
|
||||
add_header X-XSS-Protection "1; mode=block";
|
||||
add_header X-Content-Type-Options "nosniff" always;
|
||||
add_header Referrer-Policy "no-referrer-when-downgrade";
|
||||
try_files $uri $uri/ /index.html;
|
||||
}
|
||||
|
||||
location /version.json {
|
||||
@@ -50,6 +55,12 @@ server {
|
||||
rewrite /files/(.*) /$1 break;
|
||||
}
|
||||
|
||||
location /widgets {
|
||||
alias /usr/share/nginx/widgets;
|
||||
try_files $uri $uri/ /widgets/index.html;
|
||||
add_header Content-Security-Policy "frame-ancestors *";
|
||||
}
|
||||
|
||||
error_page 404 /404.html;
|
||||
location = /40x.html {
|
||||
}
|
||||
@@ -57,4 +68,4 @@ server {
|
||||
error_page 500 502 503 504 /50x.html;
|
||||
location = /50x.html {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,10 +46,26 @@ elif [[ ${SERVER_TYPE} == "webserver" ]]; then
|
||||
EOF
|
||||
fi
|
||||
|
||||
# Create an empty configuration json
|
||||
echo "{}" > /tmp/configuration.json
|
||||
|
||||
# Copy the external configuration file if it exists
|
||||
if test -f "/mnt/external_files/configs/configuration.json"; then
|
||||
echo "Copying external configuration"
|
||||
cp /mnt/external_files/configs/configuration.json /tmp/configuration.json
|
||||
fi
|
||||
|
||||
# Update from env variables
|
||||
echo "Updating configuration from env"
|
||||
/opt/clearml/utilities/update_from_env.py \
|
||||
--verbose \
|
||||
/tmp/configuration.json \
|
||||
/usr/share/nginx/html/configuration.json
|
||||
|
||||
export NGINX_APISERVER_ADDR=${NGINX_APISERVER_ADDRESS:-http://apiserver:8008}
|
||||
export NGINX_FILESERVER_ADDR=${NGINX_FILESERVER_ADDRESS:-http://fileserver:8081}
|
||||
COMMENT_IPV6_LISTEN=$([ "$DISABLE_NGINX_IPV6" = "true" ] && echo "#" || echo "") \
|
||||
envsubst '${COMMENT_IPV6_LISTEN} ${NGINX_APISERVER_ADDR} ${NGINX_FILESERVER_ADDR}' < /etc/nginx/clearml.conf.template > /etc/nginx/sites-enabled/default
|
||||
export COMMENT_IPV6_LISTEN=$([ "$DISABLE_NGINX_IPV6" = "true" ] && echo "#" || echo "")
|
||||
envsubst '${COMMENT_IPV6_LISTEN} ${NGINX_APISERVER_ADDR} ${NGINX_FILESERVER_ADDR}' < /etc/nginx/clearml.conf.template > /etc/nginx/sites-enabled/default
|
||||
|
||||
if [[ -n "${CLEARML_SERVER_SUB_PATH}" ]]; then
|
||||
mkdir -p /etc/nginx/default.d/
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
#!/usr/bin/env bash
|
||||
set -x
|
||||
set -o errexit
|
||||
set -o nounset
|
||||
set -o pipefail
|
||||
|
||||
apt-get update -y
|
||||
apt-get install -y python3-setuptools python3-dev build-essential nginx gettext
|
||||
apt-get install -y vim curl
|
||||
apt-get install -y python3-setuptools python3-dev build-essential nginx gettext vim curl
|
||||
|
||||
python3 -m ensurepip
|
||||
python3 -m pip install --upgrade pip
|
||||
|
||||
104
docker/build/internal_files/update_from_env.py
Normal file
104
docker/build/internal_files/update_from_env.py
Normal file
@@ -0,0 +1,104 @@
|
||||
#!/usr/bin/env python3
|
||||
""" Update json configuration file from environment variables """
|
||||
from argparse import ArgumentParser, FileType
|
||||
import json
|
||||
from os import environ
|
||||
from typing import Any, Generator, Tuple, Optional, List
|
||||
|
||||
|
||||
class PathConflictError(Exception):
|
||||
def __init__(self, path_: List[str]):
|
||||
self.path = path_
|
||||
|
||||
|
||||
def scan(
|
||||
obj: Any, path_: str = None, sep: str = ".", parent_=None, key_=None,
|
||||
) -> Generator[Tuple[str, Any, Optional[dict], str], None, None]:
|
||||
if not isinstance(obj, dict):
|
||||
yield path_.lower(), obj, parent_, key_
|
||||
else:
|
||||
for k, v in obj.items():
|
||||
yield from scan(v, path_=sep.join(filter(None, (path_, k))), parent_=obj, key_=k, sep=sep)
|
||||
|
||||
|
||||
def set_path(p: List[str], obj: dict, v: Any):
|
||||
key_, *rest = p
|
||||
if not rest:
|
||||
obj[key_] = v
|
||||
else:
|
||||
if key_ in obj:
|
||||
if not isinstance(obj[key_], dict):
|
||||
raise PathConflictError(rest)
|
||||
else:
|
||||
obj[key_] = {}
|
||||
return set_path(rest, obj[key_], v)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = ArgumentParser(description=__doc__)
|
||||
parser.add_argument("input_file", type=FileType(), help="Input JSON file")
|
||||
parser.add_argument("output_file", type=FileType("w"), help="Output JSON file")
|
||||
parser.add_argument(
|
||||
"--env-prefix", "-p", default="WEBSERVER", help="Environment variables prefix (default=%(default)s)",
|
||||
dest="prefix", required=False
|
||||
)
|
||||
parser.add_argument(
|
||||
"--env-separator", "-s", default="__", help="Environment variable name separator (default=%(default)s)",
|
||||
dest="sep"
|
||||
)
|
||||
parser.add_argument("--verbose", "-v", action="store_true", default=False)
|
||||
parser.add_argument(
|
||||
"--disable-parse-env-value", action="store_false", default=True, help="Don't parse env value as JSON",
|
||||
dest="parse_env"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.prefix:
|
||||
print("Error: script does not support an empty prefix")
|
||||
exit(1)
|
||||
|
||||
data = None
|
||||
try:
|
||||
data = json.load(args.input_file)
|
||||
except json.JSONDecodeError as ex:
|
||||
print(f"Error parsing JSON file {args.input_file.name}: {str(ex)}")
|
||||
exit(1)
|
||||
|
||||
def parse_value(k, v):
|
||||
try:
|
||||
return json.loads(v)
|
||||
except json.JSONDecodeError as ex:
|
||||
print(f"Error parsing {k} JSON value `{v}`: {str(ex)}")
|
||||
exit(2)
|
||||
|
||||
prefix = args.prefix + args.sep
|
||||
|
||||
env_vars = {
|
||||
k.lstrip(prefix): parse_value(k, v) if args.parse_env else v
|
||||
for k, v in environ.items() if k.startswith(prefix)
|
||||
}
|
||||
|
||||
for path, value, parent, key in scan(data, sep=args.sep):
|
||||
if not (parent and key):
|
||||
continue
|
||||
|
||||
match = next((k for k in env_vars if k.lower() == path), None)
|
||||
if match:
|
||||
replace = env_vars.pop(match)
|
||||
parent[key] = replace
|
||||
if args.verbose:
|
||||
print(f"Replacing {path}={value} with {replace}")
|
||||
|
||||
for k, v in env_vars.items():
|
||||
path = k.split(args.sep)
|
||||
try:
|
||||
set_path(path, data, v)
|
||||
except PathConflictError as ex:
|
||||
print(f"Error: failed setting value into {k}: {path[:-len(ex.path)]} is not a dictionary")
|
||||
|
||||
try:
|
||||
json.dump(data, args.output_file, sort_keys=True, indent=2)
|
||||
except Exception as ex:
|
||||
print(f"Error writing JSON file {args.output_file.name}: {str(ex)}")
|
||||
exit(3)
|
||||
@@ -49,13 +49,10 @@ services:
|
||||
cluster.routing.allocation.disk.watermark.low: 500mb
|
||||
cluster.routing.allocation.disk.watermark.high: 500mb
|
||||
cluster.routing.allocation.disk.watermark.flood_stage: 500mb
|
||||
discovery.zen.minimum_master_nodes: "1"
|
||||
discovery.type: "single-node"
|
||||
http.compression_level: "7"
|
||||
node.ingest: "true"
|
||||
node.name: clearml
|
||||
reindex.remote.whitelist: '*.*'
|
||||
xpack.monitoring.enabled: "false"
|
||||
reindex.remote.whitelist: "'*.*'"
|
||||
xpack.security.enabled: "false"
|
||||
ulimits:
|
||||
memlock:
|
||||
@@ -64,7 +61,7 @@ services:
|
||||
nofile:
|
||||
soft: 65536
|
||||
hard: 65536
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:7.17.7
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:7.17.18
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- c:/opt/clearml/data/elastic_7:/usr/share/elasticsearch/data
|
||||
@@ -93,7 +90,7 @@ services:
|
||||
networks:
|
||||
- backend
|
||||
container_name: clearml-mongo
|
||||
image: mongo:4.4.9
|
||||
image: mongo:4.4.29
|
||||
restart: unless-stopped
|
||||
command: --setParameter internalQueryMaxBlockingSortMemoryUsageBytes=196100200
|
||||
volumes:
|
||||
@@ -104,7 +101,7 @@ services:
|
||||
networks:
|
||||
- backend
|
||||
container_name: clearml-redis
|
||||
image: redis:5.0
|
||||
image: redis:6.2
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- c:/opt/clearml/data/redis:/data
|
||||
|
||||
@@ -49,13 +49,10 @@ services:
|
||||
cluster.routing.allocation.disk.watermark.low: 500mb
|
||||
cluster.routing.allocation.disk.watermark.high: 500mb
|
||||
cluster.routing.allocation.disk.watermark.flood_stage: 500mb
|
||||
discovery.zen.minimum_master_nodes: "1"
|
||||
discovery.type: "single-node"
|
||||
http.compression_level: "7"
|
||||
node.ingest: "true"
|
||||
node.name: clearml
|
||||
reindex.remote.whitelist: '*.*'
|
||||
xpack.monitoring.enabled: "false"
|
||||
reindex.remote.whitelist: "'*.*'"
|
||||
xpack.security.enabled: "false"
|
||||
ulimits:
|
||||
memlock:
|
||||
@@ -64,7 +61,7 @@ services:
|
||||
nofile:
|
||||
soft: 65536
|
||||
hard: 65536
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:7.17.7
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:7.17.18
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- /opt/clearml/data/elastic_7:/usr/share/elasticsearch/data
|
||||
@@ -92,7 +89,7 @@ services:
|
||||
networks:
|
||||
- backend
|
||||
container_name: clearml-mongo
|
||||
image: mongo:4.4.9
|
||||
image: mongo:4.4.29
|
||||
restart: unless-stopped
|
||||
command: --setParameter internalQueryMaxBlockingSortMemoryUsageBytes=196100200
|
||||
volumes:
|
||||
@@ -103,7 +100,7 @@ services:
|
||||
networks:
|
||||
- backend
|
||||
container_name: clearml-redis
|
||||
image: redis:5.0
|
||||
image: redis:6.2
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- /opt/clearml/data/redis:/data
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
boltons>=19.1.0
|
||||
flask-compress>=1.4.0
|
||||
flask-cors>=3.0.5
|
||||
flask>=2.3.2
|
||||
flask>=2.3.3
|
||||
gunicorn>=20.1.0
|
||||
pyhocon>=0.3.35
|
||||
setuptools>=65.5.1
|
||||
urllib3>=1.26.18
|
||||
urllib3>=1.26.18
|
||||
werkzeug>=3.0.1
|
||||
Reference in New Issue
Block a user