Compare commits

48 Commits

Author SHA1 Message Date
allegroai
7558426bc6 Fix max upload size limit 2024-06-26 11:21:53 +03:00
allegroai
ce01e37c66 Refactor docker compose files: remove legacy, add services agent initialization in Linux 2024-06-26 10:53:43 +03:00
allegroai
92b42d66b7 Remove default credentials and reset existing credentials if none were provided 2024-06-26 10:52:42 +03:00
allegroai
f7d36bea4f Use an auth token in async_urls_delete when contacting the fileserver 2024-06-20 18:00:19 +03:00
allegroai
f1c876089b Add worker_pattern parameter to workers.get_all and get_count endpoints 2024-06-20 17:59:28 +03:00
allegroai
dd0ecb712d Added fileserver.upload.max_upload_size_mb setting 2024-06-20 17:58:33 +03:00
allegroai
fcfc1e8998 Support a more granular distributed lock wait 2024-06-20 17:57:54 +03:00
allegroai
9c210bb4fa Fix fixed users creation/removal 2024-06-20 17:57:23 +03:00
allegroai
14547155cb Delete pipeline steps in pipelines.delete_runs 2024-06-20 17:55:52 +03:00
allegroai
3f34f83a91 Version bump to 1.16.0
API version bump to 2.30
Add missing endpoints to schema
2024-06-20 17:55:17 +03:00
allegroai
da3941e6f2 Upgrade pymongo dependency 2024-06-20 17:53:15 +03:00
allegroai
2e19a18ee4 Support automatic handling of pipeline steps if a pipeline controller task ID was passed to one of the tasks endpoints 2024-06-20 17:52:46 +03:00
allegroai
cdc668e3c8 Fileserver authorization is enabled by default 2024-06-20 17:50:02 +03:00
allegroai
7c9889605a Add token authorization to fileserver 2024-06-20 17:48:54 +03:00
allegroai
5456ee4ebf Data tool export projects by name now includes subprojects + option for exporting all projects added 2024-06-20 17:48:18 +03:00
allegroai
562cb77003 Support getting and clearing task logs using specific metrics 2024-06-20 17:47:39 +03:00
allegroai
91df2bb3b7 Use better token generation for the secret key 2024-06-20 17:46:23 +03:00
allegroai
cb9812caee Do not return any mongodb instructions as a result of task update operations 2024-06-20 17:44:17 +03:00
allegroai
0496582d96 Ensure min interval on workers history charts so that we do not get "saw like" chart due to the missing points in the intervals 2024-06-20 17:43:52 +03:00
allegroai
beff19e104 Fix do not return full file path on errors from the fileserver 2024-06-20 17:43:19 +03:00
pollfly
639b3d59a4 Update docstrings (#246)
Edit description so they can be rendered using MDX
2024-06-20 17:00:31 +03:00
allegroai
c0d687e2ef Fix missing git in Dockerfile for building webapp 2024-03-28 17:50:35 +02:00
allegroai
9c95c63ce0 Version bump to v1.15.0 2024-03-24 11:25:05 +02:00
allegroai
73179f53c2 Use latest patch versions for ES and Mongo 2024-03-24 11:24:51 +02:00
allegroai
ddc8a76279 Set API version to v2.29 2024-03-18 16:02:45 +02:00
allegroai
ac7ea0d477 Allow filtering task models.input.model field by array of ids 2024-03-18 16:01:45 +02:00
allegroai
3544ed19f8 Use latest patch versions for mongodb and ES 2024-03-18 15:59:15 +02:00
allegroai
5e68f053a0 Add widgets link in nginx configuration 2024-03-18 15:58:19 +02:00
allegroai
7bd5fdad59 Update webserver build: allow using external configuration from a file or from environment variables 2024-03-18 15:57:19 +02:00
allegroai
484c72aa0c Upgrade to Debian bookworm 2024-03-18 15:56:18 +02:00
allegroai
2027afbed5 Added missing ES index template for scalar events 2024-03-18 15:54:38 +02:00
allegroai
7d649f1964 Support controlling config value inheritance from the base folder 2024-03-18 15:53:58 +02:00
allegroai
8d237b3cae Upgrade Redis to v6.2 2024-03-18 15:53:07 +02:00
allegroai
e8ee6ce72e Code cleanup 2024-03-18 15:52:22 +02:00
allegroai
5749ff0454 Add security headers to webserver 2024-03-18 15:50:40 +02:00
allegroai
5189adf4f1 Improve handling of fixed users 2024-03-18 15:49:42 +02:00
allegroai
92a4e56c1f Add support for cookies extensions 2024-03-18 15:46:07 +02:00
allegroai
33528870ae Request cookies processing enhanced for more flexibility 2024-03-18 15:45:09 +02:00
allegroai
85f5b8b6f6 Fix last metrics for task are updated for events reported without variants 2024-03-18 15:44:28 +02:00
allegroai
6112910768 Make sure that legacy templates are deleted and empty db check is done on the new templates 2024-03-18 15:40:13 +02:00
allegroai
d3013ac285 Invalidate token on user logoff 2024-03-18 15:38:44 +02:00
allegroai
88abf28287 Add ElasticSearch 8.x support 2024-03-18 15:37:44 +02:00
allegroai
6a1fc04d1e Set cookies SameSite value to Lax 2024-02-13 16:18:21 +02:00
allegroai
ee8eb03698 Fix crash when importing events for public company tasks 2024-02-13 16:17:52 +02:00
allegroai
5799baae45 Make sure that APIs that aggregate task/model data from projects can be called for the root project 2024-02-13 16:17:33 +02:00
allegroai
801e536c5e Fix tasks.started to correctly handle null values in the started field 2024-02-13 16:17:02 +02:00
allegroai
6e484ea8f4 Fix missing region parameter when deleting files from minio server 2024-02-13 16:16:24 +02:00
allegroai
a47e65d974 Add input parameters check to multiple APIs 2024-02-13 16:15:55 +02:00
88 changed files with 1826 additions and 845 deletions

View File

@@ -13,6 +13,14 @@ from apiserver.config_repo import config
from apiserver.utilities.stringenum import StringEnum
class TaskRequest(Base):
task: str = StringField(required=True)
class ModelRequest(Base):
model: str = StringField(required=True)
class HistogramRequestBase(Base):
samples: int = IntField(default=2000, validators=[Min(1), Max(6000)])
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
@@ -29,6 +37,11 @@ class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
model_events: bool = BoolField(default=False)
class GetMetricsAndVariantsRequest(Base):
task: str = StringField(required=True)
model_events: bool = BoolField(default=False)
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
tasks: Sequence[str] = ListField(
items_types=str,
@@ -51,6 +64,12 @@ class TaskMetric(Base):
variants: Sequence[str] = ListField(items_types=str)
class LegacyMetricEventsRequest(TaskRequest):
iters: int = IntField(default=1, validators=validators.Min(1))
scroll_id: str = StringField()
model_events: bool = BoolField(default=False)
class MetricEventsRequest(Base):
metrics: Sequence[TaskMetric] = ListField(
items_types=TaskMetric, validators=[Length(minimum_value=1)]
@@ -59,7 +78,14 @@ class MetricEventsRequest(Base):
navigate_earlier: bool = BoolField(default=True)
refresh: bool = BoolField(default=False)
scroll_id: str = StringField()
model_events: bool = BoolField()
model_events: bool = BoolField(default=False)
class VectorMetricsIterHistogramRequest(Base):
task: str = StringField(required=True)
metric: str = StringField(required=True)
variant: str = StringField(required=True)
model_events: bool = BoolField(default=False)
class GetVariantSampleRequest(Base):
@@ -110,11 +136,17 @@ class TaskEventsRequest(TaskEventsRequestBase):
model_events: bool = BoolField(default=False)
class LegacyLogEventsRequest(TaskEventsRequestBase):
order: Optional[str] = ActualEnumField(LogOrderEnum, default=LogOrderEnum.desc)
scroll_id: str = StringField()
class LogEventsRequest(TaskEventsRequestBase):
batch_size: int = IntField(default=5000)
navigate_earlier: bool = BoolField(default=True)
from_timestamp: Optional[int] = IntField()
order: Optional[str] = ActualEnumField(LogOrderEnum)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class ScalarMetricsIterRawRequest(TaskEventsRequestBase):
@@ -160,6 +192,11 @@ class MultiTaskMetricsRequest(MultiTasksRequestBase):
event_type: EventType = ActualEnumField(EventType, default=EventType.all)
class LegacyMultiTaskEventsRequest(MultiTasksRequestBase):
iters: int = IntField(default=1, validators=validators.Min(1))
scroll_id: str = StringField()
class MultiTaskPlotsRequest(MultiTasksRequestBase):
iters: int = IntField(default=1)
scroll_id: str = StringField()
@@ -177,6 +214,14 @@ class TaskPlotsRequest(Base):
model_events: bool = BoolField(default=False)
class GetScalarMetricDataRequest(Base):
task: str = StringField(required=True)
metric: str = StringField(required=True)
scroll_id: str = StringField()
no_scroll: bool = BoolField(default=False)
model_events: bool = BoolField(default=False)
class ClearScrollRequest(Base):
scroll_id: str = StringField()
@@ -185,3 +230,5 @@ class ClearTaskLogRequest(Base):
task: str = StringField(required=True)
threshold_sec = IntField()
allow_locked = BoolField(default=False)
exclude_metrics = ListField(items_types=[str])
include_metrics = ListField(items_types=[str])

View File

@@ -42,6 +42,21 @@ class ModelRequest(models.Base):
model = fields.StringField(required=True)
class TaskRequest(models.Base):
task = fields.StringField(required=True)
class UpdateForTaskRequest(TaskRequest):
uri = fields.StringField()
iteration = fields.IntField()
override_model_id = fields.StringField()
class UpdateModelRequest(ModelRequest):
task = fields.StringField()
iteration = fields.IntField()
class DeleteModelRequest(ModelRequest):
force = fields.BoolField(default=False)
delete_external_artifacts = fields.BoolField(default=True)

View File

@@ -46,7 +46,7 @@ class ProjectTagsRequest(TagsRequest):
class MultiProjectRequest(models.Base):
projects = fields.ListField(str)
projects = fields.ListField(items_types=[str, type(None)])
include_subprojects = fields.BoolField(default=True)

View File

@@ -6,6 +6,10 @@ class ReportStatsOptionRequest(Base):
enabled = BoolField(default=None, nullable=True)
class GetConfigRequest(Base):
path = StringField()
class ReportStatsOptionResponse(Base):
supported = BoolField(default=True)
enabled = BoolField()

View File

@@ -101,6 +101,10 @@ class DequeueRequest(UpdateRequest):
new_status = StringField()
class StopRequest(UpdateRequest):
include_pipeline_steps = BoolField(default=False)
class EnqueueRequest(UpdateRequest):
queue = StringField()
queue_name = StringField()
@@ -112,6 +116,7 @@ class DeleteRequest(UpdateRequest):
return_file_urls = BoolField(default=False)
delete_output_models = BoolField(default=True)
delete_external_artifacts = BoolField(default=True)
include_pipeline_steps = BoolField(default=False)
class SetRequirementsRequest(TaskRequest):
@@ -264,6 +269,7 @@ class DeleteConfigurationRequest(TaskUpdateRequest):
class ArchiveRequest(MultiTaskRequest):
status_reason = StringField(default="")
status_message = StringField(default="")
include_pipeline_steps = BoolField(default=False)
class ArchiveResponse(models.Base):
@@ -275,8 +281,17 @@ class TaskBatchRequest(BatchRequest):
status_message = StringField(default="")
class ArchiveManyRequest(TaskBatchRequest):
include_pipeline_steps = BoolField(default=False)
class UnarchiveManyRequest(TaskBatchRequest):
include_pipeline_steps = BoolField(default=False)
class StopManyRequest(TaskBatchRequest):
force = BoolField(default=False)
include_pipeline_steps = BoolField(default=False)
class DequeueManyRequest(TaskBatchRequest):
@@ -297,6 +312,7 @@ class DeleteManyRequest(TaskBatchRequest):
delete_output_models = BoolField(default=True)
force = BoolField(default=False)
delete_external_artifacts = BoolField(default=True)
include_pipeline_steps = BoolField(default=False)
class ResetManyRequest(TaskBatchRequest):

View File

@@ -4,6 +4,10 @@ from jsonmodels.models import Base
from apiserver.apimodels import DictField
class UserRequest(Base):
user = StringField(required=True)
class CreateRequest(Base):
id = StringField(required=True)
name = StringField(required=True)

View File

@@ -100,6 +100,7 @@ class GetAllRequest(Base):
last_seen = IntField(default=3600)
tags = ListField(str)
system_tags = ListField(str)
worker_pattern = StringField()
class GetAllResponse(Base):

View File

@@ -44,7 +44,6 @@ from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.redis_manager import redman
from apiserver.service_repo.auth import Identity
from apiserver.tools import safe_get
from apiserver.utilities.dicts import nested_get
from apiserver.utilities.json import loads
@@ -375,7 +374,7 @@ class EventBLL(object):
if invalid_iterations_count:
raise BulkIndexError(
f"{invalid_iterations_count} document(s) failed to index.",
[invalid_iteration_error],
[{"_index": invalid_iteration_error}],
)
if not added:
@@ -439,10 +438,8 @@ class EventBLL(object):
last_events contains [hashed_metric_name -> hashed_variant_name -> event]. Keys are hashed to avoid mongodb
key conflicts due to invalid characters and/or long field names.
"""
metric = event.get("metric")
variant = event.get("variant")
if not (metric and variant):
return
metric = event.get("metric") or ""
variant = event.get("variant") or ""
metric_hash = dbutils.hash_field_name(metric)
variant_hash = dbutils.hash_field_name(variant)
@@ -487,9 +484,9 @@ class EventBLL(object):
recent than the currently stored event for its metric/event_type combination.
last_events contains [metric_name -> event_type -> event]
"""
metric = event.get("metric")
metric = event.get("metric") or ""
event_type = event.get("type")
if not (metric and event_type):
if not event_type:
return
timestamp = last_events[metric][event_type].get("timestamp", None)
@@ -661,8 +658,8 @@ class EventBLL(object):
Return events and next scroll id from the scrolled query
Release the scroll once it is exhausted
"""
total_events = safe_get(es_res, "hits/total/value", default=0)
events = [doc["_source"] for doc in safe_get(es_res, "hits/hits", default=[])]
total_events = nested_get(es_res, ("hits", "total", "value"), default=0)
events = [doc["_source"] for doc in nested_get(es_res, ("hits", "hits"), default=[])]
next_scroll_id = es_res.get("_scroll_id")
if next_scroll_id and not events:
self.clear_scroll(next_scroll_id)
@@ -1230,6 +1227,8 @@ class EventBLL(object):
task_id: str,
allow_locked: bool = False,
threshold_sec: int = None,
include_metrics: Sequence[str] = None,
exclude_metrics: Sequence[str] = None,
):
self._validate_task_state(
company_id=company_id, task_id=task_id, allow_locked=allow_locked
@@ -1254,8 +1253,16 @@ class EventBLL(object):
}
)
sort = {"timestamp": {"order": "desc"}}
if include_metrics:
must.append({"terms": {"metric": include_metrics}})
more_conditions = {}
if exclude_metrics:
more_conditions = {"must_not": [{"terms": {"metric": exclude_metrics}}]}
es_req = {
"query": {"bool": {"must": must}},
"query": {"bool": {"must": must, **more_conditions}},
**({"sort": sort} if sort else {}),
}
es_res = delete_company_events(

View File

@@ -9,7 +9,7 @@ from elasticsearch import Elasticsearch
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.task import Task
from apiserver.tools import safe_get
from apiserver.utilities.dicts import nested_get
class EventType(Enum):
@@ -123,8 +123,8 @@ def get_max_metric_and_variant_counts(
es, company_id=company_id, event_type=event_type, body=es_req, **kwargs,
)
metrics_count = safe_get(
es_res, "aggregations/metrics_count/value", max_metrics_count
metrics_count = nested_get(
es_res, ("aggregations", "metrics_count", "value"), max_metrics_count
)
if not metrics_count:
return max_metrics_count, max_variants_count

View File

@@ -24,7 +24,7 @@ from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
from apiserver.bll.query import Builder as QueryBuilder
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.tools import safe_get
from apiserver.utilities.dicts import nested_get
log = config.logger(__file__)
@@ -342,12 +342,12 @@ class EventMetrics:
total amount of intervals does not exceed the samples
Return the interval and resulting amount of intervals
"""
count = safe_get(data, "count/value", default=0)
count = nested_get(data, ("count", "value"), default=0)
if count < samples:
return metric, variant, 1, count
min_index = safe_get(data, "min_index/value", default=0)
max_index = safe_get(data, "max_index/value", default=min_index)
min_index = nested_get(data, ("min_index", "value"), default=0)
max_index = nested_get(data, ("max_index", "value"), default=min_index)
index_range = max_index - min_index + 1
interval = max(1, math.ceil(float(index_range) / samples))
max_samples = math.ceil(float(index_range) / interval)
@@ -592,5 +592,5 @@ class EventMetrics:
return [
metric["key"]
for metric in safe_get(es_res, "aggregations/metrics/buckets", default=[])
for metric in nested_get(es_res, ("aggregations", "metrics", "buckets"), default=[])
]

View File

@@ -6,7 +6,6 @@ from operator import itemgetter
from typing import Sequence, Tuple, Optional, Mapping, Callable
import attr
import dpath
from boltons.iterutils import first
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, ListField, IntField
@@ -27,6 +26,7 @@ from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.metrics import MetricEventStats
from apiserver.database.model.task.task import Task
from apiserver.utilities.dicts import nested_get
class VariantState(Base):
@@ -305,13 +305,13 @@ class MetricEventsIterator:
return [
MetricState(
metric=metric["key"],
timestamp=dpath.get(metric, "last_event_timestamp/value"),
timestamp=nested_get(metric, ("last_event_timestamp", "value")),
variants=[
init_variant_state(variant)
for variant in dpath.get(metric, "variants/buckets")
for variant in nested_get(metric, ("variants", "buckets"))
],
)
for metric in dpath.get(es_res, "aggregations/metrics/buckets")
for metric in nested_get(es_res, ("aggregations", "metrics", "buckets"))
]
@abc.abstractmethod
@@ -430,14 +430,14 @@ class MetricEventsIterator:
def get_iteration_events(it_: dict) -> Sequence:
return [
self._process_event(ev["_source"])
for m in dpath.get(it_, "metrics/buckets")
for v in dpath.get(m, "variants/buckets")
for ev in dpath.get(v, "events/hits/hits")
for m in nested_get(it_, ("metrics", "buckets"))
for v in nested_get(m, ("variants", "buckets"))
for ev in nested_get(v, ("events", "hits", "hits"))
if is_valid_event(ev["_source"])
]
iterations = []
for it in dpath.get(es_res, "aggregations/iters/buckets"):
for it in nested_get(es_res, ("aggregations", "iters", "buckets")):
events = get_iteration_events(it)
if events:
iterations.append({"iter": it["key"], "events": events})

View File

@@ -869,7 +869,7 @@ class ProjectBLL:
company,
project_ids: Sequence[str],
user_ids: Optional[Sequence[str]] = None,
) -> Set[str]:
) -> Set[Union[str, type(None)]]:
"""
Get the set of user ids that created tasks/models in the given projects
If project_ids is empty then all projects are examined

View File

@@ -18,7 +18,7 @@ from apiserver.config.info import get_deployment_type
from apiserver.database.model import Company, User
from apiserver.database.model.queue import Queue
from apiserver.database.model.task.task import Task
from apiserver.tools import safe_get
from apiserver.utilities.dicts import nested_get
from apiserver.utilities.json import dumps
from apiserver.version import __version__ as current_version
from .resource_monitor import ResourceMonitor, stat_threads
@@ -162,7 +162,7 @@ class StatisticsReporter:
def _get_cardinality_fields(categories: Sequence[dict]) -> dict:
names = {"cpu": "num_cores"}
return {
names[c["key"]]: safe_get(c, "count/value")
names[c["key"]]: nested_get(c, ("count", "value"))
for c in categories
if c["key"] in names
}
@@ -175,21 +175,21 @@ class StatisticsReporter:
}
return {
names[m["key"]]: {
"min": safe_get(m, "min/value"),
"max": safe_get(m, "max/value"),
"avg": safe_get(m, "avg/value"),
"min": nested_get(m, ("min", "value")),
"max": nested_get(m, ("max", "value")),
"avg": nested_get(m, ("avg", "value")),
}
for m in metrics
if m["key"] in names
}
buckets = safe_get(res, "aggregations/workers/buckets", default=[])
buckets = nested_get(res, ("aggregations", "workers", "buckets"), default=[])
return {
b["key"]: {
key: {
"interval_sec": agent_resource_threshold_sec,
**_get_cardinality_fields(safe_get(b, "categories/buckets", [])),
**_get_metric_fields(safe_get(b, "metrics/buckets", [])),
**_get_cardinality_fields(nested_get(b, ("categories", "buckets"), [])),
**_get_metric_fields(nested_get(b, ("metrics", "buckets"), [])),
}
}
for b in buckets
@@ -227,7 +227,7 @@ class StatisticsReporter:
},
}
res = cls._run_worker_stats_query(company_id, es_req)
buckets = safe_get(res, "aggregations/workers/buckets", default=[])
buckets = nested_get(res, ("aggregations", "workers", "buckets"), default=[])
return {
b["key"]: {"last_activity_time": b["last_activity_time"]["value"]}
for b in buckets

View File

@@ -22,7 +22,7 @@ from apiserver.database.model.task.task import (
TaskStatusMessage,
ArtifactModes,
Execution,
DEFAULT_LAST_ITERATION,
DEFAULT_LAST_ITERATION, TaskType,
)
from apiserver.database.utils import get_options
from apiserver.service_repo.auth import Identity
@@ -32,54 +32,79 @@ log = config.logger(__file__)
queue_bll = QueueBLL()
def _get_pipeline_steps_for_controller_task(
    task: Task, company_id: str, only: Sequence[str] = None
) -> Sequence[Task]:
    """
    Collect the tasks whose parent is the passed task (its pipeline steps).
    An empty list is returned when no task was passed or when the task type
    is not 'controller'.
    If 'only' is passed then the query is limited to the listed fields
    """
    if not task:
        return []
    if task.type != TaskType.controller:
        return []

    steps = Task.objects(company=company_id, parent=task.id)
    # restrict the loaded fields only when the caller asked for it
    return list(steps.only(*only) if only else steps)
def archive_task(
task: Union[str, Task],
company_id: str,
identity: Identity,
status_message: str,
status_reason: str,
include_pipeline_steps: bool,
) -> int:
"""
Dequeue and archive task
Return 1 if successful
"""
user_id = identity.user
fields = (
"id",
"company",
"execution",
"status",
"project",
"system_tags",
"enqueue_status",
"type",
)
if isinstance(task, str):
task = get_task_with_write_access(
task,
company_id=company_id,
identity=identity,
only=(
"id",
"company",
"execution",
"status",
"project",
"system_tags",
"enqueue_status",
),
only=fields,
)
user_id = identity.user
try:
TaskBLL.dequeue_and_change_status(
task,
company_id=company_id,
user_id=user_id,
def archive_task_core(task_: Task) -> int:
try:
TaskBLL.dequeue_and_change_status(
task_,
company_id=company_id,
user_id=user_id,
status_message=status_message,
status_reason=status_reason,
remove_from_all_queues=True,
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
return task_.update(
status_message=status_message,
status_reason=status_reason,
remove_from_all_queues=True,
add_to_set__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
last_changed_by=user_id,
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
return task.update(
status_message=status_message,
status_reason=status_reason,
add_to_set__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
last_changed_by=user_id,
)
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(task, company_id, only=fields)
):
for step in step_tasks:
archive_task_core(step)
return archive_task_core(task)
def unarchive_task(
@@ -88,24 +113,36 @@ def unarchive_task(
identity: Identity,
status_message: str,
status_reason: str,
include_pipeline_steps: bool,
) -> int:
"""
Unarchive task. Return 1 if successful
"""
fields = ("id", "type")
task = get_task_with_write_access(
task_id,
company_id=company_id,
identity=identity,
only=("id",),
)
return task.update(
status_message=status_message,
status_reason=status_reason,
pull__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
last_changed_by=identity.user,
only=fields,
)
def unarchive_task_core(task_: Task) -> int:
return task_.update(
status_message=status_message,
status_reason=status_reason,
pull__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
last_changed_by=identity.user,
)
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(task, company_id, only=fields)
):
for step in step_tasks:
unarchive_task_core(step)
return unarchive_task_core(task)
def dequeue_task(
task_id: str,
@@ -262,6 +299,7 @@ def delete_task(
status_message: str,
status_reason: str,
delete_external_artifacts: bool,
include_pipeline_steps: bool,
) -> Tuple[int, Task, CleanupResult]:
user_id = identity.user
task = get_task_with_write_access(
@@ -280,36 +318,51 @@ def delete_task(
current=task.status,
)
try:
TaskBLL.dequeue_and_change_status(
task,
company_id=company_id,
user_id=user_id,
status_message=status_message,
status_reason=status_reason,
remove_from_all_queues=True,
def delete_task_core(task_: Task, force_: bool):
try:
TaskBLL.dequeue_and_change_status(
task_,
company_id=company_id,
user_id=user_id,
status_message=status_message,
status_reason=status_reason,
remove_from_all_queues=True,
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
res = cleanup_task(
company=company_id,
user=user_id,
task=task_,
force=force_,
return_file_urls=return_file_urls,
delete_output_models=delete_output_models,
delete_external_artifacts=delete_external_artifacts,
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
cleanup_res = cleanup_task(
company=company_id,
user=user_id,
task=task,
force=force,
return_file_urls=return_file_urls,
delete_output_models=delete_output_models,
delete_external_artifacts=delete_external_artifacts,
)
if move_to_trash:
# make sure that whatever changes were done to the task are saved
# the task itself will be deleted later in the move_tasks_to_trash operation
task_.last_update = datetime.utcnow()
task_.save()
else:
task_.delete()
return res
task_ids = [task.id]
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(task, company_id)
):
for step in step_tasks:
delete_task_core(step, True)
task_ids.append(step.id)
cleanup_res = delete_task_core(task, force)
if move_to_trash:
# make sure that whatever changes were done to the task are saved
# the task itself will be deleted later in the move_tasks_to_trash operation
task.last_update = datetime.utcnow()
task.save()
else:
task.delete()
move_tasks_to_trash(task_ids)
update_project_time(task.project)
return 1, task, cleanup_res
@@ -465,6 +518,7 @@ def stop_task(
user_name: str,
status_reason: str,
force: bool,
include_pipeline_steps: bool,
) -> dict:
"""
Stop a running task. Requires task status 'in_progress' and
@@ -475,19 +529,21 @@ def stop_task(
:return: updated task fields
"""
user_id = identity.user
fields = (
"status",
"project",
"tags",
"system_tags",
"last_worker",
"last_update",
"execution.queue",
"type",
)
task = get_task_with_write_access(
task_id,
company_id=company_id,
identity=identity,
only=(
"status",
"project",
"tags",
"system_tags",
"last_worker",
"last_update",
"execution.queue",
),
only=fields,
)
def is_run_by_worker(t: Task) -> bool:
@@ -499,32 +555,41 @@ def stop_task(
and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout
)
is_queued = task.status == TaskStatus.queued
set_stopped = (
is_queued
or TaskSystemTags.development in task.system_tags
or not is_run_by_worker(task)
)
def stop_task_core(task_: Task, force_: bool):
is_queued = task_.status == TaskStatus.queued
set_stopped = (
is_queued
or TaskSystemTags.development in task_.system_tags
or not is_run_by_worker(task_)
)
if set_stopped:
if is_queued:
try:
TaskBLL.dequeue(task, company_id=company_id, silent_fail=True)
except APIError:
# dequeue may fail if the task was not enqueued
pass
if set_stopped:
if is_queued:
try:
TaskBLL.dequeue(task_, company_id=company_id, silent_fail=True)
except APIError:
# dequeue may fail if the task was not enqueued
pass
new_status = TaskStatus.stopped
status_message = f"Stopped by {user_name}"
else:
new_status = task.status
status_message = TaskStatusMessage.stopping
new_status = TaskStatus.stopped
status_message = f"Stopped by {user_name}"
else:
new_status = task_.status
status_message = TaskStatusMessage.stopping
return ChangeStatusRequest(
task=task,
new_status=new_status,
status_reason=status_reason,
status_message=status_message,
force=force,
user_id=user_id,
).execute()
return ChangeStatusRequest(
task=task_,
new_status=new_status,
status_reason=status_reason,
status_message=status_message,
force=force_,
user_id=user_id,
).execute()
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(task, company_id, only=fields)
):
for step in step_tasks:
stop_task_core(step, True)
return stop_task_core(task, force)

View File

@@ -4,6 +4,7 @@ from typing import Sequence
import attr
import six
from mongoengine import Q
from mongoengine.base import UPDATE_OPERATORS
from apiserver.apierrors import errors
from apiserver.bll.util import update_project_time
@@ -78,8 +79,16 @@ class ChangeStatusRequest(object):
update_project_time(project_id)
# make sure that _raw_ queries are not returned back to the client
fields.pop("__raw__", None)
def is_mongo_operator(field: str) -> bool:
    """
    Return True if the key has the form '<operator>__<field>' where <operator>
    is one of the mongoengine update operators (for example 'pull__system_tags')
    """
    head, _, tail = field.partition("__")
    # bool() so that a plain field name yields False and not the empty 'tail' string
    return bool(tail) and head in UPDATE_OPERATORS
# make sure to not return _raw_ queries or any of the update operators
fields = {
key: value
for key, value in fields.items()
if not (key == "__raw__" or is_mongo_operator(key))
}
return dict(updated=updated, fields=fields)
@@ -182,6 +191,9 @@ def get_many_tasks_for_writing(
)
)
if not company_id:
return result
forbidden_tasks = {task.id for task in result if not task.company}
if forbidden_tasks:
if throw_on_forbidden:

View File

@@ -1,4 +1,5 @@
import itertools
import re
from datetime import datetime, timedelta
from time import time
from typing import Sequence, Set, Optional
@@ -27,14 +28,15 @@ from apiserver.database.model.project import Project
from apiserver.database.model.queue import Queue
from apiserver.database.model.task.task import Task
from apiserver.redis_manager import redman
from apiserver.tools import safe_get
from apiserver.utilities.dicts import nested_get
from .stats import WorkerStats
log = config.logger(__file__)
class WorkerBLL:
_key_regex_trans = str.maketrans({"*": ".*", "?": ".?"})
def __init__(self, es=None, redis=None):
self.es_client = es or es_factory.connect("workers")
self.config = config.get("services.workers", ConfigTree())
@@ -208,15 +210,25 @@ class WorkerBLL:
last_seen: Optional[int] = None,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
worker_pattern: str = None,
):
if not last_seen:
return len(
self._get_keys(company_id, user_tags=tags, system_tags=system_tags)
self._get_keys(
company_id,
user_tags=tags,
system_tags=system_tags,
worker_pattern=worker_pattern,
)
)
return len(
self.get_all(
company_id, last_seen=last_seen, tags=tags, system_tags=system_tags
company_id,
last_seen=last_seen,
tags=tags,
system_tags=system_tags,
worker_pattern=worker_pattern,
)
)
@@ -226,6 +238,7 @@ class WorkerBLL:
last_seen: Optional[int] = None,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
worker_pattern: str = None,
) -> Sequence[WorkerEntry]:
"""
Get all the company workers that were active during the last_seen period
@@ -234,7 +247,12 @@ class WorkerBLL:
:return:
"""
try:
workers = self._get(company_id, user_tags=tags, system_tags=system_tags)
workers = self._get(
company_id,
user_tags=tags,
system_tags=system_tags,
worker_pattern=worker_pattern,
)
except Exception as e:
raise server_error.DataError("failed loading worker entries", err=e.args[0])
@@ -254,6 +272,7 @@ class WorkerBLL:
last_seen: int,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
worker_pattern: str = None,
) -> Sequence[WorkerResponseEntry]:
helpers = [
WorkerConversionHelper.from_worker_entry(entry)
@@ -262,6 +281,7 @@ class WorkerBLL:
last_seen=last_seen,
tags=tags,
system_tags=system_tags,
worker_pattern=worker_pattern,
)
]
@@ -287,7 +307,7 @@ class WorkerBLL:
filter(
None,
(
safe_get(info, "next_entry/task")
nested_get(info, ("next_entry", "task"))
for info in queues_info.values()
),
)
@@ -311,7 +331,7 @@ class WorkerBLL:
continue
entry.name = info.get("name", None)
entry.num_tasks = info.get("num_entries", 0)
task_id = safe_get(info, "next_entry/task")
task_id = nested_get(info, ("next_entry", "task"))
if task_id:
task = tasks_info.get(task_id, None)
entry.next_task = IdNameEntry(
@@ -321,7 +341,7 @@ class WorkerBLL:
for helper in helpers:
worker = helper.worker
if helper.task_id:
task = tasks_info.get(helper.task_id, None)
task: Task = tasks_info.get(helper.task_id, None)
if task:
worker.task.running_time = (task.active_duration or 0) * 1000
worker.task.last_iteration = task.last_iteration
@@ -417,16 +437,25 @@ class WorkerBLL:
user: str = "*",
user_tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
worker_pattern: str = None,
) -> Sequence[bytes]:
if not (user_tags or system_tags):
match = self._get_worker_key(company, user, "*")
match = self._get_worker_key(company, user, worker_pattern or "*")
return list(self.redis.scan_iter(match))
def filter_by_user(in_keys: Set[bytes]) -> Set[bytes]:
if user == "*":
return in_keys
user_bytes = user.encode()
return {k for k in in_keys if user_bytes in k}
def filter_by_user_and_pattern(in_keys: Set[bytes]) -> Set[bytes]:
    """
    Narrow the passed worker keys to those matching the requested user
    and, if a worker pattern was passed, the worker pattern.

    :param in_keys: worker keys to filter
    :return: the matching keys. If neither a specific user nor a worker pattern
        was requested then in_keys is returned as is
    """
    filtered = in_keys
    if user != "*":
        user_marker = user.encode()
        filtered = {key for key in filtered if user_marker in key}

    if not worker_pattern:
        return filtered

    # The translated pattern is anchored to the end of the key
    translated = worker_pattern.translate(self._key_regex_trans)
    key_regex = re.compile(f"{translated}$".encode())
    return {key for key in filtered if key_regex.search(key)}
worker_keys = set()
for tags, tags_field in (
@@ -449,7 +478,7 @@ class WorkerBLL:
)
tagged_workers.update(self.redis.zrange(tagged_workers_key, 0, -1))
tagged_workers = filter_by_user(tagged_workers)
tagged_workers = filter_by_user_and_pattern(tagged_workers)
worker_keys = (
worker_keys.intersection(tagged_workers)
if worker_keys
@@ -463,7 +492,7 @@ class WorkerBLL:
all_workers_key = self._get_all_workers_key(company)
self.redis.zremrangebyscore(all_workers_key, min=0, max=timestamp)
worker_keys.update(self.redis.zrange(all_workers_key, 0, -1))
worker_keys = filter_by_user(worker_keys)
worker_keys = filter_by_user_and_pattern(worker_keys)
if not worker_keys:
return []
@@ -488,13 +517,18 @@ class WorkerBLL:
user: str = "*",
user_tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
worker_pattern: str = None,
) -> Sequence[WorkerEntry]:
"""Get worker entries matching the company and user, worker patterns"""
entries = []
for keys in chunked_iter(
self._get_keys(
company, user=user, user_tags=user_tags, system_tags=system_tags
company,
user=user,
user_tags=user_tags,
system_tags=system_tags,
worker_pattern=worker_pattern,
),
1000,
):

View File

@@ -13,6 +13,8 @@ log = config.logger(__file__)
class WorkerStats:
min_chart_interval = config.get("services.workers.min_chart_interval_sec", 40)
def __init__(self, es):
self.es = es
@@ -203,6 +205,7 @@ class WorkerStats:
"""
if from_date >= to_date:
raise bad_request.FieldsValueError("from_date must be less than to_date")
interval = max(interval, self.min_chart_interval)
must = [QueryBuilder.dates_range(from_date, to_date)]
if active_only:

View File

@@ -6,7 +6,7 @@ from functools import reduce
from os import getenv
from os.path import expandvars
from pathlib import Path
from typing import List, Any, TypeVar, Sequence
from typing import List, Any, TypeVar, Sequence, Set
from boltons.iterutils import first
from pyhocon import ConfigTree, ConfigFactory, ConfigValues
@@ -35,6 +35,7 @@ class BasicConfig:
folder: str = None,
verbose: bool = True,
prefix: Sequence[str] = DEFAULT_PREFIXES,
exclude_files_from_base_folder: Sequence[str] = None,
):
folder = (
Path(folder)
@@ -44,6 +45,11 @@ class BasicConfig:
if not folder.is_dir():
raise ValueError("Invalid configuration folder")
self.exclude_files_from_base_folder = (
set(exclude_files_from_base_folder)
if exclude_files_from_base_folder
else set()
)
self.verbose = verbose
self.extra_config_path_override_var = [
@@ -85,7 +91,7 @@ class BasicConfig:
return logging.getLogger(path)
def _read_extra_env_config_values(self) -> ConfigTree:
""" Loads extra configuration from environment-injected values """
"""Loads extra configuration from environment-injected values"""
result = ConfigTree()
for prefix in self.extra_config_values_env_key_prefix:
@@ -125,12 +131,18 @@ class BasicConfig:
def _reload(self) -> ConfigTree:
extra_config_values = self._read_extra_env_config_values()
configs = [self._read_recursive(path) for path in self._paths]
configs = [
self._read_recursive(
path,
exclude_files=(
self.exclude_files_from_base_folder if idx == 0 else None
),
)
for idx, path in enumerate(self._paths)
]
return reduce(
lambda last, config: self._merge_configs(
last, config, copy_trees=True
),
lambda last, config: self._merge_configs(last, config, copy_trees=True),
configs + [extra_config_values],
ConfigTree(),
)
@@ -141,9 +153,14 @@ class BasicConfig:
for key, value in b.items():
override = key.startswith(override_prefix)
if override:
key = key[len(override_prefix):]
key = key[len(override_prefix) :]
# if key is in both a and b and both values are dictionary then merge it otherwise override it
if not override and key in a and isinstance(a[key], ConfigTree) and isinstance(b[key], ConfigTree):
if (
not override
and key in a
and isinstance(a[key], ConfigTree)
and isinstance(b[key], ConfigTree)
):
if copy_trees:
a[key] = a[key].copy()
cls._merge_configs(a[key], b[key], copy_trees=copy_trees)
@@ -156,13 +173,15 @@ class BasicConfig:
a[key] = value
if a.root:
if b.root:
a.history[key] = a.history.get(key, []) + b.history.get(key, [value])
a.history[key] = a.history.get(key, []) + b.history.get(
key, [value]
)
else:
a.history[key] = a.history.get(key, []) + [value]
return a
def _read_recursive(self, conf_root) -> ConfigTree:
def _read_recursive(self, conf_root, exclude_files: Set[str]) -> ConfigTree:
conf = ConfigTree()
if not conf_root:
@@ -180,6 +199,8 @@ class BasicConfig:
print(f"Loading config from {conf_root}")
for file in conf_root.rglob("*.conf"):
if exclude_files and file.name in exclude_files:
continue
key = ".".join(file.relative_to(conf_root).with_suffix("").parts)
conf.put(key, self._read_single_file(file))

View File

@@ -58,6 +58,9 @@
# verify user tokens
verify_user_tokens: false
# If set then users that were created from secure credentials or fixed user settings and are no longer in these settings will be deleted on startup
delete_missing_autocreated_users: true
# max token expiration timeout in seconds (1 year)
max_expiration_sec: 31536000
@@ -72,6 +75,7 @@
httponly: true # allow only http to access the cookies (no JS etc)
secure: false # not using HTTPS
domain: null # Limit to localhost is not supported
samesite: Lax
max_age: 99999999999
}

View File

@@ -2,10 +2,9 @@ fileserver = "http://localhost:8081"
elastic {
events {
hosts: [{host: "127.0.0.1", port: 9200}]
hosts: [{host: "127.0.0.1", port: 9200, scheme: http}]
args {
timeout: 60
dead_timeout: 10
max_retries: 3
retry_on_timeout: true
}
@@ -13,10 +12,9 @@ elastic {
}
workers {
hosts: [{host:"127.0.0.1", port:9200}]
hosts: [{host:"127.0.0.1", port:9200, scheme: http}]
args {
timeout: 60
dead_timeout: 10
max_retries: 3
retry_on_timeout: true
}

View File

@@ -1,13 +1,13 @@
{
http {
session_secret {
apiserver: "Gx*gB-L2U8!Naqzd#8=7A4&+=In4H(da424H33ZTDQRGF6=FWw"
apiserver: "V8gcW3EneNDcNfO7G_TSUsWe7uLozyacc9_I33o7bxUo8rCN31VLRg"
}
}
auth {
# token sign secret
token_secret: "7E1ua3xP9GT2(cIQOfhjp+gwN6spBeCAmN-XuugYle00I=Wc+u"
token_secret: "Rq8FW84sSqVgq7WvBB_4EzNl9y8z8IGiDXX3C345_a5AZfcwZcwCIA"
}
credentials {
@@ -15,24 +15,29 @@
apiserver {
role: "system"
user_key: "62T8CP7HGBC6647XF9314C2VY67RJO"
user_secret: "FhS8VZv_I4%6Mo$8S1BWc$n$=o1dMYSivuiWU-Vguq7qGOKskG-d+b@tn_Iq"
user_secret: "gaOfhDX2-bpkeI7-cwEcaMuGijxaG2UG3jbIvg4DxmVGF0LNI7rgvCb1-ne38IlBo1w"
}
fileserver {
role: "system"
user_key: "GSQWPEKSKNKF354LC9V6BHXKTYFD5I"
user_secret: "tuBXcGQBECsEhcNiK2kiWi750z9r8Z85XrQ9V0c24huTuCb2xf2X1nKG"
}
webserver {
role: "system"
user_key: "EYVQ385RW7Y2QQUH88CZ7DWIQ1WUHP"
user_secret: "yfc8KQo*GMXb*9p((qcYC7ByFIpF7I&4VH3BfUYXH%o9vX1ZUZQEEw1Inc)S"
user_secret: "XhkH6a6ds9JBnM_MrahYyYdO-wS2bqFSm8gl-V0UZXH26Ydd6Eyi28TeBEoSr6Z3Bes"
revoke_in_fixed_mode: true
}
services_agent {
role: "admin"
user_key: "P4BMJA7RK3TKBXGSY8OAA1FA8TOD11"
user_secret: "9LsgSfa0SYz0zli1_c500ZcLqanre2xkWOpepyt1w-BKK3_DKPHrtoj3JSHvyy8bIi0"
user_key: ""
user_secret: ""
}
tests {
role: "user"
display_name: "Default User"
user_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
user_secret: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
user_secret: "LPEJbGJ6bK4tujQcmrD3i1dbMBDdwUwelVa-LG0K0FFmY9bzH_H0Sw"
revoke_in_fixed_mode: true
}
}

View File

@@ -9,4 +9,5 @@ fileserver {
# Can be in the form <scheme>://host:port/path or /path
url_prefixes: ["https://files.community-master.hosted.allegro.ai/"]
timeout_sec: 300
token_expiration_sec: 600
}

View File

@@ -18,8 +18,9 @@ aws {
{
# This will apply to all buckets in this host (unless key/value is specifically provided for a given bucket)
host: "localhost:9000"
key: "evg_user"
secret: "evg_pass"
key: "minioadmin"
secret: "minioadmin"
# region: my-server
multipart: false
secure: false
}

View File

@@ -0,0 +1,5 @@
default_worker_timeout_sec: 600
default_cluster_timeout_sec: 600
# The minimal sampling interval for resource dashboard and worker activity charts
min_chart_interval_sec: 40

View File

@@ -5,7 +5,7 @@ from textwrap import shorten
import dpath
from dpath.exceptions import InvalidKeyName
from elasticsearch import ElasticsearchException
from elastic_transport import TransportError, ApiError
from elasticsearch.helpers import BulkIndexError
from jsonmodels.errors import ValidationError as JsonschemaValidationError
from mongoengine.errors import (
@@ -210,9 +210,9 @@ def translate_errors_context(message=None, **kwargs):
raise errors.bad_request.ValidationError(e.args[0])
except BulkIndexError as e:
ElasticErrorsHandler.bulk_error(e, message, **kwargs)
except ElasticsearchException as e:
except (TransportError, ApiError) as e:
raise errors.server_error.DataError(e, message, **kwargs)
except InvalidKeyName:
raise errors.server_error.DataError("invalid empty key encountered in data")
except Exception as ex:
except Exception:
raise

View File

@@ -4,6 +4,7 @@ from mongoengine import (
EmbeddedDocumentListField,
EmailField,
DateTimeField,
BooleanField,
)
from apiserver.database import Database, strict
@@ -76,3 +77,6 @@ class User(DbModelMixin, AuthDocument):
email = EmailField(unique=True, sparse=True)
""" Email uniquely identifying the user """
autocreated = BooleanField(default=False)
""" Set to true if the user was auto created based on config settings"""

View File

@@ -1297,7 +1297,6 @@ class GetMixin(PropsMixin):
return result
class UpdateMixin(object):
__user_set_allowed_fields = None
__locked_when_published_fields = None

View File

@@ -231,11 +231,12 @@ class Task(AttributedDocument):
"parent",
"hyperparams.*",
"execution.queue",
"models.input.model",
),
range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
datetime_fields=("status_changed", "last_update"),
pattern_fields=("name", "comment", "report"),
fields=("runtime.*", "models.input.model"),
fields=("runtime.*",),
)
id = StringField(primary_key=True)

View File

@@ -2,6 +2,9 @@
| Release | ApiVersion |
|---------|------------|
| v1.16 | 2.30 |
| v1.15 | 2.29 |
| v1.14 | 2.28 |
| v1.13 | 2.27 |
| v1.12 | 2.26 |
| v1.11 | 2.25 |

View File

@@ -4,34 +4,89 @@ Apply elasticsearch mappings to given hosts.
"""
import argparse
import json
import logging
from pathlib import Path
from typing import Optional, Sequence, Tuple
from elasticsearch import Elasticsearch
from elasticsearch import Elasticsearch, exceptions
HERE = Path(__file__).resolve().parent
logging.getLogger("elasticsearch").setLevel(logging.WARNING)
logging.getLogger("elastic_transport").setLevel(logging.WARNING)
def apply_mappings_to_cluster(
hosts: Sequence, key: Optional[str] = None, es_args: dict = None, http_auth: Tuple = None
hosts: Sequence,
key: Optional[str] = None,
es_args: dict = None,
http_auth: Tuple = None,
):
"""Hosts may be a sequence of strings or dicts in the form {"host": <host>, "port": <port>}"""
def _send_template(f):
with f.open() as json_data:
data = json.load(json_data)
template_name = f.stem
res = es.indices.put_template(name=template_name, body=data)
return {"mapping": template_name, "result": res}
def _send_component_template(ct_file):
    """
    Load the component template body from the passed json file and put it
    to the cluster under the name of the file stem

    :param ct_file: path object of the component template json file
    :return: dict with the template name and the result of the put call
    """
    with ct_file.open() as stream:
        payload = json.load(stream)

    name = ct_file.stem
    response = es.cluster.put_component_template(name=name, body=payload)
    return {"component_template": name, "result": response}
p = HERE / "mappings"
if key:
files = (p / key).glob("*.json")
else:
files = p.glob("**/*.json")
def _send_index_template(it_file):
    """
    Load the index template body from the passed json file and put it
    to the cluster under the name of the file stem

    :param it_file: path object of the index template json file
    :return: dict with the template name and the result of the put call
    """
    with it_file.open() as stream:
        payload = json.load(stream)

    name = it_file.stem
    response = es.indices.put_index_template(name=name, body=payload)
    return {"index_template": name, "result": response}
# def _send_legacy_template(f):
# with f.open() as json_data:
# data = json.load(json_data)
# template_name = f.stem
# res = es.indices.put_template(name=template_name, body=data)
# return {"mapping": template_name, "result": res}
def _delete_legacy_templates(legacy_folder):
    """
    Delete from the cluster the legacy templates that are named after
    the json files found in the passed folder

    :param legacy_folder: path object of the folder holding the legacy mapping files
    :return: list of dicts, one per deleted template, with its name and the delete call result
    """
    deleted = []
    for legacy_file in legacy_folder.glob("*.json"):
        name = legacy_file.stem
        try:
            if es.indices.get_template(name=name):
                response = es.indices.delete_template(name=name)
            else:
                # Empty lookup result, nothing to delete
                continue
        except exceptions.NotFoundError:
            # The template is not found in the cluster
            continue
        deleted.append({"deleted legacy mapping": name, "result": response})

    return deleted
es = Elasticsearch(hosts=hosts, http_auth=http_auth, **(es_args or {}))
return [_send_template(f) for f in files]
root = HERE / "index_templates"
if key:
folders = [root / key]
else:
folders = [f for f in root.iterdir() if f.is_dir()]
ret = []
for f in folders:
for ct in (f / "component_templates").glob("*.json"):
ret.append(_send_component_template(ct))
for it in f.glob("*.json"):
ret.append(_send_index_template(it))
legacy_root = HERE / "mappings"
for f in folders:
legacy_f = legacy_root / f.stem
if not legacy_f.exists() or not legacy_f.is_dir():
continue
ret.extend(_delete_legacy_templates(legacy_f))
return ret
# p = HERE / "mappings"
# if key:
# files = (p / key).glob("*.json")
# else:
# files = p.glob("**/*.json")
#
# return [_send_template(f) for f in files]
def parse_args():

View File

@@ -0,0 +1,48 @@
{
"template": {
"settings": {
"number_of_replicas": 0,
"number_of_shards": 1
},
"mappings": {
"_source": {
"enabled": true
},
"properties": {
"@timestamp": {
"type": "date"
},
"task": {
"type": "keyword"
},
"type": {
"type": "keyword"
},
"worker": {
"type": "keyword"
},
"timestamp": {
"type": "date"
},
"iter": {
"type": "long"
},
"metric": {
"type": "keyword"
},
"variant": {
"type": "keyword"
},
"value": {
"type": "float"
},
"company_id": {
"type": "keyword"
},
"model_event": {
"type": "boolean"
}
}
}
}
}

View File

@@ -0,0 +1,18 @@
{
"index_patterns": "events-log-*",
"template": {
"mappings": {
"properties": {
"msg": {
"type": "text",
"index": false
},
"level": {
"type": "keyword"
}
}
}
},
"priority": 500,
"composed_of": ["events_common"]
}

View File

@@ -0,0 +1,18 @@
{
"index_patterns": "events-plot-*",
"template": {
"mappings": {
"properties": {
"plot_str": {
"type": "text",
"index": false
},
"plot_data": {
"type": "binary"
}
}
}
},
"priority": 500,
"composed_of": ["events_common"]
}

View File

@@ -0,0 +1,17 @@
{
"index_patterns": "events-training_debug_image-*",
"template": {
"mappings": {
"properties": {
"key": {
"type": "keyword"
},
"url": {
"type": "keyword"
}
}
}
},
"priority": 500,
"composed_of": ["events_common"]
}

View File

@@ -0,0 +1,5 @@
{
"index_patterns": "events-training_stats_scalar-*",
"priority": 500,
"composed_of": ["events_common"]
}

View File

@@ -0,0 +1,31 @@
{
"index_patterns": "queue_metrics_*",
"template": {
"settings": {
"number_of_replicas": 0,
"number_of_shards": 1
},
"mappings": {
"_source": {
"enabled": true
},
"properties": {
"timestamp": {
"type": "date"
},
"queue": {
"type": "keyword"
},
"average_waiting_time": {
"type": "float"
},
"queue_length": {
"type": "integer"
},
"company_id": {
"type": "keyword"
}
}
}
}
}

View File

@@ -0,0 +1,43 @@
{
"index_patterns": "worker_stats_*",
"template": {
"settings": {
"number_of_replicas": 0,
"number_of_shards": 1
},
"mappings": {
"_source": {
"enabled": true
},
"properties": {
"timestamp": {
"type": "date"
},
"worker": {
"type": "keyword"
},
"category": {
"type": "keyword"
},
"metric": {
"type": "keyword"
},
"variant": {
"type": "keyword"
},
"value": {
"type": "float"
},
"unit": {
"type": "keyword"
},
"task": {
"type": "keyword"
},
"company_id": {
"type": "keyword"
}
}
}
}
}

View File

@@ -10,6 +10,8 @@ from apiserver.config_repo import config
from apiserver.elastic.apply_mappings import apply_mappings_to_cluster
log = config.logger(__file__)
logging.getLogger("elasticsearch").setLevel(logging.WARNING)
logging.getLogger("elastic_transport").setLevel(logging.WARNING)
class MissingElasticConfiguration(Exception):
@@ -78,6 +80,18 @@ def check_elastic_empty() -> bool:
err_type=urllib3.exceptions.NewConnectionError, args_prefix=("GET",)
)
def events_legacy_template():
    """
    Look up the legacy templates whose name matches "events*"

    :return: the get_template call result or False if NotFoundError was raised
    """
    try:
        legacy = es.indices.get_template(name="events*")
    except exceptions.NotFoundError:
        return False
    return legacy
def events_template():
    """
    Look up the index templates whose name matches "events*"

    :return: the get_index_template call result or False if NotFoundError was raised
    """
    try:
        found = es.indices.get_index_template(name="events*")
    except exceptions.NotFoundError:
        return False
    return found
try:
es_logger.addFilter(log_filter)
for retry in range(max_retries):
@@ -87,10 +101,7 @@ def check_elastic_empty() -> bool:
http_auth=es_factory.get_credentials("events", cluster_conf),
**cluster_conf.get("args", {}),
)
return not es.indices.get_template(name="events*")
except exceptions.NotFoundError as ex:
log.error(ex)
return True
return not (events_template() or events_legacy_template())
except exceptions.ConnectionError as ex:
if retry >= max_retries - 1:
raise ElasticConnectionError(

View File

@@ -1,3 +1,4 @@
import logging
from datetime import datetime
from functools import lru_cache
from os import getenv
@@ -9,6 +10,8 @@ from elasticsearch import Elasticsearch
from apiserver.config_repo import config
log = config.logger(__file__)
logging.getLogger('elasticsearch').setLevel(logging.WARNING)
logging.getLogger('elastic_transport').setLevel(logging.WARNING)
OVERRIDE_HOST_ENV_KEY = (
"CLEARML_ELASTIC_SERVICE_HOST",
@@ -32,6 +35,7 @@ if OVERRIDE_HOST:
OVERRIDE_PORT = first(filter(None, map(getenv, OVERRIDE_PORT_ENV_KEY)))
if OVERRIDE_PORT:
OVERRIDE_PORT = int(OVERRIDE_PORT)
log.info(f"Using override elastic port {OVERRIDE_PORT}")
OVERRIDE_USERNAME = first(filter(None, map(getenv, OVERRIDE_USERNAME_ENV_KEY)))

View File

@@ -19,7 +19,9 @@ from google.cloud import storage as google_storage
from mongoengine import Q
from mypy_boto3_s3.service_resource import Bucket as AWSBucket
from apiserver.bll.auth import AuthBLL
from apiserver.bll.storage import StorageBLL
from apiserver.config.info import get_default_company
from apiserver.config_repo import config
from apiserver.database import db
from apiserver.database.model.url_to_delete import UrlToDelete, StorageType, DeletionStatus
@@ -200,6 +202,8 @@ class FileserverStorage(Storage):
res_data = res.json()
return list(res_data.get("deleted", {})), res_data.get("errors", {})
token_expiration_sec = conf.get("fileserver.token_expiration_sec", 600)
def __init__(self, company: str, fileserver_host: str = None):
fileserver_host = fileserver_host or config.get("hosts.fileserver", None)
self.host = fileserver_host.rstrip("/")
@@ -220,13 +224,6 @@ class FileserverStorage(Storage):
self.company = company
# @classmethod
# def validate_fileserver_access(cls, fileserver_host: str):
# res = requests.get(
# url=fileserver_host
# )
# res.raise_for_status()
@property
def name(self) -> str:
return "Fileserver"
@@ -260,7 +257,13 @@ class FileserverStorage(Storage):
def get_client(self, base: str, urls: Sequence[UrlToDelete]) -> Client:
host = base
token = AuthBLL.get_token_for_user(
user_id="__apiserver__",
company_id=get_default_company(),
expiration_sec=self.token_expiration_sec,
).token
session = requests.session()
session.headers.update({"Authorization": "Bearer {}".format(token)})
res = session.get(url=host, timeout=self.Client.timeout)
res.raise_for_status()
@@ -285,6 +288,7 @@ class AzureStorage(Storage):
):
raise ValueError("No path found following container name")
# noinspection PyTypeChecker
return os.path.join(*parsed.path.segments[1:])
@staticmethod
@@ -450,6 +454,7 @@ class AWSStorage(Storage):
else None,
"use_ssl": cfg.secure,
"verify": cfg.verify,
"region_name": cfg.region or None,
}
name = base[len(scheme_prefix(self.scheme)) :]
bucket_name = name[len(cfg.host) + 1 :] if cfg.host else name

View File

@@ -3,7 +3,7 @@ from typing import Sequence, Union
from apiserver.config_repo import config
from apiserver.config.info import get_default_company
from apiserver.database.model.auth import Role
from apiserver.database.model.auth import Role, User as AuthUser
from apiserver.service_repo.auth.fixed_user import FixedUser
from .migration import _apply_migrations, check_mongo_empty, get_last_server_version
from .pre_populate import PrePopulate
@@ -60,14 +60,18 @@ def init_mongo_data():
fixed_mode = FixedUser.enabled()
internal_user_emails = set()
for user, credentials in config.get("secure.credentials", {}).items():
email = f"{user}@example.com"
user_data = {
"name": user,
"role": credentials.role,
"email": f"{user}@example.com",
"email": email,
"key": credentials.user_key,
"secret": credentials.user_secret,
"autocreated": True,
}
internal_user_emails.add(email.lower())
revoke = fixed_mode and credentials.get("revoke_in_fixed_mode", False)
user_id = _ensure_auth_user(user_data, company_id, log=log, revoke=revoke)
if credentials.role == Role.user:
@@ -82,8 +86,20 @@ def init_mongo_data():
for user in FixedUser.from_config():
try:
ensure_fixed_user(user, log=log)
ensure_fixed_user(user, log=log, emails=internal_user_emails)
except Exception as ex:
log.error(f"Failed creating fixed user {user.name}: {ex}")
if internal_user_emails and config.get(
f"apiserver.auth.delete_missing_autocreated_users", True
):
for user in AuthUser.objects(
company=company_id, autocreated=True, email__nin=internal_user_emails
):
log.info(
f"Removing user that is no longer in configuration: {user['id']}\t{user['email']}\t{user['name']}"
)
user.delete()
except Exception as ex:
log.exception("Failed initializing mongodb")
log.exception(f"Failed initializing mongodb: {str(ex)}")

View File

@@ -22,6 +22,7 @@ from typing import (
Mapping,
IO,
Callable,
Iterable,
)
from urllib.parse import unquote, urlparse
from uuid import uuid4, UUID, uuid5
@@ -220,6 +221,9 @@ class PrePopulate:
raise ValueError("Invalid task statuses")
file = Path(filename)
if not (experiments or projects):
projects = cls.project_cls.objects(parent=None).scalar("id")
entities = cls._resolve_entities(
experiments=experiments, projects=projects, task_statuses=task_statuses
)
@@ -417,24 +421,50 @@ class PrePopulate:
featured_index = get_index(project)
cls.project_cls.objects(id=project.id).update(featured=featured_index)
@staticmethod
def _resolve_type(
cls: Type[mongoengine.Document], ids: Optional[Sequence[str]]
@classmethod
def _resolve_entity_type(
cls, entity_type: Type[mongoengine.Document], ids: Optional[Sequence[str]]
) -> Sequence[Any]:
ids = set(ids)
items = list(cls.objects(id__in=list(ids)))
items = list(entity_type.objects(id__in=list(ids)))
resolved = {i.id for i in items}
missing = ids - resolved
for name_candidate in missing:
results = list(cls.objects(name=name_candidate))
if not results:
print(f"ERROR: no match for `{name_candidate}`")
exit(1)
elif len(results) > 1:
print(f"ERROR: more than one match for `{name_candidate}`")
exit(1)
items.append(results[0])
return items
if not missing:
return items
resolved_by_name = defaultdict(list)
for entity in entity_type.objects(name__in=list(missing)):
resolved_by_name[entity.name].append(entity)
not_found = missing - set(resolved_by_name)
if not_found:
print(f"ERROR: no match for {', '.join(not_found)}")
exit(1)
duplicates = [k for k, v in resolved_by_name.items() if len(v) > 1]
if duplicates:
print(f"ERROR: more than one match for {', '.join(duplicates)}")
exit(1)
def get_new_items(input_: Iterable) -> list:
    """Return the items from input_ whose id is not among the already resolved ids"""
    return list(filter(lambda candidate: candidate.id not in resolved, input_))
def get_projects_with_children(projects: list) -> list:
    """
    Expand the passed projects using project_ids_with_children

    :param projects: project entities that were resolved by name
    :return: the passed projects if the expansion yields exactly their ids.
        Otherwise the entities for the expanded ids that were not already resolved
    """
    own_ids = {project.id for project in projects}
    expanded_ids = project_ids_with_children(list(own_ids))
    if set(expanded_ids) != own_ids:
        return get_new_items(entity_type.objects(id__in=expanded_ids))

    return projects
new_items = get_new_items(chain(*resolved_by_name.values()))
if not new_items:
return items
if entity_type == cls.project_cls:
new_items = get_projects_with_children(new_items)
return items + new_items
@classmethod
def _check_projects_hierarchy(cls, projects: Set[Project]):
@@ -467,7 +497,7 @@ class PrePopulate:
print("Reading projects...")
projects = project_ids_with_children(projects)
entities[cls.project_cls].update(
cls._resolve_type(cls.project_cls, projects)
cls._resolve_entity_type(cls.project_cls, projects)
)
print("--> Reading project experiments...")
query = Q(
@@ -485,7 +515,7 @@ class PrePopulate:
if experiments:
print("Reading experiments...")
entities[cls.task_cls].update(cls._resolve_type(cls.task_cls, experiments))
entities[cls.task_cls].update(cls._resolve_entity_type(cls.task_cls, experiments))
print("--> Reading experiments projects...")
objs = cls.project_cls.objects(
id__in=list(

View File

@@ -9,34 +9,66 @@ from apiserver.database.model.user import User
from apiserver.service_repo.auth.fixed_user import FixedUser
def _ensure_auth_user(user_data: dict, company_id: str, log: Logger, revoke: bool = False):
key, secret = user_data.get("key"), user_data.get("secret")
def _ensure_user_credentials(
user: AuthUser, key: str, secret: str, log: Logger, revoke: bool = False
) -> None:
if revoke:
log.info(f"Revoking credentials for existing user {user.id} ({user.name})")
user.credentials = []
user.save()
return
if not (key and secret):
credentials = None
else:
creds = Credentials(key=key, secret=secret)
log.info(f"Resetting credentials for existing user {user.id} ({user.name})")
user.credentials = []
user.save()
return
user = AuthUser.objects(credentials__match=creds).first()
if user:
if revoke:
user.credentials = []
user.save()
return user.id
new_credentials = Credentials(key=key, secret=secret)
log.info(f"Setting credentials for existing user {user.id} ({user.name})")
user.credentials = [new_credentials]
user.save()
return
credentials = [] if revoke else [creds]
def _ensure_auth_user(user_data: dict, company_id: str, log: Logger, revoke: bool = False) -> str:
user_id = user_data.get("id", f"__{user_data['name']}__")
role = user_data["role"]
email = user_data["email"]
autocreated = user_data.get("autocreated", False)
key, secret = user_data.get("key"), user_data.get("secret")
user: AuthUser = AuthUser.objects(id=user_id).first()
if user:
_ensure_user_credentials(user=user, key=key, secret=secret, log=log, revoke=revoke)
if (
user.role != role
or user.email != email
or user.autocreated != autocreated
):
user.email = email
user.role = role
user.autocreated = autocreated
user.save()
return user.id
credentials = (
[Credentials(key=key, secret=secret)]
if not revoke and key and secret
else []
)
log.info(f"Creating user: {user_data['name']}")
user = AuthUser(
id=user_id,
name=user_data["name"],
company=company_id,
role=user_data["role"],
email=user_data["email"],
role=role,
email=email,
created=datetime.utcnow(),
credentials=credentials,
autocreated=autocreated,
)
user.save()
@@ -59,7 +91,17 @@ def _ensure_backend_user(user_id: str, company_id: str, user_name: str):
return user_id
def ensure_fixed_user(user: FixedUser, log: Logger):
def ensure_fixed_user(user: FixedUser, log: Logger, emails: set):
# noinspection PyTypeChecker
data = attr.asdict(user)
data["id"] = user.user_id
email = f"{user.user_id}@example.com"
data["email"] = email
data["role"] = Role.guest if user.is_guest else Role.user
data["autocreated"] = True
_ensure_auth_user(user_data=data, company_id=user.company, log=log)
db_user = User.objects(company=user.company, id=user.user_id).first()
if db_user:
# noinspection PyBroadException
@@ -69,13 +111,7 @@ def ensure_fixed_user(user: FixedUser, log: Logger):
db_user.update(name=user.name, given_name=given_name, family_name=family_name)
except Exception:
pass
return
else:
_ensure_backend_user(user.user_id, user.company, user.name)
data = attr.asdict(user)
data["id"] = user.user_id
data["email"] = f"{user.user_id}@example.com"
data["role"] = Role.guest if user.is_guest else Role.user
_ensure_auth_user(user_data=data, company_id=user.company, log=log)
return _ensure_backend_user(user.user_id, user.company, user.name)
emails.add(email)

View File

@@ -6,7 +6,7 @@ boto3>=1.26
boto3-stubs[s3]>=1.26
clearml>=1.10.3
dpath>=1.4.2,<2.0
elasticsearch==7.17.9
elasticsearch==8.12.0
fastjsonschema>=2.8
flask-compress>=1.4.0
flask-cors>=3.0.5
@@ -25,7 +25,7 @@ packaging==20.3
psutil>=5.6.5
pyhocon>=0.3.35r
pyjwt>=2.4.0
pymongo==4.4.0
pymongo==4.6.3
python-rapidjson>=0.6.3
redis>=4.5.4,<5
requests>=2.13.0

View File

@@ -947,6 +947,13 @@ get_task_log {
}
}
}
"2.30": ${get_task_log."2.9"} {
request.metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/metric_variants" }
}
}
}
get_task_events {
"2.1" {
@@ -1705,4 +1712,18 @@ clear_task_log {
}
}
}
"2.30": ${clear_task_log."2.19"} {
request.properties {
include_metrics {
type: array
description: If passed then only events for these metrics are deleted
items: {type: string}
}
exclude_metrics {
type: array
description: If passed then events for these metrics are retained
items: {type: string}
}
}
}
}

View File

@@ -349,7 +349,7 @@ get_all {
items { type: string }
}
last_update {
description: "List of last_update constraint strings (utcformat, epoch) with an optional prefix modifier (>, >=, <, <=)"
description: "List of last_update constraint strings (utcformat, epoch) with an optional prefix modifier (\>, \>=, \<, \<=)"
type: array
items {
type: string

View File

@@ -446,7 +446,7 @@ get_task_data {
type: string
}
status_changed {
description: "List of status changed constraint strings (utcformat, epoch) with an optional prefix modifier (>, >=, <, <=)"
description: "List of status changed constraint strings (utcformat, epoch) with an optional prefix modifier (\>, \>=, \<, \<=)"
type: array
items {
type: string
@@ -656,7 +656,7 @@ get_all_ex {
items { type: string }
}
status_changed {
description: "List of status changed constraint strings (utcformat, epoch) with an optional prefix modifier (>, >=, <, <=)"
description: "List of status changed constraint strings (utcformat, epoch) with an optional prefix modifier (\>, \>=, \<, \<=)"
type: array
items {
type: string

View File

@@ -277,7 +277,7 @@ get_all {
type: string
}
status_changed {
description: "List of status changed constraint strings (utcformat, epoch) with an optional prefix modifier (>, >=, <, <=)"
description: "List of status changed constraint strings (utcformat, epoch) with an optional prefix modifier (\>, \>=, \<, \<=)"
type: array
items {
type: string
@@ -1107,6 +1107,13 @@ delete_many {
default: true
}
}
"2.30": ${delete_many."2.21"} {
request.properties.include_pipeline_steps {
description: If set then for the passed pipeline controller tasks the pipeline steps will be also deleted
type: boolean
default: false
}
}
}
delete {
"2.1" {
@@ -1182,6 +1189,13 @@ delete {
default: true
}
}
"2.30": ${delete."2.21"} {
request.properties.include_pipeline_steps {
description: If set and the passed task is a pipeline controller then delete the pipeline tasks too
type: boolean
default: false
}
}
}
archive {
"2.12" {
@@ -1219,6 +1233,13 @@ archive {
}
}
}
"2.30": ${archive."2.12"} {
request.properties.include_pipeline_steps {
description: If set then for the passed pipeline controller tasks also archive the pipeline steps
type: boolean
default: false
}
}
}
archive_many {
"2.13": ${_definitions.batch_operation} {
@@ -1245,6 +1266,13 @@ archive_many {
}
}
}
"2.30": ${archive_many."2.13"} {
request.properties.include_pipeline_steps {
description: If set then for the passed pipeline controller tasks also archive the pipeline steps
type: boolean
default: false
}
}
}
unarchive_many {
"2.13": ${_definitions.batch_operation} {
@@ -1271,6 +1299,13 @@ unarchive_many {
}
}
}
"2.30": ${unarchive_many."2.13"} {
request.properties.include_pipeline_steps {
description: If set then for the passed pipeline controller tasks also unarchive the pipeline steps
type: boolean
default: false
}
}
}
started {
"2.1" {
@@ -1309,6 +1344,13 @@ stop {
} ${_references.status_change_request}
response: ${_definitions.update_response}
}
"2.30": ${stop."2.1"} {
request.properties.include_pipeline_steps {
description: If set and the passed task is a pipeline controller then stop all its steps too
type: boolean
default: false
}
}
}
stop_many {
"2.13": ${_definitions.change_many_request} {
@@ -1322,6 +1364,13 @@ stop_many {
}
}
}
"2.30": ${stop_many."2.13"} {
request.properties.include_pipeline_steps {
description: If set then for all the passed pipeline controller tasks stop their steps too
type: boolean
default: false
}
}
}
stopped {
"2.1" {

View File

@@ -310,6 +310,12 @@ get_all {
items { type: string }
}
}
"2.30": ${get_all."2.22"} {
request.properties.worker_pattern {
description: The worker name pattern. If specified then only matching keys are returned
type: string
}
}
}
get_count {
"2.26": {
@@ -345,6 +351,12 @@ get_count {
}
}
}
"2.30": ${get_count."2.26"} {
request.properties.worker_pattern {
description: The worker name pattern. If specified then only matching keys are counted
type: string
}
}
}
register {
"2.4" {

View File

@@ -21,6 +21,11 @@ log = config.logger(__file__)
class RequestHandlers:
_request_strip_prefix = config.get("apiserver.request.strip_prefix", None)
_server_header = config.get("apiserver.response.headers.server", "clearml")
_custom_cookie_settings = {
c["name"]: c["settings"]
for c in config.get("apiserver.auth.custom_cookies", {}).values()
if c.get("enabled") and c.get("settings")
}
def before_request(self):
if request.method == "OPTIONS":
@@ -29,7 +34,10 @@ class RequestHandlers:
return
if request.content_encoding:
return f"Content encoding is not supported ({request.content_encoding})", 415
return (
f"Content encoding is not supported ({request.content_encoding})",
415,
)
try:
call = self._create_api_call(request)
@@ -70,7 +78,10 @@ class RequestHandlers:
if call.result.cookies:
for key, value in call.result.cookies.items():
kwargs = config.get("apiserver.auth.cookies").copy()
kwargs = (
self._custom_cookie_settings.get(key)
or config.get("apiserver.auth.cookies")
).copy()
if value is None:
# Removing a cookie
kwargs["max_age"] = 0
@@ -87,7 +98,9 @@ class RequestHandlers:
if company:
try:
# use no default value to allow setting a null domain as well
kwargs["domain"] = config.get(f"apiserver.auth.cookies_domain_override.{company}")
kwargs["domain"] = config.get(
f"apiserver.auth.cookies_domain_override.{company}"
)
except KeyError:
pass
@@ -114,11 +127,15 @@ class RequestHandlers:
return v
for k, v in md.lists():
v = [convert_value(x) for x in v] if (len(v) > 1 or k.endswith("[]")) else convert_value(v[0])
v = (
[convert_value(x) for x in v]
if (len(v) > 1 or k.endswith("[]"))
else convert_value(v[0])
)
nested_set(body, k.rstrip("[]").split("."), v)
def _update_call_data(self, call, req):
""" Use request payload/form to fill call data or batched data """
"""Use request payload/form to fill call data or batched data"""
if req.content_type == "application/json-lines":
items = []
for i, line in enumerate(req.data.splitlines()):
@@ -148,6 +165,9 @@ class RequestHandlers:
call.set_error_result(msg=msg, code=code, subcode=subcode)
return call
def _get_session_auth_cookie(self, req):
    """Return the request cookie named by ``apiserver.auth.session_auth_cookie_name``."""
    cookie_name = config.get("apiserver.auth.session_auth_cookie_name")
    return req.cookies.get(cookie_name)
def _create_api_call(self, req):
call = None
try:
@@ -161,9 +181,7 @@ class RequestHandlers:
# Resolve authorization: if cookies contain an authorization token, use it as a starting point.
# in any case, request headers always take precedence.
auth_cookie = req.cookies.get(
config.get("apiserver.auth.session_auth_cookie_name")
)
auth_cookie = self._get_session_auth_cookie(req)
headers = (
{}
if not auth_cookie

View File

@@ -1,4 +1,4 @@
from .auth import get_auth_func, authorize_impersonation
from .auth import get_auth_func, authorize_impersonation, revoke_auth_token
from .payload import Token, Basic, AuthType, Payload
from .identity import Identity
from .utils import get_client_id, get_secret_key

View File

@@ -1,5 +1,6 @@
import base64
from datetime import datetime
from time import time
import bcrypt
import jwt
@@ -11,15 +12,16 @@ from apiserver.database.errors import translate_errors_context
from apiserver.database.model.auth import User, Entities, Credentials
from apiserver.database.model.company import Company
from apiserver.database.utils import get_options
from apiserver.redis_manager import redman
from .fixed_user import FixedUser
from .identity import Identity
from .payload import Payload, Token, Basic, AuthType
log = config.logger(__file__)
entity_keys = set(get_options(Entities))
verify_user_tokens = config.get("apiserver.auth.verify_user_tokens", True)
_revoked_tokens_key = "revoked_tokens"
redis = redman.connection("apiserver")
def get_auth_func(auth_type):
@@ -41,8 +43,10 @@ def authorize_token(jwt_token, service, action, call):
log.error(f"{msg} Call info: {info}")
try:
return Token.from_encoded_token(jwt_token)
token = Token.from_encoded_token(jwt_token)
if is_token_revoked(token):
raise errors.unauthorized.InvalidToken("revoked token")
return token
except jwt.exceptions.InvalidKeyError as ex:
log_error("Failed parsing token.")
raise errors.unauthorized.InvalidToken(
@@ -154,3 +158,23 @@ def compare_secret_key_hash(secret_key: str, hashed_secret: str) -> bool:
return bcrypt.checkpw(
secret_key.encode(), base64.b64decode(hashed_secret.encode("ascii"))
)
def is_token_revoked(token: Token) -> bool:
    """Report whether the token's session id is present in the revoked set.

    A value that is not a ``Token``, or a token without a session id, is
    never reported as revoked.
    """
    if isinstance(token, Token) and token.session_id:
        revoked_score = redis.zscore(_revoked_tokens_key, token.session_id)
        return revoked_score is not None
    return False
def revoke_auth_token(token: Token):
    """Add the token's session id to the revoked-tokens sorted set.

    The entry is scored with the token's ``exp`` value, falling back to the
    current time plus ``Token.default_expiration_sec`` when ``exp`` is unset.
    The same call also removes entries scored between 0 and the current time.
    Does nothing for a non-``Token`` value or a token without a session id.
    """
    if not (isinstance(token, Token) and token.session_id):
        return

    now_ts = int(time())
    expires_at = token.exp or (now_ts + Token.default_expiration_sec)

    redis.zadd(_revoked_tokens_key, {token.session_id: expires_at})
    # Purge entries whose score is not later than the current time
    redis.zremrangebyscore(_revoked_tokens_key, min=0, max=now_ts)

View File

@@ -1,3 +1,5 @@
from uuid import uuid4
import jwt
from datetime import datetime, timedelta
@@ -20,7 +22,15 @@ class Token(Payload):
default_expiration_sec = config.get("apiserver.auth.default_expiration_sec")
def __init__(
self, exp=None, iat=None, nbf=None, env=None, identity=None, entities=None, **_
self,
exp=None,
iat=None,
nbf=None,
env=None,
identity=None,
session_id=None,
entities=None,
**_,
):
super(Token, self).__init__(
AuthType.bearer_token, identity=identity, entities=entities
@@ -28,8 +38,13 @@ class Token(Payload):
self.exp = exp
self.iat = iat
self.nbf = nbf
self._session_id = session_id
self._env = env or config.get("env", "<unknown>")
@property
def session_id(self):
    """Session id stored on this token at construction (``None`` if not given)."""
    return self._session_id
@property
def env(self):
return self._env
@@ -102,8 +117,11 @@ class Token(Payload):
expiration_sec = expiration_sec or cls.default_expiration_sec
now = datetime.utcnow()
session_id = uuid4().hex
token = cls(identity=identity, entities=entities, iat=now)
token = cls(
identity=identity, entities=entities, iat=now, session_id=session_id
)
if expiration_sec:
# add 'expiration' claim

View File

@@ -1,40 +1,38 @@
import random
import secrets
import string
sys_random = random.SystemRandom()
def get_random_string(length):
    """
    Create a random crypto-safe sequence of 'length' or more characters
    Possible characters: alphanumeric, '-' and '_'
    Make sure that it starts from alphanumeric for better compatibility with yaml files

    :param length: number of random bytes passed to secrets.token_urlsafe;
        the returned text is at least this many characters long
    :return: URL-safe random string that does not start with '-' or '_'
    """
    token = secrets.token_urlsafe(length)
    # Regenerate until the first character is alphanumeric. A bounded retry
    # count could still hand back a token starting with '-' or '_'; each
    # attempt fails with probability 2/64, so this loop ends almost at once.
    while token.startswith(("-", "_")):
        token = secrets.token_urlsafe(length)
    return token
def get_random_string(
length: int = 12, allowed_chars: str = string.ascii_letters + string.digits
def get_client_id(
length: int = 30, allowed_chars: str = string.ascii_uppercase + string.digits
) -> str:
"""
Returns a securely generated random string.
The default length of 12 with the a-z, A-Z, 0-9 character set returns
a 71-bit value. log_2((26+26+10)^12) =~ 71 bits.
Taken from the django.utils.crypto module.
Create a random client id composed of 'length' upper case characters or digits
"""
return "".join(sys_random.choice(allowed_chars) for _ in range(length))
def get_client_id(length: int = 20) -> str:
"""
Create a random secret key.
Taken from the Django project.
"""
chars = string.ascii_uppercase + string.digits
return get_random_string(length, chars)
return "".join(secrets.choice(allowed_chars) for _ in range(length))
def get_secret_key(length: int = 50) -> str:
"""
Create a random secret key.
Taken from the Django project.
NOTE: asterisk is not supported due to issues with environment variables containing
asterisks (in case the secret key is stored in an environment variable)
Create a random secret key
"""
chars = string.ascii_letters + string.digits
return get_random_string(length, chars)
return get_random_string(length)
if __name__ == "__main__":
print(get_client_id())
print(get_secret_key())

View File

@@ -39,7 +39,7 @@ class ServiceRepo(object):
"""If the check is set, parsing will fail for endpoint request with the version that is greater than the current
maximum """
_max_version = PartialVersion("2.28")
_max_version = PartialVersion("2.30")
""" Maximum version number (the highest min_version value across all endpoints) """
_endpoint_exp = (
@@ -296,7 +296,7 @@ class ServiceRepo(object):
except APIError as ex:
# report stack trace only for gene
include_stack = cls._return_stack and cls._should_return_stack(
include_stack = cls._should_return_stack(
ex.code, ex.subcode
)
call.set_error_result(
@@ -310,8 +310,11 @@ class ServiceRepo(object):
pass
except Exception as ex:
log.exception(ex)
include_stack = cls._should_return_stack(
500, 0
)
call.set_error_result(
code=500, subcode=0, msg=str(ex), include_stack=cls._return_stack
code=500, subcode=0, msg=str(ex), include_stack=include_stack
)
finally:
content, content_type = call.get_response()

View File

@@ -24,6 +24,7 @@ from apiserver.database.errors import translate_errors_context
from apiserver.database.model.auth import User, Role
from apiserver.service_repo import APICall, endpoint
from apiserver.service_repo.auth import Token
from apiserver.service_repo.auth.auth import is_token_revoked, revoke_auth_token
from apiserver.service_repo.auth.fixed_user import FixedUser
log = config.logger(__file__)
@@ -35,7 +36,7 @@ log = config.logger(__file__)
response_data_model=GetTokenResponse,
)
def login(call: APICall, *_, **__):
""" Generates a token based on the authenticated user (intended for use with credentials) """
"""Generates a token based on the authenticated user (intended for use with credentials)"""
call.result.data_model = AuthBLL.get_token_for_user(
user_id=call.identity.user,
company_id=call.identity.company,
@@ -48,6 +49,7 @@ def login(call: APICall, *_, **__):
@endpoint("auth.logout", min_version="2.2")
def logout(call: APICall, *_, **__):
revoke_auth_token(call.auth)
call.result.set_auth_cookie(None)
@@ -57,7 +59,7 @@ def logout(call: APICall, *_, **__):
response_data_model=GetTokenResponse,
)
def get_token_for_user(call: APICall, _: str, request: GetTokenForUserRequest):
""" Generates a token based on a requested user and company. INTERNAL. """
"""Generates a token based on a requested user and company. INTERNAL."""
if call.identity.role not in Role.get_system_roles():
if call.identity.role != Role.admin and call.identity.user != request.user:
raise errors.bad_request.InvalidUserId(
@@ -81,12 +83,14 @@ def get_token_for_user(call: APICall, _: str, request: GetTokenForUserRequest):
response_data_model=ValidateResponse,
)
def validate_token_endpoint(call: APICall, _, __):
""" Validate a token and return identity if valid. INTERNAL. """
"""Validate a token and return identity if valid. INTERNAL."""
try:
# if invalid, decoding will fail
token = Token.from_encoded_token(call.data_model.token)
call.result.data_model = ValidateResponse(
valid=True, user=token.identity.user, company=token.identity.company
valid=not is_token_revoked(token),
user=token.identity.user,
company=token.identity.company,
)
except Exception as e:
call.result.data_model = ValidateResponse(valid=False, msg=e.args[0])
@@ -98,7 +102,7 @@ def validate_token_endpoint(call: APICall, _, __):
response_data_model=CreateUserResponse,
)
def create_user(call: APICall, _, request: CreateUserRequest):
""" Create a user from. INTERNAL. """
"""Create a user from. INTERNAL."""
if (
call.identity.role not in Role.get_system_roles()
and request.company != call.identity.company

View File

@@ -32,6 +32,14 @@ from apiserver.apimodels.events import (
TaskMetric,
MultiTaskPlotsRequest,
MultiTaskMetricsRequest,
LegacyLogEventsRequest,
TaskRequest,
GetMetricsAndVariantsRequest,
ModelRequest,
LegacyMetricEventsRequest,
GetScalarMetricDataRequest,
VectorMetricsIterHistogramRequest,
LegacyMultiTaskEventsRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_common import EventType, MetricVariants, TaskCompanies
@@ -97,15 +105,15 @@ def add_batch(call: APICall, company_id, _):
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
@endpoint("events.get_task_log", required_fields=["task"])
def get_task_log_v1_5(call, company_id, _):
task_id = call.data["task"]
@endpoint("events.get_task_log")
def get_task_log_v1_5(call, company_id, request: LegacyLogEventsRequest):
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
order = call.data.get("order") or "desc"
scroll_id = call.data.get("scroll_id")
batch_size = int(call.data.get("batch_size") or 500)
order = request.order
scroll_id = request.scroll_id
batch_size = request.batch_size
events, scroll_id, total_events = event_bll.scroll_task_events(
task.get_index_company(),
task_id,
@@ -119,17 +127,17 @@ def get_task_log_v1_5(call, company_id, _):
)
@endpoint("events.get_task_log", min_version="1.7", required_fields=["task"])
def get_task_log_v1_7(call, company_id, _):
task_id = call.data["task"]
@endpoint("events.get_task_log", min_version="1.7")
def get_task_log_v1_7(call, company_id, request: LegacyLogEventsRequest):
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
order = call.data.get("order") or "desc"
order = request.order
from_ = call.data.get("from") or "head"
scroll_id = call.data.get("scroll_id")
batch_size = int(call.data.get("batch_size") or 500)
scroll_id = request.scroll_id
batch_size = request.batch_size
scroll_order = "asc" if (from_ == "head") else "desc"
@@ -164,6 +172,7 @@ def get_task_log(call, company_id, request: LogEventsRequest):
batch_size=request.batch_size,
navigate_earlier=request.navigate_earlier,
from_timestamp=request.from_timestamp,
metric_variants=_get_metric_variants_from_request(request.metrics),
)
if request.order and (
@@ -177,9 +186,9 @@ def get_task_log(call, company_id, request: LogEventsRequest):
)
@endpoint("events.download_task_log", required_fields=["task"])
def download_task_log(call, company_id, _):
task_id = call.data["task"]
@endpoint("events.download_task_log")
def download_task_log(call, company_id, request: TaskRequest):
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
@@ -257,10 +266,12 @@ def download_task_log(call, company_id, _):
call.result.raw_data = generate()
@endpoint("events.get_vector_metrics_and_variants", required_fields=["task"])
def get_vector_metrics_and_variants(call, company_id, _):
task_id = call.data["task"]
model_events = call.data["model_events"]
@endpoint("events.get_vector_metrics_and_variants")
def get_vector_metrics_and_variants(
call, company_id, request: GetMetricsAndVariantsRequest
):
task_id = request.task
model_events = request.model_events
task_or_model = _assert_task_or_model_exists(
company_id,
task_id,
@@ -273,10 +284,12 @@ def get_vector_metrics_and_variants(call, company_id, _):
)
@endpoint("events.get_scalar_metrics_and_variants", required_fields=["task"])
def get_scalar_metrics_and_variants(call, company_id, _):
task_id = call.data["task"]
model_events = call.data["model_events"]
@endpoint("events.get_scalar_metrics_and_variants")
def get_scalar_metrics_and_variants(
call, company_id, request: GetMetricsAndVariantsRequest
):
task_id = request.task
model_events = request.model_events
task_or_model = _assert_task_or_model_exists(
company_id,
task_id,
@@ -292,18 +305,19 @@ def get_scalar_metrics_and_variants(call, company_id, _):
# todo: !!! currently returning 10,000 records. should decide on a better way to control it
@endpoint(
"events.vector_metrics_iter_histogram",
required_fields=["task", "metric", "variant"],
)
def vector_metrics_iter_histogram(call, company_id, _):
task_id = call.data["task"]
model_events = call.data["model_events"]
def vector_metrics_iter_histogram(
call, company_id, request: VectorMetricsIterHistogramRequest
):
task_id = request.task
model_events = request.model_events
task_or_model = _assert_task_or_model_exists(
company_id,
task_id,
model_events=model_events,
)[0]
metric = call.data["metric"]
variant = call.data["variant"]
metric = request.metric
variant = request.variant
iterations, vectors = event_bll.get_vector_metrics_per_iter(
task_or_model.get_index_company(), task_id, metric, variant
)
@@ -404,13 +418,13 @@ def get_task_events(_, company_id, request: TaskEventsRequest):
)
@endpoint("events.get_scalar_metric_data", required_fields=["task", "metric"])
def get_scalar_metric_data(call, company_id, _):
task_id = call.data["task"]
metric = call.data["metric"]
scroll_id = call.data.get("scroll_id")
no_scroll = call.data.get("no_scroll", False)
model_events = call.data.get("model_events", False)
@endpoint("events.get_scalar_metric_data")
def get_scalar_metric_data(call, company_id, request: GetScalarMetricDataRequest):
task_id = request.task
metric = request.metric
scroll_id = request.scroll_id
no_scroll = request.no_scroll
model_events = request.model_events
task_or_model = _assert_task_or_model_exists(
company_id,
@@ -435,9 +449,9 @@ def get_scalar_metric_data(call, company_id, _):
)
@endpoint("events.get_task_latest_scalar_values", required_fields=["task"])
def get_task_latest_scalar_values(call, company_id, _):
task_id = call.data["task"]
@endpoint("events.get_task_latest_scalar_values")
def get_task_latest_scalar_values(call, company_id, request: TaskRequest):
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
@@ -558,11 +572,11 @@ def get_task_single_value_metrics(
)
@endpoint("events.get_multi_task_plots", required_fields=["tasks"])
def get_multi_task_plots_v1_7(call, company_id, _):
task_ids = call.data["tasks"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
@endpoint("events.get_multi_task_plots")
def get_multi_task_plots_v1_7(call, company_id, request: LegacyMultiTaskEventsRequest):
task_ids = request.tasks
iters = request.iters
scroll_id = request.scroll_id
companies = _get_task_or_model_index_companies(company_id, task_ids)
@@ -644,11 +658,11 @@ def get_multi_task_plots(call, company_id, request: MultiTaskPlotsRequest):
)
@endpoint("events.get_task_plots", required_fields=["task"])
def get_task_plots_v1_7(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
@endpoint("events.get_task_plots")
def get_task_plots_v1_7(call, company_id, request: LegacyMetricEventsRequest):
task_id = request.task
iters = request.iters
scroll_id = request.scroll_id
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
@@ -766,11 +780,11 @@ def task_plots(call, company_id, request: MetricEventsRequest):
)
@endpoint("events.debug_images", required_fields=["task"])
def get_debug_images_v1_7(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters") or 1
scroll_id = call.data.get("scroll_id")
@endpoint("events.debug_images")
def get_debug_images_v1_7(call, company_id, request: LegacyMetricEventsRequest):
task_id = request.task
iters = request.iters
scroll_id = request.scroll_id
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
@@ -803,12 +817,12 @@ def get_debug_images_v1_7(call, company_id, _):
)
@endpoint("events.debug_images", min_version="1.8", required_fields=["task"])
def get_debug_images_v1_8(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters") or 1
scroll_id = call.data.get("scroll_id")
model_events = call.data.get("model_events", False)
@endpoint("events.debug_images", min_version="1.8")
def get_debug_images_v1_8(call, company_id, request: LegacyMetricEventsRequest):
task_id = request.task
iters = request.iters
scroll_id = request.scroll_id
model_events = request.model_events
tasks_or_model = _assert_task_or_model_exists(
company_id,
@@ -975,8 +989,7 @@ def get_multi_task_metrics(call: APICall, company_id, request: MultiTaskMetricsR
return {"metrics": []}
metrics = event_bll.metrics.get_multi_task_metrics(
companies=companies,
event_type=request.event_type
companies=companies, event_type=request.event_type
)
res = [
{
@@ -985,14 +998,12 @@ def get_multi_task_metrics(call: APICall, company_id, request: MultiTaskMetricsR
}
for m, vars_ in metrics.items()
]
call.result.data = {
"metrics": sorted(res, key=itemgetter("metric"))
}
call.result.data = {"metrics": sorted(res, key=itemgetter("metric"))}
@endpoint("events.delete_for_task", required_fields=["task"])
def delete_for_task(call, company_id, _):
task_id = call.data["task"]
@endpoint("events.delete_for_task")
def delete_for_task(call, company_id, request: TaskRequest):
task_id = request.task
allow_locked = call.data.get("allow_locked", False)
get_task_with_write_access(
@@ -1005,9 +1016,9 @@ def delete_for_task(call, company_id, _):
)
@endpoint("events.delete_for_model", required_fields=["model"])
def delete_for_model(call: APICall, company_id: str, _):
model_id = call.data["model"]
@endpoint("events.delete_for_model")
def delete_for_model(call: APICall, company_id: str, request: ModelRequest):
model_id = request.model
allow_locked = call.data.get("allow_locked", False)
model_bll.assert_exists(company_id, model_id, return_models=False)
@@ -1031,6 +1042,8 @@ def clear_task_log(call: APICall, company_id: str, request: ClearTaskLogRequest)
task_id=task_id,
allow_locked=request.allow_locked,
threshold_sec=request.threshold_sec,
exclude_metrics=request.exclude_metrics,
include_metrics=request.include_metrics,
)
)

View File

@@ -7,6 +7,7 @@ from apiserver.apimodels.login import (
)
from apiserver.config import info
from apiserver.service_repo import endpoint, APICall
from apiserver.service_repo.auth import revoke_auth_token
from apiserver.service_repo.auth.fixed_user import FixedUser
@@ -37,4 +38,5 @@ def supported_modes(call: APICall, _, __: GetSupportedModesRequest):
@endpoint("login.logout", min_version="2.13")
def logout(call: APICall, _, __):
revoke_auth_token(call.auth)
call.result.set_auth_cookie(None)

View File

@@ -21,6 +21,10 @@ from apiserver.apimodels.models import (
ModelsPublishManyRequest,
ModelsDeleteManyRequest,
ModelsGetRequest,
ModelRequest,
TaskRequest,
UpdateForTaskRequest,
UpdateModelRequest,
)
from apiserver.apimodels.tasks import UpdateTagsRequest
from apiserver.bll.model import ModelBLL, Metadata
@@ -67,9 +71,9 @@ def conform_model_data(call: APICall, model_data: Union[Sequence[dict], dict]):
unescape_metadata(call, model_data)
@endpoint("models.get_by_id", required_fields=["model"])
def get_by_id(call: APICall, company_id, _):
model_id = call.data["model"]
@endpoint("models.get_by_id")
def get_by_id(call: APICall, company_id, request: ModelRequest):
model_id = request.model
call_data = Metadata.escape_query_parameters(call.data)
models = Model.get_many(
company=company_id,
@@ -87,12 +91,12 @@ def get_by_id(call: APICall, company_id, _):
call.result.data = {"model": models[0]}
@endpoint("models.get_by_task_id", required_fields=["task"])
def get_by_task_id(call: APICall, company_id, _):
@endpoint("models.get_by_task_id")
def get_by_task_id(call: APICall, company_id, request: TaskRequest):
if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
raise errors.moved_permanently.NotSupported("use models.get_by_id/get_all apis")
task_id = call.data["task"]
task_id = request.task
query = dict(id=task_id, company=company_id)
task = Task.get(_only=["models"], **query)
@@ -157,7 +161,7 @@ def get_by_id_ex(call: APICall, company_id, _):
call.result.data = {"models": models}
@endpoint("models.get_all", required_fields=[])
@endpoint("models.get_all")
def get_all(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
call_data = Metadata.escape_query_parameters(call.data)
@@ -236,15 +240,15 @@ def _reset_cached_tags(company: str, projects: Sequence[str]):
)
@endpoint("models.update_for_task", required_fields=["task"])
def update_for_task(call: APICall, company_id, _):
@endpoint("models.update_for_task")
def update_for_task(call: APICall, company_id, request: UpdateForTaskRequest):
if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
raise errors.moved_permanently.NotSupported("use tasks.add_or_update_model")
task_id = call.data["task"]
uri = call.data.get("uri")
iteration = call.data.get("iteration")
override_model_id = call.data.get("override_model_id")
task_id = request.task
uri = request.uri
iteration = request.iteration
override_model_id = request.override_model_id
if not (uri or override_model_id) or (uri and override_model_id):
raise errors.bad_request.MissingRequiredFields(
"exactly one field is required", fields=("uri", "override_model_id")
@@ -411,9 +415,9 @@ def validate_task(company_id: str, identity: Identity, fields: dict):
)
@endpoint("models.edit", required_fields=["model"], response_data_model=UpdateResponse)
def edit(call: APICall, company_id, _):
model_id = call.data["model"]
@endpoint("models.edit", response_data_model=UpdateResponse)
def edit(call: APICall, company_id, request: UpdateModelRequest):
model_id = request.model
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
@@ -428,7 +432,7 @@ def edit(call: APICall, company_id, _):
d.update(value)
fields[key] = d
iteration = call.data.get("iteration")
iteration = request.iteration
task_id = model.task or fields.get("task")
if task_id and iteration is not None:
TaskBLL.update_statistics(
@@ -460,13 +464,9 @@ def edit(call: APICall, company_id, _):
call.result.data_model = UpdateResponse(updated=0)
def _update_model(call: APICall, company_id, model_id=None):
model_id = model_id or call.data["model"]
def _update_model(call: APICall, company_id, model_id):
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
data = prepare_update_fields(call, company_id, call.data)
task_id = data.get("task")
iteration = data.get("iteration")
if task_id and iteration is not None:
@@ -502,11 +502,9 @@ def _update_model(call: APICall, company_id, model_id=None):
return UpdateResponse(updated=updated_count, fields=updated_fields)
@endpoint(
"models.update", required_fields=["model"], response_data_model=UpdateResponse
)
def update(call, company_id, _):
call.result.data_model = _update_model(call, company_id)
@endpoint("models.update", response_data_model=UpdateResponse)
def update(call, company_id, request: UpdateModelRequest):
call.result.data_model = _update_model(call, company_id, model_id=request.model)
@endpoint(
@@ -629,7 +627,9 @@ def archive_many(call: APICall, company_id, request: BatchRequest):
)
def unarchive_many(call: APICall, company_id, request: BatchRequest):
results, failures = run_batch_operation(
func=partial(ModelBLL.unarchive_model, company_id=company_id, user_id=call.identity.user),
func=partial(
ModelBLL.unarchive_model, company_id=company_id, user_id=call.identity.user
),
ids=request.ids,
)
call.result.data_model = BatchResponse(

View File

@@ -67,6 +67,7 @@ def delete_runs(call: APICall, company_id: str, request: DeleteRunsRequest):
status_message="",
status_reason="Pipeline run deleted",
delete_external_artifacts=True,
include_pipeline_steps=True,
),
ids=list(ids),
)

View File

@@ -59,13 +59,12 @@ create_fields = {
}
@endpoint("projects.get_by_id", required_fields=["project"])
def get_by_id(call):
assert isinstance(call, APICall)
project_id = call.data["project"]
@endpoint("projects.get_by_id")
def get_by_id(call: APICall, company: str, request: ProjectRequest):
project_id = request.project
with translate_errors_context():
query = Q(id=project_id) & get_company_or_none_constraint(call.identity.company)
query = Q(id=project_id) & get_company_or_none_constraint(company)
project = Project.objects(query).first()
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
@@ -147,8 +146,10 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
requested_ids = data.get("id")
if isinstance(requested_ids, str):
requested_ids = [requested_ids]
_adjust_search_parameters(
data, shallow_search=request.shallow_search,
data,
shallow_search=request.shallow_search,
)
selected_project_ids = None
if request.active_users or request.children_type:
@@ -246,7 +247,9 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
if request.include_dataset_stats:
dataset_stats = project_bll.get_dataset_stats(
company=company_id, project_ids=project_ids, users=request.active_users,
company=company_id,
project_ids=project_ids,
users=request.active_users,
)
for project in projects:
project["dataset_stats"] = dataset_stats.get(project["id"])
@@ -255,15 +258,16 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
@endpoint("projects.get_all")
def get_all(call: APICall):
def get_all(call: APICall, company: str, _):
data = call.data
conform_tag_fields(call, data)
_adjust_search_parameters(
data, shallow_search=data.get("shallow_search", False),
data,
shallow_search=data.get("shallow_search", False),
)
ret_params = {}
projects = Project.get_many(
company=call.identity.company,
company=company,
query_dict=data,
query=_hidden_query(
search_hidden=data.get("search_hidden"), ids=data.get("id")
@@ -277,9 +281,11 @@ def get_all(call: APICall):
@endpoint(
"projects.create", required_fields=["name"], response_data_model=IdResponse,
"projects.create",
required_fields=["name"],
response_data_model=IdResponse,
)
def create(call: APICall):
def create(call: APICall, company: str, _):
identity = call.identity
with translate_errors_context():
@@ -288,15 +294,15 @@ def create(call: APICall):
return IdResponse(
id=ProjectBLL.create(
user=identity.user, company=identity.company, **fields,
user=identity.user,
company=company,
**fields,
)
)
@endpoint(
"projects.update", required_fields=["project"], response_data_model=UpdateResponse
)
def update(call: APICall):
@endpoint("projects.update", response_data_model=UpdateResponse)
def update(call: APICall, company: str, request: ProjectRequest):
"""
update
@@ -309,9 +315,7 @@ def update(call: APICall):
call.data, create_fields, Project.get_fields(), discard_none_values=False
)
conform_tag_fields(call, fields, validate=True)
updated = ProjectBLL.update(
company=call.identity.company, project_id=call.data["project"], **fields
)
updated = ProjectBLL.update(company=company, project_id=request.project, **fields)
conform_output_tags(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
@@ -375,7 +379,6 @@ def delete(call: APICall, company_id: str, request: DeleteRequest):
def get_unique_metric_variants(
call: APICall, company_id: str, request: GetUniqueMetricsRequest
):
metrics = project_queries.get_unique_metric_variants(
company_id,
[request.project] if request.project else None,
@@ -429,7 +432,6 @@ def get_model_metadata_values(
request_data_model=GetParamsRequest,
)
def get_hyper_parameters(call: APICall, company_id: str, request: GetParamsRequest):
total, remaining, parameters = project_queries.get_aggregated_project_parameters(
company_id,
project_ids=[request.project] if request.project else None,

View File

@@ -3,7 +3,11 @@ from datetime import datetime
from pyhocon.config_tree import NoneValue
from apiserver.apierrors import errors
from apiserver.apimodels.server import ReportStatsOptionRequest, ReportStatsOptionResponse
from apiserver.apimodels.server import (
ReportStatsOptionRequest,
ReportStatsOptionResponse,
GetConfigRequest,
)
from apiserver.bll.statistics.stats_reporter import StatisticsReporter
from apiserver.config_repo import config
from apiserver.config.info import get_version, get_build_number, get_commit_number
@@ -22,8 +26,8 @@ def get_stats(call: APICall):
@endpoint("server.config")
def get_config(call: APICall):
path = call.data.get("path")
def get_config(call: APICall, _, request: GetConfigRequest):
path = request.path
if path:
c = dict(config.get(path))
else:

View File

@@ -55,7 +55,6 @@ from apiserver.apimodels.tasks import (
ResetManyRequest,
DeleteManyRequest,
PublishManyRequest,
TaskBatchRequest,
EnqueueManyResponse,
EnqueueBatchItem,
DequeueBatchItem,
@@ -68,6 +67,9 @@ from apiserver.apimodels.tasks import (
DequeueRequest,
DequeueManyRequest,
UpdateTagsRequest,
StopRequest,
UnarchiveManyRequest,
ArchiveManyRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.model import ModelBLL
@@ -98,7 +100,6 @@ from apiserver.bll.task.task_operations import (
delete_task,
publish_task,
unarchive_task,
move_tasks_to_trash,
)
from apiserver.bll.task.utils import (
update_task,
@@ -266,7 +267,7 @@ def get_by_id_ex(call: APICall, company_id, _):
call.result.data = {"tasks": tasks}
@endpoint("tasks.get_all", required_fields=[])
@endpoint("tasks.get_all")
def get_all(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
call_data = escape_execution_parameters(call.data)
@@ -295,9 +296,9 @@ def get_types(call: APICall, company_id, request: GetTypesRequest):
@endpoint(
"tasks.stop", request_data_model=UpdateRequest, response_data_model=UpdateResponse
"tasks.stop", response_data_model=UpdateResponse
)
def stop(call: APICall, company_id, req_model: UpdateRequest):
def stop(call: APICall, company_id, request: StopRequest):
"""
stop
:summary: Stop a running task. Requires task status 'in_progress' and
@@ -308,12 +309,13 @@ def stop(call: APICall, company_id, req_model: UpdateRequest):
"""
call.result.data_model = UpdateResponse(
**stop_task(
task_id=req_model.task,
task_id=request.task,
company_id=company_id,
identity=call.identity,
user_name=call.identity.user_name,
status_reason=req_model.status_reason,
force=req_model.force,
status_reason=request.status_reason,
force=request.force,
include_pipeline_steps=request.include_pipeline_steps,
)
)
@@ -332,6 +334,7 @@ def stop_many(call: APICall, company_id, request: StopManyRequest):
user_name=call.identity.user_name,
status_reason=request.status_reason,
force=request.force,
include_pipeline_steps=request.include_pipeline_steps,
),
ids=request.ids,
)
@@ -364,13 +367,21 @@ def stopped(call: APICall, company_id, req_model: UpdateRequest):
response_data_model=StartedResponse,
)
def started(call: APICall, company_id, req_model: UpdateRequest):
started_update = {}
if Task.objects(id=req_model.task, started=None).only("id"):
# this is the fix for older versions putting started to None on reset
started_update["started"] = datetime.utcnow()
else:
# don't override a previous, smaller "started" field value
started_update["min__started"] = datetime.utcnow()
res = StartedResponse(
**set_task_status_from_call(
req_model,
company_id=company_id,
identity=call.identity,
new_status=TaskStatus.in_progress,
min__started=datetime.utcnow(), # don't override a previous, smaller "started" field value
**started_update,
)
)
res.started = res.updated
@@ -523,7 +534,7 @@ def _validate_and_get_task_from_call(call: APICall, **kwargs) -> Tuple[Task, dic
@endpoint("tasks.validate", request_data_model=CreateRequest)
def validate(call: APICall, company_id, req_model: CreateRequest):
def validate(call: APICall, _, __: CreateRequest):
parent = call.data.get("parent")
if parent and parent.startswith(deleted_prefix):
call.data.pop("parent")
@@ -547,7 +558,7 @@ def _reset_cached_tags(company: str, projects: Sequence[str]):
@endpoint(
"tasks.create", request_data_model=CreateRequest, response_data_model=IdResponse
)
def create(call: APICall, company_id, req_model: CreateRequest):
def create(call: APICall, company_id, _: CreateRequest):
task, fields = _validate_and_get_task_from_call(call)
with translate_errors_context():
@@ -1019,14 +1030,7 @@ def reset(call: APICall, company_id, request: ResetRequest):
clear_all=request.clear_all,
delete_external_artifacts=request.delete_external_artifacts,
)
res = ResetResponse(**updates, dequeued=dequeued)
# do not return artifacts since they are not serializable
res.fields.pop("execution.artifacts", None)
for key, value in attr.asdict(cleanup_res).items():
setattr(res, key, value)
res = ResetResponse(**updates, **attr.asdict(cleanup_res), dequeued=dequeued)
call.result.data_model = res
@@ -1050,23 +1054,19 @@ def reset_many(call: APICall, company_id, request: ResetManyRequest):
ids=request.ids,
)
def clean_res(res: dict) -> dict:
# do not return artifacts since they are not serializable
fields = res.get("fields")
if fields:
fields.pop("execution.artifacts", None)
return res
call.result.data_model = ResetManyResponse(
succeeded=[
succeeded = []
for _id, (dequeued, cleanup, res) in results:
succeeded.append(
ResetBatchItem(
id=_id,
dequeued=bool(dequeued.get("removed")) if dequeued else False,
**attr.asdict(cleanup),
**clean_res(res),
**res,
)
for _id, (dequeued, cleanup, res) in results
],
)
call.result.data_model = ResetManyResponse(
succeeded=succeeded,
failed=failures,
)
@@ -1090,6 +1090,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
"project",
"system_tags",
"enqueue_status",
"type",
),
)
for task in tasks:
@@ -1099,6 +1100,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
task=task,
status_message=request.status_message,
status_reason=request.status_reason,
include_pipeline_steps=request.include_pipeline_steps,
)
call.result.data_model = ArchiveResponse(archived=archived)
@@ -1106,10 +1108,9 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
@endpoint(
"tasks.archive_many",
request_data_model=TaskBatchRequest,
response_data_model=BatchResponse,
)
def archive_many(call: APICall, company_id, request: TaskBatchRequest):
def archive_many(call: APICall, company_id, request: ArchiveManyRequest):
results, failures = run_batch_operation(
func=partial(
archive_task,
@@ -1117,6 +1118,7 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest):
identity=call.identity,
status_message=request.status_message,
status_reason=request.status_reason,
include_pipeline_steps=request.include_pipeline_steps,
),
ids=request.ids,
)
@@ -1128,10 +1130,9 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest):
@endpoint(
"tasks.unarchive_many",
request_data_model=TaskBatchRequest,
response_data_model=BatchResponse,
)
def unarchive_many(call: APICall, company_id, request: TaskBatchRequest):
def unarchive_many(call: APICall, company_id, request: UnarchiveManyRequest):
results, failures = run_batch_operation(
func=partial(
unarchive_task,
@@ -1139,6 +1140,7 @@ def unarchive_many(call: APICall, company_id, request: TaskBatchRequest):
identity=call.identity,
status_message=request.status_message,
status_reason=request.status_reason,
include_pipeline_steps=request.include_pipeline_steps,
),
ids=request.ids,
)
@@ -1163,10 +1165,9 @@ def delete(call: APICall, company_id, request: DeleteRequest):
status_message=request.status_message,
status_reason=request.status_reason,
delete_external_artifacts=request.delete_external_artifacts,
include_pipeline_steps=request.include_pipeline_steps,
)
if deleted:
if request.move_to_trash:
move_tasks_to_trash([request.task])
_reset_cached_tags(company_id, projects=[task.project] if task.project else [])
call.result.data = dict(deleted=bool(deleted), **attr.asdict(cleanup_res))
@@ -1185,15 +1186,12 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest):
status_message=request.status_message,
status_reason=request.status_reason,
delete_external_artifacts=request.delete_external_artifacts,
include_pipeline_steps=request.include_pipeline_steps,
),
ids=request.ids,
)
if results:
if request.move_to_trash:
task_ids = set(task.id for _, (_, task, _) in results)
if task_ids:
move_tasks_to_trash(list(task_ids))
projects = set(task.project for _, (_, task, _) in results)
_reset_cached_tags(company_id, projects=list(projects))

View File

@@ -7,12 +7,16 @@ from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.apimodels.base import UpdateResponse
from apiserver.apimodels.users import CreateRequest, SetPreferencesRequest
from apiserver.apimodels.users import (
CreateRequest,
SetPreferencesRequest,
UserRequest,
)
from apiserver.bll.project import ProjectBLL
from apiserver.bll.user import UserBLL
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.auth import Role
from apiserver.database.model.auth import Role, User as AuthUser
from apiserver.database.model.company import Company
from apiserver.database.model.user import User
from apiserver.database.utils import parse_from_call
@@ -48,13 +52,13 @@ def get_user(call, company_id, user_id, only=None):
return res.to_proper_dict()
@endpoint("users.get_by_id", required_fields=["user"])
def get_by_id(call: APICall, company_id, _):
user_id = call.data["user"]
@endpoint("users.get_by_id")
def get_by_id(call: APICall, company_id, request: UserRequest):
user_id = request.user
call.result.data = {"user": get_user(call, company_id, user_id)}
@endpoint("users.get_all_ex", required_fields=[])
@endpoint("users.get_all_ex")
def get_all_ex(call: APICall, company_id, _):
with translate_errors_context("retrieving users"):
res = User.get_many_with_join(company=company_id, query_dict=call.data)
@@ -62,7 +66,7 @@ def get_all_ex(call: APICall, company_id, _):
call.result.data = {"users": res}
@endpoint("users.get_all_ex", min_version="2.8", required_fields=[])
@endpoint("users.get_all_ex", min_version="2.8")
def get_all_ex2_8(call: APICall, company_id, _):
with translate_errors_context("retrieving users"):
data = call.data
@@ -83,7 +87,7 @@ def get_all_ex2_8(call: APICall, company_id, _):
call.result.data = {"users": res}
@endpoint("users.get_all", required_fields=[])
@endpoint("users.get_all")
def get_all(call: APICall, company_id, _):
with translate_errors_context("retrieving users"):
res = User.get_many(
@@ -138,9 +142,9 @@ def create(call: APICall):
UserBLL.create(call.data_model)
@endpoint("users.delete", required_fields=["user"])
def delete(call: APICall):
UserBLL.delete(call.data["user"])
@endpoint("users.delete")
def delete(_: APICall, __, request: UserRequest):
UserBLL.delete(request.user)
def update_user(user_id, company_id, data: dict) -> Tuple[int, dict]:
@@ -154,14 +158,22 @@ def update_user(user_id, company_id, data: dict) -> Tuple[int, dict]:
update_fields = {
k: v for k, v in create_fields.items() if k in User.user_set_allowed()
}
auth_user_update_fields = ("name",)
partial_update_dict = parse_from_call(data, update_fields, User.get_fields())
with translate_errors_context("updating user"):
return User.safe_update(company_id, user_id, partial_update_dict)
ret = User.safe_update(company_id, user_id, partial_update_dict)
auth_update = {
k: v for k, v in partial_update_dict.items() if k in auth_user_update_fields
}
if auth_update:
AuthUser.objects(id=user_id).update(**auth_update)
return ret
@endpoint("users.update", required_fields=["user"], response_data_model=UpdateResponse)
def update(call, company_id, _):
user_id = call.data["user"]
@endpoint("users.update", response_data_model=UpdateResponse)
def update(call, company_id, request: UserRequest):
user_id = request.user
update_count, updated_fields = update_user(user_id, company_id, call.data)
call.result.data_model = UpdateResponse(updated=update_count, fields=updated_fields)

View File

@@ -47,6 +47,7 @@ def get_all(call: APICall, company_id: str, request: GetAllRequest):
request.last_seen,
tags=request.tags,
system_tags=request.system_tags,
worker_pattern=request.worker_pattern,
)
)
@@ -61,6 +62,7 @@ def get_all(call: APICall, company_id: str, request: GetCountRequest):
request.last_seen,
tags=request.tags,
system_tags=request.system_tags,
worker_pattern=request.worker_pattern,
)
}

View File

@@ -21,7 +21,7 @@ def distributed_lock(name: str, timeout: int, max_wait: int = 0):
max_wait = max_wait or timeout * 2
pid = os.getpid()
while _redis.set(lock_name, value=pid, ex=timeout, nx=True) is None:
sleep(1)
sleep(0.1)
if time.time() - start > max_wait:
holder = _redis.get(lock_name)
raise Exception(f"Could not acquire {name} lock for {max_wait} seconds. The lock is hold by {holder}")

View File

@@ -5,6 +5,45 @@ from apiserver.tests.automated import TestService
class TestPipelines(TestService):
def test_controller_operations(self):
    """Check that controller-level operations propagate to pipeline steps.

    Creates a controller task with two child step tasks, then verifies that
    stop, archive/unarchive and delete called on the controller with
    ``include_pipeline_steps=True`` also affect the steps.
    """
    task_name = "pipelines test"
    project, task = self._temp_project_and_task(name=task_name)
    steps = [
        self.api.tasks.create(
            name=f"Pipeline step {i}",
            project=project,
            type="training",
            system_tags=["pipeline"],
            parent=task,
        ).id
        for i in range(2)
    ]
    ids = [task, *steps]
    res = self.api.tasks.get_all_ex(id=ids, search_hidden=True)
    self.assertEqual(len(res.tasks), len(ids))

    # stop
    partial_ids = [task, steps[0]]
    self.api.tasks.enqueue_many(ids=partial_ids)
    res = self.api.tasks.get_all_ex(id=partial_ids, search_hidden=True)
    # The original passed a bare generator to assertTrue (always truthy) and
    # read a "stats" attribute, so nothing was actually verified.
    # NOTE(review): assumes enqueued tasks report status "queued" - confirm.
    self.assertTrue(all(t.status == "queued" for t in res.tasks))
    self.api.tasks.stop(task=task, include_pipeline_steps=True)
    res = self.api.tasks.get_all_ex(id=ids, search_hidden=True)
    # NOTE(review): the exact post-stop status is not visible from this test;
    # assert only that nothing is left queued or running - tighten if known.
    self.assertTrue(
        all(t.status not in ("queued", "in_progress") for t in res.tasks)
    )

    # archive/unarchive
    self.api.tasks.archive(tasks=[task], include_pipeline_steps=True)
    res = self.api.tasks.get_all_ex(
        id=ids, search_hidden=True, system_tags=["-archived"]
    )
    self.assertEqual(len(res.tasks), 0)
    self.api.tasks.unarchive_many(ids=[task], include_pipeline_steps=True)
    res = self.api.tasks.get_all_ex(
        id=ids, search_hidden=True, system_tags=["-archived"]
    )
    self.assertEqual(len(res.tasks), len(ids))

    # delete
    self.api.tasks.delete(task=task, force=True, include_pipeline_steps=True)
    res = self.api.tasks.get_all_ex(id=ids, search_hidden=True)
    self.assertEqual(len(res.tasks), 0)
def test_delete_runs(self):
queue = self.api.queues.get_default().id
task_name = "pipelines test"

View File

@@ -113,7 +113,7 @@ class TestProjectTags(TestService):
new_tags = ["New model tag"]
self.api.models.update_tags(ids=[model], add_tags=new_tags)
data = self.api.projects.get_model_tags(projects=[p])
self.assertEqual(set(data.tags), set([*new_tags, *initial_tags]))
self.assertEqual(set(data.tags), {*new_tags, *initial_tags})
def new_task(self, **kwargs):
self.update_missing(

View File

@@ -193,33 +193,33 @@ class TestTaskEvents(TestService):
def test_last_scalar_metrics(self):
metric = "Metric1"
variant = "Variant1"
iter_count = 100
task = self._temp_task()
events = [
{
**self._create_task_event("training_stats_scalar", task, iteration),
"metric": metric,
"variant": variant,
"value": iteration,
}
for iteration in range(iter_count)
]
# send 2 batches to check the interaction with already stored db value
# each batch contains multiple iterations
self.send_batch(events[:50])
self.send_batch(events[50:])
for variant in ("Variant1", None):
iter_count = 100
task = self._temp_task()
events = [
{
**self._create_task_event("training_stats_scalar", task, iteration),
"metric": metric,
"variant": variant,
"value": iteration,
}
for iteration in range(iter_count)
]
# send 2 batches to check the interaction with already stored db value
# each batch contains multiple iterations
self.send_batch(events[:50])
self.send_batch(events[50:])
task_data = self.api.tasks.get_by_id(task=task).task
metric_data = first(first(task_data.last_metrics.values()).values())
self.assertEqual(iter_count - 1, metric_data.value)
self.assertEqual(iter_count - 1, metric_data.max_value)
self.assertEqual(iter_count - 1, metric_data.max_value_iteration)
self.assertEqual(0, metric_data.min_value)
self.assertEqual(0, metric_data.min_value_iteration)
task_data = self.api.tasks.get_by_id(task=task).task
metric_data = first(first(task_data.last_metrics.values()).values())
self.assertEqual(iter_count - 1, metric_data.value)
self.assertEqual(iter_count - 1, metric_data.max_value)
self.assertEqual(iter_count - 1, metric_data.max_value_iteration)
self.assertEqual(0, metric_data.min_value)
self.assertEqual(0, metric_data.min_value_iteration)
res = self.api.events.get_task_latest_scalar_values(task=task)
self.assertEqual(iter_count - 1, res.last_iter)
res = self.api.events.get_task_latest_scalar_values(task=task)
self.assertEqual(iter_count - 1, res.last_iter)
def test_model_events(self):
model = self._temp_model(ready=False)
@@ -346,6 +346,34 @@ class TestTaskEvents(TestService):
# test order
self._assert_log_events(task=task, order="asc")
metric = "metric"
variant = "variant"
events = [
self._create_task_event(
"log",
task=task,
iteration=iter_,
timestamp=timestamp + iter_ * 1000,
msg=f"This is a log message from test task iter {iter_}",
metric=metric,
variant=variant,
)
for iter_ in range(2)
]
self.send_batch(events)
res = self.api.events.get_task_log(task=task)
self.assertEqual(res.total, 12)
res = self.api.events.get_task_log(task=task, metrics=[{"metric": metric}])
self.assertEqual(res.total, 2)
# test clear
self.api.events.clear_task_log(task=task, exclude_metrics=[metric])
res = self.api.events.get_task_log(task=task)
self.assertEqual(res.total, 2)
self.api.events.clear_task_log(task=task)
res = self.api.events.get_task_log(task=task)
self.assertEqual(res.total, 0)
def _assert_log_events(
self,
task,

View File

@@ -32,7 +32,7 @@ class TestWorkersService(TestService):
self.api.workers.register(worker=w, system_tags=[system_tag])
# total workers count include the new ones
count = self.api.workers.get_count().count
self.assertGreater(count, len(test_workers))
self.assertGreaterEqual(count, len(test_workers))
# filter by system tag and last seen
count = self.api.workers.get_count(system_tags=[system_tag], last_seen=4).count
self.assertEqual(count, len(test_workers))

View File

@@ -7,13 +7,17 @@ from humanfriendly import parse_timespan
def setup():
from apiserver.database import db
db.initialize()
def gen_token(args):
from apiserver.bll.auth import AuthBLL
resp = AuthBLL.get_token_for_user(args.user_id, args.company_id, parse_timespan(args.expiration))
print('Token:\n%s' % resp.token)
resp = AuthBLL.get_token_for_user(
args.user_id, args.company_id, int(parse_timespan(args.expiration))
)
print("Token:\n%s" % resp.token)
def safe_get(obj, glob, default=None, separator="/"):
@@ -23,19 +27,24 @@ def safe_get(obj, glob, default=None, separator="/"):
return default
if __name__ == '__main__':
if __name__ == "__main__":
top_parser = ArgumentParser(__doc__)
subparsers = top_parser.add_subparsers(title='Sections')
subparsers = top_parser.add_subparsers(title="Sections")
token = subparsers.add_parser('token')
token_commands = token.add_subparsers(title='Commands')
token_create = token_commands.add_parser('generate', description='Generate a new token')
token_create.add_argument('--user-id', '-u', help='User ID', required=True)
token_create.add_argument('--company-id', '-c', help='Company ID', required=True)
token_create.add_argument('--expiration', '-exp',
help="Token expiration (time span, shorthand suffixes are supported, default 1m)",
default=parse_timespan('1m'))
token = subparsers.add_parser("token")
token_commands = token.add_subparsers(title="Commands")
token_create = token_commands.add_parser(
"generate", description="Generate a new token"
)
token_create.add_argument("--user-id", "-u", help="User ID", required=True)
token_create.add_argument("--company-id", "-c", help="Company ID", required=True)
token_create.add_argument(
"--expiration",
"-exp",
help="Token expiration (time span, shorthand suffixes are supported, default 1m)",
default=parse_timespan("1m"),
)
token_create.set_defaults(_func=gen_token)
args = top_parser.parse_args()

View File

@@ -1 +1 @@
__version__ = "1.14.0"
__version__ = "1.16.0"

View File

@@ -1,17 +1,19 @@
FROM node:18-bullseye as webapp_builder
FROM node:20-bookworm-slim as webapp_builder
ARG CLEARML_WEB_GIT_URL=https://github.com/allegroai/clearml-web.git
USER root
WORKDIR /opt
RUN apt-get update && apt-get install -y git
RUN git clone ${CLEARML_WEB_GIT_URL} clearml-web
RUN mv clearml-web /opt/open-webapp
COPY --chmod=744 docker/build/internal_files/build_webapp.sh /tmp/internal_files/
RUN /bin/bash -c '/tmp/internal_files/build_webapp.sh'
FROM python:3.9-slim-bullseye
FROM python:3.9-slim-bookworm
COPY --chmod=744 docker/build/internal_files/entrypoint.sh /opt/clearml/
COPY --chmod=744 docker/build/internal_files/update_from_env.py /opt/clearml/utilities/
COPY fileserver /opt/clearml/fileserver/
COPY apiserver /opt/clearml/apiserver/

View File

@@ -29,7 +29,12 @@ server {
include /etc/nginx/default.d/*.conf;
location / {
try_files $uri$args $uri$args/ $uri index.html /index.html;
add_header Strict-Transport-Security "max-age=31536000; includeSubDomains" always;
add_header Content-Security-Policy "frame-ancestors 'self'";
add_header X-XSS-Protection "1; mode=block";
add_header X-Content-Type-Options "nosniff" always;
add_header Referrer-Policy "no-referrer-when-downgrade";
try_files $uri $uri/ /index.html;
}
location /version.json {
@@ -50,6 +55,12 @@ server {
rewrite /files/(.*) /$1 break;
}
location /widgets {
alias /usr/share/nginx/widgets;
try_files $uri $uri/ /widgets/index.html;
add_header Content-Security-Policy "frame-ancestors *";
}
error_page 404 /404.html;
location = /40x.html {
}
@@ -57,4 +68,4 @@ server {
error_page 500 502 503 504 /50x.html;
location = /50x.html {
}
}
}

View File

@@ -46,10 +46,26 @@ elif [[ ${SERVER_TYPE} == "webserver" ]]; then
EOF
fi
# Create an empty configuration json
echo "{}" > /tmp/configuration.json
# Copy the external configuration file if it exists
if test -f "/mnt/external_files/configs/configuration.json"; then
echo "Copying external configuration"
cp /mnt/external_files/configs/configuration.json /tmp/configuration.json
fi
# Update from env variables
echo "Updating configuration from env"
/opt/clearml/utilities/update_from_env.py \
--verbose \
/tmp/configuration.json \
/usr/share/nginx/html/configuration.json
export NGINX_APISERVER_ADDR=${NGINX_APISERVER_ADDRESS:-http://apiserver:8008}
export NGINX_FILESERVER_ADDR=${NGINX_FILESERVER_ADDRESS:-http://fileserver:8081}
COMMENT_IPV6_LISTEN=$([ "$DISABLE_NGINX_IPV6" = "true" ] && echo "#" || echo "") \
envsubst '${COMMENT_IPV6_LISTEN} ${NGINX_APISERVER_ADDR} ${NGINX_FILESERVER_ADDR}' < /etc/nginx/clearml.conf.template > /etc/nginx/sites-enabled/default
export COMMENT_IPV6_LISTEN=$([ "$DISABLE_NGINX_IPV6" = "true" ] && echo "#" || echo "")
envsubst '${COMMENT_IPV6_LISTEN} ${NGINX_APISERVER_ADDR} ${NGINX_FILESERVER_ADDR}' < /etc/nginx/clearml.conf.template > /etc/nginx/sites-enabled/default
if [[ -n "${CLEARML_SERVER_SUB_PATH}" ]]; then
mkdir -p /etc/nginx/default.d/

View File

@@ -1,11 +1,11 @@
#!/usr/bin/env bash
set -x
set -o errexit
set -o nounset
set -o pipefail
apt-get update -y
apt-get install -y python3-setuptools python3-dev build-essential nginx gettext
apt-get install -y vim curl
apt-get install -y python3-setuptools python3-dev build-essential nginx gettext vim curl
python3 -m ensurepip
python3 -m pip install --upgrade pip

View File

@@ -0,0 +1,104 @@
#!/usr/bin/env python3
""" Update json configuration file from environment variables """
from argparse import ArgumentParser, FileType
import json
from os import environ
from typing import Any, Generator, Tuple, Optional, List
class PathConflictError(Exception):
    """Raised when a path cannot be set because it crosses a non-dict value.

    The ``path`` attribute holds the path components that were still left to
    traverse below the conflicting key.
    """

    def __init__(self, path_: List[str]):
        # Forward the payload to Exception so that ``args``/``str()`` are
        # populated and the instance can be re-created from ``args``.
        super().__init__(path_)
        self.path = path_
def scan(
    obj: Any,
    path_: Optional[str] = None,
    sep: str = ".",
    parent_=None,
    key_=None,
) -> Generator[Tuple[str, Any, Optional[dict], Optional[str]], None, None]:
    """Walk nested dictionaries and yield every non-dictionary value found.

    :param obj: Object to scan; only ``dict`` instances are descended into
    :param path_: Path accumulated so far (``None`` at the root)
    :param sep: Separator used to join keys into a path
    :param parent_: Dictionary that directly contains ``obj`` (``None`` at the root)
    :param key_: Key under which ``obj`` is stored in ``parent_`` (``None`` at the root)
    :return: Generator of ``(lower-cased path, value, parent dict, key)`` tuples
    """
    if not isinstance(obj, dict):
        # A non-dict root has no path yet (path_ is None); report it with an
        # empty path instead of failing on None.lower()
        yield (path_ or "").lower(), obj, parent_, key_
        return
    for k, v in obj.items():
        yield from scan(
            v,
            # filter(None, ...) drops the missing root path (and empty keys)
            path_=sep.join(filter(None, (path_, k))),
            parent_=obj,
            key_=k,
            sep=sep,
        )
def set_path(p: List[str], obj: dict, v: Any):
    """Store ``v`` in ``obj`` under the nested key path ``p``.

    Missing intermediate dictionaries are created along the way. If an
    existing intermediate value is not a dictionary, ``PathConflictError`` is
    raised carrying the path components that follow the conflicting key.
    """
    *parents, leaf = p
    node = obj
    for depth, name in enumerate(parents):
        if name not in node:
            node[name] = {}
        elif not isinstance(node[name], dict):
            # Report everything below the conflicting key, leaf included
            raise PathConflictError(parents[depth + 1:] + [leaf])
        node = node[name]
    node[leaf] = v
if __name__ == '__main__':
parser = ArgumentParser(description=__doc__)
parser.add_argument("input_file", type=FileType(), help="Input JSON file")
parser.add_argument("output_file", type=FileType("w"), help="Output JSON file")
parser.add_argument(
"--env-prefix", "-p", default="WEBSERVER", help="Environment variables prefix (default=%(default)s)",
dest="prefix", required=False
)
parser.add_argument(
"--env-separator", "-s", default="__", help="Environment variable name separator (default=%(default)s)",
dest="sep"
)
parser.add_argument("--verbose", "-v", action="store_true", default=False)
parser.add_argument(
"--disable-parse-env-value", action="store_false", default=True, help="Don't parse env value as JSON",
dest="parse_env"
)
args = parser.parse_args()
if not args.prefix:
print("Error: script does not support an empty prefix")
exit(1)
data = None
try:
data = json.load(args.input_file)
except json.JSONDecodeError as ex:
print(f"Error parsing JSON file {args.input_file.name}: {str(ex)}")
exit(1)
def parse_value(k, v):
try:
return json.loads(v)
except json.JSONDecodeError as ex:
print(f"Error parsing {k} JSON value `{v}`: {str(ex)}")
exit(2)
prefix = args.prefix + args.sep
env_vars = {
k.lstrip(prefix): parse_value(k, v) if args.parse_env else v
for k, v in environ.items() if k.startswith(prefix)
}
for path, value, parent, key in scan(data, sep=args.sep):
if not (parent and key):
continue
match = next((k for k in env_vars if k.lower() == path), None)
if match:
replace = env_vars.pop(match)
parent[key] = replace
if args.verbose:
print(f"Replacing {path}={value} with {replace}")
for k, v in env_vars.items():
path = k.split(args.sep)
try:
set_path(path, data, v)
except PathConflictError as ex:
print(f"Error: failed setting value into {k}: {path[:-len(ex.path)]} is not a dictionary")
try:
json.dump(data, args.output_file, sort_keys=True, indent=2)
except Exception as ex:
print(f"Error writing JSON file {args.output_file.name}: {str(ex)}")
exit(3)

View File

@@ -19,12 +19,11 @@ services:
environment:
CLEARML_ELASTIC_SERVICE_HOST: elasticsearch
CLEARML_ELASTIC_SERVICE_PORT: 9200
CLEARML_ELASTIC_SERVICE_PASSWORD: ${ELASTIC_PASSWORD}
CLEARML_MONGODB_SERVICE_HOST: mongo
CLEARML_MONGODB_SERVICE_PORT: 27017
CLEARML_REDIS_SERVICE_HOST: redis
CLEARML_REDIS_SERVICE_PORT: 6379
CLEARML_SERVER_DEPLOYMENT_TYPE: ${CLEARML_SERVER_DEPLOYMENT_TYPE:-win10}
CLEARML_SERVER_DEPLOYMENT_TYPE: win10
CLEARML__apiserver__pre_populate__enabled: "true"
CLEARML__apiserver__pre_populate__zip_files: "/opt/clearml/db-pre-populate"
CLEARML__apiserver__pre_populate__artifacts_path: "/mnt/fileserver"
@@ -41,21 +40,16 @@ services:
- backend
container_name: clearml-elastic
environment:
ES_JAVA_OPTS: -Xms2g -Xmx2g -Dlog4j2.formatMsgNoLookups=true
ELASTIC_PASSWORD: ${ELASTIC_PASSWORD}
bootstrap.memory_lock: "true"
cluster.name: clearml
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
cluster.routing.allocation.disk.watermark.low: 500mb
cluster.routing.allocation.disk.watermark.high: 500mb
cluster.routing.allocation.disk.watermark.flood_stage: 500mb
discovery.zen.minimum_master_nodes: "1"
discovery.type: "single-node"
http.compression_level: "7"
node.ingest: "true"
node.name: clearml
reindex.remote.whitelist: '*.*'
xpack.monitoring.enabled: "false"
reindex.remote.whitelist: "'*.*'"
xpack.security.enabled: "false"
ulimits:
memlock:
@@ -64,7 +58,7 @@ services:
nofile:
soft: 65536
hard: 65536
image: docker.elastic.co/elasticsearch/elasticsearch:7.17.7
image: docker.elastic.co/elasticsearch/elasticsearch:7.17.18
restart: unless-stopped
volumes:
- c:/opt/clearml/data/elastic_7:/usr/share/elasticsearch/data
@@ -93,7 +87,7 @@ services:
networks:
- backend
container_name: clearml-mongo
image: mongo:4.4.9
image: mongo:4.4.29
restart: unless-stopped
command: --setParameter internalQueryMaxBlockingSortMemoryUsageBytes=196100200
volumes:
@@ -104,7 +98,7 @@ services:
networks:
- backend
container_name: clearml-redis
image: redis:5.0
image: redis:6.2
restart: unless-stopped
volumes:
- c:/opt/clearml/data/redis:/data
@@ -140,7 +134,6 @@ services:
environment:
CLEARML_ELASTIC_SERVICE_HOST: elasticsearch
CLEARML_ELASTIC_SERVICE_PORT: 9200
CLEARML_ELASTIC_SERVICE_PASSWORD: ${ELASTIC_PASSWORD}
CLEARML_MONGODB_SERVICE_HOST: mongo
CLEARML_MONGODB_SERVICE_PORT: 27017
CLEARML_REDIS_SERVICE_HOST: redis

View File

@@ -19,17 +19,18 @@ services:
environment:
CLEARML_ELASTIC_SERVICE_HOST: elasticsearch
CLEARML_ELASTIC_SERVICE_PORT: 9200
CLEARML_ELASTIC_SERVICE_PASSWORD: ${ELASTIC_PASSWORD}
CLEARML_MONGODB_SERVICE_HOST: mongo
CLEARML_MONGODB_SERVICE_PORT: 27017
CLEARML_REDIS_SERVICE_HOST: redis
CLEARML_REDIS_SERVICE_PORT: 6379
CLEARML_SERVER_DEPLOYMENT_TYPE: ${CLEARML_SERVER_DEPLOYMENT_TYPE:-linux}
CLEARML_SERVER_DEPLOYMENT_TYPE: linux
CLEARML__apiserver__pre_populate__enabled: "true"
CLEARML__apiserver__pre_populate__zip_files: "/opt/clearml/db-pre-populate"
CLEARML__apiserver__pre_populate__artifacts_path: "/mnt/fileserver"
CLEARML__services__async_urls_delete__enabled: "true"
CLEARML__services__async_urls_delete__fileserver__url_prefixes: "[${CLEARML_FILES_HOST:-}]"
CLEARML__secure__credentials__services_agent__user_key: ${CLEARML_AGENT_ACCESS_KEY:-}
CLEARML__secure__credentials__services_agent__user_secret: ${CLEARML_AGENT_SECRET_KEY:-}
ports:
- "8008:8008"
networks:
@@ -41,21 +42,16 @@ services:
- backend
container_name: clearml-elastic
environment:
ES_JAVA_OPTS: -Xms2g -Xmx2g -Dlog4j2.formatMsgNoLookups=true
ELASTIC_PASSWORD: ${ELASTIC_PASSWORD}
bootstrap.memory_lock: "true"
cluster.name: clearml
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
cluster.routing.allocation.disk.watermark.low: 500mb
cluster.routing.allocation.disk.watermark.high: 500mb
cluster.routing.allocation.disk.watermark.flood_stage: 500mb
discovery.zen.minimum_master_nodes: "1"
discovery.type: "single-node"
http.compression_level: "7"
node.ingest: "true"
node.name: clearml
reindex.remote.whitelist: '*.*'
xpack.monitoring.enabled: "false"
reindex.remote.whitelist: "'*.*'"
xpack.security.enabled: "false"
ulimits:
memlock:
@@ -64,7 +60,7 @@ services:
nofile:
soft: 65536
hard: 65536
image: docker.elastic.co/elasticsearch/elasticsearch:7.17.7
image: docker.elastic.co/elasticsearch/elasticsearch:7.17.18
restart: unless-stopped
volumes:
- /opt/clearml/data/elastic_7:/usr/share/elasticsearch/data
@@ -92,7 +88,7 @@ services:
networks:
- backend
container_name: clearml-mongo
image: mongo:4.4.9
image: mongo:4.4.29
restart: unless-stopped
command: --setParameter internalQueryMaxBlockingSortMemoryUsageBytes=196100200
volumes:
@@ -103,7 +99,7 @@ services:
networks:
- backend
container_name: clearml-redis
image: redis:5.0
image: redis:6.2
restart: unless-stopped
volumes:
- /opt/clearml/data/redis:/data
@@ -139,7 +135,6 @@ services:
environment:
CLEARML_ELASTIC_SERVICE_HOST: elasticsearch
CLEARML_ELASTIC_SERVICE_PORT: 9200
CLEARML_ELASTIC_SERVICE_PASSWORD: ${ELASTIC_PASSWORD}
CLEARML_MONGODB_SERVICE_HOST: mongo
CLEARML_MONGODB_SERVICE_PORT: 27017
CLEARML_REDIS_SERVICE_HOST: redis
@@ -170,8 +165,8 @@ services:
CLEARML_WEB_HOST: ${CLEARML_WEB_HOST:-}
CLEARML_API_HOST: http://apiserver:8008
CLEARML_FILES_HOST: ${CLEARML_FILES_HOST:-}
CLEARML_API_ACCESS_KEY: ${CLEARML_API_ACCESS_KEY:-}
CLEARML_API_SECRET_KEY: ${CLEARML_API_SECRET_KEY:-}
CLEARML_API_ACCESS_KEY: ${CLEARML_AGENT_ACCESS_KEY:-$CLEARML_API_ACCESS_KEY}
CLEARML_API_SECRET_KEY: ${CLEARML_AGENT_SECRET_KEY:-$CLEARML_API_SECRET_KEY}
CLEARML_AGENT_GIT_USER: ${CLEARML_AGENT_GIT_USER}
CLEARML_AGENT_GIT_PASS: ${CLEARML_AGENT_GIT_PASS}
CLEARML_AGENT_UPDATE_VERSION: ${CLEARML_AGENT_UPDATE_VERSION:->=0.17.0}

View File

@@ -1,116 +0,0 @@
version: "3.6"
services:
apiserver:
command:
- apiserver
container_name: trains-apiserver
image: allegroai/trains:latest
restart: unless-stopped
volumes:
- c:/opt/trains/logs:/var/log/trains
- c:/opt/trains/config:/opt/trains/config
depends_on:
- redis
- mongo
- elasticsearch
- fileserver
environment:
TRAINS_ELASTIC_SERVICE_HOST: elasticsearch
TRAINS_ELASTIC_SERVICE_PORT: 9200
TRAINS_MONGODB_SERVICE_HOST: mongo
TRAINS_MONGODB_SERVICE_PORT: 27017
TRAINS_REDIS_SERVICE_HOST: redis
TRAINS_REDIS_SERVICE_PORT: 6379
TRAINS_SERVER_DEPLOYMENT_TYPE: ${TRAINS_SERVER_DEPLOYMENT_TYPE:-win10}
TRAINS__apiserver__mongo__pre_populate__enabled: "true"
TRAINS__apiserver__mongo__pre_populate__zip_file: "/opt/trains/db-pre-populate/export.zip"
ports:
- "8008:8008"
networks:
- backend
elasticsearch:
networks:
- backend
container_name: trains-elastic
environment:
ES_JAVA_OPTS: -Xms2g -Xmx2g
bootstrap.memory_lock: "true"
cluster.name: trains
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
cluster.routing.allocation.disk.watermark.low: 500mb
cluster.routing.allocation.disk.watermark.high: 500mb
cluster.routing.allocation.disk.watermark.flood_stage: 500mb
discovery.zen.minimum_master_nodes: "1"
discovery.type: "single-node"
http.compression_level: "7"
node.ingest: "true"
node.name: trains
reindex.remote.whitelist: '*.*'
xpack.monitoring.enabled: "false"
xpack.security.enabled: "false"
ulimits:
memlock:
soft: -1
hard: -1
nofile:
soft: 65536
hard: 65536
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
restart: unless-stopped
volumes:
- c:/opt/trains/data/elastic_7:/usr/share/elasticsearch/data
fileserver:
networks:
- backend
command:
- fileserver
container_name: trains-fileserver
image: allegroai/trains:latest
restart: unless-stopped
volumes:
- c:/opt/trains/logs:/var/log/trains
- c:/opt/trains/data/fileserver:/mnt/fileserver
- c:/opt/trains/config:/opt/trains/config
ports:
- "8081:8081"
mongo:
networks:
- backend
container_name: trains-mongo
image: mongo:3.6.5
restart: unless-stopped
command: --setParameter internalQueryExecMaxBlockingSortBytes=196100200
volumes:
- c:/opt/trains/data/mongo/db:/data/db
- c:/opt/trains/data/mongo/configdb:/data/configdb
redis:
networks:
- backend
container_name: trains-redis
image: redis:5.0
restart: unless-stopped
volumes:
- c:/opt/trains/data/redis:/data
webserver:
command:
- webserver
container_name: trains-webserver
image: allegroai/trains:latest
restart: unless-stopped
volumes:
- c:/trains/logs:/var/log/trains
depends_on:
- apiserver
ports:
- "8080:80"
networks:
backend:
driver: bridge

View File

@@ -1,153 +0,0 @@
version: "3.6"
services:
apiserver:
command:
- apiserver
container_name: trains-apiserver
image: allegroai/clearml:latest
restart: unless-stopped
volumes:
- /opt/trains/logs:/var/log/trains
- /opt/trains/config:/opt/trains/config
- /opt/trains/data/fileserver:/mnt/fileserver
depends_on:
- redis
- mongo
- elasticsearch
- fileserver
environment:
TRAINS_ELASTIC_SERVICE_HOST: elasticsearch
TRAINS_ELASTIC_SERVICE_PORT: 9200
TRAINS_MONGODB_SERVICE_HOST: mongo
TRAINS_MONGODB_SERVICE_PORT: 27017
TRAINS_REDIS_SERVICE_HOST: redis
TRAINS_REDIS_SERVICE_PORT: 6379
TRAINS_SERVER_DEPLOYMENT_TYPE: ${TRAINS_SERVER_DEPLOYMENT_TYPE:-linux}
TRAINS__apiserver__pre_populate__enabled: "true"
TRAINS__apiserver__pre_populate__zip_files: "/opt/trains/db-pre-populate"
TRAINS__apiserver__pre_populate__artifacts_path: "/mnt/fileserver"
ports:
- "8008:8008"
networks:
- backend
- frontend
elasticsearch:
networks:
- backend
container_name: trains-elastic
environment:
ES_JAVA_OPTS: -Xms2g -Xmx2g
bootstrap.memory_lock: "true"
cluster.name: trains
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
cluster.routing.allocation.disk.watermark.low: 500mb
cluster.routing.allocation.disk.watermark.high: 500mb
cluster.routing.allocation.disk.watermark.flood_stage: 500mb
discovery.zen.minimum_master_nodes: "1"
discovery.type: "single-node"
http.compression_level: "7"
node.ingest: "true"
node.name: trains
reindex.remote.whitelist: '*.*'
xpack.monitoring.enabled: "false"
xpack.security.enabled: "false"
ulimits:
memlock:
soft: -1
hard: -1
nofile:
soft: 65536
hard: 65536
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
restart: unless-stopped
volumes:
- /opt/trains/data/elastic_7:/usr/share/elasticsearch/data
fileserver:
networks:
- backend
command:
- fileserver
container_name: trains-fileserver
image: allegroai/clearml:latest
restart: unless-stopped
volumes:
- /opt/trains/logs:/var/log/trains
- /opt/trains/data/fileserver:/mnt/fileserver
- /opt/trains/config:/opt/trains/config
ports:
- "8081:8081"
mongo:
networks:
- backend
container_name: trains-mongo
image: mongo:3.6.5
restart: unless-stopped
command: --setParameter internalQueryExecMaxBlockingSortBytes=196100200
volumes:
- /opt/trains/data/mongo/db:/data/db
- /opt/trains/data/mongo/configdb:/data/configdb
redis:
networks:
- backend
container_name: trains-redis
image: redis:5.0
restart: unless-stopped
volumes:
- /opt/trains/data/redis:/data
webserver:
command:
- webserver
container_name: trains-webserver
image: allegroai/clearml:latest
restart: unless-stopped
depends_on:
- apiserver
ports:
- "8080:80"
networks:
- backend
- frontend
agent-services:
networks:
- backend
container_name: trains-agent-services
image: allegroai/trains-agent-services:latest
restart: unless-stopped
privileged: true
environment:
TRAINS_HOST_IP: ${TRAINS_HOST_IP}
TRAINS_WEB_HOST: ${TRAINS_WEB_HOST:-}
TRAINS_API_HOST: http://apiserver:8008
TRAINS_FILES_HOST: ${TRAINS_FILES_HOST:-}
TRAINS_API_ACCESS_KEY: ${TRAINS_API_ACCESS_KEY:-}
TRAINS_API_SECRET_KEY: ${TRAINS_API_SECRET_KEY:-}
TRAINS_AGENT_GIT_USER: ${TRAINS_AGENT_GIT_USER}
TRAINS_AGENT_GIT_PASS: ${TRAINS_AGENT_GIT_PASS}
TRAINS_AGENT_UPDATE_VERSION: ${TRAINS_AGENT_UPDATE_VERSION:->=0.15.0}
TRAINS_AGENT_DEFAULT_BASE_DOCKER: "ubuntu:18.04"
AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID:-}
AWS_SECRET_ACCESS_KEY: ${AWS_SECRET_ACCESS_KEY:-}
AWS_DEFAULT_REGION: ${AWS_DEFAULT_REGION:-}
AZURE_STORAGE_ACCOUNT: ${AZURE_STORAGE_ACCOUNT:-}
AZURE_STORAGE_KEY: ${AZURE_STORAGE_KEY:-}
GOOGLE_APPLICATION_CREDENTIALS: ${GOOGLE_APPLICATION_CREDENTIALS:-}
TRAINS_WORKER_ID: "trains-services"
TRAINS_AGENT_DOCKER_HOST_MOUNT: "/opt/trains/agent:/root/.trains"
volumes:
- /var/run/docker.sock:/var/run/docker.sock
- /opt/trains/agent:/root/.trains
depends_on:
- apiserver
networks:
backend:
driver: bridge
frontend:
driver: bridge

129
fileserver/auth.py Normal file
View File

@@ -0,0 +1,129 @@
import json
from hashlib import sha256
from typing import Optional
import attr
from flask import abort, Response, Request
from redis import StrictRedis
from clearml_agent.backend_api import Session
from werkzeug.exceptions import HTTPException
from config import config
from redis_manager import redman
# Module-level logger for this file, obtained from the shared config object
log = config.logger(__file__)
@attr.s(auto_attribs=True)
class TokenInfo:
    """Identity details stored for a validated token.

    Instances are serialized with ``attr.asdict`` into the redis token cache
    and rebuilt from the cached JSON on a cache hit.
    """

    # company reported by the token validation response
    company: str
    # user reported by the token validation response
    user: str
class FileserverSession(Session):
    """API session whose ``client`` attribute is pinned to ``"fileserver"``."""

    @property
    def client(self):
        """Always report this session's client as ``"fileserver"``."""
        return "fileserver"

    @client.setter
    def client(self, value):
        """Ignore assignments so the base class cannot override the client."""
class AuthHandler:
    """Validate fileserver request tokens against the API server.

    Successful validations are cached in redis (alias ``"fileserver"``) so that
    repeated requests carrying the same token are answered from the cache until
    the cache entry expires.
    """

    # Read once when the class is defined; when falsy, instance() returns None
    enabled = config.get("fileserver.auth.enabled", False)
    _instance = None

    @classmethod
    def instance(cls):
        """Return the lazily created shared handler, or None when auth is disabled."""
        if not cls.enabled:
            return None
        if not cls._instance:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        """Open the API server session and the redis connection used for caching."""
        self.session = FileserverSession(
            api_key=config.get("secure.credentials.fileserver.user_key"),
            secret_key=config.get("secure.credentials.fileserver.user_secret"),
            host=config.get("hosts.api_server"),
            initialize_logging=False,
        )
        self.redis: StrictRedis = redman.connection("fileserver")

    def _validate_and_get_token_info(self, token: str) -> TokenInfo:
        """Return the company/user info for ``token``, aborting the request if invalid.

        The redis cache is consulted first. On a miss the token is sent to the
        API server's ``auth.validate_token`` endpoint and a positive answer is
        cached for ``fileserver.auth.tokens_cache_threshold_sec`` seconds.

        :param token: the raw token taken from the request
        :return: the token's company and user
        :raises HTTPException: 401 when the API server reports the token as
            invalid, the API server's own status code when it answers with a
            non-200 status, and 500 on any unexpected failure
        """
        # Tokens longer than 256 chars are hashed for the cache key; shorter ones are used as-is
        token_hash = sha256(token.encode()).hexdigest() if len(token) > 256 else token
        key = f"token_{token_hash}"
        token_data = self.redis.get(key)
        if token_data:
            return TokenInfo(**json.loads(token_data))

        try:
            res = self.session.send_request(
                service="auth", action="validate_token", json={"token": token}
            )
            if res.status_code == 500:
                log.error("Error validating token")
                abort(Response(f"Internal error (status={res.status_code})", 500))
            elif res.status_code != 200:
                log.error("Error validating token")
                abort(res.status_code)

            data = res.json()["data"]
            if not data["valid"]:
                # A reply without "msg" must still produce a 401 rather than a
                # KeyError that the generic handler below would turn into a 500
                msg = data.get("msg") or "Invalid token"
                log.error(f"Error validating token: {msg}")
                abort(Response(msg, 401))

            info = TokenInfo(
                company=data.get("company", "unknown"),
                user=data.get("user"),
            )
            timeout_sec = config.get(
                "fileserver.auth.tokens_cache_threshold_sec", 12 * 60 * 60
            )
            self.redis.setex(key, time=timeout_sec, value=json.dumps(attr.asdict(info)))
            return info
        except HTTPException:
            # raised by the abort() calls above: let it reach the client unchanged
            raise
        except Exception:
            log.exception("Failed decoding token")
            abort(500)

    def validate(self, request: Request):
        """Abort the request with 401 unless it carries a valid token.

        :param request: the incoming flask request
        :raises HTTPException: 401 when no token is present, otherwise whatever
            the token validation raises
        """
        token = self.get_token(request)
        if not token:
            log.error("Error getting token")
            abort(401)

        self._validate_and_get_token_info(token)

    @staticmethod
    def get_token(request: Request) -> Optional[str]:
        """Extract the authorization token from the request, if any.

        The ``Authorization`` header takes precedence and must use the bearer
        scheme. Otherwise the first non-empty cookie among the names listed in
        ``fileserver.auth.cookie_names`` is used.

        :param request: the incoming flask request
        :return: the token, or None when the request carries none
        :raises HTTPException: 401 when an ``Authorization`` header is present
            but is not a bearer token
        """
        auth_header = request.headers.get("Authorization")
        if auth_header:
            if not auth_header.startswith("Bearer "):
                log.error("Only bearer token authorization is supported")
                abort(
                    Response("Only bearer token authorization is supported", status=401)
                )
            return auth_header.partition(" ")[2]

        for cookie_name in config.get("fileserver.auth.cookie_names", []):
            cookie = request.cookies.get(cookie_name)
            if cookie:
                return cookie

        return None

View File

@@ -10,6 +10,21 @@ delete {
allow_batch: true
}
upload {
# the max size in MB of the upload contents in one upload call (0 means no limit)
max_upload_size_mb: 0
}
cors {
origins: "*"
}
auth {
# enable/disable auth validation on upload, download and delete requests
enabled: true
# names of cookies in which authorization token can be found
cookie_names: ["clearml_token_basic"]
# how long (in seconds) a successfully validated token is kept in the redis cache
# before it has to be validated against the API server again (43200 = 12 hours)
tokens_cache_threshold_sec: 43200
}

View File

@@ -0,0 +1,9 @@
# base URL of the API server that the fileserver contacts in order to validate tokens
api_server: "http://apiserver:8008"
redis {
# connection settings for the "fileserver" redis alias (holds the validated tokens cache).
# host and port can be overridden using the CLEARML_REDIS_SERVICE_HOST / CLEARML_REDIS_SERVICE_PORT
# environment variables (or their TRAINS_ / unprefixed variants)
fileserver {
host: "redis"
port: 6379
db: 8
}
}

View File

@@ -0,0 +1,7 @@
credentials {
# system credentials as they appear in the auth DB, used for intra-service communications
# the fileserver key/secret below are the ones the fileserver uses to open its API server session
# NOTE(review): these look like shipped default values - confirm whether deployments are expected to override them
fileserver {
user_key: "GSQWPEKSKNKF354LC9V6BHXKTYFD5I"
user_secret: "tuBXcGQBECsEhcNiK2kiWi750z9r8Z85XrQ9V0c24huTuCb2xf2X1nKG"
}
}

View File

@@ -15,6 +15,7 @@ from flask_cors import CORS
from werkzeug.exceptions import NotFound
from werkzeug.security import safe_join
from auth import AuthHandler
from config import config
from utils import get_env_bool
@@ -34,10 +35,17 @@ app.config["UPLOAD_FOLDER"] = first(
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = config.get(
"fileserver.download.cache_timeout_sec", 5 * 60
)
if max_upload_size := config.get("fileserver.upload.max_upload_size_mb", None):
app.config["MAX_CONTENT_LENGTH"] = max_upload_size * 1024 * 1024
auth_handler = AuthHandler.instance()
@app.route("/", methods=["GET"])
def ping():
    """Health-check endpoint.

    The endpoint itself is open: a token is validated only when auth is
    enabled and the caller chose to send one, in which case an invalid token
    aborts the request.
    """
    if auth_handler:
        if auth_handler.get_token(request):
            auth_handler.validate(request)
    return "OK", 200
@@ -57,6 +65,9 @@ def after_request(response):
@app.route("/", methods=["POST"])
def upload():
if auth_handler:
auth_handler.validate(request)
results = []
for filename, file in request.files.items():
if not filename:
@@ -76,6 +87,9 @@ def upload():
@app.route("/<path:path>", methods=["GET"])
def download(path):
if auth_handler:
auth_handler.validate(request)
as_attachment = "download" in request.args
_, encoding = mimetypes.guess_type(os.path.basename(path))
@@ -105,18 +119,24 @@ def _get_full_path(path: str) -> Path:
@app.route("/<path:path>", methods=["DELETE"])
def delete(path):
real_path = _get_full_path(path)
if not real_path.exists() or not real_path.is_file():
log.error(f"Error deleting file {str(real_path)}. Not found or not a file")
abort(Response(f"File {str(real_path)} not found", 404))
if auth_handler:
auth_handler.validate(request)
real_path.unlink()
full_path = _get_full_path(path)
if not full_path.exists() or not full_path.is_file():
log.error(f"Error deleting file {str(full_path)}. Not found or not a file")
abort(Response(f"File {str(path)} not found", 404))
log.info(f"Deleted file {str(real_path)}")
full_path.unlink()
log.info(f"Deleted file {str(full_path)}")
return json.dumps(str(path)), 200
def batch_delete():
if auth_handler:
auth_handler.validate(request)
body = request.get_json(force=True, silent=False)
if not body:
abort(Response("Json payload is missing", 400))
@@ -139,17 +159,17 @@ def batch_delete():
record_error("Empty path not allowed", file, path)
continue
path = _get_full_path(path)
full_path = _get_full_path(path)
if not path.exists():
if not full_path.exists():
record_error("Not found", file, path)
continue
try:
if path.is_file():
path.unlink()
elif path.is_dir():
shutil.rmtree(path)
if full_path.is_file():
full_path.unlink()
elif full_path.is_dir():
shutil.rmtree(full_path)
else:
record_error("Not a file or folder", file, path)
continue
@@ -157,7 +177,7 @@ def batch_delete():
record_error(ex.strerror, file, path)
continue
except Exception as ex:
record_error(str(ex).replace(str(path), ""), file, path)
record_error(str(ex).replace(str(full_path), ""), file, path)
continue
deleted[file] = str(path)

103
fileserver/redis_manager.py Normal file
View File

@@ -0,0 +1,103 @@
from os import getenv
from boltons.iterutils import first
from redis import StrictRedis
from redis.cluster import RedisCluster
from config import config
log = config.logger(__file__)

# Environment variables that may override the configured redis connection
# settings. Each tuple is scanned in order and the first non-empty value wins.
OVERRIDE_HOST_ENV_KEY = (
    "CLEARML_REDIS_SERVICE_HOST",
    "TRAINS_REDIS_SERVICE_HOST",
    "REDIS_SERVICE_HOST",
)
OVERRIDE_PORT_ENV_KEY = (
    "CLEARML_REDIS_SERVICE_PORT",
    "TRAINS_REDIS_SERVICE_PORT",
    "REDIS_SERVICE_PORT",
)
OVERRIDE_PASSWORD_ENV_KEY = (
    "CLEARML_REDIS_SERVICE_PASSWORD",
    "TRAINS_REDIS_SERVICE_PASSWORD",
    "REDIS_SERVICE_PASSWORD",
)


def _env_override(env_keys):
    """Return the first non-empty value among the given env vars, or None."""
    for env_key in env_keys:
        env_value = getenv(env_key)
        if env_value:
            return env_value
    return None


OVERRIDE_HOST = _env_override(OVERRIDE_HOST_ENV_KEY)
if OVERRIDE_HOST:
    log.info(f"Using override redis host {OVERRIDE_HOST}")

OVERRIDE_PORT = _env_override(OVERRIDE_PORT_ENV_KEY)
if OVERRIDE_PORT:
    log.info(f"Using override redis port {OVERRIDE_PORT}")

# the password override is deliberately not logged
OVERRIDE_PASSWORD = _env_override(OVERRIDE_PASSWORD_ENV_KEY)
class ConfigError(Exception):
    """Raised when a redis alias configuration is missing its host or port."""
class GeneralError(Exception):
    """Raised when a connection is requested for an unknown redis alias."""
class RedisManager(object):
    """Create and hold one redis client per configured alias."""

    def __init__(self, redis_config_dict):
        """Build a client for every alias in the configuration.

        :param redis_config_dict: mapping of alias name to its config section;
            each section must expose ``as_plain_ordered_dict()``
        :raises ConfigError: when an alias ends up without a host or a port
        """
        self.aliases = {}
        for alias, alias_config in redis_config_dict.items():
            alias_config = alias_config.as_plain_ordered_dict()
            alias_config["password"] = config.get(
                f"secure.redis.{alias}.password", None
            )

            # "cluster" only selects the client class and is not a client
            # argument: always strip it, so that an explicit `cluster: false`
            # is not forwarded to StrictRedis as an unexpected keyword
            is_cluster = alias_config.pop("cluster", False)

            # environment overrides take precedence over the configured values
            host = OVERRIDE_HOST or alias_config.get("host", None)
            if host:
                alias_config["host"] = host

            port = OVERRIDE_PORT or alias_config.get("port", None)
            if port:
                alias_config["port"] = port

            password = OVERRIDE_PASSWORD or alias_config.get("password", None)
            if password:
                alias_config["password"] = password

            if not port or not host:
                raise ConfigError(
                    f"Redis configuration is invalid. missing port or host (alias={alias})"
                )

            if is_cluster:
                # "db" is not passed to the cluster client; tolerate it being absent
                alias_config.pop("db", None)
                self.aliases[alias] = RedisCluster(**alias_config)
            else:
                self.aliases[alias] = StrictRedis(**alias_config)

    def connection(self, alias) -> StrictRedis:
        """Return the client registered for ``alias``.

        :raises GeneralError: when no client is registered under ``alias``
        """
        obj = self.aliases.get(alias)
        if not obj:
            raise GeneralError(f"Invalid Redis alias {alias}")

        # issues a GET before handing the client out; presumably a connectivity
        # check, so connection problems surface here - TODO confirm
        obj.get("health")
        return obj

    def host(self, alias):
        """Return the host of the first available pooled connection, or None.

        NOTE: relies on the private ``_available_connections`` attribute of the
        redis connection pool.
        """
        r = self.connection(alias)
        if isinstance(r, RedisCluster):
            connections = r.get_default_node().redis_connection.connection_pool._available_connections
        else:
            connections = r.connection_pool._available_connections

        if not connections:
            return None

        return connections[0].host
# Module-level manager instance, built at import time from the "hosts.redis" config section
redman = RedisManager(config.get("hosts.redis"))

View File

@@ -1,9 +1,11 @@
boltons>=19.1.0
clearml-agent>=1.5.2
flask-compress>=1.4.0
flask-cors>=3.0.5
flask>=2.3.3
gunicorn>=20.1.0
pyhocon>=0.3.35
redis>=4.5.4,<5
setuptools>=65.5.1
urllib3>=1.26.18
werkzeug>=3.0.1