Compare commits

23 Commits

Author SHA1 Message Date
allegroai
08a7bc7c9f Fix not all the event logs are returned from sharded ES 2022-05-20 15:11:05 +03:00
allegroai
fb256d7e5b Version bump to v1.5 2022-05-18 15:29:45 +03:00
allegroai
710443b078 Fix move task to trash is not thread-safe 2022-05-18 10:31:20 +03:00
allegroai
e0cde2f7c9 Add support for deleting pipeline projects 2022-05-18 10:30:21 +03:00
allegroai
60b9c8de14 Allow arbitrary task fields in project statistics filter 2022-05-18 10:29:36 +03:00
allegroai
ecffe26be4 Fix auth.edit_credentials 2022-05-18 10:28:58 +03:00
allegroai
2570bd9e26 Fix ES issue with capital letters in index name 2022-05-18 10:18:23 +03:00
allegroai
174f84514a Fix no destination when merging projects 2022-05-18 10:17:34 +03:00
allegroai
65cb8d7b43 Refactor method name 2022-05-18 10:16:41 +03:00
allegroai
5f8ef808a3 Fix ES issue with capital letters in index name 2022-05-18 10:16:19 +03:00
allegroai
4941ac70e0 Add events.clear_task_log 2022-05-17 16:09:23 +03:00
allegroai
67cd461145 Add auth.edit_credentials 2022-05-17 16:08:12 +03:00
allegroai
92b5fc6f9a Fix handling hidden sub-projects 2022-05-17 16:06:34 +03:00
allegroai
b90165b4e4 Support queue_name in tasks enqueue 2022-05-17 16:04:34 +03:00
allegroai
6c2dcb5c8a Improve error message 2022-05-17 15:56:18 +03:00
allegroai
3efed32934 Add X-Jwt-Payload to redacted headers 2022-05-17 15:55:41 +03:00
allegroai
69737308fe Version bump to v1.4.0 2022-04-18 16:38:22 +03:00
allegroai
a6dbea808a Add indices for task.last_update and task.status_changed 2022-04-18 16:37:22 +03:00
allegroai
5131b17901 Support not returning hidden sub-projects when include_stats is specified without search_hidden 2022-04-18 16:36:14 +03:00
allegroai
5f21c3a56d Add support for searching hidden projects and tasks 2022-04-18 16:34:18 +03:00
allegroai
2350ac64ed Fix internal error on count task events if there is no events index 2022-04-18 16:31:02 +03:00
allegroai
d146127c18 Add events.clear_scroll endpoint to clear event search scrolls 2022-04-18 16:29:57 +03:00
Mal Miller
abd65e103e Ensure agent-services waits for API server to be ready (#129) 2022-03-31 11:10:45 +03:00
31 changed files with 493 additions and 118 deletions

View File

@@ -26,6 +26,9 @@
23: ["invalid_domain_name", "malformed domain name"]
24: ["not_public_object", "object is not public"]
# Auth / Login
75: ["invalid_access_key", "access key not found for user"]
# Tasks
100: ["task_error", "general task error"]
101: ["invalid_task_id", "invalid task id"]
@@ -86,7 +89,7 @@
# Database
800: ["data_validation_error", "data validation error"]
801: ["expected_unique_data", "value combination already exists"]
801: ["expected_unique_data", "value combination already exists (unique field already contains this value)"]
# Workers
1001: ["invalid_worker_id", "invalid worker id"]

View File

@@ -96,6 +96,11 @@ class GetCredentialsResponse(Base):
credentials = ListField(CredentialsResponse)
class EditCredentialsRequest(Base):
access_key = StringField(required=True)
label = StringField()
class RevokeCredentialsRequest(Base):
access_key = StringField(required=True)

View File

@@ -137,3 +137,13 @@ class TaskPlotsRequest(Base):
scroll_id: str = StringField()
no_scroll: bool = BoolField(default=False)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class ClearScrollRequest(Base):
scroll_id: str = StringField()
class ClearTaskLogRequest(Base):
task: str = StringField(required=True)
threshold_sec = IntField()
allow_locked = BoolField(default=False)

View File

@@ -65,3 +65,4 @@ class ProjectsGetRequest(models.Base):
active_users = fields.ListField(str)
check_own_contents = fields.BoolField(default=False)
shallow_search = fields.BoolField(default=False)
search_hidden = fields.BoolField(default=False)

View File

@@ -96,6 +96,7 @@ class UpdateRequest(TaskUpdateRequest):
class EnqueueRequest(UpdateRequest):
queue = StringField()
queue_name = StringField()
class DeleteRequest(UpdateRequest):
@@ -262,6 +263,7 @@ class StopManyRequest(TaskBatchRequest):
class EnqueueManyRequest(TaskBatchRequest):
queue = StringField()
queue_name = StringField()
validate_tasks = BoolField(default=False)

View File

@@ -8,7 +8,7 @@ from datetime import datetime
from operator import attrgetter
from typing import Sequence, Set, Tuple, Optional, List, Mapping, Union
from elasticsearch import helpers
import elasticsearch
from elasticsearch.helpers import BulkIndexError
from mongoengine import Q
from nested_dict import nested_dict
@@ -48,6 +48,9 @@ MAX_LONG = 2 ** 63 - 1
MIN_LONG = -(2 ** 63)
log = config.logger(__file__)
class PlotFields:
valid_plot = "valid_plot"
plot_len = "plot_len"
@@ -219,7 +222,7 @@ class EventBLL(object):
with TimingContext("es", "events_add_batch"):
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
with closing(
helpers.streaming_bulk(
elasticsearch.helpers.streaming_bulk(
self.es,
actions,
chunk_size=chunk_size,
@@ -962,18 +965,23 @@ class EventBLL(object):
for tb in es_res["aggregations"]["tasks"]["buckets"]
}
@staticmethod
def _validate_task_state(company_id: str, task_id: str, allow_locked: bool = False):
extra_msg = None
query = Q(id=task_id, company=company_id)
if not allow_locked:
query &= Q(status__nin=LOCKED_TASK_STATUSES)
extra_msg = "or task published"
res = Task.objects(query).only("id").first()
if not res:
raise errors.bad_request.InvalidTaskId(
extra_msg, company=company_id, id=task_id
)
def delete_task_events(self, company_id, task_id, allow_locked=False):
with translate_errors_context():
extra_msg = None
query = Q(id=task_id, company=company_id)
if not allow_locked:
query &= Q(status__nin=LOCKED_TASK_STATUSES)
extra_msg = "or task published"
res = Task.objects(query).only("id").first()
if not res:
raise errors.bad_request.InvalidTaskId(
extra_msg, company=company_id, id=task_id
)
self._validate_task_state(
company_id=company_id, task_id=task_id, allow_locked=allow_locked
)
es_req = {"query": {"term": {"task": task_id}}}
with translate_errors_context(), TimingContext("es", "delete_task_events"):
@@ -987,6 +995,51 @@ class EventBLL(object):
return es_res.get("deleted", 0)
def clear_task_log(
self,
company_id: str,
task_id: str,
allow_locked: bool = False,
threshold_sec: int = None,
):
self._validate_task_state(
company_id=company_id, task_id=task_id, allow_locked=allow_locked
)
if check_empty_data(
self.es, company_id=company_id, event_type=EventType.task_log
):
return 0
with translate_errors_context(), TimingContext("es", "clear_task_log"):
must = [{"term": {"task": task_id}}]
sort = None
if threshold_sec:
timestamp_ms = int(threshold_sec * 1000)
must.append(
{
"range": {
"timestamp": {
"lt": (
es_factory.get_timestamp_millis() - timestamp_ms
)
}
}
}
)
sort = {"timestamp": {"order": "desc"}}
es_req = {
"query": {"bool": {"must": must}},
**({"sort": sort} if sort else {}),
}
es_res = delete_company_events(
es=self.es,
company_id=company_id,
event_type=EventType.task_log,
body=es_req,
refresh=True,
)
return es_res.get("deleted", 0)
def delete_multi_task_events(self, company_id: str, task_ids: Sequence[str]):
"""
Delete mutliple task events. No check is done for tasks write access
@@ -1005,3 +1058,16 @@ class EventBLL(object):
)
return es_res.get("deleted", 0)
def clear_scroll(self, scroll_id: str):
if scroll_id == self.empty_scroll:
return
# noinspection PyBroadException
try:
self.es.clear_scroll(scroll_id=scroll_id)
except elasticsearch.exceptions.NotFoundError:
pass
except elasticsearch.exceptions.RequestError:
pass
except Exception as ex:
log.exception("Failed clearing scroll %s", scroll_id)

View File

@@ -41,7 +41,7 @@ class EventSettings:
def get_index_name(company_id: str, event_type: str):
event_type = event_type.lower().replace(" ", "_")
return f"events-{event_type}-{company_id}"
return f"events-{event_type}-{company_id.lower()}"
def check_empty_data(es: Elasticsearch, company_id: str, event_type: EventType) -> bool:

View File

@@ -400,7 +400,7 @@ class EventMetrics:
return es_res.get("aggregations")
def get_tasks_metrics(
def get_task_metrics(
self, company_id, task_ids: Sequence, event_type: EventType
) -> Sequence:
"""

View File

@@ -67,6 +67,9 @@ class EventsIterator:
task_id: str,
metric_variants: MetricVariants = None,
) -> int:
if check_empty_data(self.es, company_id, event_type):
return 0
query, _ = self._get_initial_query_and_must(task_id, metric_variants)
es_req = {
"query": query,
@@ -78,7 +81,6 @@ class EventsIterator:
company_id=company_id,
event_type=event_type,
body=es_req,
routing=task_id,
)
return es_result["count"]
@@ -119,7 +121,6 @@ class EventsIterator:
company_id=company_id,
event_type=event_type,
body=es_req,
routing=task_id,
)
hits = es_result["hits"]["hits"]
hits_total = es_result["hits"]["total"]["value"]
@@ -143,7 +144,6 @@ class EventsIterator:
company_id=company_id,
event_type=event_type,
body=es_req,
routing=task_id,
)
last_second_hits = es_result["hits"]["hits"]
if not last_second_hits or len(last_second_hits) < 2:

View File

@@ -6,6 +6,7 @@ from redis import Redis
from apiserver.config_repo import config
from apiserver.bll.project import project_ids_with_children
from apiserver.database.model import EntityVisibility
from apiserver.database.model.base import GetMixin
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
@@ -42,6 +43,8 @@ class _TagsCache:
query &= GetMixin.get_list_field_query(name, vals)
if project:
query &= Q(project__in=project_ids_with_children([project]))
else:
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
return self.db_cls.objects(query).distinct(field)

View File

@@ -17,6 +17,7 @@ from typing import (
Any,
)
from boltons.iterutils import partition
from mongoengine import Q, Document
from apiserver import database
@@ -62,20 +63,24 @@ class ProjectBLL:
source=source_id
)
source = Project.get(company, source_id)
destination = Project.get(company, destination_id)
if source_id in destination.path:
raise errors.bad_request.ProjectCannotBeMergedIntoItsChild(
source=source_id, destination=destination_id
)
if destination_id:
destination = Project.get(company, destination_id)
if source_id in destination.path:
raise errors.bad_request.ProjectCannotBeMergedIntoItsChild(
source=source_id, destination=destination_id
)
else:
destination = None
children = _get_sub_projects(
[source.id], _only=("id", "name", "parent", "path")
)[source.id]
cls.validate_projects_depth(
projects=children,
old_parent_depth=len(source.path) + 1,
new_parent_depth=len(destination.path) + 1,
)
if destination:
cls.validate_projects_depth(
projects=children,
old_parent_depth=len(source.path) + 1,
new_parent_depth=len(destination.path) + 1,
)
moved_entities = 0
for entity_type in (Task, Model):
@@ -146,10 +151,8 @@ class ProjectBLL:
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
location=new_parent.name if new_parent else ""
)
if (
new_parent
and project_id == new_parent.id
or project_id in new_parent.path
if new_parent and (
project_id == new_parent.id or project_id in new_parent.path
):
raise errors.bad_request.ProjectCannotBeMovedUnderItself(
project=project_id, parent=new_parent.id
@@ -511,13 +514,16 @@ class ProjectBLL:
project_ids: Sequence[str],
specific_state: Optional[EntityVisibility] = None,
include_children: bool = True,
search_hidden: bool = False,
filter_: Mapping[str, Any] = None,
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
if not project_ids:
return {}, {}
child_projects = (
_get_sub_projects(project_ids, _only=("id", "name"))
_get_sub_projects(
project_ids, _only=("id", "name"), search_hidden=search_hidden
)
if include_children
else {}
)
@@ -740,10 +746,13 @@ class ProjectBLL:
If projects is None or empty then get parents for all the company tasks
"""
query = Q(company=company_id)
if projects:
if include_subprojects:
projects = _ids_with_children(projects)
query &= Q(project__in=projects)
else:
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
if state == EntityVisibility.archived:
query &= Q(system_tags__in=[EntityVisibility.archived.value])
@@ -772,7 +781,8 @@ class ProjectBLL:
if project_ids:
project_ids = _ids_with_children(project_ids)
query &= Q(project__in=project_ids)
else:
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
res = Task.objects(query).distinct(field="type")
return set(res).intersection(external_task_types)
@@ -799,17 +809,20 @@ class ProjectBLL:
if not filter_:
return conditions
for field in ("tags", "system_tags"):
field_filter = filter_.get(field)
if not field_filter:
continue
if not isinstance(field_filter, list) or not all(
isinstance(t, str) for t in field_filter
for field, field_filter in filter_.items():
if not (
field_filter
and isinstance(field_filter, list)
and all(isinstance(t, str) for t in field_filter)
):
raise errors.bad_request.ValidationError(
f"List of strings expected for the field: {field}"
)
conditions[field] = {"$in": field_filter}
exclude, include = partition(field_filter, lambda x: x.startswith("-"))
conditions[field] = {
**({"$in": include} if include else {}),
**({"$nin": [e[1:] for e in exclude]} if exclude else {}),
}
return conditions

View File

@@ -13,7 +13,7 @@ from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, ArtifactModes
from apiserver.database.model.task.task import Task, ArtifactModes, TaskType
from apiserver.timing_context import TimingContext
from .sub_projects import _ids_with_children
@@ -32,22 +32,28 @@ class DeleteProjectResult:
def validate_project_delete(company: str, project_id: str):
project = Project.get_for_writing(
company=company, id=project_id, _only=("id", "path")
company=company, id=project_id, _only=("id", "path", "system_tags")
)
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
is_pipeline = "pipeline" in (project.system_tags or [])
project_ids = _ids_with_children([project_id])
ret = {}
for cls in (Task, Model):
ret[f"{cls.__name__.lower()}s"] = cls.objects(
project__in=project_ids,
).count()
ret[f"{cls.__name__.lower()}s"] = cls.objects(project__in=project_ids).count()
for cls in (Task, Model):
ret[f"non_archived_{cls.__name__.lower()}s"] = cls.objects(
project__in=project_ids,
system_tags__nin=[EntityVisibility.archived.value],
).count()
query = dict(
project__in=project_ids, system_tags__nin=[EntityVisibility.archived.value]
)
name = f"non_archived_{cls.__name__.lower()}s"
if not is_pipeline:
ret[name] = cls.objects(**query).count()
else:
ret[name] = (
cls.objects(**query, type=TaskType.controller).count()
if cls == Task
else 0
)
return ret
@@ -56,23 +62,30 @@ def delete_project(
company: str, project_id: str, force: bool, delete_contents: bool
) -> Tuple[DeleteProjectResult, Set[str]]:
project = Project.get_for_writing(
company=company, id=project_id, _only=("id", "path")
company=company, id=project_id, _only=("id", "path", "system_tags")
)
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
is_pipeline = "pipeline" in (project.system_tags or [])
project_ids = _ids_with_children([project_id])
if not force:
for cls, error in (
(Task, errors.bad_request.ProjectHasTasks),
(Model, errors.bad_request.ProjectHasModels),
):
non_archived = cls.objects(
project__in=project_ids,
system_tags__nin=[EntityVisibility.archived.value],
).only("id")
query = dict(
project__in=project_ids, system_tags__nin=[EntityVisibility.archived.value]
)
if not is_pipeline:
for cls, error in (
(Task, errors.bad_request.ProjectHasTasks),
(Model, errors.bad_request.ProjectHasModels),
):
non_archived = cls.objects(**query).only("id")
if non_archived:
raise error("use force=true to delete", id=project_id)
else:
non_archived = Task.objects(**query, type=TaskType.controller).only("id")
if non_archived:
raise error("use force=true to delete", id=project_id)
raise errors.bad_request.ProjectHasTasks(
"please archive all the runs inside the project", id=project_id
)
if not delete_contents:
with TimingContext("mongo", "update_children"):

View File

@@ -4,6 +4,7 @@ from typing import Tuple, Optional, Sequence, Mapping
from apiserver import database
from apiserver.apierrors import errors
from apiserver.database.model import EntityVisibility
from apiserver.database.model.project import Project
name_separator = "/"
@@ -100,12 +101,17 @@ def _get_writable_project_from_name(
def _get_sub_projects(
project_ids: Sequence[str], _only: Sequence[str] = ("id", "path")
project_ids: Sequence[str],
_only: Sequence[str] = ("id", "path"),
search_hidden=True,
) -> Mapping[str, Sequence[Project]]:
"""
Return the list of child projects of all the levels for the parent project ids
"""
qs = Project.objects(path__in=project_ids)
query = dict(path__in=project_ids)
if not search_hidden:
query["system_tags__nin"] = [EntityVisibility.hidden.value]
qs = Project.objects(**query)
if _only:
_only = set(_only) | {"path"}
qs = qs.only(*_only)

View File

@@ -50,6 +50,18 @@ class QueueBLL(object):
queue.save()
return queue
def get_by_name(
self,
company_id: str,
queue_name: str,
only: Optional[Sequence[str]] = None,
) -> Queue:
qs = Queue.objects(name=queue_name, company=company_id)
if only:
qs = qs.only(*only)
return qs.first()
def get_by_id(
self, company_id: str, queue_id: str, only: Optional[Sequence[str]] = None
) -> Queue:

View File

@@ -29,7 +29,7 @@ class QueueMetrics:
@staticmethod
def _queue_metrics_prefix_for_company(company_id: str) -> str:
"""Returns the es index prefix for the company"""
return f"queue_metrics_{company_id}_"
return f"queue_metrics_{company_id.lower()}_"
@staticmethod
def _get_es_index_suffix():

View File

@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Callable, Any, Tuple, Union
from typing import Callable, Any, Tuple, Union, Sequence
from apiserver.apierrors import errors, APIError
from apiserver.bll.queue import QueueBLL
@@ -25,6 +25,7 @@ from apiserver.database.model.task.task import (
)
from apiserver.utilities.dicts import nested_set
log = config.logger(__file__)
queue_bll = QueueBLL()
@@ -83,10 +84,7 @@ def unarchive_task(
def dequeue_task(
task_id: str,
company_id: str,
status_message: str,
status_reason: str,
task_id: str, company_id: str, status_message: str, status_reason: str,
) -> Tuple[int, dict]:
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(**query)
@@ -94,10 +92,7 @@ def dequeue_task(
raise errors.bad_request.InvalidTaskId(**query)
res = TaskBLL.dequeue_and_change_status(
task,
company_id,
status_message=status_message,
status_reason=status_reason,
task, company_id, status_message=status_message, status_reason=status_reason,
)
return 1, res
@@ -108,9 +103,23 @@ def enqueue_task(
queue_id: str,
status_message: str,
status_reason: str,
queue_name: str = None,
validate: bool = False,
force: bool = False,
) -> Tuple[int, dict]:
if queue_id and queue_name:
raise errors.bad_request.ValidationError(
"Either queue id or queue name should be provided"
)
if queue_name:
queue = queue_bll.get_by_name(
company_id=company_id, queue_name=queue_name, only=("id",)
)
if not queue:
queue = queue_bll.create(company_id=company_id, name=queue_name)
queue_id = queue.id
if not queue_id:
# try to get default queue
queue_id = queue_bll.get_default(company_id).id
@@ -155,6 +164,30 @@ def enqueue_task(
return 1, res
def move_tasks_to_trash(tasks: Sequence[str]) -> int:
try:
collection_name = Task._get_collection_name()
trash_collection_name = f"{collection_name}__trash"
Task.aggregate(
[
{"$match": {"_id": {"$in": tasks}}},
{
"$merge": {
"into": trash_collection_name,
"on": "_id",
"whenMatched": "replace",
"whenNotMatched": "insert",
}
},
],
allow_disk_use=True,
)
except Exception as ex:
log.error(f"Error copying tasks to trash {str(ex)}")
return Task.objects(id__in=tasks).delete()
def delete_task(
task_id: str,
company_id: str,
@@ -200,18 +233,12 @@ def delete_task(
)
if move_to_trash:
collection_name = task._get_collection_name()
archived_collection = "{}__trash".format(collection_name)
task.switch_collection(archived_collection)
try:
# A simple save() won't do due to mongoengine caching (nothing will be saved), so we have to force
# an insert. However, if for some reason such an ID exists, let's make sure we'll keep going.
task.save(force_insert=True)
except Exception:
pass
task.switch_collection(collection_name)
# make sure that whatever changes were done to the task are saved
# the task itself will be deleted later in the move_tasks_to_trash operation
task.save()
else:
task.delete()
task.delete()
update_project_time(task.project)
return 1, task, cleanup_res

View File

@@ -20,7 +20,7 @@ class WorkerStats:
@staticmethod
def worker_stats_prefix_for_company(company_id: str) -> str:
"""Returns the es index prefix for the company"""
return f"worker_stats_{company_id}_"
return f"worker_stats_{company_id.lower()}_"
def _search_company_stats(self, company_id: str, es_req: dict) -> dict:
return self.es.search(

View File

@@ -60,3 +60,4 @@ def validate_id(cls, company, **kwargs):
class EntityVisibility(Enum):
active = "active"
archived = "archived"
hidden = "hidden"

View File

@@ -175,6 +175,8 @@ class Task(AttributedDocument):
"active_duration",
"parent",
"project",
"last_update",
"status_changed",
"models.input.model",
("company", "name"),
("company", "user"),

View File

@@ -262,6 +262,38 @@ get_credentials {
}
}
edit_credentials {
allow_roles = [ "*" ]
internal: false
"2.19" {
description: """Updates the label of the existing credentials for the authenticated user."""
request {
type: object
required: [ access_key ]
properties {
access_key {
type: string
description: Existing credentials key
}
label {
type: string
description: New credentials label
}
}
}
response {
type: object
properties {
updated {
description: "Number of credentials updated"
type: integer
enum: [0, 1]
}
}
}
}
}
revoke_credentials {
allow_roles = [ "*" ]
internal: false

View File

@@ -1304,3 +1304,58 @@ scalar_metrics_iter_raw {
}
}
}
clear_scroll {
"2.18" {
description: "Clear an open Scroll ID"
request {
type: object
required: [
scroll_id
]
properties {
scroll_id {
description: "Scroll ID as returned by previous events service calls"
type: string
}
}
}
response {
type: object
additionalProperties: false
}
}
}
clear_task_log {
"2.19" {
description: Remove old logs from task
request {
type: object
required: [task]
properties {
task {
description: Task ID
type: string
}
allow_locked {
type: boolean
description: Allow deleting events even if the task is locked
default: false
}
threshold_sec {
description: The amount of seconds ago to retain the log records. The older log records will be deleted. If not passed or 0 then all the log records for the task will be deleted
type: integer
}
}
}
response {
type: object
properties {
deleted {
description: The number of deleted log records
type: integer
}
}
}
}
}

View File

@@ -455,7 +455,14 @@ get_all {
}
}
}
"2.15": ${get_all."2.13"} {
"2.14": ${get_all."2.13"} {
request.properties.search_hidden {
description: "If set to 'true' then hidden projects are included in the search results"
type: boolean
default: false
}
}
"2.15": ${get_all."2.14"} {
request {
properties {
scroll_id {
@@ -536,7 +543,14 @@ get_all_ex {
}
}
}
"2.15": ${get_all_ex."2.13"} {
"2.14": ${get_all_ex."2.13"} {
request.properties.search_hidden {
description: "If set to 'true' then hidden projects are included in the search results"
type: boolean
default: false
}
}
"2.15": ${get_all_ex."2.14"} {
request {
properties {
scroll_id {
@@ -568,15 +582,9 @@ get_all_ex {
}
"2.17": ${get_all_ex."2.16"} {
request.properties.include_stats_filter {
description: The filter for selecting entities that participate in statistics calculation
description: The filter for selecting entities that participate in statistics calculation. For each task field that you want to filter on pass the list of allowed values. Prepend the value with '-' to exclude
type: object
properties {
system_tags {
description: The list of allowed system tags
type: array
items { type: string }
}
}
additionalProperties: true
}
}
}

View File

@@ -685,7 +685,14 @@ get_all_ex {
}
}
}
"2.15": ${get_all_ex."2.13"} {
"2.14": ${get_all_ex."2.13"} {
request.properties.search_hidden {
description: "If set to 'true' then hidden tasks are included in the search results"
type: boolean
default: false
}
}
"2.15": ${get_all_ex."2.14"} {
request {
properties {
scroll_id {
@@ -822,7 +829,14 @@ get_all {
}
}
}
"2.15": ${get_all."2.1"} {
"2.14": ${get_all."2.1"} {
request.properties.search_hidden {
description: "If set to 'true' then hidden tasks are included in the search results"
type: boolean
default: false
}
}
"2.15": ${get_all."2.14"} {
request {
properties {
scroll_id {
@@ -1884,7 +1898,7 @@ Fails if the following parameters in the task were not filled:
]
properties {
queue {
description: "Queue id. If not provided, task is added to the default queue."
description: "Queue id. If not provided and no queue name is passed then task is added to the default queue."
type: string
}
}
@@ -1900,6 +1914,12 @@ Fails if the following parameters in the task were not filled:
}
}
}
"2.19": ${enqueue."1.5"} {
request.properties.queue_name {
description: The name of the queue. If the queue does not exist then it is auto-created. Cannot be used together with the queue id
type: string
}
}
}
enqueue_many {
"2.13": ${_definitions.change_many_request} {
@@ -1908,7 +1928,7 @@ enqueue_many {
properties {
ids.description: "IDs of the tasks to enqueue"
queue {
description: "Queue id. If not provided, tasks are added to the default queue."
description: "Queue id. If not provided and no queue name is passed then tasks are added to the default queue."
type: string
}
validate_tasks {
@@ -1927,6 +1947,12 @@ enqueue_many {
}
}
}
"2.19": ${enqueue_many."2.13"} {
request.properties.queue_name {
description: The name of the queue. If the queue does not exist then it is auto-created. Cannot be used together with the queue id
type: string
}
}
}
dequeue {
"1.5" {

View File

@@ -313,6 +313,7 @@ class APICall(DataContainer):
_redacted_headers = {
HEADER_AUTHORIZATION: " ",
"Cookie": "=",
"X-Jwt-Payload": "",
}
""" Headers whose value should be redacted. Maps header name to partition char """
@@ -692,6 +693,10 @@ class APICall(DataContainer):
# this will allow us to debug authorization errors).
for header, sep in self._redacted_headers.items():
if header in headers:
prefix, _, redact = headers[header].partition(sep)
if sep:
prefix, _, redact = headers[header].partition(sep)
else:
prefix = sep = ""
redact = headers[header]
headers[header] = prefix + sep + f"<{len(redact)} bytes redacted>"
return headers

View File

@@ -14,6 +14,7 @@ from apiserver.apimodels.auth import (
RevokeCredentialsRequest,
EditUserReq,
CreateCredentialsRequest,
EditCredentialsRequest,
)
from apiserver.apimodels.base import UpdateResponse
from apiserver.bll.auth import AuthBLL
@@ -122,6 +123,27 @@ def create_credentials(call: APICall, _, request: CreateCredentialsRequest):
call.result.data_model = CreateCredentialsResponse(credentials=credentials)
@endpoint("auth.edit_credentials")
def edit_credentials(call: APICall, company_id: str, request: EditCredentialsRequest):
identity = call.identity
access_key = request.access_key
updated = User.objects(
id=identity.user,
company=company_id,
credentials__match={"key": access_key},
).update_one(set__credentials__S__label=request.label)
if not updated:
raise errors.bad_request.InvalidAccessKey(
"invalid user or invalid access key",
user=identity.user,
access_key=access_key,
company=company_id,
)
call.result.data = {"updated": updated}
@endpoint(
"auth.revoke_credentials",
request_data_model=RevokeCredentialsRequest,

View File

@@ -25,6 +25,8 @@ from apiserver.apimodels.events import (
TaskPlotsRequest,
TaskEventsRequest,
ScalarMetricsIterRawRequest,
ClearScrollRequest,
ClearTaskLogRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_common import EventType, MetricVariants
@@ -768,14 +770,14 @@ def next_debug_image_sample(call, company_id, request: NextDebugImageSampleReque
@endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
def get_tasks_metrics(call: APICall, company_id, request: TaskMetricsRequest):
def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest):
task = task_bll.assert_exists(
company_id,
task_ids=request.tasks,
allow_public=True,
only=("company", "company_origin"),
)[0]
res = event_bll.metrics.get_tasks_metrics(
res = event_bll.metrics.get_task_metrics(
task.get_index_company(), task_ids=request.tasks, event_type=request.event_type
)
call.result.data = {
@@ -796,6 +798,21 @@ def delete_for_task(call, company_id, req_model):
)
@endpoint("events.clear_task_log")
def clear_task_log(call: APICall, company_id: str, request: ClearTaskLogRequest):
task_id = request.task
task_bll.assert_exists(company_id, task_id, return_tasks=False)
call.result.data = dict(
deleted=event_bll.clear_task_log(
company_id=company_id,
task_id=task_id,
allow_locked=request.allow_locked,
threshold_sec=request.threshold_sec,
)
)
def _get_top_iter_unique_events_per_task(events, max_iters, tasks):
key = itemgetter("metric", "variant", "task", "iter")
@@ -936,3 +953,9 @@ def scalar_metrics_iter_raw(
scroll_id=scroll.get_scroll_id(),
variants=variants,
)
@endpoint("events.clear_scroll", min_version="2.18")
def clear_scroll(_, __, request: ClearScrollRequest):
if request.scroll_id:
event_bll.clear_scroll(request.scroll_id)

View File

@@ -26,6 +26,7 @@ from apiserver.bll.project.project_cleanup import (
validate_project_delete,
)
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
from apiserver.database.model.project import Project
from apiserver.database.utils import (
parse_from_call,
@@ -73,6 +74,16 @@ def get_by_id(call):
call.result.data = {"project": project_dict}
def _hidden_query(search_hidden: bool, ids: Sequence) -> Q:
"""
1. Add only non-hidden tasks search condition (unless specifically specified differently)
"""
if search_hidden or ids:
return Q()
return Q(system_tags__ne=EntityVisibility.hidden.value)
def _adjust_search_parameters(data: dict, shallow_search: bool):
"""
1. Make sure that there is no external query on path
@@ -91,12 +102,14 @@ def _adjust_search_parameters(data: dict, shallow_search: bool):
@endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest)
def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
conform_tag_fields(call, call.data)
allow_public = not request.non_public
data = call.data
conform_tag_fields(call, data)
allow_public = not request.non_public
requested_ids = data.get("id")
_adjust_search_parameters(
data, shallow_search=request.shallow_search,
)
with TimingContext("mongo", "projects_get_all"):
data = call.data
if request.active_users:
ids = project_bll.get_projects_with_active_user(
company=company_id,
@@ -105,16 +118,14 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
allow_public=allow_public,
)
if not ids:
call.result.data = {"projects": []}
return
return {"projects": []}
data["id"] = ids
_adjust_search_parameters(data, shallow_search=request.shallow_search)
ret_params = {}
projects = Project.get_many_with_join(
projects: Sequence[dict] = Project.get_many_with_join(
company=company_id,
query_dict=data,
query=_hidden_query(search_hidden=request.search_hidden, ids=requested_ids),
allow_public=allow_public,
ret_params=ret_params,
)
@@ -143,6 +154,7 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
project_ids=list(project_ids),
specific_state=request.stats_for_state,
include_children=request.stats_with_children,
search_hidden=request.search_hidden,
filter_=request.include_stats_filter,
)
@@ -155,20 +167,24 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
@endpoint("projects.get_all")
def get_all(call: APICall):
conform_tag_fields(call, call.data)
data = call.data
_adjust_search_parameters(data, shallow_search=data.get("shallow_search", False))
with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
conform_tag_fields(call, data)
_adjust_search_parameters(
data, shallow_search=data.get("shallow_search", False),
)
with TimingContext("mongo", "projects_get_all"):
ret_params = {}
projects = Project.get_many(
company=call.identity.company,
query_dict=data,
query=_hidden_query(
search_hidden=data.get("search_hidden"), ids=data.get("id")
),
parameters=data,
allow_public=True,
ret_params=ret_params,
)
conform_output_tags(call, projects)
call.result.data = {"projects": projects, **ret_params}

View File

@@ -94,10 +94,12 @@ from apiserver.bll.task.task_operations import (
delete_task,
publish_task,
unarchive_task,
move_tasks_to_trash,
)
from apiserver.bll.task.utils import update_task, get_task_for_update, deleted_prefix
from apiserver.bll.util import SetFieldsResolver, run_batch_operation
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
from apiserver.database.model.task.output import Output
from apiserver.database.model.task.task import (
Task,
@@ -213,6 +215,16 @@ def _process_include_subprojects(call_data: dict):
call_data["project"] = project_ids_with_children(project_ids)
def _hidden_query(data: dict) -> Q:
"""
1. Add only non-hidden tasks search condition (unless specifically specified differently)
"""
if data.get("search_hidden") or data.get("id"):
return Q()
return Q(system_tags__ne=EntityVisibility.hidden.value)
@endpoint("tasks.get_all_ex", required_fields=[])
def get_all_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
@@ -225,6 +237,7 @@ def get_all_ex(call: APICall, company_id, _):
tasks = Task.get_many_with_join(
company=company_id,
query_dict=call_data,
query=_hidden_query(call_data),
allow_public=True,
ret_params=ret_params,
)
@@ -259,6 +272,7 @@ def get_all(call: APICall, company_id, _):
company=company_id,
parameters=call_data,
query_dict=call_data,
query=_hidden_query(call_data),
allow_public=True,
ret_params=ret_params,
)
@@ -848,6 +862,7 @@ def enqueue(call: APICall, company_id, request: EnqueueRequest):
queue_id=request.queue,
status_message=request.status_message,
status_reason=request.status_reason,
queue_name=request.queue_name,
force=request.force,
)
call.result.data_model = EnqueueResponse(queued=queued, **res)
@@ -866,6 +881,7 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
queue_id=request.queue,
status_message=request.status_message,
status_reason=request.status_reason,
queue_name=request.queue_name,
validate=request.validate_tasks,
),
ids=request.ids,
@@ -1060,6 +1076,8 @@ def delete(call: APICall, company_id, request: DeleteRequest):
status_reason=request.status_reason,
)
if deleted:
if request.move_to_trash:
move_tasks_to_trash([request.task])
_reset_cached_tags(company_id, projects=[task.project] if task.project else [])
call.result.data = dict(deleted=bool(deleted), **attr.asdict(cleanup_res))
@@ -1081,6 +1099,10 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest):
)
if results:
if request.move_to_trash:
task_ids = set(task.id for _, (_, task, _) in results)
if task_ids:
move_tasks_to_trash(list(task_ids))
projects = set(task.project for _, (_, task, _) in results)
_reset_cached_tags(company_id, projects=list(projects))

View File

@@ -20,7 +20,7 @@ class TestBatchOperations(TestService):
ids = [*tasks, missing_id]
# enqueue
res = self.api.tasks.enqueue_many(ids=ids)
res = self.api.tasks.enqueue_many(ids=ids, queue_name="test")
self._assert_succeeded(res, tasks)
self._assert_failed(res, [missing_id])
data = self.api.tasks.get_all_ex(id=ids).tasks

View File

@@ -1 +1 @@
__version__ = "1.3.0"
__version__ = "1.5.0"

View File

@@ -154,6 +154,8 @@ services:
- /opt/clearml/agent:/root/.clearml
depends_on:
- apiserver
entrypoint: >
bash -c "curl --retry 10 --retry-delay 10 --retry-connrefused 'http://apiserver:8008/debug.ping' && /usr/agent/entrypoint.sh"
networks:
backend: