Compare commits

24 Commits

Author SHA1 Message Date
allegroai
83a0485518 Fix user credentials reset on apiserver restart 2024-07-17 11:22:52 +03:00
allegroai
f3491cc9b9 Update README 2024-07-07 13:28:40 +03:00
allegroai
7558426bc6 Fix max upload size limit 2024-06-26 11:21:53 +03:00
allegroai
ce01e37c66 Refactor docker compose files: remove legacy, add services agent initialization in Linux 2024-06-26 10:53:43 +03:00
allegroai
92b42d66b7 Remove default credentials and reset existing credentials if none were provided 2024-06-26 10:52:42 +03:00
allegroai
f7d36bea4f Use an auth token in async_urls_delete when contacting the fileserver 2024-06-20 18:00:19 +03:00
allegroai
f1c876089b Add worker_pattern parameter to workers.get_all and get_count endpoints 2024-06-20 17:59:28 +03:00
allegroai
dd0ecb712d Added fileserver.upload.max_upload_size_mb setting 2024-06-20 17:58:33 +03:00
allegroai
fcfc1e8998 Support a more granular distributed lock wait 2024-06-20 17:57:54 +03:00
allegroai
9c210bb4fa Fix fixed users creation/removal 2024-06-20 17:57:23 +03:00
allegroai
14547155cb Delete pipeline steps in pipelines.delete_runs 2024-06-20 17:55:52 +03:00
allegroai
3f34f83a91 Version bump to 1.16.0
API version bump to 2.30
Add missing endpoints to schema
2024-06-20 17:55:17 +03:00
allegroai
da3941e6f2 Upgrade pymongo dependency 2024-06-20 17:53:15 +03:00
allegroai
2e19a18ee4 Support automatic handling of pipeline steps if a pipeline controller task ID was passed to one of the tasks endpoints 2024-06-20 17:52:46 +03:00
allegroai
cdc668e3c8 Fileserver authorization is enabled by default 2024-06-20 17:50:02 +03:00
allegroai
7c9889605a Add token authorization to fileserver 2024-06-20 17:48:54 +03:00
allegroai
5456ee4ebf Data tool: exporting projects by name now includes subprojects; added an option for exporting all projects 2024-06-20 17:48:18 +03:00
allegroai
562cb77003 Support getting and clearing task logs using specific metrics 2024-06-20 17:47:39 +03:00
allegroai
91df2bb3b7 Use better token generation for the secret key 2024-06-20 17:46:23 +03:00
allegroai
cb9812caee Do not return any mongodb instructions as a result of task update operations 2024-06-20 17:44:17 +03:00
allegroai
0496582d96 Ensure a minimum interval on workers history charts so that we do not get a "saw-like" chart due to missing points in the intervals 2024-06-20 17:43:52 +03:00
allegroai
beff19e104 Fix: do not return the full file path on errors from the fileserver 2024-06-20 17:43:19 +03:00
pollfly
639b3d59a4 Update docstrings (#246)
Edit descriptions so they can be rendered using MDX
2024-06-20 17:00:31 +03:00
allegroai
c0d687e2ef Fix missing git in Dockerfile for building webapp 2024-03-28 17:50:35 +02:00
46 changed files with 944 additions and 543 deletions

View File

@@ -6,7 +6,7 @@
</br>Experiment Manager, ML-Ops and Data-Management**
[![GitHub license](https://img.shields.io/badge/license-SSPL-green.svg)](https://img.shields.io/badge/license-SSPL-green.svg)
[![Python versions](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)
[![Python versions](https://img.shields.io/badge/python-3.9-blue.svg)](https://img.shields.io/badge/python-3.9-blue.svg)
[![GitHub version](https://img.shields.io/github/release-pre/allegroai/trains-server.svg)](https://img.shields.io/github/release-pre/allegroai/trains-server.svg)
[![Artifact Hub](https://img.shields.io/endpoint?url=https://artifacthub.io/badge/repository/allegroai)](https://artifacthub.io/packages/search?repo=allegroai)

View File

@@ -146,6 +146,7 @@ class LogEventsRequest(TaskEventsRequestBase):
navigate_earlier: bool = BoolField(default=True)
from_timestamp: Optional[int] = IntField()
order: Optional[str] = ActualEnumField(LogOrderEnum)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class ScalarMetricsIterRawRequest(TaskEventsRequestBase):
@@ -229,3 +230,5 @@ class ClearTaskLogRequest(Base):
task: str = StringField(required=True)
threshold_sec = IntField()
allow_locked = BoolField(default=False)
exclude_metrics = ListField(items_types=[str])
include_metrics = ListField(items_types=[str])

View File

@@ -101,6 +101,10 @@ class DequeueRequest(UpdateRequest):
new_status = StringField()
class StopRequest(UpdateRequest):
include_pipeline_steps = BoolField(default=False)
class EnqueueRequest(UpdateRequest):
queue = StringField()
queue_name = StringField()
@@ -112,6 +116,7 @@ class DeleteRequest(UpdateRequest):
return_file_urls = BoolField(default=False)
delete_output_models = BoolField(default=True)
delete_external_artifacts = BoolField(default=True)
include_pipeline_steps = BoolField(default=False)
class SetRequirementsRequest(TaskRequest):
@@ -264,6 +269,7 @@ class DeleteConfigurationRequest(TaskUpdateRequest):
class ArchiveRequest(MultiTaskRequest):
status_reason = StringField(default="")
status_message = StringField(default="")
include_pipeline_steps = BoolField(default=False)
class ArchiveResponse(models.Base):
@@ -275,8 +281,17 @@ class TaskBatchRequest(BatchRequest):
status_message = StringField(default="")
class ArchiveManyRequest(TaskBatchRequest):
include_pipeline_steps = BoolField(default=False)
class UnarchiveManyRequest(TaskBatchRequest):
include_pipeline_steps = BoolField(default=False)
class StopManyRequest(TaskBatchRequest):
force = BoolField(default=False)
include_pipeline_steps = BoolField(default=False)
class DequeueManyRequest(TaskBatchRequest):
@@ -297,6 +312,7 @@ class DeleteManyRequest(TaskBatchRequest):
delete_output_models = BoolField(default=True)
force = BoolField(default=False)
delete_external_artifacts = BoolField(default=True)
include_pipeline_steps = BoolField(default=False)
class ResetManyRequest(TaskBatchRequest):

View File

@@ -100,6 +100,7 @@ class GetAllRequest(Base):
last_seen = IntField(default=3600)
tags = ListField(str)
system_tags = ListField(str)
worker_pattern = StringField()
class GetAllResponse(Base):

View File

@@ -1227,6 +1227,8 @@ class EventBLL(object):
task_id: str,
allow_locked: bool = False,
threshold_sec: int = None,
include_metrics: Sequence[str] = None,
exclude_metrics: Sequence[str] = None,
):
self._validate_task_state(
company_id=company_id, task_id=task_id, allow_locked=allow_locked
@@ -1251,8 +1253,16 @@ class EventBLL(object):
}
)
sort = {"timestamp": {"order": "desc"}}
if include_metrics:
must.append({"terms": {"metric": include_metrics}})
more_conditions = {}
if exclude_metrics:
more_conditions = {"must_not": [{"terms": {"metric": exclude_metrics}}]}
es_req = {
"query": {"bool": {"must": must}},
"query": {"bool": {"must": must, **more_conditions}},
**({"sort": sort} if sort else {}),
}
es_res = delete_company_events(

View File

@@ -22,7 +22,7 @@ from apiserver.database.model.task.task import (
TaskStatusMessage,
ArtifactModes,
Execution,
DEFAULT_LAST_ITERATION,
DEFAULT_LAST_ITERATION, TaskType,
)
from apiserver.database.utils import get_options
from apiserver.service_repo.auth import Identity
@@ -32,54 +32,79 @@ log = config.logger(__file__)
queue_bll = QueueBLL()
def _get_pipeline_steps_for_controller_task(
    task: Task, company_id: str, only: Sequence[str] = None
) -> Sequence[Task]:
    """
    Return the tasks whose parent is the passed pipeline controller task

    :param task: candidate controller task. A falsy value, or a task whose
        type is not TaskType.controller, yields an empty list
    :param company_id: company used to filter the queried tasks
    :param only: optional field names. When passed, the query is restricted
        to these fields
    :return: list of the company tasks whose parent is task.id
    """
    is_controller = bool(task) and task.type == TaskType.controller
    if not is_controller:
        return []

    steps = Task.objects(company=company_id, parent=task.id)
    return list(steps.only(*only) if only else steps)
def archive_task(
task: Union[str, Task],
company_id: str,
identity: Identity,
status_message: str,
status_reason: str,
include_pipeline_steps: bool,
) -> int:
"""
Dequeue and archive task
Return 1 if successful
"""
user_id = identity.user
fields = (
"id",
"company",
"execution",
"status",
"project",
"system_tags",
"enqueue_status",
"type",
)
if isinstance(task, str):
task = get_task_with_write_access(
task,
company_id=company_id,
identity=identity,
only=(
"id",
"company",
"execution",
"status",
"project",
"system_tags",
"enqueue_status",
),
only=fields,
)
user_id = identity.user
try:
TaskBLL.dequeue_and_change_status(
task,
company_id=company_id,
user_id=user_id,
def archive_task_core(task_: Task) -> int:
try:
TaskBLL.dequeue_and_change_status(
task_,
company_id=company_id,
user_id=user_id,
status_message=status_message,
status_reason=status_reason,
remove_from_all_queues=True,
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
return task_.update(
status_message=status_message,
status_reason=status_reason,
remove_from_all_queues=True,
add_to_set__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
last_changed_by=user_id,
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
return task.update(
status_message=status_message,
status_reason=status_reason,
add_to_set__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
last_changed_by=user_id,
)
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(task, company_id, only=fields)
):
for step in step_tasks:
archive_task_core(step)
return archive_task_core(task)
def unarchive_task(
@@ -88,24 +113,36 @@ def unarchive_task(
identity: Identity,
status_message: str,
status_reason: str,
include_pipeline_steps: bool,
) -> int:
"""
Unarchive task. Return 1 if successful
"""
fields = ("id", "type")
task = get_task_with_write_access(
task_id,
company_id=company_id,
identity=identity,
only=("id",),
)
return task.update(
status_message=status_message,
status_reason=status_reason,
pull__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
last_changed_by=identity.user,
only=fields,
)
def unarchive_task_core(task_: Task) -> int:
    """
    Pull the archived system tag from the task and stamp the change

    status_message, status_reason and identity are taken from the
    enclosing scope.
    :return: the value returned by the task update call
    """
    changes = dict(
        status_message=status_message,
        status_reason=status_reason,
        pull__system_tags=EntityVisibility.archived.value,
        last_change=datetime.utcnow(),
        last_changed_by=identity.user,
    )
    return task_.update(**changes)
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(task, company_id, only=fields)
):
for step in step_tasks:
unarchive_task_core(step)
return unarchive_task_core(task)
def dequeue_task(
task_id: str,
@@ -262,6 +299,7 @@ def delete_task(
status_message: str,
status_reason: str,
delete_external_artifacts: bool,
include_pipeline_steps: bool,
) -> Tuple[int, Task, CleanupResult]:
user_id = identity.user
task = get_task_with_write_access(
@@ -280,36 +318,51 @@ def delete_task(
current=task.status,
)
try:
TaskBLL.dequeue_and_change_status(
task,
company_id=company_id,
user_id=user_id,
status_message=status_message,
status_reason=status_reason,
remove_from_all_queues=True,
def delete_task_core(task_: Task, force_: bool):
try:
TaskBLL.dequeue_and_change_status(
task_,
company_id=company_id,
user_id=user_id,
status_message=status_message,
status_reason=status_reason,
remove_from_all_queues=True,
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
res = cleanup_task(
company=company_id,
user=user_id,
task=task_,
force=force_,
return_file_urls=return_file_urls,
delete_output_models=delete_output_models,
delete_external_artifacts=delete_external_artifacts,
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
cleanup_res = cleanup_task(
company=company_id,
user=user_id,
task=task,
force=force,
return_file_urls=return_file_urls,
delete_output_models=delete_output_models,
delete_external_artifacts=delete_external_artifacts,
)
if move_to_trash:
# make sure that whatever changes were done to the task are saved
# the task itself will be deleted later in the move_tasks_to_trash operation
task_.last_update = datetime.utcnow()
task_.save()
else:
task_.delete()
return res
task_ids = [task.id]
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(task, company_id)
):
for step in step_tasks:
delete_task_core(step, True)
task_ids.append(step.id)
cleanup_res = delete_task_core(task, force)
if move_to_trash:
# make sure that whatever changes were done to the task are saved
# the task itself will be deleted later in the move_tasks_to_trash operation
task.last_update = datetime.utcnow()
task.save()
else:
task.delete()
move_tasks_to_trash(task_ids)
update_project_time(task.project)
return 1, task, cleanup_res
@@ -465,6 +518,7 @@ def stop_task(
user_name: str,
status_reason: str,
force: bool,
include_pipeline_steps: bool,
) -> dict:
"""
Stop a running task. Requires task status 'in_progress' and
@@ -475,19 +529,21 @@ def stop_task(
:return: updated task fields
"""
user_id = identity.user
fields = (
"status",
"project",
"tags",
"system_tags",
"last_worker",
"last_update",
"execution.queue",
"type",
)
task = get_task_with_write_access(
task_id,
company_id=company_id,
identity=identity,
only=(
"status",
"project",
"tags",
"system_tags",
"last_worker",
"last_update",
"execution.queue",
),
only=fields,
)
def is_run_by_worker(t: Task) -> bool:
@@ -499,32 +555,41 @@ def stop_task(
and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout
)
is_queued = task.status == TaskStatus.queued
set_stopped = (
is_queued
or TaskSystemTags.development in task.system_tags
or not is_run_by_worker(task)
)
def stop_task_core(task_: Task, force_: bool):
is_queued = task_.status == TaskStatus.queued
set_stopped = (
is_queued
or TaskSystemTags.development in task_.system_tags
or not is_run_by_worker(task_)
)
if set_stopped:
if is_queued:
try:
TaskBLL.dequeue(task, company_id=company_id, silent_fail=True)
except APIError:
# dequeue may fail if the task was not enqueued
pass
if set_stopped:
if is_queued:
try:
TaskBLL.dequeue(task_, company_id=company_id, silent_fail=True)
except APIError:
# dequeue may fail if the task was not enqueued
pass
new_status = TaskStatus.stopped
status_message = f"Stopped by {user_name}"
else:
new_status = task.status
status_message = TaskStatusMessage.stopping
new_status = TaskStatus.stopped
status_message = f"Stopped by {user_name}"
else:
new_status = task_.status
status_message = TaskStatusMessage.stopping
return ChangeStatusRequest(
task=task,
new_status=new_status,
status_reason=status_reason,
status_message=status_message,
force=force,
user_id=user_id,
).execute()
return ChangeStatusRequest(
task=task_,
new_status=new_status,
status_reason=status_reason,
status_message=status_message,
force=force_,
user_id=user_id,
).execute()
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(task, company_id, only=fields)
):
for step in step_tasks:
stop_task_core(step, True)
return stop_task_core(task, force)

View File

@@ -4,6 +4,7 @@ from typing import Sequence
import attr
import six
from mongoengine import Q
from mongoengine.base import UPDATE_OPERATORS
from apiserver.apierrors import errors
from apiserver.bll.util import update_project_time
@@ -78,8 +79,16 @@ class ChangeStatusRequest(object):
update_project_time(project_id)
# make sure that _raw_ queries are not returned back to the client
fields.pop("__raw__", None)
def is_mongo_operator(field: str) -> bool:
    """
    Check whether an update key has the form <operator>__<rest>

    :param field: key of the update fields dictionary, e.g. "status"
    :return: True only when something follows the first "__" and the part
        before it is one of UPDATE_OPERATORS
    """
    head, _, tail = field.partition("__")
    # bool() keeps the declared return type: a bare `tail and ...` returns
    # the empty string (not False) for keys that contain no "__"
    return bool(tail) and head in UPDATE_OPERATORS
# make sure to not return _raw_ queries or any of the update operators
fields = {
key: value
for key, value in fields.items()
if not (key == "__raw__" or is_mongo_operator(key))
}
return dict(updated=updated, fields=fields)

View File

@@ -1,4 +1,5 @@
import itertools
import re
from datetime import datetime, timedelta
from time import time
from typing import Sequence, Set, Optional
@@ -34,6 +35,8 @@ log = config.logger(__file__)
class WorkerBLL:
_key_regex_trans = str.maketrans({"*": ".*", "?": ".?"})
def __init__(self, es=None, redis=None):
self.es_client = es or es_factory.connect("workers")
self.config = config.get("services.workers", ConfigTree())
@@ -207,15 +210,25 @@ class WorkerBLL:
last_seen: Optional[int] = None,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
worker_pattern: str = None,
):
if not last_seen:
return len(
self._get_keys(company_id, user_tags=tags, system_tags=system_tags)
self._get_keys(
company_id,
user_tags=tags,
system_tags=system_tags,
worker_pattern=worker_pattern,
)
)
return len(
self.get_all(
company_id, last_seen=last_seen, tags=tags, system_tags=system_tags
company_id,
last_seen=last_seen,
tags=tags,
system_tags=system_tags,
worker_pattern=worker_pattern,
)
)
@@ -225,6 +238,7 @@ class WorkerBLL:
last_seen: Optional[int] = None,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
worker_pattern: str = None,
) -> Sequence[WorkerEntry]:
"""
Get all the company workers that were active during the last_seen period
@@ -233,7 +247,12 @@ class WorkerBLL:
:return:
"""
try:
workers = self._get(company_id, user_tags=tags, system_tags=system_tags)
workers = self._get(
company_id,
user_tags=tags,
system_tags=system_tags,
worker_pattern=worker_pattern,
)
except Exception as e:
raise server_error.DataError("failed loading worker entries", err=e.args[0])
@@ -253,6 +272,7 @@ class WorkerBLL:
last_seen: int,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
worker_pattern: str = None,
) -> Sequence[WorkerResponseEntry]:
helpers = [
WorkerConversionHelper.from_worker_entry(entry)
@@ -261,6 +281,7 @@ class WorkerBLL:
last_seen=last_seen,
tags=tags,
system_tags=system_tags,
worker_pattern=worker_pattern,
)
]
@@ -320,7 +341,7 @@ class WorkerBLL:
for helper in helpers:
worker = helper.worker
if helper.task_id:
task = tasks_info.get(helper.task_id, None)
task: Task = tasks_info.get(helper.task_id, None)
if task:
worker.task.running_time = (task.active_duration or 0) * 1000
worker.task.last_iteration = task.last_iteration
@@ -416,16 +437,25 @@ class WorkerBLL:
user: str = "*",
user_tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
worker_pattern: str = None,
) -> Sequence[bytes]:
if not (user_tags or system_tags):
match = self._get_worker_key(company, user, "*")
match = self._get_worker_key(company, user, worker_pattern or "*")
return list(self.redis.scan_iter(match))
def filter_by_user(in_keys: Set[bytes]) -> Set[bytes]:
if user == "*":
return in_keys
user_bytes = user.encode()
return {k for k in in_keys if user_bytes in k}
def filter_by_user_and_pattern(in_keys: Set[bytes]) -> Set[bytes]:
    """
    Narrow the passed worker keys by the user and worker_pattern values
    taken from the enclosing scope

    :param in_keys: candidate worker keys (bytes)
    :return: the keys that contain the encoded user (unless user is "*")
        and whose tail matches the translated worker pattern (if passed)
    """
    filtered = in_keys
    if user != "*":
        user_token = user.encode()
        filtered = {key for key in filtered if user_token in key}

    if not worker_pattern:
        return filtered

    # NOTE(review): only the characters mapped in _key_regex_trans are
    # translated; any other regex metacharacter in worker_pattern keeps its
    # regex meaning - confirm this is intended
    translated = worker_pattern.translate(self._key_regex_trans)
    # the trailing "$" together with search() anchors the match to the key end only
    matcher = re.compile(f"{translated}$".encode())
    return {key for key in filtered if matcher.search(key)}
worker_keys = set()
for tags, tags_field in (
@@ -448,7 +478,7 @@ class WorkerBLL:
)
tagged_workers.update(self.redis.zrange(tagged_workers_key, 0, -1))
tagged_workers = filter_by_user(tagged_workers)
tagged_workers = filter_by_user_and_pattern(tagged_workers)
worker_keys = (
worker_keys.intersection(tagged_workers)
if worker_keys
@@ -462,7 +492,7 @@ class WorkerBLL:
all_workers_key = self._get_all_workers_key(company)
self.redis.zremrangebyscore(all_workers_key, min=0, max=timestamp)
worker_keys.update(self.redis.zrange(all_workers_key, 0, -1))
worker_keys = filter_by_user(worker_keys)
worker_keys = filter_by_user_and_pattern(worker_keys)
if not worker_keys:
return []
@@ -487,13 +517,18 @@ class WorkerBLL:
user: str = "*",
user_tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
worker_pattern: str = None,
) -> Sequence[WorkerEntry]:
"""Get worker entries matching the company and user, worker patterns"""
entries = []
for keys in chunked_iter(
self._get_keys(
company, user=user, user_tags=user_tags, system_tags=system_tags
company,
user=user,
user_tags=user_tags,
system_tags=system_tags,
worker_pattern=worker_pattern,
),
1000,
):

View File

@@ -13,6 +13,8 @@ log = config.logger(__file__)
class WorkerStats:
min_chart_interval = config.get("services.workers.min_chart_interval_sec", 40)
def __init__(self, es):
self.es = es
@@ -203,6 +205,7 @@ class WorkerStats:
"""
if from_date >= to_date:
raise bad_request.FieldsValueError("from_date must be less than to_date")
interval = max(interval, self.min_chart_interval)
must = [QueryBuilder.dates_range(from_date, to_date)]
if active_only:

View File

@@ -1,13 +1,13 @@
{
http {
session_secret {
apiserver: "Gx*gB-L2U8!Naqzd#8=7A4&+=In4H(da424H33ZTDQRGF6=FWw"
apiserver: "V8gcW3EneNDcNfO7G_TSUsWe7uLozyacc9_I33o7bxUo8rCN31VLRg"
}
}
auth {
# token sign secret
token_secret: "7E1ua3xP9GT2(cIQOfhjp+gwN6spBeCAmN-XuugYle00I=Wc+u"
token_secret: "Rq8FW84sSqVgq7WvBB_4EzNl9y8z8IGiDXX3C345_a5AZfcwZcwCIA"
}
credentials {
@@ -15,24 +15,29 @@
apiserver {
role: "system"
user_key: "62T8CP7HGBC6647XF9314C2VY67RJO"
user_secret: "FhS8VZv_I4%6Mo$8S1BWc$n$=o1dMYSivuiWU-Vguq7qGOKskG-d+b@tn_Iq"
user_secret: "gaOfhDX2-bpkeI7-cwEcaMuGijxaG2UG3jbIvg4DxmVGF0LNI7rgvCb1-ne38IlBo1w"
}
fileserver {
role: "system"
user_key: "GSQWPEKSKNKF354LC9V6BHXKTYFD5I"
user_secret: "tuBXcGQBECsEhcNiK2kiWi750z9r8Z85XrQ9V0c24huTuCb2xf2X1nKG"
}
webserver {
role: "system"
user_key: "EYVQ385RW7Y2QQUH88CZ7DWIQ1WUHP"
user_secret: "yfc8KQo*GMXb*9p((qcYC7ByFIpF7I&4VH3BfUYXH%o9vX1ZUZQEEw1Inc)S"
user_secret: "XhkH6a6ds9JBnM_MrahYyYdO-wS2bqFSm8gl-V0UZXH26Ydd6Eyi28TeBEoSr6Z3Bes"
revoke_in_fixed_mode: true
}
services_agent {
role: "admin"
user_key: "P4BMJA7RK3TKBXGSY8OAA1FA8TOD11"
user_secret: "9LsgSfa0SYz0zli1_c500ZcLqanre2xkWOpepyt1w-BKK3_DKPHrtoj3JSHvyy8bIi0"
user_key: ""
user_secret: ""
}
tests {
role: "user"
display_name: "Default User"
user_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
user_secret: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
user_secret: "LPEJbGJ6bK4tujQcmrD3i1dbMBDdwUwelVa-LG0K0FFmY9bzH_H0Sw"
revoke_in_fixed_mode: true
}
}

View File

@@ -9,4 +9,5 @@ fileserver {
# Can be in the form <schema>://host:port/path or /path
url_prefixes: ["https://files.community-master.hosted.allegro.ai/"]
timeout_sec: 300
token_expiration_sec: 600
}

View File

@@ -0,0 +1,5 @@
default_worker_timeout_sec: 600
default_cluster_timeout_sec: 600
# The minimal sampling interval for resource dashboard and worker activity charts
min_chart_interval_sec: 40

View File

@@ -2,6 +2,9 @@
| Release | ApiVersion |
|---------|------------|
| v1.16 | 2.30 |
| v1.15 | 2.29 |
| v1.14 | 2.28 |
| v1.13 | 2.27 |
| v1.12 | 2.26 |
| v1.11 | 2.25 |

View File

@@ -19,7 +19,9 @@ from google.cloud import storage as google_storage
from mongoengine import Q
from mypy_boto3_s3.service_resource import Bucket as AWSBucket
from apiserver.bll.auth import AuthBLL
from apiserver.bll.storage import StorageBLL
from apiserver.config.info import get_default_company
from apiserver.config_repo import config
from apiserver.database import db
from apiserver.database.model.url_to_delete import UrlToDelete, StorageType, DeletionStatus
@@ -200,6 +202,8 @@ class FileserverStorage(Storage):
res_data = res.json()
return list(res_data.get("deleted", {})), res_data.get("errors", {})
token_expiration_sec = conf.get("fileserver.token_expiration_sec", 600)
def __init__(self, company: str, fileserver_host: str = None):
fileserver_host = fileserver_host or config.get("hosts.fileserver", None)
self.host = fileserver_host.rstrip("/")
@@ -220,13 +224,6 @@ class FileserverStorage(Storage):
self.company = company
# @classmethod
# def validate_fileserver_access(cls, fileserver_host: str):
# res = requests.get(
# url=fileserver_host
# )
# res.raise_for_status()
@property
def name(self) -> str:
return "Fileserver"
@@ -260,7 +257,13 @@ class FileserverStorage(Storage):
def get_client(self, base: str, urls: Sequence[UrlToDelete]) -> Client:
host = base
token = AuthBLL.get_token_for_user(
user_id="__apiserver__",
company_id=get_default_company(),
expiration_sec=self.token_expiration_sec,
).token
session = requests.session()
session.headers.update({"Authorization": "Bearer {}".format(token)})
res = session.get(url=host, timeout=self.Client.timeout)
res.raise_for_status()
@@ -285,6 +288,7 @@ class AzureStorage(Storage):
):
raise ValueError("No path found following container name")
# noinspection PyTypeChecker
return os.path.join(*parsed.path.segments[1:])
@staticmethod

View File

@@ -73,7 +73,7 @@ def init_mongo_data():
}
internal_user_emails.add(email.lower())
revoke = fixed_mode and credentials.get("revoke_in_fixed_mode", False)
user_id = _ensure_auth_user(user_data, company_id, log=log, revoke=revoke)
user_id = _ensure_auth_user(user_data, company_id, log=log, revoke=revoke, internal_user=True)
if credentials.role == Role.user:
_ensure_backend_user(user_id, company_id, credentials.display_name)

View File

@@ -22,6 +22,7 @@ from typing import (
Mapping,
IO,
Callable,
Iterable,
)
from urllib.parse import unquote, urlparse
from uuid import uuid4, UUID, uuid5
@@ -220,6 +221,9 @@ class PrePopulate:
raise ValueError("Invalid task statuses")
file = Path(filename)
if not (experiments or projects):
projects = cls.project_cls.objects(parent=None).scalar("id")
entities = cls._resolve_entities(
experiments=experiments, projects=projects, task_statuses=task_statuses
)
@@ -417,24 +421,50 @@ class PrePopulate:
featured_index = get_index(project)
cls.project_cls.objects(id=project.id).update(featured=featured_index)
@staticmethod
def _resolve_type(
cls: Type[mongoengine.Document], ids: Optional[Sequence[str]]
@classmethod
def _resolve_entity_type(
cls, entity_type: Type[mongoengine.Document], ids: Optional[Sequence[str]]
) -> Sequence[Any]:
ids = set(ids)
items = list(cls.objects(id__in=list(ids)))
items = list(entity_type.objects(id__in=list(ids)))
resolved = {i.id for i in items}
missing = ids - resolved
for name_candidate in missing:
results = list(cls.objects(name=name_candidate))
if not results:
print(f"ERROR: no match for `{name_candidate}`")
exit(1)
elif len(results) > 1:
print(f"ERROR: more than one match for `{name_candidate}`")
exit(1)
items.append(results[0])
return items
if not missing:
return items
resolved_by_name = defaultdict(list)
for entity in entity_type.objects(name__in=list(missing)):
resolved_by_name[entity.name].append(entity)
not_found = missing - set(resolved_by_name)
if not_found:
print(f"ERROR: no match for {', '.join(not_found)}")
exit(1)
duplicates = [k for k, v in resolved_by_name.items() if len(v) > 1]
if duplicates:
print(f"ERROR: more than one match for {', '.join(duplicates)}")
exit(1)
def get_new_items(input_: Iterable) -> list:
    """Return the input items whose id is not in the resolved set of the enclosing scope"""
    fresh = []
    for candidate in input_:
        if candidate.id in resolved:
            continue
        fresh.append(candidate)
    return fresh
def get_projects_with_children(projects: list) -> list:
    """
    Expand the passed projects using project_ids_with_children

    presumably the helper returns the passed ids together with the ids of
    their sub-projects - TODO confirm
    :return: the passed list itself when the expansion yields the same ids,
        otherwise the not yet resolved entities loaded for the expanded ids
    """
    own_ids = {project.id for project in projects}
    expanded_ids = project_ids_with_children(list(own_ids))
    if set(expanded_ids) != own_ids:
        return get_new_items(entity_type.objects(id__in=expanded_ids))
    return projects
new_items = get_new_items(chain(*resolved_by_name.values()))
if not new_items:
return items
if entity_type == cls.project_cls:
new_items = get_projects_with_children(new_items)
return items + new_items
@classmethod
def _check_projects_hierarchy(cls, projects: Set[Project]):
@@ -467,7 +497,7 @@ class PrePopulate:
print("Reading projects...")
projects = project_ids_with_children(projects)
entities[cls.project_cls].update(
cls._resolve_type(cls.project_cls, projects)
cls._resolve_entity_type(cls.project_cls, projects)
)
print("--> Reading project experiments...")
query = Q(
@@ -485,7 +515,7 @@ class PrePopulate:
if experiments:
print("Reading experiments...")
entities[cls.task_cls].update(cls._resolve_type(cls.task_cls, experiments))
entities[cls.task_cls].update(cls._resolve_entity_type(cls.task_cls, experiments))
print("--> Reading experiments projects...")
objs = cls.project_cls.objects(
id__in=list(

View File

@@ -9,33 +9,84 @@ from apiserver.database.model.user import User
from apiserver.service_repo.auth.fixed_user import FixedUser
def _ensure_auth_user(user_data: dict, company_id: str, log: Logger, revoke: bool = False):
key, secret = user_data.get("key"), user_data.get("secret")
def _ensure_user_credentials(
user: AuthUser,
key: str,
secret: str,
log: Logger,
revoke: bool = False,
internal_user: bool = False,
) -> None:
if revoke:
log.info(f"Revoking credentials for existing user {user.id} ({user.name})")
user.credentials = []
user.save()
return
if not (key and secret):
credentials = None
else:
creds = Credentials(key=key, secret=secret)
if internal_user:
log.info(f"Resetting credentials for existing user {user.id} ({user.name})")
user.credentials = []
user.save()
return
user = AuthUser.objects(credentials__match=creds).first()
if user:
if revoke:
user.credentials = []
user.save()
return user.id
new_credentials = Credentials(key=key, secret=secret)
if internal_user:
log.info(f"Setting credentials for existing user {user.id} ({user.name})")
user.credentials = [new_credentials]
user.save()
return
credentials = [] if revoke else [creds]
if user.credentials is None:
user.credentials = []
if not any((cred.key, cred.secret) == (key, secret) for cred in user.credentials):
log.info(f"Adding credentials for existing user {user.id} ({user.name})")
user.credentials.append(new_credentials)
user.save()
def _ensure_auth_user(
user_data: dict,
company_id: str,
log: Logger,
revoke: bool = False,
internal_user: bool = False,
) -> str:
user_id = user_data.get("id", f"__{user_data['name']}__")
role = user_data["role"]
email = user_data["email"]
autocreated = user_data.get("autocreated", False)
key, secret = user_data.get("key"), user_data.get("secret")
user: AuthUser = AuthUser.objects(id=user_id).first()
if user:
_ensure_user_credentials(
user=user,
key=key,
secret=secret,
log=log,
revoke=revoke,
internal_user=internal_user,
)
if user.role != role or user.email != email or user.autocreated != autocreated:
user.email = email
user.role = role
user.autocreated = autocreated
user.save()
return user.id
credentials = (
[Credentials(key=key, secret=secret)] if not revoke and key and secret else []
)
log.info(f"Creating user: {user_data['name']}")
user = AuthUser(
id=user_id,
name=user_data["name"],
company=company_id,
role=user_data["role"],
email=user_data["email"],
role=role,
email=email,
created=datetime.utcnow(),
credentials=credentials,
autocreated=autocreated,
@@ -62,17 +113,7 @@ def _ensure_backend_user(user_id: str, company_id: str, user_name: str):
def ensure_fixed_user(user: FixedUser, log: Logger, emails: set):
db_user = User.objects(company=user.company, id=user.user_id).first()
if db_user:
# noinspection PyBroadException
try:
log.info(f"Updating user name: {user.name}")
given_name, _, family_name = user.name.partition(" ")
db_user.update(name=user.name, given_name=given_name, family_name=family_name)
except Exception:
pass
return
# noinspection PyTypeChecker
data = attr.asdict(user)
data["id"] = user.user_id
email = f"{user.user_id}@example.com"
@@ -81,6 +122,19 @@ def ensure_fixed_user(user: FixedUser, log: Logger, emails: set):
data["autocreated"] = True
_ensure_auth_user(user_data=data, company_id=user.company, log=log)
emails.add(email)
return _ensure_backend_user(user.user_id, user.company, user.name)
db_user = User.objects(company=user.company, id=user.user_id).first()
if db_user:
# noinspection PyBroadException
try:
log.info(f"Updating user name: {user.name}")
given_name, _, family_name = user.name.partition(" ")
db_user.update(
name=user.name, given_name=given_name, family_name=family_name
)
except Exception:
pass
else:
_ensure_backend_user(user.user_id, user.company, user.name)
emails.add(email)

View File

@@ -25,7 +25,7 @@ packaging==20.3
psutil>=5.6.5
pyhocon>=0.3.35r
pyjwt>=2.4.0
pymongo==4.4.0
pymongo==4.6.3
python-rapidjson>=0.6.3
redis>=4.5.4,<5
requests>=2.13.0

View File

@@ -947,6 +947,13 @@ get_task_log {
}
}
}
"2.30": ${get_task_log."2.9"} {
request.metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/metric_variants" }
}
}
}
get_task_events {
"2.1" {
@@ -1705,4 +1712,18 @@ clear_task_log {
}
}
}
"2.30": ${clear_task_log."2.19"} {
request.properties {
include_metrics {
type: array
description: If passed then only events for these metrics are deleted
items: {type: string}
}
exclude_metrics {
type: array
description: If passed then events for these metrics are retained
items: {type: string}
}
}
}
}

View File

@@ -349,7 +349,7 @@ get_all {
items { type: string }
}
last_update {
description: "List of last_update constraint strings (utcformat, epoch) with an optional prefix modifier (>, >=, <, <=)"
description: "List of last_update constraint strings (utcformat, epoch) with an optional prefix modifier (\>, \>=, \<, \<=)"
type: array
items {
type: string

View File

@@ -446,7 +446,7 @@ get_task_data {
type: string
}
status_changed {
description: "List of status changed constraint strings (utcformat, epoch) with an optional prefix modifier (>, >=, <, <=)"
description: "List of status changed constraint strings (utcformat, epoch) with an optional prefix modifier (\>, \>=, \<, \<=)"
type: array
items {
type: string
@@ -656,7 +656,7 @@ get_all_ex {
items { type: string }
}
status_changed {
description: "List of status changed constraint strings (utcformat, epoch) with an optional prefix modifier (>, >=, <, <=)"
description: "List of status changed constraint strings (utcformat, epoch) with an optional prefix modifier (\>, \>=, \<, \<=)"
type: array
items {
type: string

View File

@@ -277,7 +277,7 @@ get_all {
type: string
}
status_changed {
description: "List of status changed constraint strings (utcformat, epoch) with an optional prefix modifier (>, >=, <, <=)"
description: "List of status changed constraint strings (utcformat, epoch) with an optional prefix modifier (\>, \>=, \<, \<=)"
type: array
items {
type: string
@@ -1107,6 +1107,13 @@ delete_many {
default: true
}
}
"2.30": ${delete_many."2.21"} {
request.properties.include_pipeline_steps {
description: If set then for the passed pipeline controller tasks the pipeline steps will be also deleted
type: boolean
default: false
}
}
}
delete {
"2.1" {
@@ -1182,6 +1189,13 @@ delete {
default: true
}
}
"2.30": ${delete."2.21"} {
request.properties.include_pipeline_steps {
description: If set and the passed task is a pipeline controller then delete the pipeline tasks too
type: boolean
default: false
}
}
}
archive {
"2.12" {
@@ -1219,6 +1233,13 @@ archive {
}
}
}
"2.30": ${archive."2.12"} {
request.properties.include_pipeline_steps {
description: If set then for the passed pipeline controller tasks also archive the pipeline steps
type: boolean
default: false
}
}
}
archive_many {
"2.13": ${_definitions.batch_operation} {
@@ -1245,6 +1266,13 @@ archive_many {
}
}
}
"2.30": ${archive_many."2.13"} {
request.properties.include_pipeline_steps {
description: If set then for the passed pipeline controller tasks also archive the pipeline steps
type: boolean
default: false
}
}
}
unarchive_many {
"2.13": ${_definitions.batch_operation} {
@@ -1271,6 +1299,13 @@ unarchive_many {
}
}
}
"2.30": ${unarchive_many."2.13"} {
request.properties.include_pipeline_steps {
description: If set then for the passed pipeline controller tasks also unarchive the pipeline steps
type: boolean
default: false
}
}
}
started {
"2.1" {
@@ -1309,6 +1344,13 @@ stop {
} ${_references.status_change_request}
response: ${_definitions.update_response}
}
"2.30": ${stop."2.1"} {
request.properties.include_pipeline_steps {
description: If set and the passed task is a pipeline controller then stop all its steps too
type: boolean
default: false
}
}
}
stop_many {
"2.13": ${_definitions.change_many_request} {
@@ -1322,6 +1364,13 @@ stop_many {
}
}
}
"2.30": ${stop_many."2.13"} {
request.properties.include_pipeline_steps {
description: If set then for all the passed pipeline controller tasks stop their steps too
type: boolean
default: false
}
}
}
stopped {
"2.1" {

View File

@@ -310,6 +310,12 @@ get_all {
items { type: string }
}
}
"2.30": ${get_all."2.22"} {
request.properties.worker_pattern {
description: The worker name pattern. If specified then only matching keys returned
type: string
}
}
}
get_count {
"2.26": {
@@ -345,6 +351,12 @@ get_count {
}
}
}
"2.30": ${get_count."2.26"} {
request.properties.worker_pattern {
description: The worker name pattern. If specified then only matching keys are counted
type: string
}
}
}
register {
"2.4" {

View File

@@ -1,40 +1,38 @@
import random
import secrets
import string
sys_random = random.SystemRandom()
def get_random_string(length):
    """
    Create a random crypto-safe sequence of 'length' or more characters

    Possible characters: alphanumeric, '-' and '_'
    The result never starts with '-' or '_' for better compatibility with yaml files

    :param length: number of random bytes passed to secrets.token_urlsafe; the
        url-safe text encoding makes the returned string longer than 'length'
    :return: the generated token
    """
    # Resample until the token qualifies instead of giving up after a fixed
    # number of attempts, so the leading-character guarantee always holds.
    # Roughly 1 in 32 tokens is rejected, so this almost always exits at once.
    while True:
        token = secrets.token_urlsafe(length)
        if not token.startswith(("-", "_")):
            return token
def get_random_string(
length: int = 12, allowed_chars: str = string.ascii_letters + string.digits
def get_client_id(
length: int = 30, allowed_chars: str = string.ascii_uppercase + string.digits
) -> str:
"""
Returns a securely generated random string.
The default length of 12 with the a-z, A-Z, 0-9 character set returns
a 71-bit value. log_2((26+26+10)^12) =~ 71 bits.
Taken from the django.utils.crypto module.
Create a random client id composed of 'length' upper case characters or digits
"""
return "".join(sys_random.choice(allowed_chars) for _ in range(length))
def get_client_id(length: int = 20) -> str:
"""
Create a random secret key.
Taken from the Django project.
"""
chars = string.ascii_uppercase + string.digits
return get_random_string(length, chars)
return "".join(secrets.choice(allowed_chars) for _ in range(length))
def get_secret_key(length: int = 50) -> str:
"""
Create a random secret key.
Taken from the Django project.
NOTE: asterisk is not supported due to issues with environment variables containing
asterisks (in case the secret key is stored in an environment variable)
Create a random secret key
"""
chars = string.ascii_letters + string.digits
return get_random_string(length, chars)
return get_random_string(length)
if __name__ == "__main__":
print(get_client_id())
print(get_secret_key())

View File

@@ -39,7 +39,7 @@ class ServiceRepo(object):
"""If the check is set, parsing will fail for endpoint request with the version that is greater than the current
maximum """
_max_version = PartialVersion("2.29")
_max_version = PartialVersion("2.30")
""" Maximum version number (the highest min_version value across all endpoints) """
_endpoint_exp = (
@@ -296,7 +296,7 @@ class ServiceRepo(object):
except APIError as ex:
# report stack trace only for gene
include_stack = cls._return_stack and cls._should_return_stack(
include_stack = cls._should_return_stack(
ex.code, ex.subcode
)
call.set_error_result(
@@ -310,8 +310,11 @@ class ServiceRepo(object):
pass
except Exception as ex:
log.exception(ex)
include_stack = cls._should_return_stack(
500, 0
)
call.set_error_result(
code=500, subcode=0, msg=str(ex), include_stack=cls._return_stack
code=500, subcode=0, msg=str(ex), include_stack=include_stack
)
finally:
content, content_type = call.get_response()

View File

@@ -172,6 +172,7 @@ def get_task_log(call, company_id, request: LogEventsRequest):
batch_size=request.batch_size,
navigate_earlier=request.navigate_earlier,
from_timestamp=request.from_timestamp,
metric_variants=_get_metric_variants_from_request(request.metrics),
)
if request.order and (
@@ -1041,6 +1042,8 @@ def clear_task_log(call: APICall, company_id: str, request: ClearTaskLogRequest)
task_id=task_id,
allow_locked=request.allow_locked,
threshold_sec=request.threshold_sec,
exclude_metrics=request.exclude_metrics,
include_metrics=request.include_metrics,
)
)

View File

@@ -67,6 +67,7 @@ def delete_runs(call: APICall, company_id: str, request: DeleteRunsRequest):
status_message="",
status_reason="Pipeline run deleted",
delete_external_artifacts=True,
include_pipeline_steps=True,
),
ids=list(ids),
)

View File

@@ -55,7 +55,6 @@ from apiserver.apimodels.tasks import (
ResetManyRequest,
DeleteManyRequest,
PublishManyRequest,
TaskBatchRequest,
EnqueueManyResponse,
EnqueueBatchItem,
DequeueBatchItem,
@@ -68,6 +67,9 @@ from apiserver.apimodels.tasks import (
DequeueRequest,
DequeueManyRequest,
UpdateTagsRequest,
StopRequest,
UnarchiveManyRequest,
ArchiveManyRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.model import ModelBLL
@@ -98,7 +100,6 @@ from apiserver.bll.task.task_operations import (
delete_task,
publish_task,
unarchive_task,
move_tasks_to_trash,
)
from apiserver.bll.task.utils import (
update_task,
@@ -295,9 +296,9 @@ def get_types(call: APICall, company_id, request: GetTypesRequest):
@endpoint(
"tasks.stop", request_data_model=UpdateRequest, response_data_model=UpdateResponse
"tasks.stop", response_data_model=UpdateResponse
)
def stop(call: APICall, company_id, req_model: UpdateRequest):
def stop(call: APICall, company_id, request: StopRequest):
"""
stop
:summary: Stop a running task. Requires task status 'in_progress' and
@@ -308,12 +309,13 @@ def stop(call: APICall, company_id, req_model: UpdateRequest):
"""
call.result.data_model = UpdateResponse(
**stop_task(
task_id=req_model.task,
task_id=request.task,
company_id=company_id,
identity=call.identity,
user_name=call.identity.user_name,
status_reason=req_model.status_reason,
force=req_model.force,
status_reason=request.status_reason,
force=request.force,
include_pipeline_steps=request.include_pipeline_steps,
)
)
@@ -332,6 +334,7 @@ def stop_many(call: APICall, company_id, request: StopManyRequest):
user_name=call.identity.user_name,
status_reason=request.status_reason,
force=request.force,
include_pipeline_steps=request.include_pipeline_steps,
),
ids=request.ids,
)
@@ -531,7 +534,7 @@ def _validate_and_get_task_from_call(call: APICall, **kwargs) -> Tuple[Task, dic
@endpoint("tasks.validate", request_data_model=CreateRequest)
def validate(call: APICall, company_id, req_model: CreateRequest):
def validate(call: APICall, _, __: CreateRequest):
parent = call.data.get("parent")
if parent and parent.startswith(deleted_prefix):
call.data.pop("parent")
@@ -555,7 +558,7 @@ def _reset_cached_tags(company: str, projects: Sequence[str]):
@endpoint(
"tasks.create", request_data_model=CreateRequest, response_data_model=IdResponse
)
def create(call: APICall, company_id, req_model: CreateRequest):
def create(call: APICall, company_id, _: CreateRequest):
task, fields = _validate_and_get_task_from_call(call)
with translate_errors_context():
@@ -1027,14 +1030,7 @@ def reset(call: APICall, company_id, request: ResetRequest):
clear_all=request.clear_all,
delete_external_artifacts=request.delete_external_artifacts,
)
res = ResetResponse(**updates, dequeued=dequeued)
# do not return artifacts since they are not serializable
res.fields.pop("execution.artifacts", None)
for key, value in attr.asdict(cleanup_res).items():
setattr(res, key, value)
res = ResetResponse(**updates, **attr.asdict(cleanup_res), dequeued=dequeued)
call.result.data_model = res
@@ -1058,23 +1054,19 @@ def reset_many(call: APICall, company_id, request: ResetManyRequest):
ids=request.ids,
)
def clean_res(res: dict) -> dict:
# do not return artifacts since they are not serializable
fields = res.get("fields")
if fields:
fields.pop("execution.artifacts", None)
return res
call.result.data_model = ResetManyResponse(
succeeded=[
succeeded = []
for _id, (dequeued, cleanup, res) in results:
succeeded.append(
ResetBatchItem(
id=_id,
dequeued=bool(dequeued.get("removed")) if dequeued else False,
**attr.asdict(cleanup),
**clean_res(res),
**res,
)
for _id, (dequeued, cleanup, res) in results
],
)
call.result.data_model = ResetManyResponse(
succeeded=succeeded,
failed=failures,
)
@@ -1098,6 +1090,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
"project",
"system_tags",
"enqueue_status",
"type",
),
)
for task in tasks:
@@ -1107,6 +1100,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
task=task,
status_message=request.status_message,
status_reason=request.status_reason,
include_pipeline_steps=request.include_pipeline_steps,
)
call.result.data_model = ArchiveResponse(archived=archived)
@@ -1114,10 +1108,9 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
@endpoint(
"tasks.archive_many",
request_data_model=TaskBatchRequest,
response_data_model=BatchResponse,
)
def archive_many(call: APICall, company_id, request: TaskBatchRequest):
def archive_many(call: APICall, company_id, request: ArchiveManyRequest):
results, failures = run_batch_operation(
func=partial(
archive_task,
@@ -1125,6 +1118,7 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest):
identity=call.identity,
status_message=request.status_message,
status_reason=request.status_reason,
include_pipeline_steps=request.include_pipeline_steps,
),
ids=request.ids,
)
@@ -1136,10 +1130,9 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest):
@endpoint(
"tasks.unarchive_many",
request_data_model=TaskBatchRequest,
response_data_model=BatchResponse,
)
def unarchive_many(call: APICall, company_id, request: TaskBatchRequest):
def unarchive_many(call: APICall, company_id, request: UnarchiveManyRequest):
results, failures = run_batch_operation(
func=partial(
unarchive_task,
@@ -1147,6 +1140,7 @@ def unarchive_many(call: APICall, company_id, request: TaskBatchRequest):
identity=call.identity,
status_message=request.status_message,
status_reason=request.status_reason,
include_pipeline_steps=request.include_pipeline_steps,
),
ids=request.ids,
)
@@ -1171,10 +1165,9 @@ def delete(call: APICall, company_id, request: DeleteRequest):
status_message=request.status_message,
status_reason=request.status_reason,
delete_external_artifacts=request.delete_external_artifacts,
include_pipeline_steps=request.include_pipeline_steps,
)
if deleted:
if request.move_to_trash:
move_tasks_to_trash([request.task])
_reset_cached_tags(company_id, projects=[task.project] if task.project else [])
call.result.data = dict(deleted=bool(deleted), **attr.asdict(cleanup_res))
@@ -1193,15 +1186,12 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest):
status_message=request.status_message,
status_reason=request.status_reason,
delete_external_artifacts=request.delete_external_artifacts,
include_pipeline_steps=request.include_pipeline_steps,
),
ids=request.ids,
)
if results:
if request.move_to_trash:
task_ids = set(task.id for _, (_, task, _) in results)
if task_ids:
move_tasks_to_trash(list(task_ids))
projects = set(task.project for _, (_, task, _) in results)
_reset_cached_tags(company_id, projects=list(projects))

View File

@@ -47,6 +47,7 @@ def get_all(call: APICall, company_id: str, request: GetAllRequest):
request.last_seen,
tags=request.tags,
system_tags=request.system_tags,
worker_pattern=request.worker_pattern,
)
)
@@ -61,6 +62,7 @@ def get_all(call: APICall, company_id: str, request: GetCountRequest):
request.last_seen,
tags=request.tags,
system_tags=request.system_tags,
worker_pattern=request.worker_pattern,
)
}

View File

@@ -21,7 +21,7 @@ def distributed_lock(name: str, timeout: int, max_wait: int = 0):
max_wait = max_wait or timeout * 2
pid = os.getpid()
while _redis.set(lock_name, value=pid, ex=timeout, nx=True) is None:
sleep(1)
sleep(0.1)
if time.time() - start > max_wait:
holder = _redis.get(lock_name)
raise Exception(f"Could not acquire {name} lock for {max_wait} seconds. The lock is hold by {holder}")

View File

@@ -5,6 +5,45 @@ from apiserver.tests.automated import TestService
class TestPipelines(TestService):
def test_controller_operations(self):
    # Exercises stop / archive / unarchive / delete on a controller task with
    # include_pipeline_steps=True and checks the effect on its child step tasks.
    task_name = "pipelines test"
    project, task = self._temp_project_and_task(name=task_name)
    # two step tasks tagged "pipeline" whose parent is the controller task
    steps = [
        self.api.tasks.create(
            name=f"Pipeline step {i}",
            project=project,
            type="training",
            system_tags=["pipeline"],
            parent=task
        ).id
        for i in range(2)
    ]
    ids = [task, *steps]
    res = self.api.tasks.get_all_ex(id=ids, search_hidden=True)
    self.assertEqual(len(res.tasks), len(ids))

    # stop
    partial_ids = [task, steps[0]]
    self.api.tasks.enqueue_many(ids=partial_ids)
    res = self.api.tasks.get_all_ex(id=partial_ids, search_hidden=True)
    # NOTE(review): assertTrue() is given a generator object, which is always
    # truthy, so this check can never fail. It also reads `t.stats` -
    # presumably `t.status` was intended. Confirm the expected status before
    # wrapping the expression in all().
    self.assertTrue(t.stats == "in_progress" for t in res.tasks)
    self.api.tasks.stop(task=task, include_pipeline_steps=True)
    res = self.api.tasks.get_all_ex(id=ids, search_hidden=True)
    # NOTE(review): same vacuous generator assertion as above - verify and fix together
    self.assertTrue(t.stats == "created" for t in res.tasks)

    # archive/unarchive
    self.api.tasks.archive(tasks=[task], include_pipeline_steps=True)
    # "-archived" is presumably an exclusion filter (tasks without the tag) - TODO confirm
    res = self.api.tasks.get_all_ex(id=ids, search_hidden=True, system_tags=["-archived"])
    self.assertEqual(len(res.tasks), 0)
    self.api.tasks.unarchive_many(ids=[task], include_pipeline_steps=True)
    res = self.api.tasks.get_all_ex(id=ids, search_hidden=True, system_tags=["-archived"])
    self.assertEqual(len(res.tasks), len(ids))

    # delete
    self.api.tasks.delete(task=task, force=True, include_pipeline_steps=True)
    res = self.api.tasks.get_all_ex(id=ids, search_hidden=True)
    self.assertEqual(len(res.tasks), 0)
def test_delete_runs(self):
queue = self.api.queues.get_default().id
task_name = "pipelines test"

View File

@@ -346,6 +346,34 @@ class TestTaskEvents(TestService):
# test order
self._assert_log_events(task=task, order="asc")
metric = "metric"
variant = "variant"
events = [
self._create_task_event(
"log",
task=task,
iteration=iter_,
timestamp=timestamp + iter_ * 1000,
msg=f"This is a log message from test task iter {iter_}",
metric=metric,
variant=variant,
)
for iter_ in range(2)
]
self.send_batch(events)
res = self.api.events.get_task_log(task=task)
self.assertEqual(res.total, 12)
res = self.api.events.get_task_log(task=task, metrics=[{"metric": metric}])
self.assertEqual(res.total, 2)
# test clear
self.api.events.clear_task_log(task=task, exclude_metrics=[metric])
res = self.api.events.get_task_log(task=task)
self.assertEqual(res.total, 2)
self.api.events.clear_task_log(task=task)
res = self.api.events.get_task_log(task=task)
self.assertEqual(res.total, 0)
def _assert_log_events(
self,
task,

View File

@@ -32,7 +32,7 @@ class TestWorkersService(TestService):
self.api.workers.register(worker=w, system_tags=[system_tag])
# total workers count include the new ones
count = self.api.workers.get_count().count
self.assertGreater(count, len(test_workers))
self.assertGreaterEqual(count, len(test_workers))
# filter by system tag and last seen
count = self.api.workers.get_count(system_tags=[system_tag], last_seen=4).count
self.assertEqual(count, len(test_workers))

View File

@@ -1 +1 @@
__version__ = "1.15.0"
__version__ = "1.16.0"

View File

@@ -5,6 +5,7 @@ ARG CLEARML_WEB_GIT_URL=https://github.com/allegroai/clearml-web.git
USER root
WORKDIR /opt
RUN apt-get update && apt-get install -y git
RUN git clone ${CLEARML_WEB_GIT_URL} clearml-web
RUN mv clearml-web /opt/open-webapp
COPY --chmod=744 docker/build/internal_files/build_webapp.sh /tmp/internal_files/

View File

@@ -19,12 +19,11 @@ services:
environment:
CLEARML_ELASTIC_SERVICE_HOST: elasticsearch
CLEARML_ELASTIC_SERVICE_PORT: 9200
CLEARML_ELASTIC_SERVICE_PASSWORD: ${ELASTIC_PASSWORD}
CLEARML_MONGODB_SERVICE_HOST: mongo
CLEARML_MONGODB_SERVICE_PORT: 27017
CLEARML_REDIS_SERVICE_HOST: redis
CLEARML_REDIS_SERVICE_PORT: 6379
CLEARML_SERVER_DEPLOYMENT_TYPE: ${CLEARML_SERVER_DEPLOYMENT_TYPE:-win10}
CLEARML_SERVER_DEPLOYMENT_TYPE: win10
CLEARML__apiserver__pre_populate__enabled: "true"
CLEARML__apiserver__pre_populate__zip_files: "/opt/clearml/db-pre-populate"
CLEARML__apiserver__pre_populate__artifacts_path: "/mnt/fileserver"
@@ -41,8 +40,6 @@ services:
- backend
container_name: clearml-elastic
environment:
ES_JAVA_OPTS: -Xms2g -Xmx2g -Dlog4j2.formatMsgNoLookups=true
ELASTIC_PASSWORD: ${ELASTIC_PASSWORD}
bootstrap.memory_lock: "true"
cluster.name: clearml
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
@@ -137,7 +134,6 @@ services:
environment:
CLEARML_ELASTIC_SERVICE_HOST: elasticsearch
CLEARML_ELASTIC_SERVICE_PORT: 9200
CLEARML_ELASTIC_SERVICE_PASSWORD: ${ELASTIC_PASSWORD}
CLEARML_MONGODB_SERVICE_HOST: mongo
CLEARML_MONGODB_SERVICE_PORT: 27017
CLEARML_REDIS_SERVICE_HOST: redis

View File

@@ -19,17 +19,18 @@ services:
environment:
CLEARML_ELASTIC_SERVICE_HOST: elasticsearch
CLEARML_ELASTIC_SERVICE_PORT: 9200
CLEARML_ELASTIC_SERVICE_PASSWORD: ${ELASTIC_PASSWORD}
CLEARML_MONGODB_SERVICE_HOST: mongo
CLEARML_MONGODB_SERVICE_PORT: 27017
CLEARML_REDIS_SERVICE_HOST: redis
CLEARML_REDIS_SERVICE_PORT: 6379
CLEARML_SERVER_DEPLOYMENT_TYPE: ${CLEARML_SERVER_DEPLOYMENT_TYPE:-linux}
CLEARML_SERVER_DEPLOYMENT_TYPE: linux
CLEARML__apiserver__pre_populate__enabled: "true"
CLEARML__apiserver__pre_populate__zip_files: "/opt/clearml/db-pre-populate"
CLEARML__apiserver__pre_populate__artifacts_path: "/mnt/fileserver"
CLEARML__services__async_urls_delete__enabled: "true"
CLEARML__services__async_urls_delete__fileserver__url_prefixes: "[${CLEARML_FILES_HOST:-}]"
CLEARML__secure__credentials__services_agent__user_key: ${CLEARML_AGENT_ACCESS_KEY:-}
CLEARML__secure__credentials__services_agent__user_secret: ${CLEARML_AGENT_SECRET_KEY:-}
ports:
- "8008:8008"
networks:
@@ -41,8 +42,6 @@ services:
- backend
container_name: clearml-elastic
environment:
ES_JAVA_OPTS: -Xms2g -Xmx2g -Dlog4j2.formatMsgNoLookups=true
ELASTIC_PASSWORD: ${ELASTIC_PASSWORD}
bootstrap.memory_lock: "true"
cluster.name: clearml
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
@@ -136,7 +135,6 @@ services:
environment:
CLEARML_ELASTIC_SERVICE_HOST: elasticsearch
CLEARML_ELASTIC_SERVICE_PORT: 9200
CLEARML_ELASTIC_SERVICE_PASSWORD: ${ELASTIC_PASSWORD}
CLEARML_MONGODB_SERVICE_HOST: mongo
CLEARML_MONGODB_SERVICE_PORT: 27017
CLEARML_REDIS_SERVICE_HOST: redis
@@ -167,8 +165,8 @@ services:
CLEARML_WEB_HOST: ${CLEARML_WEB_HOST:-}
CLEARML_API_HOST: http://apiserver:8008
CLEARML_FILES_HOST: ${CLEARML_FILES_HOST:-}
CLEARML_API_ACCESS_KEY: ${CLEARML_API_ACCESS_KEY:-}
CLEARML_API_SECRET_KEY: ${CLEARML_API_SECRET_KEY:-}
CLEARML_API_ACCESS_KEY: ${CLEARML_AGENT_ACCESS_KEY:-$CLEARML_API_ACCESS_KEY}
CLEARML_API_SECRET_KEY: ${CLEARML_AGENT_SECRET_KEY:-$CLEARML_API_SECRET_KEY}
CLEARML_AGENT_GIT_USER: ${CLEARML_AGENT_GIT_USER}
CLEARML_AGENT_GIT_PASS: ${CLEARML_AGENT_GIT_PASS}
CLEARML_AGENT_UPDATE_VERSION: ${CLEARML_AGENT_UPDATE_VERSION:->=0.17.0}

View File

@@ -1,116 +0,0 @@
version: "3.6"
services:
apiserver:
command:
- apiserver
container_name: trains-apiserver
image: allegroai/trains:latest
restart: unless-stopped
volumes:
- c:/opt/trains/logs:/var/log/trains
- c:/opt/trains/config:/opt/trains/config
depends_on:
- redis
- mongo
- elasticsearch
- fileserver
environment:
TRAINS_ELASTIC_SERVICE_HOST: elasticsearch
TRAINS_ELASTIC_SERVICE_PORT: 9200
TRAINS_MONGODB_SERVICE_HOST: mongo
TRAINS_MONGODB_SERVICE_PORT: 27017
TRAINS_REDIS_SERVICE_HOST: redis
TRAINS_REDIS_SERVICE_PORT: 6379
TRAINS_SERVER_DEPLOYMENT_TYPE: ${TRAINS_SERVER_DEPLOYMENT_TYPE:-win10}
TRAINS__apiserver__mongo__pre_populate__enabled: "true"
TRAINS__apiserver__mongo__pre_populate__zip_file: "/opt/trains/db-pre-populate/export.zip"
ports:
- "8008:8008"
networks:
- backend
elasticsearch:
networks:
- backend
container_name: trains-elastic
environment:
ES_JAVA_OPTS: -Xms2g -Xmx2g
bootstrap.memory_lock: "true"
cluster.name: trains
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
cluster.routing.allocation.disk.watermark.low: 500mb
cluster.routing.allocation.disk.watermark.high: 500mb
cluster.routing.allocation.disk.watermark.flood_stage: 500mb
discovery.zen.minimum_master_nodes: "1"
discovery.type: "single-node"
http.compression_level: "7"
node.ingest: "true"
node.name: trains
reindex.remote.whitelist: '*.*'
xpack.monitoring.enabled: "false"
xpack.security.enabled: "false"
ulimits:
memlock:
soft: -1
hard: -1
nofile:
soft: 65536
hard: 65536
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
restart: unless-stopped
volumes:
- c:/opt/trains/data/elastic_7:/usr/share/elasticsearch/data
fileserver:
networks:
- backend
command:
- fileserver
container_name: trains-fileserver
image: allegroai/trains:latest
restart: unless-stopped
volumes:
- c:/opt/trains/logs:/var/log/trains
- c:/opt/trains/data/fileserver:/mnt/fileserver
- c:/opt/trains/config:/opt/trains/config
ports:
- "8081:8081"
mongo:
networks:
- backend
container_name: trains-mongo
image: mongo:3.6.5
restart: unless-stopped
command: --setParameter internalQueryExecMaxBlockingSortBytes=196100200
volumes:
- c:/opt/trains/data/mongo/db:/data/db
- c:/opt/trains/data/mongo/configdb:/data/configdb
redis:
networks:
- backend
container_name: trains-redis
image: redis:5.0
restart: unless-stopped
volumes:
- c:/opt/trains/data/redis:/data
webserver:
command:
- webserver
container_name: trains-webserver
image: allegroai/trains:latest
restart: unless-stopped
volumes:
- c:/trains/logs:/var/log/trains
depends_on:
- apiserver
ports:
- "8080:80"
networks:
backend:
driver: bridge

View File

@@ -1,153 +0,0 @@
version: "3.6"
services:
apiserver:
command:
- apiserver
container_name: trains-apiserver
image: allegroai/clearml:latest
restart: unless-stopped
volumes:
- /opt/trains/logs:/var/log/trains
- /opt/trains/config:/opt/trains/config
- /opt/trains/data/fileserver:/mnt/fileserver
depends_on:
- redis
- mongo
- elasticsearch
- fileserver
environment:
TRAINS_ELASTIC_SERVICE_HOST: elasticsearch
TRAINS_ELASTIC_SERVICE_PORT: 9200
TRAINS_MONGODB_SERVICE_HOST: mongo
TRAINS_MONGODB_SERVICE_PORT: 27017
TRAINS_REDIS_SERVICE_HOST: redis
TRAINS_REDIS_SERVICE_PORT: 6379
TRAINS_SERVER_DEPLOYMENT_TYPE: ${TRAINS_SERVER_DEPLOYMENT_TYPE:-linux}
TRAINS__apiserver__pre_populate__enabled: "true"
TRAINS__apiserver__pre_populate__zip_files: "/opt/trains/db-pre-populate"
TRAINS__apiserver__pre_populate__artifacts_path: "/mnt/fileserver"
ports:
- "8008:8008"
networks:
- backend
- frontend
elasticsearch:
networks:
- backend
container_name: trains-elastic
environment:
ES_JAVA_OPTS: -Xms2g -Xmx2g
bootstrap.memory_lock: "true"
cluster.name: trains
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
cluster.routing.allocation.disk.watermark.low: 500mb
cluster.routing.allocation.disk.watermark.high: 500mb
cluster.routing.allocation.disk.watermark.flood_stage: 500mb
discovery.zen.minimum_master_nodes: "1"
discovery.type: "single-node"
http.compression_level: "7"
node.ingest: "true"
node.name: trains
reindex.remote.whitelist: '*.*'
xpack.monitoring.enabled: "false"
xpack.security.enabled: "false"
ulimits:
memlock:
soft: -1
hard: -1
nofile:
soft: 65536
hard: 65536
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
restart: unless-stopped
volumes:
- /opt/trains/data/elastic_7:/usr/share/elasticsearch/data
fileserver:
networks:
- backend
command:
- fileserver
container_name: trains-fileserver
image: allegroai/clearml:latest
restart: unless-stopped
volumes:
- /opt/trains/logs:/var/log/trains
- /opt/trains/data/fileserver:/mnt/fileserver
- /opt/trains/config:/opt/trains/config
ports:
- "8081:8081"
mongo:
networks:
- backend
container_name: trains-mongo
image: mongo:3.6.5
restart: unless-stopped
command: --setParameter internalQueryExecMaxBlockingSortBytes=196100200
volumes:
- /opt/trains/data/mongo/db:/data/db
- /opt/trains/data/mongo/configdb:/data/configdb
redis:
networks:
- backend
container_name: trains-redis
image: redis:5.0
restart: unless-stopped
volumes:
- /opt/trains/data/redis:/data
webserver:
command:
- webserver
container_name: trains-webserver
image: allegroai/clearml:latest
restart: unless-stopped
depends_on:
- apiserver
ports:
- "8080:80"
networks:
- backend
- frontend
agent-services:
networks:
- backend
container_name: trains-agent-services
image: allegroai/trains-agent-services:latest
restart: unless-stopped
privileged: true
environment:
TRAINS_HOST_IP: ${TRAINS_HOST_IP}
TRAINS_WEB_HOST: ${TRAINS_WEB_HOST:-}
TRAINS_API_HOST: http://apiserver:8008
TRAINS_FILES_HOST: ${TRAINS_FILES_HOST:-}
TRAINS_API_ACCESS_KEY: ${TRAINS_API_ACCESS_KEY:-}
TRAINS_API_SECRET_KEY: ${TRAINS_API_SECRET_KEY:-}
TRAINS_AGENT_GIT_USER: ${TRAINS_AGENT_GIT_USER}
TRAINS_AGENT_GIT_PASS: ${TRAINS_AGENT_GIT_PASS}
TRAINS_AGENT_UPDATE_VERSION: ${TRAINS_AGENT_UPDATE_VERSION:->=0.15.0}
TRAINS_AGENT_DEFAULT_BASE_DOCKER: "ubuntu:18.04"
AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID:-}
AWS_SECRET_ACCESS_KEY: ${AWS_SECRET_ACCESS_KEY:-}
AWS_DEFAULT_REGION: ${AWS_DEFAULT_REGION:-}
AZURE_STORAGE_ACCOUNT: ${AZURE_STORAGE_ACCOUNT:-}
AZURE_STORAGE_KEY: ${AZURE_STORAGE_KEY:-}
GOOGLE_APPLICATION_CREDENTIALS: ${GOOGLE_APPLICATION_CREDENTIALS:-}
TRAINS_WORKER_ID: "trains-services"
TRAINS_AGENT_DOCKER_HOST_MOUNT: "/opt/trains/agent:/root/.trains"
volumes:
- /var/run/docker.sock:/var/run/docker.sock
- /opt/trains/agent:/root/.trains
depends_on:
- apiserver
networks:
backend:
driver: bridge
frontend:
driver: bridge

129
fileserver/auth.py Normal file
View File

@@ -0,0 +1,129 @@
import json
from hashlib import sha256
from typing import Optional
import attr
from flask import abort, Response, Request
from redis import StrictRedis
from clearml_agent.backend_api import Session
from werkzeug.exceptions import HTTPException
from config import config
from redis_manager import redman
log = config.logger(__file__)
@attr.s(auto_attribs=True)
class TokenInfo:
    """Identity details extracted from a validated authorization token."""

    # company reported for the token
    company: str
    # user reported for the token
    user: str
class FileserverSession(Session):
    """API session whose ``client`` value is always "fileserver"."""

    @property
    def client(self):
        return "fileserver"

    @client.setter
    def client(self, _):
        # do not allow the base class to override the client
        pass
class AuthHandler:
    """Validates request authorization tokens against the API server.

    Successful validations are cached in redis so that the API server is not
    contacted on every request. Obtain the handler through ``instance()``,
    which returns ``None`` when ``fileserver.auth.enabled`` is off.
    """

    # read once, when the class is defined
    enabled = config.get("fileserver.auth.enabled", False)
    _instance = None

    @classmethod
    def instance(cls):
        """Return the shared handler, creating it on first use; None if auth is disabled."""
        if not cls.enabled:
            return None
        if not cls._instance:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        # session authenticated with the fileserver's own system credentials
        self.session = FileserverSession(
            api_key=config.get("secure.credentials.fileserver.user_key"),
            secret_key=config.get("secure.credentials.fileserver.user_secret"),
            host=config.get("hosts.api_server"),
            initialize_logging=False,
        )
        self.redis: StrictRedis = redman.connection("fileserver")

    def _validate_and_get_token_info(self, token: str) -> TokenInfo:
        """Return the token's info, from the redis cache or via auth.validate_token.

        Aborts the request with the API server's status code (or 401 for an
        invalid token, 500 for unexpected failures) when validation fails.
        """
        # tokens longer than 256 characters are hashed for the cache key;
        # shorter ones are used verbatim
        token_hash = sha256(token.encode()).hexdigest() if len(token) > 256 else token
        key = f"token_{token_hash}"
        token_data = self.redis.get(key)
        if token_data:
            return TokenInfo(**json.loads(token_data))

        try:
            res = self.session.send_request(
                service="auth", action="validate_token", json={"token": token}
            )
            if res.status_code == 500:
                log.error("Error validating token")
                abort(Response(f"Internal error (status={res.status_code})", 500))
            elif res.status_code != 200:
                log.error("Error validating token")
                abort(res.status_code)

            data = res.json()["data"]
            if not data["valid"]:
                log.error(f"Error validating token: {data['msg']}")
                abort(Response(data["msg"], 401))

            info = TokenInfo(
                company=data.get("company", "unknown"),
                user=data.get("user"),
            )
            timeout_sec = config.get(
                "fileserver.auth.tokens_cache_threshold_sec", 12 * 60 * 60
            )
            self.redis.setex(key, time=timeout_sec, value=json.dumps(attr.asdict(info)))
            return info
        except HTTPException:
            # raised by abort() above; let it propagate unchanged
            raise
        except Exception:
            log.exception("Failed decoding token")
            abort(500)

    def validate(self, request: Request):
        """Abort with 401 if the request carries no token, otherwise validate the token."""
        token = self.get_token(request)
        if not token:
            log.error("Error getting token")
            abort(401)

        self._validate_and_get_token_info(token)

    @staticmethod
    def get_token(request: Request) -> Optional[str]:
        """Extract the token from the Authorization header or a configured cookie.

        An Authorization header, when present, must use the "Bearer " scheme
        (otherwise the request is aborted with 401). Without that header, the
        first non-empty cookie named in ``fileserver.auth.cookie_names`` is
        returned. Returns None when no token is found.
        """
        auth_header = request.headers.get("Authorization")
        if auth_header:
            if not auth_header.startswith("Bearer "):
                log.error("Only bearer token authorization is supported")
                abort(
                    Response("Only bearer token authorization is supported", status=401)
                )
            return auth_header.partition(" ")[2]

        for cookie_name in config.get("fileserver.auth.cookie_names", []):
            cookie = request.cookies.get(cookie_name)
            if cookie:
                return cookie

        return None

View File

@@ -10,6 +10,21 @@ delete {
allow_batch: true
}
upload {
# the max size in Mb of the upload contents in one upload call
max_upload_size_mb: 0
}
cors {
origins: "*"
}
auth {
# enable/disable auth validation on upload/download
enabled: true
# names of cookies in which authorization token can be found
cookie_names: ["clearml_token_basic"]
tokens_cache_threshold_sec: 43200
}

View File

@@ -0,0 +1,9 @@
api_server: "http://apiserver:8008"
redis {
fileserver {
host: "redis"
port: 6379
db: 8
}
}

View File

@@ -0,0 +1,7 @@
credentials {
# system credentials as they appear in the auth DB, used for intra-service communications
fileserver {
user_key: "GSQWPEKSKNKF354LC9V6BHXKTYFD5I"
user_secret: "tuBXcGQBECsEhcNiK2kiWi750z9r8Z85XrQ9V0c24huTuCb2xf2X1nKG"
}
}

View File

@@ -15,6 +15,7 @@ from flask_cors import CORS
from werkzeug.exceptions import NotFound
from werkzeug.security import safe_join
from auth import AuthHandler
from config import config
from utils import get_env_bool
@@ -34,10 +35,17 @@ app.config["UPLOAD_FOLDER"] = first(
# Default max-age for served files, taken from the
# fileserver.download.cache_timeout_sec setting (falls back to 5 * 60)
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = config.get(
    "fileserver.download.cache_timeout_sec", 5 * 60
)

# The setting is given in megabytes; a missing or zero value leaves
# MAX_CONTENT_LENGTH unset
max_upload_size = config.get("fileserver.upload.max_upload_size_mb", None)
if max_upload_size:
    app.config["MAX_CONTENT_LENGTH"] = max_upload_size * 1024 * 1024

# May be falsy; every endpoint checks it before validating a request
auth_handler = AuthHandler.instance()
@app.route("/", methods=["GET"])
def ping():
    """Answer "OK"; a token is validated only when the request supplies one."""
    token_supplied = auth_handler and auth_handler.get_token(request)
    if token_supplied:
        auth_handler.validate(request)

    return "OK", 200
@@ -57,6 +65,9 @@ def after_request(response):
@app.route("/", methods=["POST"])
def upload():
if auth_handler:
auth_handler.validate(request)
results = []
for filename, file in request.files.items():
if not filename:
@@ -76,6 +87,9 @@ def upload():
@app.route("/<path:path>", methods=["GET"])
def download(path):
if auth_handler:
auth_handler.validate(request)
as_attachment = "download" in request.args
_, encoding = mimetypes.guess_type(os.path.basename(path))
@@ -105,18 +119,24 @@ def _get_full_path(path: str) -> Path:
@app.route("/<path:path>", methods=["DELETE"])
def delete(path):
    """Delete a single stored file.

    Authorization is checked before the path is inspected, so an
    unauthorized caller cannot learn whether a file exists. The file is
    resolved and unlinked exactly once.

    :param path: the file path as given in the request URL
    :return: the JSON-encoded request path and a 200 status
    Aborts with 404 when the resolved path does not exist or is not a file.
    """
    if auth_handler:
        auth_handler.validate(request)

    full_path = _get_full_path(path)
    if not full_path.exists() or not full_path.is_file():
        log.error(f"Error deleting file {str(full_path)}. Not found or not a file")
        # Report the path as requested, not the resolved server-side location
        abort(Response(f"File {str(path)} not found", 404))

    full_path.unlink()
    log.info(f"Deleted file {str(full_path)}")

    return json.dumps(str(path)), 200
def batch_delete():
if auth_handler:
auth_handler.validate(request)
body = request.get_json(force=True, silent=False)
if not body:
abort(Response("Json payload is missing", 400))
@@ -139,17 +159,17 @@ def batch_delete():
record_error("Empty path not allowed", file, path)
continue
path = _get_full_path(path)
full_path = _get_full_path(path)
if not path.exists():
if not full_path.exists():
record_error("Not found", file, path)
continue
try:
if path.is_file():
path.unlink()
elif path.is_dir():
shutil.rmtree(path)
if full_path.is_file():
full_path.unlink()
elif full_path.is_dir():
shutil.rmtree(full_path)
else:
record_error("Not a file or folder", file, path)
continue
@@ -157,7 +177,7 @@ def batch_delete():
record_error(ex.strerror, file, path)
continue
except Exception as ex:
record_error(str(ex).replace(str(path), ""), file, path)
record_error(str(ex).replace(str(full_path), ""), file, path)
continue
deleted[file] = str(path)

103
fileserver/redis_manager.py Normal file
View File

@@ -0,0 +1,103 @@
from os import getenv
from boltons.iterutils import first
from redis import StrictRedis
from redis.cluster import RedisCluster
from config import config
log = config.logger(__file__)

# Environment variable names that may override the configured Redis host,
# port and password. Each tuple is scanned in order and the first variable
# holding a non-empty value wins.
OVERRIDE_HOST_ENV_KEY = (
    "CLEARML_REDIS_SERVICE_HOST",
    "TRAINS_REDIS_SERVICE_HOST",
    "REDIS_SERVICE_HOST",
)
OVERRIDE_PORT_ENV_KEY = (
    "CLEARML_REDIS_SERVICE_PORT",
    "TRAINS_REDIS_SERVICE_PORT",
    "REDIS_SERVICE_PORT",
)
OVERRIDE_PASSWORD_ENV_KEY = (
    "CLEARML_REDIS_SERVICE_PASSWORD",
    "TRAINS_REDIS_SERVICE_PASSWORD",
    "REDIS_SERVICE_PASSWORD",
)

OVERRIDE_HOST = first(
    value for value in map(getenv, OVERRIDE_HOST_ENV_KEY) if value
)
if OVERRIDE_HOST:
    log.info(f"Using override redis host {OVERRIDE_HOST}")

OVERRIDE_PORT = first(
    value for value in map(getenv, OVERRIDE_PORT_ENV_KEY) if value
)
if OVERRIDE_PORT:
    log.info(f"Using override redis port {OVERRIDE_PORT}")

# The password override is intentionally not logged
OVERRIDE_PASSWORD = first(
    value for value in map(getenv, OVERRIDE_PASSWORD_ENV_KEY) if value
)
class ConfigError(Exception):
    """Raised when a Redis alias configuration is missing its host or port."""
class GeneralError(Exception):
    """Raised when a Redis connection is requested for an unknown alias."""
class RedisManager(object):
    """Build one Redis client per configured alias and hand them out by name.

    Host, port and password of every alias can be overridden through the
    module-level ``OVERRIDE_HOST`` / ``OVERRIDE_PORT`` / ``OVERRIDE_PASSWORD``
    values. The password otherwise comes from the
    ``secure.redis.<alias>.password`` setting.
    """

    def __init__(self, redis_config_dict):
        """Create a client for every alias in ``redis_config_dict``.

        :param redis_config_dict: mapping of alias name to a config node
            exposing ``as_plain_ordered_dict()``; the resulting dict (minus
            the ``cluster`` flag) is passed to the client as keyword arguments
        :raises ConfigError: when an alias ends up without a host or a port
        """
        self.aliases = {}
        for alias, alias_config in redis_config_dict.items():
            alias_config = alias_config.as_plain_ordered_dict()
            alias_config["password"] = config.get(
                f"secure.redis.{alias}.password", None
            )

            # "cluster" only selects the client class. It is not a client
            # keyword argument, so it is always removed from the kwargs,
            # including when it is explicitly set to false.
            is_cluster = alias_config.pop("cluster", False)

            host = OVERRIDE_HOST or alias_config.get("host", None)
            if host:
                alias_config["host"] = host

            port = OVERRIDE_PORT or alias_config.get("port", None)
            if port:
                alias_config["port"] = port

            password = OVERRIDE_PASSWORD or alias_config.get("password", None)
            if password:
                alias_config["password"] = password

            if not port or not host:
                raise ConfigError(
                    f"Redis configuration is invalid. missing port or host (alias={alias})"
                )

            if is_cluster:
                # The cluster client is created without a db argument; the
                # key may legitimately be absent from the configuration.
                alias_config.pop("db", None)
                self.aliases[alias] = RedisCluster(**alias_config)
            else:
                self.aliases[alias] = StrictRedis(**alias_config)

    def connection(self, alias) -> StrictRedis:
        """Return the client registered for ``alias``.

        :raises GeneralError: when no client exists for the alias
        """
        obj = self.aliases.get(alias)
        if obj is None:
            raise GeneralError(f"Invalid Redis alias {alias}")

        # The result is discarded; presumably a liveness probe so that
        # connection problems surface here rather than at the caller.
        obj.get("health")
        return obj

    def host(self, alias):
        """Return the host of the first pooled connection for ``alias``.

        Returns None when the pool currently holds no available connection.
        """
        r = self.connection(alias)
        if isinstance(r, RedisCluster):
            connections = (
                r.get_default_node().redis_connection.connection_pool._available_connections
            )
        else:
            connections = r.connection_pool._available_connections

        if not connections:
            return None

        return connections[0].host
# Module-level manager instance, built at import time from the "hosts.redis"
# configuration section
redman = RedisManager(config.get("hosts.redis"))

View File

@@ -1,9 +1,11 @@
boltons>=19.1.0
clearml-agent>=1.5.2
flask-compress>=1.4.0
flask-cors>=3.0.5
flask>=2.3.3
gunicorn>=20.1.0
pyhocon>=0.3.35
redis>=4.5.4,<5
setuptools>=65.5.1
urllib3>=1.26.18
werkzeug>=3.0.1