Compare commits

24 Commits
1.0.0 ... 1.1.0

Author SHA1 Message Date
allegroai
c5d6ce3e65 Version bump 2021-07-25 14:40:57 +03:00
allegroai
694dbc31c4 Fix incorrect ES query (merge issue) 2021-07-25 14:40:49 +03:00
allegroai
6488dc54e6 Better handling of stack trace report on 500 error 2021-07-25 14:39:59 +03:00
allegroai
158da9b480 Allow setting status_message in tasks.update
Optimizations and refactoring
2021-07-25 14:35:36 +03:00
allegroai
ec2e071ab7 Fix mongoengine cannot handle field name with leading or trailing "_" when used in fields query within get_all endpoints 2021-07-25 14:34:04 +03:00
allegroai
465e270342 Fix queued task is not dequeued on tasks.stop 2021-07-25 14:32:09 +03:00
allegroai
6705aff56f Allow requesting plots and iter_histograms for all variants 2021-07-25 14:30:38 +03:00
allegroai
9069cfe1da Support querying task events per specific metrics and variants 2021-07-25 14:29:41 +03:00
allegroai
677bb3ba6d Add force parameter to tasks.enqueue 2021-07-25 14:27:46 +03:00
allegroai
cb253cff9e Don't use special characters in secrets 2021-07-25 14:26:49 +03:00
allegroai
39ceb5ac5c Fix pre-populate logic to avoid overriding existing users 2021-07-25 14:26:31 +03:00
allegroai
d4edeaaf1b Add projects.validate_delete 2021-07-25 14:17:29 +03:00
allegroai
56aea1ffb8 Fix filtering on hyperparams (https://github.com/allegroai/clearml/issues/385, https://clearml.slack.com/archives/CTK20V944/p1626600582284700) 2021-07-25 13:55:09 +03:00
allegroai
09ab2af34c Version bump 2021-05-27 17:13:19 +03:00
allegroai
8bb26a6b0b Fix fileserver depends on deprecated flask._compat.fspath and safe_join 2021-05-27 17:13:02 +03:00
allegroai
3f2304549d Move new migrations to 1_0_2 2021-05-27 16:56:47 +03:00
allegroai
ad72a435f1 Clean Task runtime on reset 2021-05-27 16:56:03 +03:00
allegroai
f34332344e Fix Task container raises validation error on null values 2021-05-27 16:55:32 +03:00
allegroai
d324b57dd7 Fix bad error message format 2021-05-27 16:55:00 +03:00
allegroai
2216bfe875 Version bump 2021-05-11 16:12:48 +03:00
allegroai
9beefa7473 Add missing login.logout endpoint 2021-05-11 16:12:27 +03:00
allegroai
8ebc334889 Fix broken config dir backwards compatibility (/opt/trains/config should still be supported) 2021-05-11 16:12:13 +03:00
allegroai
e662c850af Update config file in docs 2021-05-04 11:07:38 +03:00
allegroai
1e5163e530 Upgrade jinja2 version due to CVE-2020-28493 2021-05-03 23:23:06 +03:00
37 changed files with 633 additions and 284 deletions

View File

@@ -14,12 +14,18 @@ from apiserver.utilities.stringenum import StringEnum
class HistogramRequestBase(Base):
samples: int = IntField(default=6000, validators=[Min(1), Max(6000)])
samples: int = IntField(default=2000, validators=[Min(1), Max(6000)])
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
class MetricVariants(Base):
metric: str = StringField(required=True)
variants: Sequence[str] = ListField(items_types=str)
class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
task: str = StringField(required=True)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
@@ -39,6 +45,7 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
class TaskMetric(Base):
task: str = StringField(required=True)
metric: str = StringField(default=None)
variants: Sequence[str] = ListField(items_types=str)
class DebugImagesRequest(Base):
@@ -59,8 +66,8 @@ class TaskMetricVariant(Base):
class GetDebugImageSampleRequest(TaskMetricVariant):
iteration: Optional[int] = IntField()
scroll_id: Optional[str] = StringField()
refresh: bool = BoolField(default=False)
scroll_id: Optional[str] = StringField()
class NextDebugImageSampleRequest(Base):
@@ -102,3 +109,10 @@ class TaskMetricsRequest(Base):
items_types=str, validators=[Length(minimum_value=1)]
)
event_type: EventType = ActualEnumField(EventType, required=True)
class TaskPlotsRequest(Base):
task: str = StringField(required=True)
iters: int = IntField(default=1)
scroll_id: str = StringField()
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)

View File

@@ -57,6 +57,7 @@ class AuthBLL:
api_version=str(ServiceRepo.max_endpoint_version()),
server_version=str(get_version()),
server_build=str(get_build_number()),
feature_set="basic",
)
return GetTokenResponse(token=token.decode("ascii"))

View File

@@ -2,7 +2,7 @@ from concurrent.futures.thread import ThreadPoolExecutor
from datetime import datetime
from functools import partial
from operator import itemgetter
from typing import Sequence, Tuple, Optional, Mapping, Set
from typing import Sequence, Tuple, Optional, Mapping
import attr
import dpath
@@ -18,6 +18,7 @@ from apiserver.bll.event.event_common import (
check_empty_data,
search_company_events,
EventType,
get_metric_variants_condition,
)
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.database.errors import translate_errors_context
@@ -74,7 +75,7 @@ class DebugImagesIterator:
def get_task_events(
self,
company_id: str,
task_metrics: Mapping[str, Set[str]],
task_metrics: Mapping[str, dict],
iter_count: int,
navigate_earlier: bool = True,
refresh: bool = False,
@@ -118,7 +119,7 @@ class DebugImagesIterator:
self,
company_id,
state: DebugImageEventsScrollState,
task_metrics: Mapping[str, Set[str]],
task_metrics: Mapping[str, dict],
):
"""
Determine the metrics for which new debug image events were added
@@ -158,11 +159,11 @@ class DebugImagesIterator:
task_metrics_to_recalc = {}
for task, metrics_times in update_times.items():
old_metric_states = task_metric_states[task]
metrics_to_recalc = set(
m
metrics_to_recalc = {
m: task_metrics[task].get(m)
for m, t in metrics_times.items()
if m not in old_metric_states or old_metric_states[m].timestamp < t
)
}
if metrics_to_recalc:
task_metrics_to_recalc[task] = metrics_to_recalc
@@ -196,7 +197,7 @@ class DebugImagesIterator:
]
def _init_task_states(
self, company_id: str, task_metrics: Mapping[str, Set[str]]
self, company_id: str, task_metrics: Mapping[str, dict]
) -> Sequence[TaskScrollState]:
"""
Returned initialized metric scroll stated for the requested task metrics
@@ -213,7 +214,7 @@ class DebugImagesIterator:
]
def _init_metric_states_for_task(
self, task_metrics: Tuple[str, Set[str]], company_id: str
self, task_metrics: Tuple[str, dict], company_id: str
) -> Sequence[MetricState]:
"""
Return metric scroll states for the task filled with the variant states
@@ -222,10 +223,11 @@ class DebugImagesIterator:
task, metrics = task_metrics
must = [{"term": {"task": task}}, {"exists": {"field": "url"}}]
if metrics:
must.append({"terms": {"metric": list(metrics)}})
must.append(get_metric_variants_condition(metrics))
query = {"bool": {"must": must}}
es_req: dict = {
"size": 0,
"query": {"bool": {"must": must}},
"query": query,
"aggs": {
"metrics": {
"terms": {

View File

@@ -6,9 +6,8 @@ from collections import defaultdict
from contextlib import closing
from datetime import datetime
from operator import attrgetter
from typing import Sequence, Set, Tuple, Optional, Dict
from typing import Sequence, Set, Tuple, Optional, List, Mapping, Union
import six
from elasticsearch import helpers
from elasticsearch.helpers import BulkIndexError
from mongoengine import Q
@@ -22,6 +21,8 @@ from apiserver.bll.event.event_common import (
check_empty_data,
search_company_events,
delete_company_events,
MetricVariants,
get_metric_variants_condition,
)
from apiserver.bll.util import parallel_chunked_decorator
from apiserver.database import utils as dbutils
@@ -43,8 +44,8 @@ from apiserver.utilities.json import loads
# noinspection PyTypeChecker
EVENT_TYPES: Set[str] = set(map(attrgetter("value"), EventType))
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
MAX_LONG = 2**63 - 1
MIN_LONG = -2**63
MAX_LONG = 2 ** 63 - 1
MIN_LONG = -(2 ** 63)
class PlotFields:
@@ -94,7 +95,7 @@ class EventBLL(object):
def add_events(
self, company_id, events, worker, allow_locked_tasks=False
) -> Tuple[int, int, dict]:
actions = []
actions: List[dict] = []
task_ids = set()
task_iteration = defaultdict(lambda: 0)
task_last_scalar_events = nested_dict(
@@ -197,7 +198,6 @@ class EventBLL(object):
actions.append(es_action)
action: Dict[dict]
plot_actions = [
action["_source"]
for action in actions
@@ -260,7 +260,8 @@ class EventBLL(object):
invalid_iterations_count = errors_per_type.get(invalid_iteration_error)
if invalid_iterations_count:
raise BulkIndexError(
f"{invalid_iterations_count} document(s) failed to index.", [invalid_iteration_error]
f"{invalid_iterations_count} document(s) failed to index.",
[invalid_iteration_error],
)
if not added:
@@ -466,10 +467,16 @@ class EventBLL(object):
task_id: str,
num_last_iterations: int,
event_type: EventType,
metric_variants: MetricVariants = None,
):
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return []
must = [{"term": {"task": task_id}}]
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
query = {"bool": {"must": must}}
es_req: dict = {
"size": 0,
"aggs": {
@@ -499,7 +506,7 @@ class EventBLL(object):
},
}
},
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
"query": query,
}
with translate_errors_context(), TimingContext(
@@ -527,6 +534,7 @@ class EventBLL(object):
sort=None,
size: int = 500,
scroll_id: str = None,
metric_variants: MetricVariants = None,
):
if scroll_id == self.empty_scroll:
return TaskEventsResult()
@@ -555,6 +563,8 @@ class EventBLL(object):
if last_iterations_per_plot is None:
must.append({"terms": {"task": tasks}})
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
else:
should = []
for i, task_id in enumerate(tasks):
@@ -563,6 +573,7 @@ class EventBLL(object):
task_id=task_id,
num_last_iterations=last_iterations_per_plot,
event_type=event_type,
metric_variants=metric_variants,
)
if not last_iters:
continue
@@ -669,19 +680,19 @@ class EventBLL(object):
sort=None,
size=500,
scroll_id=None,
):
) -> TaskEventsResult:
if scroll_id == self.empty_scroll:
return [], scroll_id, 0
return TaskEventsResult()
if scroll_id:
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
else:
task_ids = [task_id] if isinstance(task_id, six.string_types) else task_id
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return TaskEventsResult()
task_ids = [task_id] if isinstance(task_id, str) else task_id
must = []
if metric:
must.append({"term": {"metric": metric}})
@@ -691,26 +702,24 @@ class EventBLL(object):
if last_iter_count is None:
must.append({"terms": {"task": task_ids}})
else:
should = []
for i, task_id in enumerate(task_ids):
last_iters = self.get_last_iters(
company_id=company_id,
event_type=event_type,
task_id=task_id,
iters=last_iter_count,
)
if not last_iters:
continue
should.append(
{
"bool": {
"must": [
{"term": {"task": task_id}},
{"terms": {"iter": last_iters}},
]
}
tasks_iters = self.get_last_iters(
company_id=company_id,
event_type=event_type,
task_id=task_ids,
iters=last_iter_count,
)
should = [
{
"bool": {
"must": [
{"term": {"task": task}},
{"terms": {"iter": last_iters}},
]
}
)
}
for task, last_iters in tasks_iters.items()
if last_iters
]
if not should:
return TaskEventsResult()
must.append({"bool": {"should": should}})
@@ -748,6 +757,7 @@ class EventBLL(object):
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return {}
query = {"bool": {"must": [{"term": {"task": task_id}}]}}
es_req = {
"size": 0,
"aggs": {
@@ -768,7 +778,7 @@ class EventBLL(object):
},
}
},
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
"query": query,
}
with translate_errors_context(), TimingContext(
@@ -787,21 +797,24 @@ class EventBLL(object):
return metrics
def get_task_latest_scalar_values(self, company_id: str, task_id: str):
def get_task_latest_scalar_values(
self, company_id, task_id
) -> Tuple[Sequence[dict], int]:
event_type = EventType.metrics_scalar
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return {}
return [], 0
query = {
"bool": {
"must": [
{"query_string": {"query": "value:>0"}},
{"term": {"task": task_id}},
]
}
}
es_req = {
"size": 0,
"query": {
"bool": {
"must": [
{"query_string": {"query": "value:>0"}},
{"term": {"task": task_id}},
]
}
},
"query": query,
"aggs": {
"metrics": {
"terms": {
@@ -905,34 +918,47 @@ class EventBLL(object):
return iterations, vectors
def get_last_iters(
self, company_id: str, event_type: EventType, task_id: str, iters: int
):
self,
company_id: str,
event_type: EventType,
task_id: Union[str, Sequence[str]],
iters: int,
) -> Mapping[str, Sequence]:
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return []
return {}
task_ids = [task_id] if isinstance(task_id, str) else task_id
es_req: dict = {
"size": 0,
"aggs": {
"iters": {
"terms": {
"field": "iter",
"size": iters,
"order": {"_key": "desc"},
}
"tasks": {
"terms": {"field": "task"},
"aggs": {
"iters": {
"terms": {
"field": "iter",
"size": iters,
"order": {"_key": "desc"},
}
}
},
}
},
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
"query": {"bool": {"must": [{"terms": {"task": task_ids}}]}},
}
with translate_errors_context(), TimingContext("es", "task_last_iter"):
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req
self.es, company_id=company_id, event_type=event_type, body=es_req,
)
if "aggregations" not in es_res:
return []
return {}
return [b["key"] for b in es_res["aggregations"]["iters"]["buckets"]]
return {
tb["key"]: [ib["key"] for ib in tb["iters"]["buckets"]]
for tb in es_res["aggregations"]["tasks"]["buckets"]
}
def delete_task_events(self, company_id, task_id, allow_locked=False):
with translate_errors_context():
@@ -965,7 +991,9 @@ class EventBLL(object):
so it should be checked by the calling code
"""
es_req = {"query": {"terms": {"task": task_ids}}}
with translate_errors_context(), TimingContext("es", "delete_multi_tasks_events"):
with translate_errors_context(), TimingContext(
"es", "delete_multi_tasks_events"
):
es_res = delete_company_events(
es=self.es,
company_id=company_id,

View File

@@ -1,5 +1,5 @@
from enum import Enum
from typing import Union, Sequence
from typing import Union, Sequence, Mapping
from boltons.typeutils import classproperty
from elasticsearch import Elasticsearch
@@ -16,6 +16,9 @@ class EventType(Enum):
all = "*"
MetricVariants = Mapping[str, Sequence[str]]
class EventSettings:
@classproperty
def max_workers(self):
@@ -64,3 +67,23 @@ def delete_company_events(
) -> dict:
es_index = get_index_name(company_id, event_type.value)
return es.delete_by_query(index=es_index, body=body, **kwargs)
def get_metric_variants_condition(
metric_variants: MetricVariants,
) -> Sequence:
conditions = [
{
"bool": {
"must": [
{"term": {"metric": metric}},
{"terms": {"variant": variants}},
]
}
}
if variants
else {"term": {"metric": metric}}
for metric, variants in metric_variants.items()
]
return {"bool": {"should": conditions}}

View File

@@ -15,6 +15,8 @@ from apiserver.bll.event.event_common import (
EventSettings,
search_company_events,
check_empty_data,
MetricVariants,
get_metric_variants_condition,
)
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
from apiserver.config_repo import config
@@ -34,7 +36,12 @@ class EventMetrics:
self.es = es
def get_scalar_metrics_average_per_iter(
self, company_id: str, task_id: str, samples: int, key: ScalarKeyEnum
self,
company_id: str,
task_id: str,
samples: int,
key: ScalarKeyEnum,
metric_variants: MetricVariants = None,
) -> dict:
"""
Get scalar metric histogram per metric and variant
@@ -46,7 +53,12 @@ class EventMetrics:
return {}
return self._get_scalar_average_per_iter_core(
task_id, company_id, event_type, samples, ScalarKey.resolve(key)
task_id=task_id,
company_id=company_id,
event_type=event_type,
samples=samples,
key=ScalarKey.resolve(key),
metric_variants=metric_variants,
)
def _get_scalar_average_per_iter_core(
@@ -57,6 +69,7 @@ class EventMetrics:
samples: int,
key: ScalarKey,
run_parallel: bool = True,
metric_variants: MetricVariants = None,
) -> dict:
intervals = self._get_task_metric_intervals(
company_id=company_id,
@@ -64,6 +77,7 @@ class EventMetrics:
task_id=task_id,
samples=samples,
field=key.field,
metric_variants=metric_variants,
)
if not intervals:
return {}
@@ -197,6 +211,7 @@ class EventMetrics:
task_id: str,
samples: int,
field: str = "iter",
metric_variants: MetricVariants = None,
) -> Sequence[MetricInterval]:
"""
Calculate interval per task metric variant so that the resulting
@@ -204,9 +219,14 @@ class EventMetrics:
Return the list og metric variant intervals as the following tuple:
(metric, variant, interval, samples)
"""
must = [{"term": {"task": task_id}}]
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
query = {"bool": {"must": must}}
es_req = {
"size": 0,
"query": {"term": {"task": task_id}},
"query": query,
"aggs": {
"metrics": {
"terms": {

View File

@@ -554,7 +554,7 @@ class ProjectBLL:
user_ids: Optional[Sequence[str]] = None,
) -> Set[str]:
"""
Get the set of user ids that created tasks/models/dataviews in the given projects
Get the set of user ids that created tasks/models in the given projects
If project_ids is empty then all projects are examined
If user_ids are passed then only subset of these users is returned
"""
@@ -676,8 +676,8 @@ class ProjectBLL:
@classmethod
def calc_own_contents(cls, company: str, project_ids: Sequence[str]) -> Dict[str, dict]:
"""
Returns the amount of task/dataviews/models per requested project
Use separate aggregation calls on Task/Dataview/Model instead of lookup
Returns the amount of task/models per requested project
Use separate aggregation calls on Task/Model instead of lookup
aggregation on projects in order not to hit memory limits on large tasks
"""
if not project_ids:

View File

@@ -30,6 +30,28 @@ class DeleteProjectResult:
urls: TaskUrls = None
def validate_project_delete(company: str, project_id: str):
project = Project.get_for_writing(
company=company, id=project_id, _only=("id", "path")
)
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
project_ids = _ids_with_children([project_id])
ret = {}
for cls in (Task, Model):
ret[f"{cls.__name__.lower()}s"] = cls.objects(
project__in=project_ids,
).count()
for cls in (Task, Model):
ret[f"non_archived_{cls.__name__.lower()}s"] = cls.objects(
project__in=project_ids,
system_tags__nin=[EntityVisibility.archived.value],
).count()
return ret
def delete_project(
company: str, project_id: str, force: bool, delete_contents: bool
) -> Tuple[DeleteProjectResult, Set[str]]:

View File

@@ -1,11 +1,10 @@
import itertools
from typing import Sequence, Tuple
from typing import Sequence, Tuple, Optional
import dpath
from apiserver.apierrors import errors
from apiserver.database.model.task.task import Task
from apiserver.tools import safe_get
from apiserver.utilities.dicts import nested_get, nested_delete, nested_set
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
@@ -14,7 +13,7 @@ hyperparams_legacy_type = "legacy"
tf_define_section = "TF_DEFINE"
def split_param_name(full_name: str, default_section: str) -> Tuple[str, str]:
def split_param_name(full_name: str, default_section: str) -> Tuple[Optional[str], str]:
"""
Return parameter section and name. The section is either TF_DEFINE or the default one
"""
@@ -62,7 +61,7 @@ def _remove_legacy_params(data: dict, with_sections: bool = False) -> int:
return removed
def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[str]:
def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[dict]:
"""
Remove the legacy params from the data dict and return the number of removed params
If the path not found then return 0
@@ -71,8 +70,10 @@ def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[str]
return []
if with_sections:
return itertools.chain.from_iterable(
_get_legacy_params(section_data) for section_data in data.values()
return list(
itertools.chain.from_iterable(
_get_legacy_params(section_data) for section_data in data.values()
)
)
return [
@@ -86,15 +87,15 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
Escape all the section and param names for hyper params and configuration to make it mongo sage
"""
for old_params_field, new_params_field, default_section in (
("execution/parameters", "hyperparams", hyperparams_default_section),
("execution/model_desc", "configuration", None),
(("execution", "parameters"), "hyperparams", hyperparams_default_section),
(("execution", "model_desc"), "configuration", None),
):
legacy_params = safe_get(fields, old_params_field)
legacy_params = nested_get(fields, old_params_field)
if legacy_params is None:
continue
if (
not safe_get(fields, new_params_field)
not fields.get(new_params_field)
and previous_task
and previous_task[new_params_field]
):
@@ -117,11 +118,11 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
new_param = dict(name=name, type=hyperparams_legacy_type, value=str(value))
if section is not None:
new_param["section"] = section
dpath.new(fields, new_path, new_param)
dpath.delete(fields, old_params_field)
nested_set(fields, new_path, new_param)
nested_delete(fields, old_params_field)
for param_field in ("hyperparams", "configuration"):
params = safe_get(fields, param_field)
params = fields.get(param_field)
if params:
escaped_params = {
ParameterKeyEscaper.escape(key): {
@@ -131,7 +132,7 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
else value
for key, value in params.items()
}
dpath.set(fields, param_field, escaped_params)
fields[param_field] = escaped_params
def params_unprepare_from_saved(fields, copy_to_legacy=False):
@@ -140,7 +141,7 @@ def params_unprepare_from_saved(fields, copy_to_legacy=False):
If copy_to_legacy is set then copy hyperparams and configuration data to the legacy location for the old clients
"""
for param_field in ("hyperparams", "configuration"):
params = safe_get(fields, param_field)
params = fields.get(param_field)
if params:
unescaped_params = {
ParameterKeyEscaper.unescape(key): {
@@ -150,18 +151,18 @@ def params_unprepare_from_saved(fields, copy_to_legacy=False):
else value
for key, value in params.items()
}
dpath.set(fields, param_field, unescaped_params)
fields[param_field] = unescaped_params
if copy_to_legacy:
for new_params_field, old_params_field, use_sections in (
(f"hyperparams", "execution/parameters", True),
(f"configuration", "execution/model_desc", False),
("hyperparams", ("execution", "parameters"), True),
("configuration", ("execution", "model_desc"), False),
):
legacy_params = _get_legacy_params(
safe_get(fields, new_params_field), with_sections=use_sections
fields.get(new_params_field), with_sections=use_sections
)
if legacy_params:
dpath.new(
nested_set(
fields,
old_params_field,
{_get_full_param_name(p): p["value"] for p in legacy_params},
@@ -174,7 +175,7 @@ def _process_path(path: str):
Need to unescape and apply a full mongo escaping
"""
parts = path.split(".")
if len(parts) < 2 or len(parts) > 3:
if len(parts) < 2 or len(parts) > 4:
raise errors.bad_request.ValidationError("invalid task field", path=path)
return ".".join(
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
@@ -184,7 +185,7 @@ def _process_path(path: str):
def escape_paths(paths: Sequence[str]) -> Sequence[str]:
for old_prefix, new_prefix in (
("execution.parameters", f"hyperparams.{hyperparams_default_section}"),
("execution.model_desc", f"configuration"),
("execution.model_desc", "configuration"),
("execution.docker_cmd", "container")
):
path: str

View File

@@ -130,14 +130,14 @@ def collect_debug_image_urls(company: str, task: str) -> Set[str]:
if not metrics:
return set()
task_metrics = {task: set(metrics)}
task_metrics = {task: {m: [] for m in metrics}}
scroll_id = None
urls = set()
while True:
res = event_bll.debug_images_iterator.get_task_events(
company_id=company,
task_metrics=task_metrics,
iter_count=100,
iter_count=10,
state_id=scroll_id,
)
if not res.metric_events or not any(

View File

@@ -109,6 +109,7 @@ def enqueue_task(
status_message: str,
status_reason: str,
validate: bool = False,
force: bool = False,
) -> Tuple[int, dict]:
if not queue_id:
# try to get default queue
@@ -128,6 +129,7 @@ def enqueue_task(
status_reason=status_reason,
status_message=status_message,
allow_same_state_transition=False,
force=force,
).execute(enqueue_status=task.status)
try:
@@ -238,6 +240,7 @@ def reset_task(
set__last_metrics={},
set__metric_stats={},
set__models__output=[],
set__runtime={},
unset__output__result=1,
unset__output__error=1,
unset__last_worker=1,
@@ -364,7 +367,21 @@ def stop_task(
and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout
)
if TaskSystemTags.development in task.system_tags or not is_run_by_worker(task):
is_queued = task.status == TaskStatus.queued
set_stopped = (
is_queued
or TaskSystemTags.development in task.system_tags
or not is_run_by_worker(task)
)
if set_stopped:
if is_queued:
try:
TaskBLL.dequeue(task, company_id=company_id, silent_fail=True)
except APIError:
# dequeue may fail if the task was not enqueued
pass
new_status = TaskStatus.stopped
status_message = f"Stopped by {user_name}"
else:

View File

@@ -19,7 +19,7 @@ from pyparsing import (
from apiserver.utilities import json
EXTRA_CONFIG_PATHS = ("/opt/clearml/config",)
EXTRA_CONFIG_PATHS = ("/opt/trains/config", "/opt/clearml/config")
DEFAULT_PREFIXES = ("clearml", "trains")
EXTRA_CONFIG_PATH_SEP = ":" if platform.system() != "Windows" else ";"

View File

@@ -3,7 +3,7 @@
debug: false # Debug mode
pretty_json: false # prettify json response
return_stack: true # return stack trace on error
log_calls: true # Log API Calls
return_stack_to_caller: true # top-level control on whether to return stack trace in an API response
# if 'return_stack' is true and error contains a status code, return stack trace only for these status codes
# valid values are:

View File

@@ -176,6 +176,13 @@ class SafeMapField(MapField, DictValidationMixin):
self.error("Empty keys are not allowed in a MapField")
class NullableStringField(StringField):
def validate(self, value):
if value is None:
return
super(NullableStringField, self).validate(value)
class SafeDictField(DictField, DictValidationMixin):
def validate(self, value):
self._safe_validate(value)

View File

@@ -117,7 +117,7 @@ class GetMixin(PropsMixin):
def __init__(self, legacy=False):
self._legacy = legacy
def key(self, v):
def key(self, v) -> Optional[str]:
if v is None:
self._next = self._default
return self._default
@@ -133,6 +133,7 @@ class GetMixin(PropsMixin):
next_ = self._next
if not self._sticky:
self._next = self._default
return next_
def value_transform(self, v):
@@ -273,10 +274,13 @@ class GetMixin(PropsMixin):
).items():
query &= cls.get_range_field_query(field, data)
for field in opts.fields or []:
data = parameters.pop(field, None)
if data is not None:
dict_query[field] = data
for field, data in cls._pop_matching_params(
patterns=opts.fields or [], parameters=parameters
).items():
if "._" in field or "_." in field:
query &= Q(__raw__={field: data})
else:
dict_query[field.replace(".", "__")] = data
for field in opts.datetime_fields or []:
data = parameters.pop(field, None)

View File

@@ -18,6 +18,7 @@ from apiserver.database.fields import (
UnionField,
SafeSortedListField,
EmbeddedDocumentListField,
NullableStringField,
)
from apiserver.database.model import AttributedDocument
from apiserver.database.model.base import ProperDictMixin, GetMixin
@@ -218,6 +219,7 @@ class Task(AttributedDocument):
"status",
"project",
"parent",
"hyperparams.*",
),
range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
datetime_fields=("status_changed", "last_update"),
@@ -232,7 +234,7 @@ class Task(AttributedDocument):
type = StringField(required=True, choices=get_options(TaskType))
status = StringField(default=TaskStatus.created, choices=get_options(TaskStatus))
status_reason = StringField()
status_message = StringField()
status_message = StringField(user_set_allowed=True)
status_changed = DateTimeField()
comment = StringField(user_set_allowed=True)
created = DateTimeField(required=True, user_set_allowed=True)
@@ -260,7 +262,7 @@ class Task(AttributedDocument):
configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
runtime = SafeDictField(default=dict)
models: Models = EmbeddedDocumentField(Models, default=Models)
container = SafeMapField(field=StringField(default=""))
container = SafeMapField(field=NullableStringField())
enqueue_status = StringField(
choices=get_options(TaskStatus), exclude_by_default=True
)

View File

@@ -298,8 +298,9 @@ class PrePopulate:
if company_id is None:
company_id = ""
# Always use a public user for pre-populated data
cls.user_cls(id=user_id, name=user_name, company="").save()
existing_user = cls.user_cls.objects(id=user_id).only("id").first()
if not existing_user:
cls.user_cls(id=user_id, name=user_name, company=company_id).save()
cls._import(zfile, company_id, user_id, metadata)

View File

@@ -97,22 +97,6 @@ def _migrate_model_labels(db: Database):
tasks.update_one({"_id": doc["_id"]}, {"$set": set_commands})
def _migrate_project_description(db: Database):
projects: Collection = db["project"]
filter = {
"$or": [
{
"$expr": {"$lt": [{"$strLenCP": "$description"}, 100]},
"description": {"$regex": "^Auto-generated at ", "$options": "i"},
},
{"description": {"$regex": "^Auto-generated during move$", "$options": "i"}},
{"description": {"$regex": "^Auto-generated while cloning$", "$options": "i"}},
]
}
for doc in projects.find(filter=filter):
projects.update_one({"_id": doc["_id"]}, {"$unset": {"description": 1}})
def _migrate_project_names(db: Database):
projects: Collection = db["project"]
@@ -141,5 +125,4 @@ def migrate_backend(db: Database):
_migrate_docker_cmd(db)
_migrate_model_labels(db)
_migrate_project_names(db)
_migrate_project_description(db)
_drop_all_indices_from_collections(db, ["task*"])

View File

@@ -0,0 +1,22 @@
from pymongo.collection import Collection
from pymongo.database import Database
def _migrate_project_description(db: Database):
projects: Collection = db["project"]
filter = {
"$or": [
{
"$expr": {"$lt": [{"$strLenCP": "$description"}, 100]},
"description": {"$regex": "^Auto-generated at ", "$options": "i"},
},
{"description": {"$regex": "^Auto-generated during move$", "$options": "i"}},
{"description": {"$regex": "^Auto-generated while cloning$", "$options": "i"}},
]
}
for doc in projects.find(filter=filter):
projects.update_one({"_id": doc["_id"]}, {"$unset": {"description": 1}})
def migrate_backend(db: Database):
_migrate_project_description(db)

View File

@@ -12,7 +12,7 @@ funcsigs==1.0.2
furl>=2.0.0
gunicorn>=19.7.1
humanfriendly==4.18
jinja2==2.10.1
jinja2==2.11.3
jsonmodels>=2.3
jsonschema>=2.6.0
luqum>=0.10.0

View File

@@ -1,6 +1,18 @@
{
_description : "Provides an API for running tasks to report events collected by the system."
_definitions {
metric_variants {
type: object
metric {
description: The metric name
type: string
}
variants {
type: array
description: The names of the metric variants
items {type: string}
}
}
metrics_scalar_event {
description: "Used for reporting scalar metrics during training task"
type: object
@@ -193,6 +205,29 @@
description: "Task ID"
type: string
}
metric {
description: "Metric name"
type: string
}
}
}
task_metric_variants {
type: object
required: [task]
properties {
task {
description: "Task ID"
type: string
}
metric {
description: "Metric name"
type: string
}
variants {
description: Metric variant names
type: array
items {type: string}
}
}
}
task_log_event {
@@ -376,7 +411,7 @@
metrics {
type: array
items { "$ref": "#/definitions/task_metric" }
description: "List metrics for which the envents will be retreived"
description: "List of task metrics for which the envents will be retreived"
}
iters {
type: integer
@@ -411,6 +446,17 @@
}
}
}
"2.14": ${debug_images."2.7"} {
request {
properties {
metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/task_metric_variants" }
}
}
}
}
}
get_debug_image_sample {
"2.12": {
@@ -804,6 +850,17 @@
}
}
}
"2.14": ${get_task_plots."2.1"} {
request {
properties {
metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/metric_variants" }
}
}
}
}
}
get_multi_task_plots {
"2.1" {
@@ -962,6 +1019,17 @@
}
}
}
"2.14": ${scalar_metrics_iter_histogram."2.1"} {
request {
properties {
metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/metric_variants" }
}
}
}
}
}
multi_task_scalar_metrics_iter_histogram {
"2.1" {

View File

@@ -93,3 +93,19 @@ supported_modes {
}
}
}
logout {
authorize: false
allow_roles = [ "*" ]
"2.13" {
description: """ Logout (including SSO, if used)) """
request {
type: object
additionalProperties: false
}
response {
type: object
additionalProperties: false
}
}
}

View File

@@ -379,7 +379,7 @@ get_all {
items { type: string }
}
page {
description: "Page number, returns a specific page out of the resulting list of dataviews"
description: "Page number, returns a specific page out of the resulting list of projects"
type: integer
minimum: 0
}
@@ -469,7 +469,7 @@ get_all_ex {
default: false
}
check_own_contents {
description: "If set to 'true' and project ids are passed to the query then for these projects their own tasks, models and dataviews are counted"
description: "If set to 'true' and project ids are passed to the query then for these projects their own tasks and models are counted"
type: boolean
default: false
}
@@ -594,7 +594,7 @@ merge {
type: object
properties {
moved_entities {
description: "The number of tasks, models and dataviews moved from the merged project into the destination"
description: "The number of tasks and models moved from the merged project into the destination"
type: integer
}
moved_projects {
@@ -605,6 +605,42 @@ merge {
}
}
}
validate_delete {
"2.14" {
description: "Validates that the project existis and can be deleted"
request {
type: object
required: [ project ]
properties {
project {
description: "Project ID"
type: string
}
}
}
response {
type: object
properties {
tasks {
description: "The total number of tasks under the project and all its children"
type: integer
}
non_archived_tasks {
description: "The total number of non-archived tasks under the project and all its children"
type: integer
}
models {
description: "The total number of models under the project and all its children"
type: integer
}
non_archived_models {
description: "The total number of non-archived models under the project and all its children"
type: integer
}
}
}
}
}
delete {
"2.1" {
description: "Deletes a project"
@@ -613,7 +649,7 @@ delete {
required: [ project ]
properties {
project {
description: "Project id"
description: "Project ID"
type: string
}
force {

View File

@@ -588,13 +588,18 @@ class APICall(DataContainer):
self._end_ts = time.time()
self._duration = int((self._end_ts - self._start_ts) * 1000)
def get_response(self, include_stack: bool = False) -> Tuple[Union[dict, str], str]:
def get_response(self, include_stack: bool = None) -> Tuple[Union[dict, str], str]:
"""
Get the response for this call.
:param include_stack: If True, stack trace stored in this call's result should
be included in the response (default is False)
be included in the response (default follows configuration)
:return: Response data (encoded according to self.content_type) and the data's content type
"""
include_stack = (
include_stack
if include_stack is not None
else config.get("apiserver.return_stack_to_caller", False)
)
def make_version_number(version: PartialVersion) -> Union[None, float, str]:
"""

View File

@@ -1,9 +1,12 @@
import random
import string
sys_random = random.SystemRandom()
def get_random_string(length=12, allowed_chars='abcdefghijklmnopqrstuvwxyz'
'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'):
def get_random_string(
length: int = 12, allowed_chars: str = string.ascii_letters + string.digits
) -> str:
"""
Returns a securely generated random string.
@@ -12,20 +15,20 @@ def get_random_string(length=12, allowed_chars='abcdefghijklmnopqrstuvwxyz'
Taken from the django.utils.crypto module.
"""
return ''.join(sys_random.choice(allowed_chars) for _ in range(length))
return "".join(sys_random.choice(allowed_chars) for _ in range(length))
def get_client_id(length=20):
def get_client_id(length: int = 20) -> str:
"""
Create a random secret key.
Taken from the Django project.
"""
chars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'
chars = string.ascii_uppercase + string.digits
return get_random_string(length, chars)
def get_secret_key(length=50):
def get_secret_key(length: int = 50) -> str:
"""
Create a random secret key.
@@ -33,5 +36,5 @@ def get_secret_key(length=50):
NOTE: asterisk is not supported due to issues with environment variables containing
asterisks (in case the secret key is stored in an environment variable)
"""
chars = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&(-_=+)'
chars = string.ascii_letters + string.digits
return get_random_string(length, chars)

View File

@@ -37,7 +37,7 @@ class ServiceRepo(object):
"""If the check is set, parsing will fail for endpoint request with the version that is grater than the current
maximum """
_max_version = PartialVersion("2.13")
_max_version = PartialVersion("2.14")
""" Maximum version number (the highest min_version value across all endpoints) """
_endpoint_exp = (

View File

@@ -41,14 +41,12 @@ def login(call: APICall, *_, **__):
)
# Add authorization cookie
call.result.cookies[
config.get("apiserver.auth.session_auth_cookie_name")
] = call.result.data_model.token
call.result.set_auth_cookie(call.result.data_model.token)
@endpoint("auth.logout", min_version="2.2")
def logout(call: APICall, *_, **__):
call.result.cookies[config.get("apiserver.auth.session_auth_cookie_name")] = None
call.result.set_auth_cookie(None)
@endpoint(

View File

@@ -3,6 +3,7 @@ from collections import defaultdict
from operator import itemgetter
import attr
from typing import Sequence, Optional
from apiserver.apierrors import errors
from apiserver.apimodels.events import (
@@ -17,9 +18,11 @@ from apiserver.apimodels.events import (
LogOrderEnum,
GetDebugImageSampleRequest,
NextDebugImageSampleRequest,
MetricVariants as ApiMetrics,
TaskPlotsRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_common import EventType
from apiserver.bll.event.event_common import EventType, MetricVariants
from apiserver.bll.task import TaskBLL
from apiserver.service_repo import APICall, endpoint
from apiserver.utilities import json
@@ -321,7 +324,7 @@ def get_task_latest_scalar_values(call, company_id, _):
)
last_iters = event_bll.get_last_iters(
company_id=company_id, event_type=EventType.all, task_id=task_id, iters=1
)
).get(task_id)
call.result.data = dict(
metrics=metrics,
last_iter=last_iters[0] if last_iters else 0,
@@ -494,11 +497,22 @@ def get_task_plots_v1_7(call, company_id, _):
)
@endpoint("events.get_task_plots", min_version="1.8", required_fields=["task"])
def get_task_plots(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
def _get_metric_variants_from_request(
req_metrics: Sequence[ApiMetrics],
) -> Optional[MetricVariants]:
if not req_metrics:
return None
return {m.metric: m.variants for m in req_metrics}
@endpoint(
"events.get_task_plots", min_version="1.8", request_data_model=TaskPlotsRequest
)
def get_task_plots(call, company_id, request: TaskPlotsRequest):
task_id = request.task
iters = request.iters
scroll_id = request.scroll_id
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
@@ -509,6 +523,7 @@ def get_task_plots(call, company_id, _):
sort=[{"iter": {"order": "desc"}}],
last_iterations_per_plot=iters,
scroll_id=scroll_id,
metric_variants=_get_metric_variants_from_request(request.metrics),
)
return_events = result.events
@@ -594,9 +609,9 @@ def get_debug_images_v1_8(call, company_id, _):
response_data_model=DebugImageResponse,
)
def get_debug_images(call, company_id, request: DebugImagesRequest):
task_metrics = defaultdict(set)
task_metrics = defaultdict(dict)
for tm in request.metrics:
task_metrics[tm.task].add(tm.metric)
task_metrics[tm.task][tm.metric] = tm.variants
for metrics in task_metrics.values():
if None in metrics:
metrics.clear()
@@ -734,11 +749,11 @@ def _get_top_iter_unique_events_per_task(events, max_iters, tasks):
def _get_top_iter_unique_events(events, max_iters):
top_unique_events = defaultdict(lambda: [])
for e in events:
key = e.get("metric", "") + e.get("variant", "")
for ev in events:
key = ev.get("metric", "") + ev.get("variant", "")
evs = top_unique_events[key]
if len(evs) < max_iters:
evs.append(e)
evs.append(ev)
unique_events = list(
itertools.chain.from_iterable(list(top_unique_events.values()))
)

View File

@@ -1,5 +1,3 @@
from jsonmodels.fields import BoolField
from apiserver.apimodels.login import (
GetSupportedModesRequest,
GetSupportedModesResponse,
@@ -35,3 +33,8 @@ def supported_modes(call: APICall, _, __: GetSupportedModesRequest):
),
authenticated=call.auth is not None,
)
@endpoint("login.logout", min_version="2.13")
def logout(call: APICall, _, __):
call.result.set_auth_cookie(None)

View File

@@ -16,10 +16,14 @@ from apiserver.apimodels.projects import (
MoveRequest,
MergeRequest,
ProjectOrNoneRequest,
ProjectRequest,
)
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.bll.project.project_cleanup import delete_project
from apiserver.bll.project.project_cleanup import (
delete_project,
validate_project_delete,
)
from apiserver.bll.task import TaskBLL
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.project import Project
@@ -230,6 +234,13 @@ def merge(call: APICall, company: str, request: MergeRequest):
}
@endpoint("projects.validate_delete")
def validate_delete(call: APICall, company_id: str, request: ProjectRequest):
call.result.data = validate_project_delete(
company=company_id, project_id=request.project
)
@endpoint("projects.delete", request_data_model=DeleteRequest)
def delete(call: APICall, company_id: str, request: DeleteRequest):
res, affected_projects = delete_project(

View File

@@ -4,7 +4,6 @@ from functools import partial
from typing import Sequence, Union, Tuple
import attr
import dpath
from mongoengine import EmbeddedDocument, Q
from mongoengine.queryset.transform import COMPARISON_OPERATORS
from pymongo import UpdateOne
@@ -220,14 +219,13 @@ def get_all_ex(call: APICall, company_id, _):
call_data = escape_execution_parameters(call)
with translate_errors_context():
with TimingContext("mongo", "task_get_all_ex"):
_process_include_subprojects(call_data)
tasks = Task.get_many_with_join(
company=company_id, query_dict=call_data, allow_public=True,
)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
with TimingContext("mongo", "task_get_all_ex"):
_process_include_subprojects(call_data)
tasks = Task.get_many_with_join(
company=company_id, query_dict=call_data, allow_public=True,
)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
@endpoint("tasks.get_by_id_ex", required_fields=["id"])
@@ -236,14 +234,13 @@ def get_by_id_ex(call: APICall, company_id, _):
call_data = escape_execution_parameters(call)
with translate_errors_context():
with TimingContext("mongo", "task_get_by_id_ex"):
tasks = Task.get_many_with_join(
company=company_id, query_dict=call_data, allow_public=True,
)
with TimingContext("mongo", "task_get_by_id_ex"):
tasks = Task.get_many_with_join(
company=company_id, query_dict=call_data, allow_public=True,
)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
@endpoint("tasks.get_all", required_fields=[])
@@ -252,16 +249,15 @@ def get_all(call: APICall, company_id, _):
call_data = escape_execution_parameters(call)
with translate_errors_context():
with TimingContext("mongo", "task_get_all"):
tasks = Task.get_many(
company=company_id,
parameters=call_data,
query_dict=call_data,
allow_public=True,
)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
with TimingContext("mongo", "task_get_all"):
tasks = Task.get_many(
company=company_id,
parameters=call_data,
query_dict=call_data,
allow_public=True,
)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
@endpoint("tasks.get_types", request_data_model=GetTypesRequest)
@@ -403,15 +399,12 @@ def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
escape_dict_field(fields, path)
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
for field in task_script_stripped_fields:
try:
path = f"script/{field}"
value = dpath.get(fields, path)
script = fields.get("script")
if script:
for field in task_script_stripped_fields:
value = script.get(field)
if isinstance(value, str):
value = value.strip()
dpath.set(fields, path, value)
except KeyError:
pass
script[field] = value.strip()
return fields
@@ -546,10 +539,12 @@ def clone_task(call: APICall, company_id, request: CloneRequest):
}
def prepare_update_fields(call: APICall, task, call_data):
def prepare_update_fields(call: APICall, call_data):
valid_fields = deepcopy(Task.user_set_allowed())
update_fields = {k: v for k, v in create_fields.items() if k in valid_fields}
update_fields["output__error"] = None
update_fields.update(
status=None, status_reason=None, status_message=None, output__error=None
)
t_fields = task_fields
t_fields.add("output__error")
fields = parse_from_call(call_data, update_fields, t_fields)
@@ -569,7 +564,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
if not task:
raise errors.bad_request.InvalidTaskId(id=task_id)
partial_update_dict, valid_fields = prepare_update_fields(call, task, call.data)
partial_update_dict, valid_fields = prepare_update_fields(call, call.data)
if not partial_update_dict:
return UpdateResponse(updated=0)
@@ -642,7 +637,7 @@ def update_batch(call: APICall, company_id, _):
updated_projects = set()
for id, data in items.items():
task = tasks[id]
fields, valid_fields = prepare_update_fields(call, task, data)
fields, valid_fields = prepare_update_fields(call, data)
partial_update_dict = Task.get_safe_update_dict(fields)
if not partial_update_dict:
continue
@@ -744,8 +739,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
"tasks.get_hyper_params", request_data_model=GetHyperParamsRequest,
)
def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest):
with translate_errors_context():
tasks_params = HyperParams.get_params(company_id, task_ids=request.tasks)
tasks_params = HyperParams.get_params(company_id, task_ids=request.tasks)
call.result.data = {
"params": [{"task": task, **data} for task, data in tasks_params.items()]
@@ -754,39 +748,36 @@ def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest):
@endpoint("tasks.edit_hyper_params", request_data_model=EditHyperParamsRequest)
def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest):
with translate_errors_context():
call.result.data = {
"updated": HyperParams.edit_params(
company_id,
task_id=request.task,
hyperparams=request.hyperparams,
replace_hyperparams=request.replace_hyperparams,
force=request.force,
)
}
call.result.data = {
"updated": HyperParams.edit_params(
company_id,
task_id=request.task,
hyperparams=request.hyperparams,
replace_hyperparams=request.replace_hyperparams,
force=request.force,
)
}
@endpoint("tasks.delete_hyper_params", request_data_model=DeleteHyperParamsRequest)
def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsRequest):
with translate_errors_context():
call.result.data = {
"deleted": HyperParams.delete_params(
company_id,
task_id=request.task,
hyperparams=request.hyperparams,
force=request.force,
)
}
call.result.data = {
"deleted": HyperParams.delete_params(
company_id,
task_id=request.task,
hyperparams=request.hyperparams,
force=request.force,
)
}
@endpoint(
"tasks.get_configurations", request_data_model=GetConfigurationsRequest,
)
def get_configurations(call: APICall, company_id, request: GetConfigurationsRequest):
with translate_errors_context():
tasks_params = HyperParams.get_configurations(
company_id, task_ids=request.tasks, names=request.names
)
tasks_params = HyperParams.get_configurations(
company_id, task_ids=request.tasks, names=request.names
)
call.result.data = {
"configurations": [
@@ -801,10 +792,9 @@ def get_configurations(call: APICall, company_id, request: GetConfigurationsRequ
def get_configuration_names(
call: APICall, company_id, request: GetConfigurationNamesRequest
):
with translate_errors_context():
tasks_params = HyperParams.get_configuration_names(
company_id, task_ids=request.tasks, skip_empty=request.skip_empty
)
tasks_params = HyperParams.get_configuration_names(
company_id, task_ids=request.tasks, skip_empty=request.skip_empty
)
call.result.data = {
"configurations": [
@@ -815,31 +805,29 @@ def get_configuration_names(
@endpoint("tasks.edit_configuration", request_data_model=EditConfigurationRequest)
def edit_configuration(call: APICall, company_id, request: EditConfigurationRequest):
with translate_errors_context():
call.result.data = {
"updated": HyperParams.edit_configuration(
company_id,
task_id=request.task,
configuration=request.configuration,
replace_configuration=request.replace_configuration,
force=request.force,
)
}
call.result.data = {
"updated": HyperParams.edit_configuration(
company_id,
task_id=request.task,
configuration=request.configuration,
replace_configuration=request.replace_configuration,
force=request.force,
)
}
@endpoint("tasks.delete_configuration", request_data_model=DeleteConfigurationRequest)
def delete_configuration(
call: APICall, company_id, request: DeleteConfigurationRequest
):
with translate_errors_context():
call.result.data = {
"deleted": HyperParams.delete_configuration(
company_id,
task_id=request.task,
configuration=request.configuration,
force=request.force,
)
}
call.result.data = {
"deleted": HyperParams.delete_configuration(
company_id,
task_id=request.task,
configuration=request.configuration,
force=request.force,
)
}
@endpoint(
@@ -854,6 +842,7 @@ def enqueue(call: APICall, company_id, request: EnqueueRequest):
queue_id=request.queue,
status_message=request.status_message,
status_reason=request.status_reason,
force=request.force,
)
call.result.data_model = EnqueueResponse(queued=queued, **res)
@@ -1169,15 +1158,14 @@ def ping(_, company_id, request: PingRequest):
def add_or_update_artifacts(
call: APICall, company_id, request: AddOrUpdateArtifactsRequest
):
with translate_errors_context():
call.result.data = {
"updated": Artifacts.add_or_update_artifacts(
company_id=company_id,
task_id=request.task,
artifacts=request.artifacts,
force=request.force,
)
}
call.result.data = {
"updated": Artifacts.add_or_update_artifacts(
company_id=company_id,
task_id=request.task,
artifacts=request.artifacts,
force=request.force,
)
}
@endpoint(
@@ -1186,31 +1174,28 @@ def add_or_update_artifacts(
request_data_model=DeleteArtifactsRequest,
)
def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest):
with translate_errors_context():
call.result.data = {
"deleted": Artifacts.delete_artifacts(
company_id=company_id,
task_id=request.task,
artifact_ids=request.artifacts,
force=request.force,
)
}
call.result.data = {
"deleted": Artifacts.delete_artifacts(
company_id=company_id,
task_id=request.task,
artifact_ids=request.artifacts,
force=request.force,
)
}
@endpoint("tasks.make_public", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Task.set_public(
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=True
)
call.result.data = Task.set_public(
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=True
)
@endpoint("tasks.make_private", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Task.set_public(
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False
)
call.result.data = Task.set_public(
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False
)
@endpoint("tasks.move", request_data_model=MoveRequest)

View File

@@ -0,0 +1,54 @@
from apiserver.apierrors import errors
from apiserver.database.model import EntityVisibility
from apiserver.tests.automated import TestService
from apiserver.database.utils import id as db_id
class TestProjectsDelete(TestService):
def setUp(self, version="2.14"):
super().setUp(version=version)
def new_task(self, **kwargs):
return self.create_temp(
"tasks", type="testing", name=db_id(), input=dict(view=dict()), **kwargs
)
def new_model(self, **kwargs):
return self.create_temp("models", uri="file:///a/b", name=db_id(), labels={}, **kwargs)
def new_project(self, **kwargs):
return self.create_temp("projects", name=db_id(), description="", **kwargs)
def test_delete_fails_with_active_task(self):
project = self.new_project()
self.new_task(project=project)
res = self.api.projects.validate_delete(project=project)
self.assertEqual(res.tasks, 1)
self.assertEqual(res.non_archived_tasks, 1)
with self.api.raises(errors.bad_request.ProjectHasTasks):
self.api.projects.delete(project=project)
def test_delete_with_archived_task(self):
project = self.new_project()
self.new_task(project=project, system_tags=[EntityVisibility.archived.value])
res = self.api.projects.validate_delete(project=project)
self.assertEqual(res.tasks, 1)
self.assertEqual(res.non_archived_tasks, 0)
self.api.projects.delete(project=project)
def test_delete_fails_with_active_model(self):
project = self.new_project()
self.new_model(project=project)
res = self.api.projects.validate_delete(project=project)
self.assertEqual(res.models, 1)
self.assertEqual(res.non_archived_models, 1)
with self.api.raises(errors.bad_request.ProjectHasModels):
self.api.projects.delete(project=project)
def test_delete_with_archived_model(self):
project = self.new_project()
self.new_model(project=project, system_tags=[EntityVisibility.archived.value])
res = self.api.projects.validate_delete(project=project)
self.assertEqual(res.models, 1)
self.assertEqual(res.non_archived_models, 0)
self.api.projects.delete(project=project)

View File

@@ -10,6 +10,7 @@ def extract_properties_to_lists(
key_names: Sequence[str],
data: Sequence[dict],
extract_func: Optional[Callable[[dict], Tuple]] = None,
target_keys: Optional[Sequence[str]] = None,
) -> dict:
"""
Given a list of dictionaries and names of dictionary keys
@@ -20,9 +21,10 @@ def extract_properties_to_lists(
:param extract_func: the optional callable that extracts properties
from a dictionary and put them in a tuple in the order corresponding to
key_names. If not specified then properties are extracted according to key_names
:param target_keys: optional alternative keys to use in the target dictionary. must be equal in length to key_names.
"""
if not data:
return {k: [] for k in key_names}
value_sequences = zip(*map(extract_func or itemgetter(*key_names), data))
return dict(zip(key_names, map(list, value_sequences)))
return dict(zip((target_keys or key_names), map(list, value_sequences)))

View File

@@ -1 +1 @@
__version__ = "1.0.0"
__version__ = "1.1.0"

View File

@@ -1,8 +1,10 @@
auth {
# Fixed users login credentials
# No other user will be able to login
# Note: password may be bcrypt-hashed (generate using `python3 -c 'import bcrypt,base64; print(base64.b64encode(bcrypt.hashpw("password".encode(), bcrypt.gensalt())))'`)
fixed_users {
enabled: true
pass_hashed: false
users: [
{
username: "jane"

View File

@@ -84,7 +84,7 @@ class BasicConfig:
if not path.is_dir() and str(path) != DEFAULT_EXTRA_CONFIG_PATH
]
if invalid:
print(f"WARNING: Invalid paths in {self.extra_config_path_env_key} env var: {' '.join(invalid)}")
print(f"WARNING: Invalid paths in {self.extra_config_path_env_key} env var: {' '.join(map(str,invalid))}")
return [path for path in paths if path.is_dir()]
def _load(self, verbose=True):

View File

@@ -5,10 +5,11 @@ from argparse import ArgumentParser
from pathlib import Path
from boltons.iterutils import first
from flask import Flask, request, send_from_directory, safe_join, abort, Response
from flask._compat import fspath
from flask import Flask, request, send_from_directory, abort, Response
from flask_compress import Compress
from flask_cors import CORS
from werkzeug.exceptions import NotFound
from werkzeug.security import safe_join
from config import config
@@ -34,7 +35,10 @@ def upload():
if not filename:
continue
file_path = filename.lstrip(os.sep)
target = Path(safe_join(app.config["UPLOAD_FOLDER"], file_path))
safe_path = safe_join(app.config["UPLOAD_FOLDER"], file_path)
if safe_path is None:
raise NotFound()
target = Path(safe_path)
target.parent.mkdir(parents=True, exist_ok=True)
file.save(str(target))
results.append(file_path)
@@ -61,8 +65,8 @@ def download(path):
def delete(path):
real_path = Path(
safe_join(
fspath(app.config["UPLOAD_FOLDER"]),
fspath(path)
os.fspath(app.config["UPLOAD_FOLDER"]),
os.fspath(path)
)
)
if not real_path.exists() or not real_path.is_file():