Compare commits

22 Commits

Author SHA1 Message Date
allegroai
11d76e7d8c Update AWS AMIs for v0.15.0 2020-06-01 23:07:38 +03:00
allegroai
e76c0fbc63 Version bump to 0.15.0 2020-06-01 22:20:58 +03:00
allegroai
fdc9956da3 Update trains-agent-services docker image 2020-06-01 21:53:33 +03:00
allegroai
f4addaa653 Add new services mode agent container to the docker-compose 2020-06-01 21:02:49 +03:00
allegroai
667964cc82 Add clear_all flag to tasks.reset 2020-06-01 13:07:35 +03:00
allegroai
e1309e30b7 Fix UPLOAD_FOLDER handling when provided as env var or when fileserver is run by gunicorn 2020-06-01 13:05:45 +03:00
allegroai
9403942ef7 Add support for additional task types as well as tasks.get_types to obtain actual types used globally or per project 2020-06-01 13:05:12 +03:00
allegroai
84a75d9e70 Add server uid to server.info response in API v2.8 2020-06-01 13:01:31 +03:00
allegroai
c85ab66ae6 Add organization.get_tags to obtain the set of all used task, model, queue and project tags 2020-06-01 13:00:35 +03:00
allegroai
bf7f0f646b Sort hyper parameters numeric values as numbers and not strings 2020-06-01 12:27:56 +03:00
allegroai
dcdf2a3d58 Fix task can't be cloned if input model was deleted 2020-06-01 12:23:29 +03:00
allegroai
f8d8fc40a6 Support filtering users by activity in projects 2020-06-01 11:55:40 +03:00
allegroai
45d434a123 When clearing a task do not delete draft models used by other tasks 2020-06-01 11:51:43 +03:00
allegroai
1834abe5bc Better handling of execution parameter paths 2020-06-01 11:49:35 +03:00
allegroai
d6321588f3 Fix role checked for endpoints not requiring authorization 2020-06-01 11:43:55 +03:00
allegroai
c17b10ff1d Revoke built-in webserver system-role credentials (used by the WebApp) in case we're running in fixed-mode 2020-06-01 11:41:43 +03:00
allegroai
b125a56f86 Make sure configuration path loaded from an environment variable name is lower-case 2020-06-01 11:40:34 +03:00
allegroai
c43ce3a17b Update 0.15 mongo migration to drop indices (so new ones will be automatically created) 2020-06-01 11:36:22 +03:00
allegroai
b0b09616a8 Fix single bad event causes events.add_batch to skip remaining events 2020-06-01 11:33:39 +03:00
allegroai
ede5586ccc Extract non-responsive tasks watchdog from main tasks logic 2020-06-01 11:31:36 +03:00
allegroai
a1dcdffa53 Update pymongo and mongoengine versions 2020-06-01 11:29:50 +03:00
allegroai
35a11db58e Support task log retrieval with no scroll 2020-06-01 11:27:36 +03:00
72 changed files with 1854 additions and 646 deletions

View File

@@ -22,6 +22,8 @@ services:
TRAINS_MONGODB_SERVICE_PORT: 27017
TRAINS_REDIS_SERVICE_HOST: redis
TRAINS_REDIS_SERVICE_PORT: 6379
TRAINS__apiserver__mongo__pre_populate__enabled: "true"
TRAINS__apiserver__mongo__pre_populate__zip_file: "/opt/trains/db-pre-populate/export.zip"
ports:
- "8008:8008"
networks:

View File

@@ -115,6 +115,36 @@ services:
ports:
- "8080:80"
agent-services:
container_name: trains-agent-services
image: allegroai/trains-agent-services:latest
restart: unless-stopped
privileged: true
environment:
TRAINS_HOST_IP: ${TRAINS_HOST_IP}
TRAINS_WEB_HOST: ${TRAINS_WEB_HOST:-}
TRAINS_API_HOST: ${TRAINS_API_HOST:-}
TRAINS_FILES_HOST: ${TRAINS_FILES_HOST:-}
TRAINS_API_ACCESS_KEY: ${TRAINS_API_ACCESS_KEY:-}
TRAINS_API_SECRET_KEY: ${TRAINS_API_SECRET_KEY:-}
TRAINS_AGENT_GIT_USER: ${TRAINS_AGENT_GIT_USER}
TRAINS_AGENT_GIT_PASS: ${TRAINS_AGENT_GIT_PASS}
TRAINS_AGENT_UPDATE_VERSION: ${TRAINS_AGENT_UPDATE_VERSION:->=0.15.0}
TRAINS_AGENT_DEFAULT_BASE_DOCKER: "ubuntu:18.04"
AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID:-}
AWS_SECRET_ACCESS_KEY: ${AWS_SECRET_ACCESS_KEY:-}
AWS_DEFAULT_REGION: ${AWS_DEFAULT_REGION:-}
AZURE_STORAGE_ACCOUNT: ${AZURE_STORAGE_ACCOUNT:-}
AZURE_STORAGE_KEY: ${AZURE_STORAGE_KEY:-}
GOOGLE_APPLICATION_CREDENTIALS: ${GOOGLE_APPLICATION_CREDENTIALS:-}
TRAINS_WORKER_ID: "trains-services"
TRAINS_AGENT_DOCKER_HOST_MOUNT: "/opt/trains/agent:/root/.trains"
volumes:
- /var/run/docker.sock:/var/run/docker.sock
- /opt/trains/agent:/root/.trains
depends_on:
- apiserver
networks:
backend:
driver: bridge

View File

@@ -50,26 +50,45 @@ To upgrade the AMI:
The following sections contain lists of AMI Image IDs, per region, for each released **trains-server** version.
### Latest version AMI - v0.14.2 (auto update)<a name="autoupdate"></a>
### Latest version AMI - v0.15.0 (auto update)<a name="autoupdate"></a>
For easier upgrades, the following AMIs automatically update to the latest release every reboot:
* **eu-north-1** : ami-095cc888970c06e09
* **ap-south-1** : ami-07019e7b3febea37e
* **eu-west-3** : ami-0433d76badf430c16
* **eu-west-2** : ami-05794c2b23ff79990
* **eu-west-1** : ami-03e3bcabd1863d666
* **ap-northeast-2** : ami-00f14188b66a5803e
* **ap-northeast-1** : ami-005c93e30c99dab0c
* **sa-east-1** : ami-0d819231779e7d264
* **ca-central-1** : ami-0eff2fd400939d960
* **ap-southeast-1** : ami-049b21bfa0d35c21c
* **ap-southeast-2** : ami-0318b96a72d5da068
* **eu-central-1** : ami-0cdb9d794340b9704
* **us-east-2** : ami-0d846a080fc5a9345
* **us-west-1** : ami-0ef970342625159bf
* **us-west-2** : ami-04f3d13b75c642506
* **us-east-1** : ami-01bef4da91280a322
* **eu-north-1** : ami-0a05eb5b384a84609
* **ap-south-1** : ami-00f190b50e60b1eb5
* **eu-west-3** : ami-044fad585e1d1798e
* **eu-west-2** : ami-04ab930416a4af8c5
* **eu-west-1** : ami-00c022f333417e78e
* **ap-northeast-2** : ami-0c436e94f461a9a22
* **ap-northeast-1** : ami-018e761ad0009d5d4
* **sa-east-1** : ami-0b6c0e8e93b6ebbdd
* **ca-central-1** : ami-0cf12aab70c14237d
* **ap-southeast-1** : ami-0fe7840b9bde05581
* **ap-southeast-2** : ami-00f230e86e1afda91
* **eu-central-1** : ami-0635d13b79f76e04f
* **us-east-2** : ami-0b323078d0206db0e
* **us-west-1** : ami-07fdc1d461906f957
* **us-west-2** : ami-0a5cac167c3ebdedb
* **us-east-1** : ami-0d03956bea3aa5a44
### v0.15.0 (static update)
* **eu-north-1** : ami-0475a5068d615769b
* **ap-south-1** : ami-00c7e642badaa2ebf
* **eu-west-3** : ami-0655f769c28843e25
* **eu-west-2** : ami-04d82f48f09e2b846
* **eu-west-1** : ami-07a2aab2dc7b4ec5f
* **ap-northeast-2** : ami-0257ab220a8bc7a52
* **ap-northeast-1** : ami-0c4900af758b91dde
* **sa-east-1** : ami-021f758a4a21d5725
* **ca-central-1** : ami-0ce9703b3b47cfe70
* **ap-southeast-1** : ami-0b38689fdb8f71b74
* **ap-southeast-2** : ami-0c2b3a171e7ae4b00
* **eu-central-1** : ami-0fdd3420d6e6b4a1f
* **us-east-2** : ami-0288e9654da36ed1c
* **us-west-1** : ami-0f1d6ee0b73fe9ca2
* **us-west-2** : ami-025f0c5bfeacbf390
* **us-east-1** : ami-0b17b0bfa8b91f805
### v0.14.2 (static update)

View File

@@ -10,12 +10,13 @@ from flask_cors import CORS
from config import config
DEFAULT_UPLOAD_FOLDER = "/mnt/fileserver"
app = Flask(__name__)
CORS(app, **config.get("fileserver.cors"))
Compress(app)
if os.environ.get("TRAINS_UPLOAD_FOLDER"):
app.config["UPLOAD_FOLDER"] = os.environ.get("TRAINS_UPLOAD_FOLDER")
app.config["UPLOAD_FOLDER"] = os.environ.get("TRAINS_UPLOAD_FOLDER") or DEFAULT_UPLOAD_FOLDER
@app.route("/", methods=["POST"])
@@ -57,12 +58,13 @@ def main():
parser.add_argument(
"--upload-folder",
"-u",
default="/mnt/fileserver",
default=DEFAULT_UPLOAD_FOLDER,
help="Upload folder (default %(default)s)",
)
args = parser.parse_args()
app.config["UPLOAD_FOLDER"] = args.upload_folder
if app.config.get("UPLOAD_FOLDER") is None:
app.config["UPLOAD_FOLDER"] = args.upload_folder
app.run(debug=args.debug, host=args.ip, port=args.port, threaded=True)

View File

@@ -1 +1 @@
__version__ = "2.7.0"
__version__ = "2.8.0"

View File

@@ -47,6 +47,7 @@ _error_codes = {
128: ('invalid_task_output', 'invalid task output'),
129: ('task_publish_in_progress', 'Task publish in progress'),
130: ('task_not_found', 'task not found'),
131: ('events_not_added', 'events not added'),
# Models
200: ('model_error', 'general task error'),

View File

@@ -13,6 +13,7 @@ from luqum.parser import parser, ParseError
from validators import email as email_validator, domain as domain_validator
from apierrors import errors
from utilities.json import loads, dumps
def make_default(field_cls, default_value):
@@ -206,10 +207,10 @@ class DomainField(fields.StringField):
raise errors.bad_request.InvalidDomainName()
class StringEnum(Enum):
def __str__(self):
return self.value
class JsonSerializableMixin:
def to_json(self: ModelBase):
return dumps(self.to_struct())
# noinspection PyMethodParameters
def _generate_next_value_(name, start, count, last_values):
return name
@classmethod
def from_json(cls: Type[ModelBase], s):
return cls(**loads(s))

View File

@@ -40,6 +40,14 @@ class DebugImagesRequest(Base):
scroll_id: str = StringField()
class LogEventsRequest(Base):
task: str = StringField(required=True)
batch_size: int = IntField(default=500)
navigate_earlier: bool = BoolField(default=True)
refresh: bool = BoolField(default=False)
scroll_id: str = StringField()
class IterationEvents(Base):
iter: int = IntField()
events: Sequence[dict] = ListField(items_types=dict)

View File

@@ -0,0 +1,10 @@
from jsonmodels import fields, models
class Filter(models.Base):
system_tags = fields.ListField([str])
class TagsRequest(models.Base):
include_system = fields.BoolField(default=False)
filter = fields.EmbeddedField(Filter)

View File

@@ -92,6 +92,10 @@ class PingRequest(TaskRequest):
pass
class GetTypesRequest(models.Base):
projects = ListField(items_types=[str])
class CloneRequest(TaskRequest):
new_task_name = StringField()
new_task_comment = StringField()
@@ -100,6 +104,7 @@ class CloneRequest(TaskRequest):
new_task_parent = StringField()
new_task_project = StringField()
execution_overrides = DictField()
validate_references = BoolField(default=False)
class AddOrUpdateArtifactsRequest(TaskRequest):
@@ -109,3 +114,7 @@ class AddOrUpdateArtifactsRequest(TaskRequest):
class AddOrUpdateArtifactsResponse(models.Base):
added = ListField([str])
updated = ListField([str])
class ResetRequest(UpdateRequest):
clear_all = BoolField(default=False)

View File

@@ -1,4 +1,3 @@
import json
from enum import Enum
import six
@@ -13,7 +12,7 @@ from jsonmodels.fields import (
)
from jsonmodels.models import Base
from apimodels import make_default, ListField, EnumField
from apimodels import make_default, ListField, EnumField, JsonSerializableMixin
DEFAULT_TIMEOUT = 10 * 60
@@ -61,7 +60,7 @@ class IdNameEntry(Base):
name = StringField()
class WorkerEntry(Base):
class WorkerEntry(Base, JsonSerializableMixin):
key = StringField() # not required due to migration issues
id = StringField(required=True)
user = EmbeddedField(IdNameEntry)
@@ -75,13 +74,6 @@ class WorkerEntry(Base):
last_activity_time = DateTimeField(required=True)
last_report_time = DateTimeField()
def to_json(self):
return json.dumps(self.to_struct())
@classmethod
def from_json(cls, s):
return cls(**json.loads(s))
class CurrentTaskEntry(IdNameEntry):
running_time = IntField()

View File

@@ -3,27 +3,25 @@ from concurrent.futures.thread import ThreadPoolExecutor
from functools import partial
from itertools import chain
from operator import attrgetter, itemgetter
from typing import Sequence, Tuple, Optional, Mapping
import attr
import dpath
from boltons.iterutils import bucketize
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, ListField, IntField
from jsonmodels.models import Base
from redis import StrictRedis
from typing import Sequence, Tuple, Optional, Mapping
import database
from apierrors import errors
from bll.redis_cache_manager import RedisCacheManager
from apimodels import JsonSerializableMixin
from bll.event.event_metrics import EventMetrics
from bll.redis_cache_manager import RedisCacheManager
from config import config
from database.errors import translate_errors_context
from jsonmodels.models import Base
from jsonmodels.fields import StringField, ListField, IntField
from database.model.task.metrics import MetricEventStats
from database.model.task.task import Task
from timing_context import TimingContext
from utilities.json import loads, dumps
class VariantScrollState(Base):
@@ -45,17 +43,10 @@ class MetricScrollState(Base):
self.last_min_iter = self.last_max_iter = None
class DebugImageEventsScrollState(Base):
class DebugImageEventsScrollState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
metrics: Sequence[MetricScrollState] = ListField([MetricScrollState])
def to_json(self):
return dumps(self.to_struct())
@classmethod
def from_json(cls, s):
return cls(**loads(s))
@attr.s(auto_attribs=True)
class DebugImagesResult(object):
@@ -65,7 +56,12 @@ class DebugImagesResult(object):
class DebugImagesIterator:
EVENT_TYPE = "training_debug_image"
STATE_EXPIRATION_SECONDS = 3600
@property
def state_expiration_sec(self):
return config.get(
f"services.events.events_retrieval.state_expiration_sec", 3600
)
@property
def _max_workers(self):
@@ -76,7 +72,7 @@ class DebugImagesIterator:
self.cache_manager = RedisCacheManager(
state_class=DebugImageEventsScrollState,
redis=redis,
expiration_interval=self.STATE_EXPIRATION_SECONDS,
expiration_interval=self.state_expiration_sec,
)
def get_task_events(
@@ -92,27 +88,31 @@ class DebugImagesIterator:
if not self.es.indices.exists(es_index):
return DebugImagesResult()
unique_metrics = set(metrics)
state = self.cache_manager.get_state(state_id) if state_id else None
if not state:
state = DebugImageEventsScrollState(
id=database.utils.id(),
metrics=self._init_metric_states(es_index, list(unique_metrics)),
)
else:
state_metrics = set((m.task, m.name) for m in state.metrics)
if state_metrics != unique_metrics:
raise errors.bad_request.InvalidScrollId(
"while getting debug images events", scroll_id=state_id
)
def init_state(state_: DebugImageEventsScrollState):
unique_metrics = set(metrics)
state_.metrics = self._init_metric_states(es_index, list(unique_metrics))
def validate_state(state_: DebugImageEventsScrollState):
"""
Validate that the metrics stored in the state are the same
as requested in the current call.
Refresh the state if requested
"""
state_metrics = set((m.task, m.name) for m in state_.metrics)
if state_metrics != set(metrics):
raise errors.bad_request.InvalidScrollId(
"Task metrics stored in the state do not match the passed ones",
scroll_id=state_.id,
)
if refresh:
self._reinit_outdated_metric_states(company_id, es_index, state)
for metric_state in state.metrics:
self._reinit_outdated_metric_states(company_id, es_index, state_)
for metric_state in state_.metrics:
metric_state.reset()
res = DebugImagesResult(next_scroll_id=state.id)
try:
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state
) as state:
res = DebugImagesResult(next_scroll_id=state.id)
with ThreadPoolExecutor(self._max_workers) as pool:
res.metric_events = list(
pool.map(
@@ -125,10 +125,8 @@ class DebugImagesIterator:
state.metrics,
)
)
finally:
self.cache_manager.set_state(state)
return res
return res
def _reinit_outdated_metric_states(
self, company_id, es_index, state: DebugImageEventsScrollState

View File

@@ -3,9 +3,8 @@ from collections import defaultdict
from contextlib import closing
from datetime import datetime
from operator import attrgetter
from typing import Sequence
from typing import Sequence, Set, Tuple
import attr
import six
from elasticsearch import helpers
from mongoengine import Q
@@ -16,6 +15,7 @@ import es_factory
from apierrors import errors
from bll.event.debug_images_iterator import DebugImagesIterator
from bll.event.event_metrics import EventMetrics, EventType
from bll.event.log_events_iterator import LogEventsIterator, TaskEventsResult
from bll.task import TaskBLL
from config import config
from database.errors import translate_errors_context
@@ -29,13 +29,6 @@ EVENT_TYPES = set(map(attrgetter("value"), EventType))
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
@attr.s(auto_attribs=True)
class TaskEventsResult(object):
total_events: int = 0
next_scroll_id: str = None
events: list = attr.ib(factory=list)
class EventBLL(object):
id_fields = ("task", "iter", "metric", "variant", "key")
@@ -47,12 +40,28 @@ class EventBLL(object):
)
self.redis = redis or redman.connection("apiserver")
self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
self.log_events_iterator = LogEventsIterator(es=self.es, redis=self.redis)
@property
def metrics(self) -> EventMetrics:
return self._metrics
def add_events(self, company_id, events, worker, allow_locked_tasks=False):
@staticmethod
def _get_valid_tasks(company_id, task_ids: Set, allow_locked_tasks=False) -> Set:
"""Verify that task exists and can be updated"""
if not task_ids:
return set()
with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
query = Q(id__in=task_ids, company=company_id)
if not allow_locked_tasks:
query &= Q(status__nin=LOCKED_TASK_STATUSES)
res = Task.objects(query).only("id")
return {r.id for r in res}
def add_events(
self, company_id, events, worker, allow_locked_tasks=False
) -> Tuple[int, int, dict]:
actions = []
task_ids = set()
task_iteration = defaultdict(lambda: 0)
@@ -62,19 +71,34 @@ class EventBLL(object):
task_last_events = nested_dict(
3, dict
) # task_id -> metric_hash -> event_type -> MetricEvent
errors_per_type = defaultdict(int)
valid_tasks = self._get_valid_tasks(
company_id,
task_ids={
event["task"] for event in events if event.get("task") is not None
},
allow_locked_tasks=allow_locked_tasks,
)
for event in events:
# remove spaces from event type
if "type" not in event:
raise errors.BadRequest("Event must have a 'type' field", event=event)
event_type = event.get("type")
if event_type is None:
errors_per_type["Event must have a 'type' field"] += 1
continue
event_type = event["type"].replace(" ", "_")
event_type = event_type.replace(" ", "_")
if event_type not in EVENT_TYPES:
raise errors.BadRequest(
"Invalid event type {}".format(event_type),
event=event,
types=EVENT_TYPES,
)
errors_per_type[f"Invalid event type {event_type}"] += 1
continue
task_id = event.get("task")
if task_id is None:
errors_per_type["Event must have a 'task' field"] += 1
continue
if task_id not in valid_tasks:
errors_per_type["Invalid task id"] += 1
continue
event["type"] = event_type
@@ -120,89 +144,75 @@ class EventBLL(object):
else:
es_action["_id"] = dbutils.id()
task_id = event.get("task")
if task_id is not None:
es_action["_routing"] = task_id
task_ids.add(task_id)
if (
iter is not None
and event.get("metric") not in self._skip_iteration_for_metric
):
task_iteration[task_id] = max(iter, task_iteration[task_id])
es_action["_routing"] = task_id
task_ids.add(task_id)
if (
iter is not None
and event.get("metric") not in self._skip_iteration_for_metric
):
task_iteration[task_id] = max(iter, task_iteration[task_id])
self._update_last_metric_events_for_task(
last_events=task_last_events[task_id], event=event,
self._update_last_metric_events_for_task(
last_events=task_last_events[task_id], event=event,
)
if event_type == EventType.metrics_scalar.value:
self._update_last_scalar_events_for_task(
last_events=task_last_scalar_events[task_id], event=event
)
if event_type == EventType.metrics_scalar.value:
self._update_last_scalar_events_for_task(
last_events=task_last_scalar_events[task_id], event=event
)
else:
es_action["_routing"] = task_id
actions.append(es_action)
if task_ids:
# verify task_ids
with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
extra_msg = None
query = Q(id__in=task_ids, company=company_id)
if not allow_locked_tasks:
query &= Q(status__nin=LOCKED_TASK_STATUSES)
extra_msg = "or task published"
res = Task.objects(query).only("id")
if len(res) < len(task_ids):
invalid_task_ids = tuple(set(task_ids) - set(r.id for r in res))
raise errors.bad_request.InvalidTaskId(
extra_msg, company=company_id, ids=invalid_task_ids
added = 0
if actions:
chunk_size = 500
with translate_errors_context(), TimingContext("es", "events_add_batch"):
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
with closing(
helpers.streaming_bulk(
self.es,
actions,
chunk_size=chunk_size,
# thread_count=8,
refresh=True,
)
) as it:
for success, info in it:
if success:
added += chunk_size
else:
errors_per_type["Error when indexing events batch"] += 1
remaining_tasks = set()
now = datetime.utcnow()
for task_id in task_ids:
# Update related tasks. For reasons of performance, we prefer to update
# all of them and not only those who's events were successful
updated = self._update_task(
company_id=company_id,
task_id=task_id,
now=now,
iter_max=task_iteration.get(task_id),
last_scalar_events=task_last_scalar_events.get(task_id),
last_events=task_last_events.get(task_id),
)
errors_in_bulk = []
added = 0
chunk_size = 500
with translate_errors_context(), TimingContext("es", "events_add_batch"):
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
with closing(
helpers.streaming_bulk(
self.es,
actions,
chunk_size=chunk_size,
# thread_count=8,
refresh=True,
)
) as it:
for success, info in it:
if success:
added += chunk_size
else:
errors_in_bulk.append(info)
if not updated:
remaining_tasks.add(task_id)
continue
remaining_tasks = set()
now = datetime.utcnow()
for task_id in task_ids:
# Update related tasks. For reasons of performance, we prefer to update all of them and not only those
# who's events were successful
updated = self._update_task(
company_id=company_id,
task_id=task_id,
now=now,
iter_max=task_iteration.get(task_id),
last_scalar_events=task_last_scalar_events.get(task_id),
last_events=task_last_events.get(task_id),
)
if not updated:
remaining_tasks.add(task_id)
continue
if remaining_tasks:
TaskBLL.set_last_update(remaining_tasks, company_id, last_update=now)
if remaining_tasks:
TaskBLL.set_last_update(
remaining_tasks, company_id, last_update=now
)
# Compensate for always adding chunk_size on success (last chunk is probably smaller)
added = min(added, len(actions))
return added, errors_in_bulk
if not added:
raise errors.bad_request.EventsNotAdded(**errors_per_type)
errors_count = sum(errors_per_type.values())
return added, errors_count, errors_per_type
def _update_last_scalar_events_for_task(self, last_events, event):
"""

View File

@@ -0,0 +1,169 @@
from typing import Optional, Tuple, Sequence
import attr
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, IntField
from jsonmodels.models import Base
from redis import StrictRedis
from apierrors import errors
from apimodels import JsonSerializableMixin
from bll.event.event_metrics import EventMetrics
from bll.redis_cache_manager import RedisCacheManager
from config import config
from database.errors import translate_errors_context
from timing_context import TimingContext
class LogEventsScrollState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
task: str = StringField(required=True)
last_min_timestamp: Optional[int] = IntField()
last_max_timestamp: Optional[int] = IntField()
def reset(self):
"""Reset the scrolling state """
self.last_min_timestamp = self.last_max_timestamp = None
@attr.s(auto_attribs=True)
class TaskEventsResult:
total_events: int = 0
next_scroll_id: str = None
events: list = attr.Factory(list)
class LogEventsIterator:
EVENT_TYPE = "log"
@property
def state_expiration_sec(self):
return config.get(
f"services.events.events_retrieval.state_expiration_sec", 3600
)
def __init__(self, redis: StrictRedis, es: Elasticsearch):
self.es = es
self.cache_manager = RedisCacheManager(
state_class=LogEventsScrollState,
redis=redis,
expiration_interval=self.state_expiration_sec,
)
def get_task_events(
self,
company_id: str,
task_id: str,
batch_size: int,
navigate_earlier: bool = True,
refresh: bool = False,
state_id: str = None,
) -> TaskEventsResult:
es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
if not self.es.indices.exists(es_index):
return TaskEventsResult()
def init_state(state_: LogEventsScrollState):
state_.task = task_id
def validate_state(state_: LogEventsScrollState):
"""
Checks that the task id stored in the state
is equal to the one passed with the current call
Refresh the state if requested
"""
if state_.task != task_id:
raise errors.bad_request.InvalidScrollId(
"Task stored in the state does not match the passed one",
scroll_id=state_.id,
)
if refresh:
state_.reset()
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state,
) as state:
res = TaskEventsResult(next_scroll_id=state.id)
res.events, res.total_events = self._get_events(
es_index=es_index,
batch_size=batch_size,
navigate_earlier=navigate_earlier,
state=state,
)
return res
def _get_events(
self,
es_index,
batch_size: int,
navigate_earlier: bool,
state: LogEventsScrollState,
) -> Tuple[Sequence[dict], int]:
"""
Return up to 'batch size' events starting from the previous timestamp either in the
direction of earlier events (navigate_earlier=True) or in the direction of later events.
If last_min_timestamp and last_max_timestamp are not set then start either from latest or earliest.
For the last timestamp all the events are brought (even if the resulting size
exceeds batch_size) so that this timestamp events will not be lost between the calls.
In case any events were received update 'last_min_timestamp' and 'last_max_timestamp'
"""
# retrieve the next batch of events
es_req = {
"size": batch_size,
"query": {"term": {"task": state.task}},
"sort": {"timestamp": "desc" if navigate_earlier else "asc"},
}
if navigate_earlier and state.last_min_timestamp is not None:
es_req["search_after"] = [state.last_min_timestamp]
elif not navigate_earlier and state.last_max_timestamp is not None:
es_req["search_after"] = [state.last_max_timestamp]
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_result = self.es.search(index=es_index, body=es_req, routing=state.task)
hits = es_result["hits"]["hits"]
hits_total = es_result["hits"]["total"]
if not hits:
return [], hits_total
events = [hit["_source"] for hit in hits]
if navigate_earlier:
state.last_max_timestamp = events[0]["timestamp"]
state.last_min_timestamp = events[-1]["timestamp"]
else:
state.last_min_timestamp = events[0]["timestamp"]
state.last_max_timestamp = events[-1]["timestamp"]
# retrieve the events that match the last event timestamp
# but did not make it into the previous call due to batch_size limitation
es_req = {
"size": 10000,
"query": {
"bool": {
"must": [
{"term": {"task": state.task}},
{"term": {"timestamp": events[-1]["timestamp"]}},
]
}
},
}
es_result = self.es.search(index=es_index, body=es_req, routing=state.task)
hits = es_result["hits"]["hits"]
if not hits or len(hits) < 2:
# if only one element is returned for the last timestamp
# then it is already present in the events
return events, hits_total
last_events = [hit["_source"] for hit in es_result["hits"]["hits"]]
already_present_ids = set(ev["_id"] for ev in events)
# return the list merged from original query results +
# leftovers from the last timestamp
return (
[
*events,
*(ev for ev in last_events if ev["_id"] not in already_present_ids),
],
hits_total,
)

View File

@@ -4,7 +4,7 @@ Module for polymorphism over different types of X axes in scalar aggregations
from abc import ABC, abstractmethod
from enum import auto
from apimodels import StringEnum
from utilities.stringenum import StringEnum
from bll.util import extract_properties_to_lists
from config import config

View File

@@ -0,0 +1,85 @@
from typing import Sequence
from mongoengine import Q
from config import config
from database.model.base import GetMixin
from database.model.model import Model
from database.model.task.task import Task
from redis_manager import redman
from utilities import json
log = config.logger(__file__)
class OrgBLL:
_tags_field = "tags"
_system_tags_field = "system_tags"
_settings_prefix = "services.organization"
def __init__(self, redis=None):
self.redis = redis or redman.connection("apiserver")
@property
def _tags_cache_expiration_seconds(self):
return config.get(
f"{self._settings_prefix}.tags_cache.expiration_seconds", 3600
)
@staticmethod
def _get_tags_cache_key(company, field: str, filter_: Sequence[str] = None):
filter_str = "_".join(filter_) if filter_ else ""
return f"{field}_{company}_{filter_str}"
@staticmethod
def _get_tags_from_db(company, field, filter_: Sequence[str] = None) -> set:
query = Q(company=company)
if filter_:
query &= GetMixin.get_list_field_query("system_tags", filter_)
tags = set()
for cls_ in (Task, Model):
tags |= set(cls_.objects(query).distinct(field))
return tags
def get_tags(
self, company, include_system: bool = False, filter_: Sequence[str] = None
) -> dict:
"""
Get tags and optionally system tags for the company
Return the dictionary of tags per tags field name
The function retrieves both cached values from Redis in one call
and re calculates any of them if missing in Redis
"""
fields = [
self._tags_field,
*([self._system_tags_field] if include_system else []),
]
redis_keys = [self._get_tags_cache_key(company, f, filter_) for f in fields]
cached = self.redis.mget(redis_keys)
ret = {}
for field, tag_data, key in zip(fields, cached, redis_keys):
if tag_data is not None:
tags = json.loads(tag_data)
else:
tags = list(self._get_tags_from_db(company, field, filter_))
self.redis.setex(
key,
time=self._tags_cache_expiration_seconds,
value=json.dumps(tags),
)
ret[field] = tags
return ret
def update_org_tags(self, company, tags=None, system_tags=None, reset=False):
"""
Updates system tags. If reset is set then both tags and system_tags
are recalculated. Otherwise only those that are not 'None'
"""
if reset or tags is not None:
self.redis.delete(self._get_tags_cache_key(company, self._tags_field))
if reset or system_tags is not None:
self.redis.delete(
self._get_tags_cache_key(company, self._system_tags_field)
)

View File

@@ -0,0 +1 @@
from .project_bll import ProjectBLL

View File

@@ -0,0 +1,33 @@
from typing import Sequence, Optional
from mongoengine import Q
from config import config
from database.model.model import Model
from database.model.task.task import Task
from timing_context import TimingContext
log = config.logger(__file__)
class ProjectBLL:
@classmethod
def get_active_users(
cls, company, project_ids: Sequence, user_ids: Optional[Sequence] = None
) -> set:
"""
Get the set of user ids that created tasks/models in the given projects
If project_ids is empty then all projects are examined
If user_ids are passed then only subset of these users is returned
"""
with TimingContext("mongo", "active_users_in_projects"):
res = set()
query = Q(company=company)
if project_ids:
query &= Q(project__in=project_ids)
if user_ids:
query &= Q(user__in=user_ids)
for cls_ in (Task, Model):
res |= set(cls_.objects(query).distinct(field="user"))
return res

View File

@@ -1,15 +1,21 @@
from typing import Optional, TypeVar, Generic, Type
from contextlib import contextmanager
from typing import Optional, TypeVar, Generic, Type, Callable
from redis import StrictRedis
import database
from timing_context import TimingContext
T = TypeVar("T")
def _do_nothing(_: T):
return
class RedisCacheManager(Generic[T]):
"""
Class for store/retreive of state objects from redis
Class for store/retrieve of state objects from redis
self.state_class - class of the state
self.redis - instance of redis
@@ -42,3 +48,32 @@ class RedisCacheManager(Generic[T]):
def _get_redis_key(self, state_id):
return f"{self.state_class}/{state_id}"
@contextmanager
def get_or_create_state(
self,
state_id=None,
init_state: Callable[[T], None] = _do_nothing,
validate_state: Callable[[T], None] = _do_nothing,
):
"""
Try to retrieve state with the given id from the Redis cache if yes then validates it
If no then create a new one with randomly generated id
Yield the state and write it back to redis once the user code block exits
:param state_id: id of the state to retrieve
:param init_state: user callback to init the newly created state
If not passed then no init except for the id generation is done
:param validate_state: user callback to validate the state if retrieved from cache
Should throw an exception if the state is not valid. If not passed then no validation is done
"""
state = self.get_state(state_id) if state_id else None
if state:
validate_state(state)
else:
state = self.state_class(id=database.utils.id())
init_state(state)
try:
yield state
finally:
self.set_state(state)

View File

@@ -280,7 +280,7 @@ class StatisticsReporter:
]
return {
group["_id"]: {k: v for k, v in group.items() if k != "_id"}
for group in Task.aggregate(*pipeline)
for group in Task.aggregate(pipeline)
}

View File

@@ -0,0 +1,89 @@
from datetime import timedelta, datetime
from time import sleep
from apierrors import errors
from bll.task import ChangeStatusRequest
from config import config
from database.model.task.task import TaskStatus, Task
from utilities.threads_manager import ThreadsManager
log = config.logger(__file__)
class NonResponsiveTasksWatchdog:
threads = ThreadsManager()
class _Settings:
"""
Retrieves watchdog settings from the config file
The properties are not cached so that the updates in
the config file are reflected
"""
_prefix = "services.tasks.non_responsive_tasks_watchdog"
@property
def enabled(self):
return config.get(f"{self._prefix}.enabled", True)
@property
def watch_interval_sec(self):
return config.get(f"{self._prefix}.watch_interval_sec", 900)
@property
def threshold_sec(self):
return config.get(f"{self._prefix}.threshold_sec", 7200)
settings = _Settings()
@classmethod
@threads.register("non_responsive_tasks_watchdog", daemon=True)
def start(cls):
sleep(cls.settings.watch_interval_sec)
while not ThreadsManager.terminating:
watch_interval = cls.settings.watch_interval_sec
if cls.settings.enabled:
try:
stopped = cls.cleanup_tasks(
threshold_sec=cls.settings.threshold_sec
)
log.info(f"{stopped} non-responsive tasks stopped")
except Exception as ex:
log.exception(f"Failed stopping tasks: {str(ex)}")
sleep(watch_interval)
@classmethod
def cleanup_tasks(cls, threshold_sec):
relevant_status = (TaskStatus.in_progress,)
threshold = timedelta(seconds=threshold_sec)
ref_time = datetime.utcnow() - threshold
log.info(
f"Starting cleanup cycle for running tasks last updated before {ref_time}"
)
tasks = list(
Task.objects(status__in=relevant_status, last_update__lt=ref_time).only(
"id", "name", "status", "project", "last_update"
)
)
log.info(f"{len(tasks)} non-responsive tasks found")
if not tasks:
return 0
err_count = 0
for task in tasks:
log.info(
f"Stopping {task.id} ({task.name}), last updated at {task.last_update}"
)
try:
ChangeStatusRequest(
task=task,
new_status=TaskStatus.stopped,
status_reason="Forced stop (non-responsive)",
status_message="Forced stop (non-responsive)",
force=True,
).execute()
except errors.bad_request.FailedChangingTaskStatus:
err_count += 1
return len(tasks) - err_count

View File

@@ -1,5 +1,5 @@
from collections import OrderedDict
from datetime import datetime, timedelta
from datetime import datetime
from operator import attrgetter
from random import random
from time import sleep
@@ -14,6 +14,7 @@ import database.utils as dbutils
import es_factory
from apierrors import errors
from apimodels.tasks import Artifact as ApiArtifact
from bll.organization import OrgBLL
from config import config
from database.errors import translate_errors_context
from database.model.model import Model
@@ -27,25 +28,37 @@ from database.model.task.task import (
TaskSystemTags,
ArtifactModes,
Artifact,
external_task_types,
)
from database.utils import get_company_or_none_constraint, id as create_id
from service_repo import APICall
from services.utils import validate_tags
from timing_context import TimingContext
from utilities.dicts import deep_merge
from utilities.threads_manager import ThreadsManager
from .utils import ChangeStatusRequest, validate_status_change, ParameterKeyEscaper
log = config.logger(__file__)
org_bll = OrgBLL()
class TaskBLL(object):
threads = ThreadsManager("TaskBLL")
def __init__(self, events_es=None):
self.events_es = (
events_es if events_es is not None else es_factory.connect("events")
)
@classmethod
def get_types(cls, company, project_ids: Optional[Sequence]) -> set:
"""
Return the list of unique task types used by company and public tasks
If project ids passed then only tasks from these projects are considered
"""
query = get_company_or_none_constraint(company)
if project_ids:
query &= Q(project__in=project_ids)
res = Task.objects(query).distinct(field="type")
return set(res).intersection(external_task_types)
@staticmethod
def get_task_with_access(
task_id, company_id, only=None, allow_public=False, requires_write_access=False
@@ -167,9 +180,12 @@ class TaskBLL(object):
tags: Optional[Sequence[str]] = None,
system_tags: Optional[Sequence[str]] = None,
execution_overrides: Optional[dict] = None,
validate_references: bool = False,
) -> Task:
validate_tags(tags, system_tags)
task = cls.get_by_id(company_id=company_id, task_id=task_id)
execution_dict = task.execution.to_proper_dict() if task.execution else {}
execution_model_overriden = False
if execution_overrides:
parameters = execution_overrides.get("parameters")
if parameters is not None:
@@ -177,6 +193,8 @@ class TaskBLL(object):
ParameterKeyEscaper.escape(k): v for k, v in parameters.items()
}
execution_dict = deep_merge(execution_dict, execution_overrides)
execution_model_overriden = execution_overrides.get("model") is not None
artifacts = execution_dict.get("artifacts")
if artifacts:
execution_dict["artifacts"] = [
@@ -204,26 +222,42 @@ class TaskBLL(object):
else None,
execution=execution_dict,
)
cls.validate(new_task)
cls.validate(
new_task,
validate_model=validate_references or execution_model_overriden,
validate_parent=validate_references or parent,
validate_project=validate_references or project,
)
new_task.save()
org_bll.update_org_tags(company_id, tags=tags, system_tags=system_tags)
return new_task
@classmethod
def validate(cls, task: Task):
assert isinstance(task, Task)
if task.parent and not Task.get(
company=task.company, id=task.parent, _only=("id",), include_public=True
def validate(
cls,
task: Task,
validate_model=True,
validate_parent=True,
validate_project=True,
):
if (
validate_parent
and task.parent
and not Task.get(
company=task.company, id=task.parent, _only=("id",), include_public=True
)
):
raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent)
if task.project and not Project.get_for_writing(
company=task.company, id=task.project
if (
validate_project
and task.project
and not Project.get_for_writing(company=task.company, id=task.project)
):
raise errors.bad_request.InvalidProjectId(id=task.project)
cls.validate_execution_model(task)
if validate_model:
cls.validate_execution_model(task)
@staticmethod
def get_unique_metric_variants(company_id, project_ids=None):
@@ -263,7 +297,7 @@ class TaskBLL(object):
]
with translate_errors_context():
result = Task.aggregate(*pipeline)
result = Task.aggregate(pipeline)
return [r["metrics"][0] for r in result]
@staticmethod
@@ -327,7 +361,7 @@ class TaskBLL(object):
metric_stats = {
dbutils.hash_field_name(metric_key): MetricEventStats(
metric=metric_key, event_stats_by_type=events_per_type(metric_data),
metric=metric_key, event_stats_by_type=events_per_type(metric_data)
)
for metric_key, metric_data in last_events.items()
}
@@ -575,58 +609,6 @@ class TaskBLL(object):
return [a.key for a in added], [a.key for a in updated]
@classmethod
@threads.register("non_responsive_tasks_watchdog", daemon=True)
def start_non_responsive_tasks_watchdog(cls):
log = config.logger("non_responsive_tasks_watchdog")
relevant_status = (TaskStatus.in_progress,)
threshold = timedelta(
seconds=config.get(
"services.tasks.non_responsive_tasks_watchdog.threshold_sec", 7200
)
)
watch_interval = config.get(
"services.tasks.non_responsive_tasks_watchdog.watch_interval_sec", 900
)
sleep(watch_interval)
while not ThreadsManager.terminating:
try:
ref_time = datetime.utcnow() - threshold
log.info(
f"Starting cleanup cycle for running tasks last updated before {ref_time}"
)
tasks = list(
Task.objects(
status__in=relevant_status, last_update__lt=ref_time
).only("id", "name", "status", "project", "last_update")
)
if tasks:
log.info(f"Stopping {len(tasks)} non-responsive tasks")
for task in tasks:
log.info(
f"Stopping {task.id} ({task.name}), last updated at {task.last_update}"
)
ChangeStatusRequest(
task=task,
new_status=TaskStatus.stopped,
status_reason="Forced stop (non-responsive)",
status_message="Forced stop (non-responsive)",
force=True,
).execute()
log.info(f"Done")
except Exception as ex:
log.exception(f"Failed stopping tasks: {str(ex)}")
sleep(watch_interval)
@staticmethod
def get_aggregated_project_execution_parameters(
company_id,
@@ -666,7 +648,7 @@ class TaskBLL(object):
]
with translate_errors_context():
result = next(Task.aggregate(*pipeline), None)
result = next(Task.aggregate(pipeline), None)
total = 0
remaining = 0

View File

@@ -33,8 +33,8 @@ log = config.logger(__file__)
class WorkerBLL:
def __init__(self, es=None, redis=None):
self.es_client = es if es is not None else es_factory.connect("workers")
self.redis = redis if redis is not None else redman.connection("workers")
self.es_client = es or es_factory.connect("workers")
self.redis = redis or redman.connection("workers")
self._stats = WorkerStats(self.es_client)
@property
@@ -223,7 +223,7 @@ class WorkerBLL:
},
]
queues_info = {
res["_id"]: res for res in Queue.objects.aggregate(*projection)
res["_id"]: res for res in Queue.objects.aggregate(projection)
}
task_ids = task_ids.union(
filter(

View File

@@ -57,7 +57,7 @@ class BasicConfig:
keys = sorted(k for k in os.environ if k.startswith(prefix))
for key in keys:
path = key[len(prefix) :].replace(EXTRA_CONFIG_VALUES_ENV_KEY_SEP, ".")
path = key[len(prefix) :].replace(EXTRA_CONFIG_VALUES_ENV_KEY_SEP, ".").lower()
result = ConfigTree.merge_configs(
result, ConfigFactory.parse_string(f"{path}: {os.environ[key]}")
)
@@ -77,7 +77,7 @@ class BasicConfig:
if not path.is_dir() and str(path) != DEFAULT_EXTRA_CONFIG_PATH
]
if invalid:
print(f"WARNING: Invalid paths in {key} env var: {' '.join(invalid)}")
print(f"WARNING: Invalid paths in {key} env var: {' '.join(map(str, invalid))}")
return [path for path in paths if path.is_dir()]
def _load(self, verbose=True):

View File

@@ -13,17 +13,21 @@
credentials {
# system credentials as they appear in the auth DB, used for intra-service communications
apiserver {
role: "system"
user_key: "62T8CP7HGBC6647XF9314C2VY67RJO"
user_secret: "FhS8VZv_I4%6Mo$8S1BWc$n$=o1dMYSivuiWU-Vguq7qGOKskG-d+b@tn_Iq"
}
webserver {
role: "system"
user_key: "EYVQ385RW7Y2QQUH88CZ7DWIQ1WUHP"
user_secret: "yfc8KQo*GMXb*9p((qcYC7ByFIpF7I&4VH3BfUYXH%o9vX1ZUZQEEw1Inc)S"
revoke_in_fixed_mode: true
}
tests {
role: "user"
display_name: "Default User"
user_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
user_secret: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
}
}
}

View File

@@ -6,4 +6,8 @@ ignore_iteration {
# max number of concurrent queries to ES when calculating events metrics
# should not exceed the amount of concurrent connections set in the ES driver
max_metrics_concurrency: 4
max_metrics_concurrency: 4
events_retrieval {
state_expiration_sec: 3600
}

View File

@@ -0,0 +1,3 @@
tags_cache {
expiration_seconds: 3600
}

View File

@@ -1,4 +1,6 @@
non_responsive_tasks_watchdog {
enabled: true
# In-progress tasks older than this value in seconds will be stopped by the watchdog
threshold_sec: 7200

View File

@@ -14,6 +14,9 @@ from mongoengine import (
DictField,
DynamicField,
)
from mongoengine.fields import key_not_string, key_starts_with_dollar
NoneType = type(None)
class LengthRangeListField(ListField):
@@ -125,17 +128,39 @@ def contains_empty_key(d):
return True
class SafeMapField(MapField):
class DictValidationMixin:
"""
DictField validation in MongoEngine requires default alias and permissions to access DB version:
https://github.com/MongoEngine/mongoengine/issues/2239
This is a stripped down implementation that does not require any of the above and implies Mongo ver 3.6+
"""
def _safe_validate(self: DictField, value):
if not isinstance(value, dict):
self.error("Only dictionaries may be used in a DictField")
if key_not_string(value):
msg = "Invalid dictionary key - documents must have only string keys"
self.error(msg)
if key_starts_with_dollar(value):
self.error(
'Invalid dictionary key name - keys may not startswith "$" characters'
)
super(DictField, self).validate(value)
class SafeMapField(MapField, DictValidationMixin):
def validate(self, value):
super(SafeMapField, self).validate(value)
self._safe_validate(value)
if contains_empty_key(value):
self.error("Empty keys are not allowed in a MapField")
class SafeDictField(DictField):
class SafeDictField(DictField, DictValidationMixin):
def validate(self, value):
super(SafeDictField, self).validate(value)
self._safe_validate(value)
if contains_empty_key(value):
self.error("Empty keys are not allowed in a DictField")
@@ -146,6 +171,7 @@ class SafeSortedListField(SortedListField):
SortedListField that does not raise an error in case items are not comparable
(in which case they will be sorted by their string representation)
"""
def to_mongo(self, *args, **kwargs):
try:
return super(SafeSortedListField, self).to_mongo(*args, **kwargs)
@@ -155,7 +181,10 @@ class SafeSortedListField(SortedListField):
def _safe_to_mongo(self, value, use_db_field=True, fields=None):
value = super(SortedListField, self).to_mongo(value, use_db_field, fields)
if self._ordering is not None:
def key(v): return str(itemgetter(self._ordering)(v))
def key(v):
return str(itemgetter(self._ordering)(v))
else:
key = str
return sorted(value, key=key, reverse=self._order_reverse)

View File

@@ -43,6 +43,7 @@ class Role(object):
class Credentials(EmbeddedDocument):
meta = {"strict": False}
key = StringField(required=True)
secret = StringField(required=True)
last_used = DateTimeField()

View File

@@ -3,7 +3,7 @@ from collections import namedtuple
from functools import reduce
from typing import Collection, Sequence, Union, Optional
from boltons.iterutils import first
from boltons.iterutils import first, bucketize
from dateutil.parser import parse as parse_datetime
from mongoengine import Q, Document, ListField, StringField
from pymongo.command_cursor import CommandCursor
@@ -34,7 +34,12 @@ class AuthDocument(Document):
class ProperDictMixin(object):
def to_proper_dict(self, strip_private=True, only=None, extra_dict=None) -> dict:
def to_proper_dict(
self: Union["ProperDictMixin", Document],
strip_private=True,
only=None,
extra_dict=None,
) -> dict:
return self.properize_dict(
self.to_mongo(use_db_field=False).to_dict(),
strip_private=strip_private,
@@ -71,6 +76,8 @@ class GetMixin(PropsMixin):
}
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
_field_collation_overrides = {}
class QueryParameterOptions(object):
def __init__(
self,
@@ -91,11 +98,48 @@ class GetMixin(PropsMixin):
self.list_fields = list_fields
self.pattern_fields = pattern_fields
class ListFieldBucketHelper:
op_prefix = "__$"
legacy_exclude_prefix = "-"
_default = "in"
_ops = {"not": "nin"}
_next = _default
def __init__(self, legacy=False):
self._legacy = legacy
def key(self, v):
if v is None:
self._next = self._default
return self._default
elif self._legacy and v.startswith(self.legacy_exclude_prefix):
self._next = self._default
return self._ops["not"]
elif v.startswith(self.op_prefix):
self._next = self._ops.get(v[len(self.op_prefix) :], self._default)
return None
next_ = self._next
self._next = self._default
return next_
def value_transform(self, v):
if self._legacy and v and v.startswith(self.legacy_exclude_prefix):
return v[len(self.legacy_exclude_prefix) :]
return v
get_all_query_options = QueryParameterOptions()
@classmethod
def get(
cls, company, id, *, _only=None, include_public=False, **kwargs
cls: Union["GetMixin", Document],
company,
id,
*,
_only=None,
include_public=False,
**kwargs,
) -> "GetMixin":
q = cls.objects(
cls._prepare_perm_query(company, allow_public=include_public)
@@ -162,17 +206,7 @@ class GetMixin(PropsMixin):
for field in tuple(opts.list_fields or ()):
data = parameters.pop(field, None)
if data:
if not isinstance(data, (list, tuple)):
raise MakeGetAllQueryError("expected list", field)
exclude = [t for t in data if t.startswith("-")]
include = list(set(data).difference(exclude))
mongoengine_field = field.replace(".", "__")
if include:
dict_query[f"{mongoengine_field}__in"] = include
if exclude:
dict_query[f"{mongoengine_field}__nin"] = [
t[1:] for t in exclude
]
query &= cls.get_list_field_query(field, data)
for field in opts.fields or []:
data = parameters.pop(field, None)
@@ -216,6 +250,47 @@ class GetMixin(PropsMixin):
return query & RegexQ(**dict_query)
@classmethod
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
"""
Get a proper mongoengine Q object that represents an "or" query for the provided values
with respect to the given list field, with support for "none of empty" in case a None value
is included.
- Exclusion can be specified by a leading "-" for each value (API versions <2.8)
or by a preceding "__$not" value (operator)
"""
if not isinstance(data, (list, tuple)):
raise MakeGetAllQueryError("expected list", field)
# TODO: backwards compatibility only for older API versions
helper = cls.ListFieldBucketHelper(legacy=True)
actions = bucketize(
data, key=helper.key, value_transform=helper.value_transform
)
allow_empty = None in actions.get("in", {})
mongoengine_field = field.replace(".", "__")
q = RegexQ()
for action in filter(None, actions):
q &= RegexQ(
**{
f"{mongoengine_field}__{action}": list(
set(filter(None, actions[action]))
)
}
)
if not allow_empty:
return q
return (
q
| Q(**{f"{mongoengine_field}__exists": False})
| Q(**{mongoengine_field: []})
)
@classmethod
def _prepare_perm_query(cls, company, allow_public=False):
if allow_public:
@@ -409,7 +484,12 @@ class GetMixin(PropsMixin):
)
@classmethod
def _get_many_no_company(cls, query, parameters=None, override_projection=None):
def _get_many_no_company(
cls: Union["GetMixin", Document],
query,
parameters=None,
override_projection=None,
):
"""
Fetch all documents matching a provided query.
This is a company-less version for internal uses. We assume the caller has either added any necessary
@@ -460,6 +540,8 @@ class GetMixin(PropsMixin):
"""
Fetch all documents matching a provided query. For the first order by field
the None values are sorted in the end regardless of the sorting order.
If the first order field is a user defined parameter (either from execution.parameters,
or from last_metrics) then the collation is set that sorts strings in numeric order where possible.
This is a company-less version for internal uses. We assume the caller has either added any necessary
constraints to the query or that no constraints are required.
@@ -500,6 +582,16 @@ class GetMixin(PropsMixin):
query_sets = [cls.objects(non_empty), cls.objects(empty)]
query_sets = [qs.order_by(*order_by) for qs in query_sets]
if order_field:
collation_override = first(
v
for k, v in cls._field_collation_overrides.items()
if order_field.startswith(k)
)
if collation_override:
query_sets = [
qs.collation(collation=collation_override) for qs in query_sets
]
if search_text:
query_sets = [qs.search_text(search_text) for qs in query_sets]
@@ -593,7 +685,13 @@ class UpdateMixin(object):
return update_dict
@classmethod
def safe_update(cls, company_id, id, partial_update_dict, injected_update=None):
def safe_update(
cls: Union["UpdateMixin", Document],
company_id,
id,
partial_update_dict,
injected_update=None,
):
update_dict = cls.get_safe_update_dict(partial_update_dict)
if not update_dict:
return 0, {}
@@ -610,7 +708,10 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
@classmethod
def aggregate(
cls: Document, *pipeline: dict, allow_disk_use=None, **kwargs
cls: Union["DbModelMixin", Document],
pipeline: Sequence[dict],
allow_disk_use=None,
**kwargs,
) -> CommandCursor:
"""
Aggregate objects of this document class according to the provided pipeline.
@@ -625,7 +726,7 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
if allow_disk_use is not None
else config.get("apiserver.mongo.aggregate.allow_disk_use", True)
)
return cls.objects.aggregate(*pipeline, **kwargs)
return cls.objects.aggregate(pipeline, **kwargs)
def validate_id(cls, company, **kwargs):
@@ -647,5 +748,5 @@ def validate_id(cls, company, **kwargs):
id_to_name.setdefault(obj_id, []).append(name)
raise errors.bad_request.ValidationError(
"Invalid {} ids".format(cls.__name__.lower()),
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]}
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]},
)

View File

@@ -1,8 +1,9 @@
from mongoengine import Document, StringField, DateTimeField, ListField, BooleanField
from mongoengine import Document, StringField, DateTimeField, BooleanField
from database import Database, strict
from database.fields import StrippedStringField, SafeDictField
from database.fields import StrippedStringField, SafeDictField, SafeSortedListField
from database.model import DbModelMixin
from database.model.base import GetMixin
from database.model.model_labels import ModelLabels
from database.model.company import Company
from database.model.project import Project
@@ -19,6 +20,7 @@ class Model(DbModelMixin, Document):
"project",
"task",
("company", "name"),
("company", "user"),
{
"name": "%s.model.main_text_index" % Database.backend,
"fields": ["$name", "$id", "$comment", "$parent", "$task", "$project"],
@@ -34,6 +36,21 @@ class Model(DbModelMixin, Document):
},
],
}
get_all_query_options = GetMixin.QueryParameterOptions(
pattern_fields=("name", "comment"),
fields=("ready",),
list_fields=(
"tags",
"system_tags",
"framework",
"uri",
"id",
"user",
"project",
"task",
"parent",
),
)
id = StringField(primary_key=True)
name = StrippedStringField(user_set_allowed=True, min_length=3)
@@ -44,8 +61,8 @@ class Model(DbModelMixin, Document):
created = DateTimeField(required=True, user_set_allowed=True)
task = StringField(reference_field=Task)
comment = StringField(user_set_allowed=True)
tags = ListField(StringField(required=True), user_set_allowed=True)
system_tags = ListField(StringField(required=True), user_set_allowed=True)
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
uri = StrippedStringField(default="", user_set_allowed=True)
framework = StringField()
design = SafeDictField()

View File

@@ -1,11 +1,14 @@
from mongoengine import MapField, IntField
from database.fields import NoneType, UnionField, SafeMapField
class ModelLabels(MapField):
class ModelLabels(SafeMapField):
def __init__(self, *args, **kwargs):
super(ModelLabels, self).__init__(field=IntField(), *args, **kwargs)
super(ModelLabels, self).__init__(
field=UnionField(types=(int, NoneType)), *args, **kwargs
)
def validate(self, value):
super(ModelLabels, self).validate(value)
if value and (len(set(value.values())) < len(value)):
non_empty_values = list(filter(None, value.values()))
if non_empty_values and len(set(non_empty_values)) < len(non_empty_values):
self.error("Same label id appears more than once in model labels")

View File

@@ -1,7 +1,7 @@
from mongoengine import StringField, DateTimeField, ListField
from mongoengine import StringField, DateTimeField
from database import Database, strict
from database.fields import StrippedStringField
from database.fields import StrippedStringField, SafeSortedListField
from database.model import AttributedDocument
from database.model.base import GetMixin
@@ -36,7 +36,7 @@ class Project(AttributedDocument):
)
description = StringField(required=True)
created = DateTimeField(required=True)
tags = ListField(StringField(required=True))
system_tags = ListField(StringField(required=True))
tags = SafeSortedListField(StringField(required=True))
system_tags = SafeSortedListField(StringField(required=True))
default_output_destination = StrippedStringField()
last_update = DateTimeField()

View File

@@ -4,11 +4,10 @@ from mongoengine import (
StringField,
DateTimeField,
EmbeddedDocumentListField,
ListField,
)
from database import Database, strict
from database.fields import StrippedStringField
from database.fields import StrippedStringField, SafeSortedListField
from database.model import DbModelMixin
from database.model.base import ProperDictMixin, GetMixin
from database.model.company import Company
@@ -41,7 +40,7 @@ class Queue(DbModelMixin, Document):
)
company = StringField(required=True, reference_field=Company)
created = DateTimeField(required=True)
tags = ListField(StringField(required=True), default=list, user_set_allowed=True)
system_tags = ListField(StringField(required=True), user_set_allowed=True)
tags = SafeSortedListField(StringField(required=True), default=list, user_set_allowed=True)
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
entries = EmbeddedDocumentListField(Entry, default=list)
last_update = DateTimeField()

View File

@@ -7,6 +7,10 @@ from database import Database, strict
from database.model import DbModelMixin
class SettingKeys:
server__uuid = "server.uuid"
class Settings(DbModelMixin, Document):
meta = {
"db_alias": Database.backend,
@@ -47,7 +51,7 @@ class Settings(DbModelMixin, Document):
""" Adds a new key/value settings. Fails if key already exists. """
key = key.strip(sep)
try:
res = Settings(key=key, value=value).save(force_insert=True)
res = cls(key=key, value=value).save(force_insert=True)
return bool(res)
except NotUniqueError:
return False

View File

@@ -18,7 +18,7 @@ from database.fields import (
SafeSortedListField,
)
from database.model import AttributedDocument
from database.model.base import ProperDictMixin
from database.model.base import ProperDictMixin, GetMixin
from database.model.model_labels import ModelLabels
from database.model.project import Project
from database.utils import get_options
@@ -100,9 +100,26 @@ class Execution(EmbeddedDocument, ProperDictMixin):
class TaskType(object):
training = "training"
testing = "testing"
inference = "inference"
data_processing = "data_processing"
application = "application"
monitor = "monitor"
controller = "controller"
optimizer = "optimizer"
service = "service"
qc = "qc"
custom = "custom"
external_task_types = set(get_options(TaskType))
class Task(AttributedDocument):
_field_collation_overrides = {
"execution.parameters.": {"locale": "en_US", "numericOrdering": True},
"last_metrics.": {"locale": "en_US", "numericOrdering": True}
}
meta = {
"db_alias": Database.backend,
"strict": strict,
@@ -113,6 +130,7 @@ class Task(AttributedDocument):
"parent",
"project",
("company", "name"),
("company", "user"),
("company", "type", "system_tags", "status"),
("company", "project", "type", "system_tags", "status"),
("status", "last_update"), # for maintenance tasks
@@ -140,6 +158,12 @@ class Task(AttributedDocument):
},
],
}
get_all_query_options = GetMixin.QueryParameterOptions(
list_fields=("id", "user", "tags", "system_tags", "type", "status", "project"),
datetime_fields=("status_changed",),
pattern_fields=("name", "comment"),
fields=("parent",),
)
id = StringField(primary_key=True)
name = StrippedStringField(
@@ -158,11 +182,11 @@ class Task(AttributedDocument):
published = DateTimeField()
parent = StringField()
project = StringField(reference_field=Project, user_set_allowed=True)
output = EmbeddedDocumentField(Output, default=Output)
output: Output = EmbeddedDocumentField(Output, default=Output)
execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
tags = ListField(StringField(required=True), user_set_allowed=True)
system_tags = ListField(StringField(required=True), user_set_allowed=True)
script = EmbeddedDocumentField(Script)
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
script: Script = EmbeddedDocumentField(Script)
last_worker = StringField()
last_worker_report = DateTimeField()
last_update = DateTimeField()

View File

@@ -2,14 +2,16 @@ from mongoengine import Document, StringField, DynamicField
from database import Database, strict
from database.model import DbModelMixin
from database.model.base import GetMixin
from database.model.company import Company
class User(DbModelMixin, Document):
meta = {
'db_alias': Database.backend,
'strict': strict,
"db_alias": Database.backend,
"strict": strict,
}
get_all_query_options = GetMixin.QueryParameterOptions(list_fields=("id",))
id = StringField(primary_key=True)
company = StringField(required=True, reference_field=Company)

View File

@@ -1,8 +1,14 @@
import copy
import re
from typing import Union
from mongoengine import Q
from mongoengine.queryset.visitor import QueryCompilerVisitor, SimplificationVisitor, QCombination
from mongoengine.queryset.visitor import (
QueryCompilerVisitor,
SimplificationVisitor,
QCombination,
QNode,
)
class RegexWrapper(object):
@@ -17,17 +23,16 @@ class RegexWrapper(object):
class RegexMixin(object):
def to_query(self, document):
def to_query(self: Union["RegexMixin", QNode], document):
query = self.accept(SimplificationVisitor())
query = query.accept(RegexQueryCompilerVisitor(document))
return query
def _combine(self, other, operation):
def _combine(self: Union["RegexMixin", QNode], other, operation):
"""Combine this node with another node into a QCombination
object.
"""
if getattr(other, 'empty', True):
if getattr(other, "empty", True):
return self
if self.empty:

View File

@@ -95,26 +95,18 @@ def parse_from_call(call_data, fields, cls_fields, discard_none_values=True):
res[field] = None
continue
if desc:
if callable(desc):
if issubclass(desc, Document):
if not desc.objects(id=value).only("id"):
raise ParseCallError(
"expecting %s id" % desc.__name__, id=value, field=field
)
elif callable(desc):
try:
desc(value)
except TypeError:
raise ParseCallError(f"expecting {desc.__name__}", field=field)
except Exception as ex:
raise ParseCallError(str(ex), field=field)
else:
if issubclass(desc, (list, tuple, dict)) and not isinstance(
value, desc
):
raise ParseCallError(
"expecting %s" % desc.__name__, field=field
)
if issubclass(desc, Document) and not desc.objects(id=value).only(
"id"
):
raise ParseCallError(
"expecting %s id" % desc.__name__, id=value, field=field
)
res[field] = value
return res

View File

@@ -38,27 +38,22 @@ def init_mongo_data():
PrePopulate.import_from_zip(zip_file, user_id=user_id)
users = [
{
"name": "apiserver",
"role": Role.system,
"email": "apiserver@example.com",
},
{
"name": "webserver",
"role": Role.system,
"email": "webserver@example.com",
},
{"name": "tests", "role": Role.user, "email": "tests@example.com"},
]
fixed_mode = FixedUser.enabled()
for user in users:
credentials = config.get(f"secure.credentials.{user['name']}")
user["key"] = credentials.user_key
user["secret"] = credentials.user_secret
_ensure_auth_user(user, company_id, log=log)
for user, credentials in config.get("secure.credentials", {}).items():
user_data = {
"name": user,
"role": credentials.role,
"email": f"{user}@example.com",
"key": credentials.user_key,
"secret": credentials.user_secret,
}
revoke = fixed_mode and credentials.get("revoke_in_fixed_mode", False)
user_id = _ensure_auth_user(user_data, company_id, log=log, revoke=revoke)
if credentials.role == Role.user:
_ensure_backend_user(user_id, company_id, credentials.display_name)
if FixedUser.enabled():
if fixed_mode:
log.info("Fixed users mode is enabled")
FixedUser.validate()
for user in FixedUser.from_config():

View File

@@ -9,7 +9,7 @@ from database.model.user import User
from service_repo.auth.fixed_user import FixedUser
def _ensure_auth_user(user_data: dict, company_id: str, log: Logger):
def _ensure_auth_user(user_data: dict, company_id: str, log: Logger, revoke: bool = False):
ensure_credentials = {"key", "secret"}.issubset(user_data)
if ensure_credentials:
user = AuthUser.objects(
@@ -18,17 +18,22 @@ def _ensure_auth_user(user_data: dict, company_id: str, log: Logger):
)
).first()
if user:
if revoke:
user.credentials = []
user.save()
return user.id
user_id = user_data.get("id", f"__{user_data['name']}__")
log.info(f"Creating user: {user_data['name']}")
user = AuthUser(
id=user_data.get("id", f"__{user_data['name']}__"),
id=user_id,
name=user_data["name"],
company=company_id,
role=user_data["role"],
email=user_data["email"],
created=datetime.utcnow(),
credentials=[Credentials(key=user_data["key"], secret=user_data["secret"])]
credentials=[Credentials(key=user_data["key"], secret=user_data["secret"])] if not revoke else []
if ensure_credentials
else None,
)

View File

@@ -6,7 +6,7 @@ from config import config
from config.info import get_default_company
from database.model.company import Company
from database.model.queue import Queue
from database.model.settings import Settings
from database.model.settings import Settings, SettingKeys
log = config.logger(__file__)
@@ -37,4 +37,4 @@ def _ensure_default_queue(company):
def _ensure_uuid():
Settings.add_value("server.uuid", str(uuid4()))
Settings.add_value(SettingKeys.server__uuid, str(uuid4()))

View File

@@ -0,0 +1,58 @@
from collections import Collection
from typing import Sequence
from pymongo.database import Database, Collection
def _drop_all_indices_from_collections(db: Database, names: Sequence[str]):
for collection_name in db.list_collection_names():
if collection_name not in names:
continue
collection: Collection = db[collection_name]
collection.drop_indexes()
def migrate_auth(db: Database):
"""
Remove the old indices from the collections since
they may come out of sync with the latest changes
in the code and mongo libraries update
"""
_drop_all_indices_from_collections(db, ["user"])
def migrate_backend(db: Database):
"""
1. Sort tags and system tags
2. Remove the old indices from the collections since
they may come out of sync with the latest changes
in the code and mongo libraries update
"""
fields = ("tags", "system_tags")
query = {"$or": [{field: {"$exists": True, "$ne": []}} for field in fields]}
for collection_name in ("task", "model", "project", "queue"):
collection = db[collection_name]
for doc in collection.find(filter=query, projection=fields):
update = {
field: sorted(doc[field])
for field in fields
if doc.get(field)
}
if update:
collection.update_one({"_id": doc["_id"]}, {"$set": update})
_drop_all_indices_from_collections(
db,
[
"company",
"model",
"project",
"queue",
"settings",
"task",
"task__trash",
"user",
"versions",
],
)

View File

@@ -14,12 +14,12 @@ Jinja2==2.10
jsonmodels>=2.3
jsonschema>=2.6.0
luqum>=0.7.2
mongoengine==0.16.2
mongoengine==0.19.1
nested_dict>=1.61
psutil>=5.6.5
pyhocon>=0.3.35
pyjwt>=1.3.0
pymongo==3.6.1 # 3.7 has a bug multiple users logged in
pymongo==3.10.1
python-rapidjson>=0.6.3
redis>=2.10.5
related>=0.7.2

View File

@@ -530,59 +530,59 @@
}
}
}
"2.7" {
description: "Get 'log' events for this task"
request {
type: object
required: [
task
]
properties {
task {
type: string
description: "Task ID"
}
batch_size {
type: integer
description: "The amount of log events to return"
}
navigate_earlier {
type: boolean
description: "If set then log events are retreived from the latest to the earliest ones (in timestamp descending order). Otherwise from the earliest to the latest ones (in timestamp ascending order). The default is True"
}
refresh {
type: boolean
description: "If set then scroll will be moved to the latest logs (if 'navigate_earlier' is set to True) or to the earliest (otherwise)"
}
scroll_id {
type: string
description: "Scroll ID of previous call (used for getting more results)"
}
}
}
response {
type: object
properties {
events {
type: array
items { type: object }
description: "Log items list"
}
returned {
type: integer
description: "Number of log events returned"
}
total {
type: number
description: "Total number of log events available for this query"
}
scroll_id {
type: string
description: "Scroll ID for getting more results"
}
}
}
}
// "2.7" {
// description: "Get 'log' events for this task"
// request {
// type: object
// required: [
// task
// ]
// properties {
// task {
// type: string
// description: "Task ID"
// }
// batch_size {
// type: integer
// description: "The amount of log events to return"
// }
// navigate_earlier {
// type: boolean
// description: "If set then log events are retreived from the latest to the earliest ones (in timestamp descending order). Otherwise from the earliest to the latest ones (in timestamp ascending order). The default is True"
// }
// refresh {
// type: boolean
// description: "If set then scroll will be moved to the latest logs (if 'navigate_earlier' is set to True) or to the earliest (otherwise)"
// }
// scroll_id {
// type: string
// description: "Scroll ID of previous call (used for getting more results)"
// }
// }
// }
// response {
// type: object
// properties {
// events {
// type: array
// items { type: object }
// description: "Log items list"
// }
// returned {
// type: integer
// description: "Number of log events returned"
// }
// total {
// type: number
// description: "Total number of log events available for this query"
// }
// scroll_id {
// type: string
// description: "Scroll ID for getting more results"
// }
// }
// }
// }
}
get_task_events {
"2.1" {

View File

@@ -159,6 +159,11 @@
description: "Get only models whose name matches this pattern (python regular expression syntax)"
type: string
}
user {
description: "List of user IDs used to filter results by the model's creating user"
type: array
items { type: string }
}
ready {
description: "Indication whether to retrieve only models that are marked ready If not supplied returns both ready and not-ready projects."
type: boolean

View File

@@ -0,0 +1,43 @@
_description: "This service provides organization level operations"
get_tags {
"2.8" {
description: "Get all the user and system tags used for the company tasks and models"
request {
type: object
properties {
include_system {
description: "If set to 'true' then the list of the system tags is also returned. The default value is 'false'"
type: boolean
default: false
}
filter {
description: "Filter on entities to collect tags from"
type: object
properties {
system_tags {
description: "The list of system tag values to filter by. Use 'null' value to specify empty tags. Use '__Snot' value to specify that the following value should be excluded"
type: array
items {type: string}
}
}
}
}
}
response {
type: object
properties {
tags {
description: "The list of unique tag values"
type: array
items {type: string}
}
system_tags {
description: "The list of unique system tag values. Returned only if 'include_system' is set to 'true' in the request"
type: array
items {type: string}
}
}
}
}
}

View File

@@ -69,6 +69,17 @@ info {
}
}
}
"2.8": ${info."2.1"} {
response {
type: object
properties {
uid {
description: "Server UID"
type: string
}
}
}
}
}
endpoints {
"2.1" {

View File

@@ -254,6 +254,15 @@ _definitions {
enum: [
training
testing
inference
data_processing
application
monitor
controller
optimizer
service
qc
custom
]
}
last_metrics_event {
@@ -475,7 +484,11 @@ get_all {
minimum: 1
}
order_by {
description: "List of field names to order by. When search_text is used, '@text_score' can be used as a field representing the text score of returned documents. Use '-' prefix to specify descending order. Optional, recommended when using page"
description: """List of field names to order by. When search_text is used,
'@text_score' can be used as a field representing the text score of returned documents.
Use '-' prefix to specify descending order. Optional, recommended when using page.
If the first order field is a hyper parameter or metric then string values are ordered
according to numeric ordering rules where applicable"""
type: array
items { type: string }
}
@@ -550,6 +563,31 @@ get_all {
}
}
}
get_types {
"2.8" {
description: "Get the list of task types used in the specified projects"
request {
type: object
properties {
projects {
description: "The list of projects which tasks will be analyzed. If not passed or empty then all the company and public tasks will be analyzed"
type: array
items: {type: string}
}
}
}
response {
type: object
properties {
types {
description: "Unique list of the task types used in the requested projects"
type: array
items: {type: string}
}
}
}
}
}
clone {
"2.5" {
description: "Clone an existing task"
@@ -591,6 +629,10 @@ clone {
description: "The execution params for the cloned task. The params not specified are taken from the original task"
"$ref": "#/definitions/execution"
}
validate_references {
description: "If set to 'false' then the task fields that are copied from the original task are not validated. The default is false."
type: boolean
}
}
}
response {
@@ -901,6 +943,11 @@ reset {
properties.force = ${_references.force_arg} {
description: "If not true, call fails if the task status is 'completed'"
}
properties.clear_all {
description: "Clear script and execution sections completely"
type: boolean
default: false
}
} ${_references.status_change_request}
response {
type: object

View File

@@ -145,6 +145,19 @@ get_all_ex {
internal: true
"2.1": ${get_all."2.1"} {
}
"2.8": ${get_all."2.1"} {
request {
type: object
properties {
active_in_projects {
description: "List of project IDs. If provided, return only users that were active in these projects. If empty list is provided, return users that were active in all projects"
type: array
items { type: string }
}
}
}
}
}
get_all {

View File

@@ -52,7 +52,7 @@ def validate_all(call: APICall, endpoint: Endpoint):
def validate_role(endpoint, call):
try:
if not endpoint.allows(call.identity.role):
if endpoint.authorize and not endpoint.allows(call.identity.role):
raise errors.forbidden.RoleNotAllowed(role=call.identity.role, allowed=endpoint.allow_roles)
except MissingIdentity:
pass

View File

@@ -11,6 +11,7 @@ from apimodels.events import (
MetricEvents,
IterationEvents,
TaskMetricsRequest,
LogEventsRequest,
)
from bll.event import EventBLL
from bll.event.event_metrics import EventMetrics
@@ -26,10 +27,10 @@ event_bll = EventBLL()
def add(call: APICall, company_id, req_model):
data = call.data.copy()
allow_locked = data.pop("allow_locked", False)
added, batch_errors = event_bll.add_events(
added, err_count, err_info = event_bll.add_events(
company_id, [data], call.worker, allow_locked_tasks=allow_locked
)
call.result.data = dict(added=added, errors=len(batch_errors))
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
call.kpis["events"] = 1
@@ -39,13 +40,13 @@ def add_batch(call: APICall, company_id, req_model):
if events is None or len(events) == 0:
raise errors.bad_request.BatchContainsNoItems()
added, batch_errors = event_bll.add_events(company_id, events, call.worker)
call.result.data = dict(added=added, errors=len(batch_errors))
added, err_count, err_info = event_bll.add_events(company_id, events, call.worker)
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
call.kpis["events"] = len(events)
@endpoint("events.get_task_log", required_fields=["task"])
def get_task_log(call, company_id, req_model):
def get_task_log_v1_5(call, company_id, req_model):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id, allow_public=True)
order = call.data.get("order") or "desc"
@@ -93,6 +94,29 @@ def get_task_log_v1_7(call, company_id, req_model):
)
# uncomment this once the front end is ready
# @endpoint("events.get_task_log", min_version="2.7", request_data_model=LogEventsRequest)
# def get_task_log(call, company_id, req_model: LogEventsRequest):
# task_id = req_model.task
# task_bll.assert_exists(company_id, task_id, allow_public=True)
#
# res = event_bll.log_events_iterator.get_task_events(
# company_id=company_id,
# task_id=task_id,
# batch_size=req_model.batch_size,
# navigate_earlier=req_model.navigate_earlier,
# refresh=req_model.refresh,
# state_id=req_model.scroll_id,
# )
#
# call.result.data = dict(
# events=res.events,
# returned=len(res.events),
# total=res.total_events,
# scroll_id=res.next_scroll_id,
# )
@endpoint("events.download_task_log", required_fields=["task"])
def download_task_log(call, company_id, req_model):
task_id = call.data["task"]

View File

@@ -12,6 +12,7 @@ from apimodels.models import (
PublishModelResponse,
ModelTaskPublishResponse,
)
from bll.organization import OrgBLL
from bll.task import TaskBLL
from config import config
from database.errors import translate_errors_context
@@ -29,51 +30,34 @@ from services.utils import conform_tag_fields, conform_output_tags
from timing_context import TimingContext
log = config.logger(__file__)
get_all_query_options = Model.QueryParameterOptions(
pattern_fields=("name", "comment"),
fields=("ready",),
list_fields=(
"tags",
"system_tags",
"framework",
"uri",
"id",
"project",
"task",
"parent",
),
)
org_bll = OrgBLL()
@endpoint("models.get_by_id", required_fields=["model"])
def get_by_id(call):
assert isinstance(call, APICall)
def get_by_id(call: APICall, company_id, _):
model_id = call.data["model"]
with translate_errors_context():
models = Model.get_many(
company=call.identity.company,
company=company_id,
query_dict=call.data,
query=Q(id=model_id),
allow_public=True,
)
if not models:
raise errors.bad_request.InvalidModelId(
"no such public or company model",
id=model_id,
company=call.identity.company,
"no such public or company model", id=model_id, company=company_id,
)
conform_output_tags(call, models[0])
call.result.data = {"model": models[0]}
@endpoint("models.get_by_task_id", required_fields=["task"])
def get_by_task_id(call):
assert isinstance(call, APICall)
def get_by_task_id(call: APICall, company_id, _):
task_id = call.data["task"]
with translate_errors_context():
query = dict(id=task_id, company=call.identity.company)
query = dict(id=task_id, company=company_id)
task = Task.get(_only=["output"], **query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
@@ -84,13 +68,11 @@ def get_by_task_id(call):
model_id = task.output.model
model = Model.objects(
Q(id=model_id) & get_company_or_none_constraint(call.identity.company)
Q(id=model_id) & get_company_or_none_constraint(company_id)
).first()
if not model:
raise errors.bad_request.InvalidModelId(
"no such public or company model",
id=model_id,
company=call.identity.company,
"no such public or company model", id=model_id, company=company_id,
)
model_dict = model.to_proper_dict()
conform_output_tags(call, model_dict)
@@ -98,31 +80,27 @@ def get_by_task_id(call):
@endpoint("models.get_all_ex", required_fields=[])
def get_all_ex(call: APICall):
def get_all_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
with translate_errors_context():
with TimingContext("mongo", "models_get_all_ex"):
models = Model.get_many_with_join(
company=call.identity.company,
query_dict=call.data,
allow_public=True,
query_options=get_all_query_options,
company=company_id, query_dict=call.data, allow_public=True
)
conform_output_tags(call, models)
call.result.data = {"models": models}
@endpoint("models.get_all", required_fields=[])
def get_all(call: APICall):
def get_all(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
with translate_errors_context():
with TimingContext("mongo", "models_get_all"):
models = Model.get_many(
company=call.identity.company,
company=company_id,
parameters=call.data,
query_dict=call.data,
allow_public=True,
query_options=get_all_query_options,
)
conform_output_tags(call, models)
call.result.data = {"models": models}
@@ -146,13 +124,18 @@ create_fields = {
def parse_model_fields(call, valid_fields):
fields = parse_from_call(call.data, valid_fields, Model.get_fields())
conform_tag_fields(call, fields)
conform_tag_fields(call, fields, validate=True)
return fields
def _update_org_tags(company, fields: dict):
org_bll.update_org_tags(
company, tags=fields.get("tags"), system_tags=fields.get("system_tags")
)
@endpoint("models.update_for_task", required_fields=["task"])
def update_for_task(call, company_id, _):
assert isinstance(call, APICall)
def update_for_task(call: APICall, company_id, _):
task_id = call.data["task"]
uri = call.data.get("uri")
iteration = call.data.get("iteration")
@@ -195,7 +178,9 @@ def update_for_task(call, company_id, _):
if task.output and task.output.model:
# model exists, update
res = _update_model(call, model_id=task.output.model).to_struct()
res = _update_model(
call, company_id, model_id=task.output.model
).to_struct()
res.update({"id": task.output.model, "created": False})
call.result.data = res
return
@@ -218,6 +203,7 @@ def update_for_task(call, company_id, _):
**fields,
)
model.save()
_update_org_tags(company_id, fields)
TaskBLL.update_statistics(
task_id=task_id,
@@ -234,48 +220,46 @@ def update_for_task(call, company_id, _):
request_data_model=CreateModelRequest,
response_data_model=CreateModelResponse,
)
def create(call, company, req_model):
assert isinstance(call, APICall)
assert isinstance(req_model, CreateModelRequest)
identity = call.identity
def create(call: APICall, company_id, req_model: CreateModelRequest):
if req_model.public:
company = ""
company_id = ""
with translate_errors_context():
project = req_model.project
if project:
validate_id(Project, company=company, project=project)
validate_id(Project, company=company_id, project=project)
task = req_model.task
req_data = req_model.to_struct()
if task:
validate_task(call, req_data)
validate_task(company_id, req_data)
fields = filter_fields(Model, req_data)
conform_tag_fields(call, fields)
conform_tag_fields(call, fields, validate=True)
# create and save model
model = Model(
id=database.utils.id(),
user=identity.user,
company=company,
user=call.identity.user,
company=company_id,
created=datetime.utcnow(),
**fields,
)
model.save()
_update_org_tags(company_id, fields)
call.result.data_model = CreateModelResponse(id=model.id, created=True)
def prepare_update_fields(call, fields):
def prepare_update_fields(call, company_id, fields: dict):
fields = fields.copy()
if "uri" in fields:
# clear UI cache if URI is provided (model updated)
fields["ui_cache"] = fields.pop("ui_cache", {})
if "task" in fields:
validate_task(call, fields)
validate_task(company_id, fields)
if "labels" in fields:
labels = fields["labels"]
@@ -290,33 +274,36 @@ def prepare_update_fields(call, fields):
invalid_keys = find_other_types(labels.keys(), str)
if invalid_keys:
raise errors.bad_request.ValidationError("labels keys must be strings", keys=invalid_keys)
raise errors.bad_request.ValidationError(
"labels keys must be strings", keys=invalid_keys
)
invalid_values = find_other_types(labels.values(), int)
if invalid_values:
raise errors.bad_request.ValidationError("labels values must be integers", values=invalid_values)
raise errors.bad_request.ValidationError(
"labels values must be integers", values=invalid_values
)
conform_tag_fields(call, fields)
conform_tag_fields(call, fields, validate=True)
return fields
def validate_task(call, fields):
Task.get_for_writing(company=call.identity.company, id=fields["task"], _only=["id"])
def validate_task(company_id, fields: dict):
Task.get_for_writing(company=company_id, id=fields["task"], _only=["id"])
@endpoint("models.edit", required_fields=["model"], response_data_model=UpdateResponse)
def edit(call: APICall):
identity = call.identity
def edit(call: APICall, company_id, _):
model_id = call.data["model"]
with translate_errors_context():
query = dict(id=model_id, company=identity.company)
query = dict(id=model_id, company=company_id)
model = Model.objects(**query).first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
fields = parse_model_fields(call, create_fields)
fields = prepare_update_fields(call, fields)
fields = prepare_update_fields(call, company_id, fields)
for key in fields:
field = getattr(model, key, None)
@@ -331,47 +318,44 @@ def edit(call: APICall):
fields[key] = d
iteration = call.data.get("iteration")
task_id = model.task or fields.get('task')
task_id = model.task or fields.get("task")
if task_id and iteration is not None:
TaskBLL.update_statistics(
task_id=task_id,
company_id=identity.company,
last_iteration_max=iteration,
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
)
if fields:
updated = model.update(upsert=False, **fields)
if updated:
_update_org_tags(company_id, fields)
conform_output_tags(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
else:
call.result.data_model = UpdateResponse(updated=0)
def _update_model(call: APICall, model_id=None):
identity = call.identity
def _update_model(call: APICall, company_id, model_id=None):
model_id = model_id or call.data["model"]
with translate_errors_context():
# get model by id
query = dict(id=model_id, company=identity.company)
query = dict(id=model_id, company=company_id)
model = Model.objects(**query).first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
data = prepare_update_fields(call, call.data)
data = prepare_update_fields(call, company_id, call.data)
task_id = data.get("task")
iteration = data.get("iteration")
if task_id and iteration is not None:
TaskBLL.update_statistics(
task_id=task_id,
company_id=identity.company,
last_iteration_max=iteration,
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
)
updated_count, updated_fields = Model.safe_update(
call.identity.company, model.id, data
)
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
if updated_count:
_update_org_tags(company_id, updated_fields)
conform_output_tags(call, updated_fields)
return UpdateResponse(updated=updated_count, fields=updated_fields)
@@ -379,8 +363,8 @@ def _update_model(call: APICall, model_id=None):
@endpoint(
"models.update", required_fields=["model"], response_data_model=UpdateResponse
)
def update(call):
call.result.data_model = _update_model(call)
def update(call, company_id, _):
call.result.data_model = _update_model(call, company_id)
@endpoint(
@@ -388,31 +372,29 @@ def update(call):
request_data_model=PublishModelRequest,
response_data_model=PublishModelResponse,
)
def set_ready(call: APICall, company, req_model: PublishModelRequest):
def set_ready(call: APICall, company_id, req_model: PublishModelRequest):
updated, published_task_data = TaskBLL.model_set_ready(
model_id=req_model.model,
company_id=company,
company_id=company_id,
publish_task=req_model.publish_task,
force_publish_task=req_model.force_publish_task
force_publish_task=req_model.force_publish_task,
)
call.result.data_model = PublishModelResponse(
updated=updated,
published_task=ModelTaskPublishResponse(
**published_task_data
) if published_task_data else None
published_task=ModelTaskPublishResponse(**published_task_data)
if published_task_data
else None,
)
@endpoint("models.delete", required_fields=["model"])
def update(call):
assert isinstance(call, APICall)
identity = call.identity
def update(call: APICall, company_id, _):
model_id = call.data["model"]
force = call.data.get("force", False)
with translate_errors_context():
query = dict(id=model_id, company=identity.company)
query = dict(id=model_id, company=company_id)
model = Model.objects(**query).only("id", "task").first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
@@ -445,4 +427,6 @@ def update(call):
)
del_count = Model.objects(**query).delete()
if del_count:
org_bll.update_org_tags(company_id, reset=True)
call.result.data = dict(deleted=del_count > 0)

View File

@@ -0,0 +1,13 @@
from apimodels.organization import TagsRequest
from bll.organization import OrgBLL
from service_repo import endpoint, APICall
org_bll = OrgBLL()
@endpoint("organization.get_tags", request_data_model=TagsRequest)
def get_tags(call: APICall, company, request: TagsRequest):
filter_ = request.filter.system_tags if request.filter else None
call.result.data = org_bll.get_tags(
company, include_system=request.include_system, filter_=filter_
)

View File

@@ -154,11 +154,9 @@ def make_projects_get_all_pipelines(company_id, project_ids, specific_state=None
# only count run time for these types of tasks
{
"$match": {
"type": {"$in": ["training", "testing", "annotation"]},
"project": {
"company": {"$in": [None, "", company_id]},
"$in": project_ids,
},
"type": {"$in": ["training", "testing"]},
"company": {"$in": [None, "", company_id]},
"project": {"$in": project_ids},
}
},
ensure_valid_fields(),
@@ -276,7 +274,7 @@ def create(call):
with translate_errors_context():
fields = parse_from_call(call.data, create_fields, Project.get_fields())
conform_tag_fields(call, fields)
conform_tag_fields(call, fields, validate=True)
now = datetime.utcnow()
project = Project(
id=database.utils.id(),
@@ -313,7 +311,7 @@ def update(call: APICall):
fields = parse_from_call(
call.data, create_fields, Project.get_fields(), discard_none_values=False
)
conform_tag_fields(call, fields)
conform_tag_fields(call, fields, validate=True)
fields["last_update"] = datetime.utcnow()
with TimingContext("mongo", "projects_update"):
updated = project.update(upsert=False, **fields)

View File

@@ -58,7 +58,9 @@ def get_all(call: APICall):
@endpoint("queues.create", min_version="2.4", request_data_model=CreateRequest)
def create(call: APICall, company_id, request: CreateRequest):
tags, system_tags = conform_tags(call, request.tags, request.system_tags)
tags, system_tags = conform_tags(
call, request.tags, request.system_tags, validate=True
)
queue = queue_bll.create(
company_id=company_id, name=request.name, tags=tags, system_tags=system_tags
)
@@ -73,7 +75,7 @@ def create(call: APICall, company_id, request: CreateRequest):
)
def update(call: APICall, company_id, req_model: UpdateRequest):
data = call.data_model_for_partial_update
conform_tag_fields(call, data)
conform_tag_fields(call, data, validate=True)
updated, fields = queue_bll.update(
company_id=company_id, queue_id=req_model.queue, **data
)
@@ -212,7 +214,9 @@ def get_queue_metrics(
dates=data["date"],
avg_waiting_times=data["avg_waiting_time"],
queue_lengths=data["queue_length"],
) if data else QueueMetrics(queue=queue)
)
if data
else QueueMetrics(queue=queue)
for queue, data in queue_dicts.items()
]
)

View File

@@ -10,6 +10,7 @@ from config.info import get_version, get_build_number, get_commit_number
from database.errors import translate_errors_context
from database.model import Company
from database.model.company import ReportStatsOption
from database.model.settings import Settings, SettingKeys
from service_repo import ServiceRepo, APICall, endpoint
@@ -60,6 +61,12 @@ def info(call: APICall):
}
@endpoint("server.info", min_version="2.8")
def info_2_8(call: APICall):
info(call)
call.result.data["uid"] = Settings.get_by_key(SettingKeys.server__uuid)
@endpoint(
"server.report_stats_option",
request_data_model=ReportStatsOptionRequest,

View File

@@ -1,7 +1,7 @@
from copy import deepcopy
from datetime import datetime
from operator import attrgetter
from typing import Sequence, Callable, Type, TypeVar, Union
from typing import Sequence, Callable, Type, TypeVar, Union, Tuple
import attr
import dpath
@@ -29,8 +29,11 @@ from apimodels.tasks import (
CloneRequest,
AddOrUpdateArtifactsRequest,
AddOrUpdateArtifactsResponse,
GetTypesRequest,
ResetRequest,
)
from bll.event import EventBLL
from bll.organization import OrgBLL
from bll.queue import QueueBLL
from bll.task import (
TaskBLL,
@@ -39,6 +42,7 @@ from bll.task import (
split_by,
ParameterKeyEscaper,
)
from bll.task.non_responsive_tasks_watchdog import NonResponsiveTasksWatchdog
from bll.util import SetFieldsResolver
from database.errors import translate_errors_context
from database.model.model import Model
@@ -58,19 +62,13 @@ from utilities import safe_get
task_fields = set(Task.get_fields())
task_script_fields = set(get_fields(Script))
get_all_query_options = Task.QueryParameterOptions(
list_fields=("id", "user", "tags", "system_tags", "type", "status", "project"),
datetime_fields=("status_changed",),
pattern_fields=("name", "comment"),
fields=("parent",),
)
task_bll = TaskBLL()
event_bll = EventBLL()
queue_bll = QueueBLL()
org_bll = OrgBLL()
TaskBLL.start_non_responsive_tasks_watchdog()
NonResponsiveTasksWatchdog.start()
def set_task_status_from_call(
@@ -110,12 +108,18 @@ def escape_execution_parameters(call: APICall):
default_prefix = "execution.parameters."
def escape_paths(paths, prefix=default_prefix):
return [
prefix + ParameterKeyEscaper.escape(path[len(prefix) :])
if path.startswith(prefix)
else path
for path in paths
]
escaped_paths = []
for path in paths:
if path == prefix:
raise errors.bad_request.ValidationError(
"invalid task field", path=path
)
escaped_paths.append(
prefix + ParameterKeyEscaper.escape(path[len(prefix) :])
if path.startswith(prefix)
else path
)
return escaped_paths
projection = Task.get_projection(call.data)
if projection:
@@ -128,7 +132,7 @@ def escape_execution_parameters(call: APICall):
@endpoint("tasks.get_all_ex", required_fields=[])
def get_all_ex(call: APICall):
def get_all_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
escape_execution_parameters(call)
@@ -136,9 +140,8 @@ def get_all_ex(call: APICall):
with translate_errors_context():
with TimingContext("mongo", "task_get_all_ex"):
tasks = Task.get_many_with_join(
company=call.identity.company,
company=company_id,
query_dict=call.data,
query_options=get_all_query_options,
allow_public=True, # required in case projection is requested for public dataset/versions
)
unprepare_from_saved(call, tasks)
@@ -146,7 +149,7 @@ def get_all_ex(call: APICall):
@endpoint("tasks.get_all", required_fields=[])
def get_all(call: APICall):
def get_all(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
escape_execution_parameters(call)
@@ -154,16 +157,22 @@ def get_all(call: APICall):
with translate_errors_context():
with TimingContext("mongo", "task_get_all"):
tasks = Task.get_many(
company=call.identity.company,
company=company_id,
parameters=call.data,
query_dict=call.data,
query_options=get_all_query_options,
allow_public=True, # required in case projection is requested for public dataset/versions
)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
@endpoint("tasks.get_types", request_data_model=GetTypesRequest)
def get_types(call: APICall, company_id, request: GetTypesRequest):
call.result.data = {
"types": list(task_bll.get_types(company_id, project_ids=request.projects))
}
@endpoint(
"tasks.stop", request_data_model=UpdateRequest, response_data_model=UpdateResponse
)
@@ -256,7 +265,7 @@ create_fields = {
def prepare_for_save(call: APICall, fields: dict):
conform_tag_fields(call, fields)
conform_tag_fields(call, fields, validate=True)
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
for field in task_script_fields:
@@ -316,7 +325,7 @@ def prepare_create_fields(
return prepare_for_save(call, fields)
def _validate_and_get_task_from_call(call: APICall, **kwargs):
def _validate_and_get_task_from_call(call: APICall, **kwargs) -> Tuple[Task, dict]:
with translate_errors_context(
field_does_not_exist_cls=errors.bad_request.ValidationError
), TimingContext("code", "parse_call"):
@@ -326,7 +335,7 @@ def _validate_and_get_task_from_call(call: APICall, **kwargs):
with TimingContext("code", "validate"):
task_bll.validate(task)
return task
return task, fields
@endpoint("tasks.validate", request_data_model=CreateRequest)
@@ -334,14 +343,21 @@ def validate(call: APICall, company_id, req_model: CreateRequest):
_validate_and_get_task_from_call(call)
def _update_org_tags(company, fields: dict):
org_bll.update_org_tags(
company, tags=fields.get("tags"), system_tags=fields.get("system_tags")
)
@endpoint(
"tasks.create", request_data_model=CreateRequest, response_data_model=IdResponse
)
def create(call: APICall, company_id, req_model: CreateRequest):
task = _validate_and_get_task_from_call(call)
task, fields = _validate_and_get_task_from_call(call)
with translate_errors_context(), TimingContext("mongo", "save_task"):
task.save()
_update_org_tags(company_id, fields)
update_project_time(task.project)
call.result.data_model = IdResponse(id=task.id)
@@ -362,6 +378,7 @@ def clone_task(call: APICall, company_id, request: CloneRequest):
tags=request.new_task_tags,
system_tags=request.new_task_system_tags,
execution_overrides=request.execution_overrides,
validate_references=request.validate_references,
)
call.result.data_model = IdResponse(id=task.id)
@@ -398,8 +415,9 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
partial_update_dict=partial_update_dict,
injected_update=dict(last_update=datetime.utcnow()),
)
update_project_time(updated_fields.get("project"))
if updated_count:
_update_org_tags(company_id, updated_fields)
update_project_time(updated_fields.get("project"))
unprepare_from_saved(call, updated_fields)
return UpdateResponse(updated=updated_count, fields=updated_fields)
@@ -431,9 +449,7 @@ def set_requirements(call: APICall, company_id, req_model: SetRequirementsReques
@endpoint("tasks.update_batch")
def update_batch(call: APICall):
identity = call.identity
def update_batch(call: APICall, company_id, _):
items = call.batched_data
if items is None:
raise errors.bad_request.BatchContainsNoItems()
@@ -443,7 +459,7 @@ def update_batch(call: APICall):
tasks = {
t.id: t
for t in Task.get_many_for_writing(
company=identity.company, query=Q(id__in=list(items))
company=company_id, query=Q(id__in=list(items))
)
}
@@ -461,7 +477,7 @@ def update_batch(call: APICall):
continue
partial_update_dict.update(last_update=now)
update_op = UpdateOne(
{"_id": id, "company": identity.company}, {"$set": partial_update_dict}
{"_id": id, "company": company_id}, {"$set": partial_update_dict}
)
bulk_ops.append(update_op)
@@ -469,7 +485,8 @@ def update_batch(call: APICall):
if bulk_ops:
res = Task._get_collection().bulk_write(bulk_ops)
updated = res.modified_count
if updated:
org_bll.update_org_tags(company_id, reset=True)
call.result.data = {"updated": updated}
@@ -524,7 +541,9 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
fields.update(last_update=now)
fixed_fields.update(last_update=now)
updated = task.update(upsert=False, **fixed_fields)
update_project_time(fields.get("project"))
if updated:
_update_org_tags(company_id, fixed_fields)
update_project_time(fields.get("project"))
unprepare_from_saved(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
else:
@@ -651,14 +670,14 @@ def _dequeue(task: Task, company_id: str, silent_fail=False):
@endpoint(
"tasks.reset", request_data_model=UpdateRequest, response_data_model=ResetResponse
"tasks.reset", request_data_model=ResetRequest, response_data_model=ResetResponse
)
def reset(call: APICall, company_id, req_model: UpdateRequest):
def reset(call: APICall, company_id, request: ResetRequest):
task = TaskBLL.get_task_with_access(
req_model.task, company_id=company_id, requires_write_access=True
request.task, company_id=company_id, requires_write_access=True
)
force = req_model.force
force = request.force
if not force and task.status == TaskStatus.published:
raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)
@@ -674,7 +693,6 @@ def reset(call: APICall, company_id, req_model: UpdateRequest):
else:
if dequeued:
api_results.update(dequeued=dequeued)
updates.update(unset__execution__queue=1)
cleaned_up = cleanup_task(task, force)
api_results.update(attr.asdict(cleaned_up))
@@ -682,11 +700,25 @@ def reset(call: APICall, company_id, req_model: UpdateRequest):
updates.update(
set__last_iteration=DEFAULT_LAST_ITERATION,
set__last_metrics={},
set__metric_stats={},
unset__output__result=1,
unset__output__model=1,
__raw__={"$pull": {"execution.artifacts": {"mode": {"$ne": "input"}}}},
unset__output__error=1,
unset__last_worker=1,
unset__last_worker_report=1,
)
if request.clear_all:
updates.update(
set__execution=Execution(),
unset__script=1,
)
else:
updates.update(unset__execution__queue=1)
updates.update(
__raw__={"$pull": {"execution.artifacts": {"mode": {"$ne": "input"}}}},
)
res = ResetResponse(
**ChangeStatusRequest(
task=task,
@@ -808,6 +840,15 @@ def get_outputs_for_deletion(task, force=False):
else:
models.draft.append(output_model)
if models.draft:
with TimingContext("mongo", "get_execution_models"):
model_ids = [m.id for m in models.draft]
dependent_tasks = Task.objects(execution__model__in=model_ids).only(
"id", "execution.model"
)
busy_models = [t.execution.model for t in dependent_tasks]
models.draft[:] = [m for m in models.draft if m.id not in busy_models]
with TimingContext("mongo", "get_task_children"):
tasks = Task.objects(parent=task.id).only("id", "parent", "status")
published_tasks = [
@@ -868,7 +909,7 @@ def delete(call: APICall, company_id, req_model: DeleteRequest):
task.switch_collection(collection_name)
task.delete()
org_bll.update_org_tags(company_id, reset=True)
call.result.data = dict(deleted=True, **attr.asdict(result))

View File

@@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Dict, Tuple
from typing import Tuple
import dpath
from boltons.iterutils import remap
@@ -8,6 +8,7 @@ from mongoengine import Q
from apierrors import errors
from apimodels.base import UpdateResponse
from apimodels.users import CreateRequest, SetPreferencesRequest
from bll.project import ProjectBLL
from bll.user import UserBLL
from config import config
from database.errors import translate_errors_context
@@ -19,10 +20,10 @@ from service_repo import APICall, endpoint
from utilities.json import loads, dumps
log = config.logger(__file__)
get_all_query_options = User.QueryParameterOptions(list_fields=("id",))
project_bll = ProjectBLL()
def get_user(call, user_id, only=None):
def get_user(call, company_id, user_id, only=None):
"""
Get user object by the user's ID
:param call: API call
@@ -34,7 +35,7 @@ def get_user(call, user_id, only=None):
# allow system users to get info for all users
query = dict(id=user_id)
else:
query = dict(id=user_id, company=call.identity.company)
query = dict(id=user_id, company=company_id)
with translate_errors_context("retrieving user"):
user = User.objects(**query)
@@ -48,47 +49,53 @@ def get_user(call, user_id, only=None):
@endpoint("users.get_by_id", required_fields=["user"])
def get_by_id(call):
assert isinstance(call, APICall)
def get_by_id(call: APICall, company_id, _):
user_id = call.data["user"]
call.result.data = {"user": get_user(call, user_id)}
call.result.data = {"user": get_user(call, company_id, user_id)}
@endpoint("users.get_all_ex", required_fields=[])
def get_all_ex(call):
assert isinstance(call, APICall)
def get_all_ex(call: APICall, company_id, _):
with translate_errors_context("retrieving users"):
res = User.get_many_with_join(
company=call.identity.company,
query_dict=call.data,
query_options=get_all_query_options,
)
res = User.get_many_with_join(company=company_id, query_dict=call.data)
call.result.data = {"users": res}
@endpoint("users.get_all_ex", min_version="2.8", required_fields=[])
def get_all_ex2_8(call: APICall, company_id, _):
with translate_errors_context("retrieving users"):
data = call.data
active_in_projects = call.data.get("active_in_projects", None)
if active_in_projects is not None:
active_users = project_bll.get_active_users(
company_id, active_in_projects, call.data.get("id")
)
active_users.discard(None)
if not active_users:
call.result.data = {"users": []}
return
data = data.copy()
data["id"] = list(active_users)
res = User.get_many_with_join(company=company_id, query_dict=data)
call.result.data = {"users": res}
@endpoint("users.get_all", required_fields=[])
def get_all(call):
assert isinstance(call, APICall)
def get_all(call: APICall, company_id, _):
with translate_errors_context("retrieving users"):
res = User.get_many(
company=call.identity.company,
parameters=call.data,
query_dict=call.data,
query_options=get_all_query_options,
company=company_id, parameters=call.data, query_dict=call.data
)
call.result.data = {"users": res}
@endpoint("users.get_current_user")
def get_current_user(call):
assert isinstance(call, APICall)
def get_current_user(call: APICall, company_id, _):
with translate_errors_context("retrieving users"):
projection = (
{"company.name"}
.union(User.get_fields())
@@ -96,7 +103,7 @@ def get_current_user(call):
)
res = User.get_many_with_join(
query=Q(id=call.identity.user),
company=call.identity.company,
company=company_id,
override_projection=projection,
)
@@ -126,13 +133,11 @@ def create(call: APICall):
@endpoint("users.delete", required_fields=["user"])
def delete(call):
assert isinstance(call, APICall)
def delete(call: APICall):
UserBLL.delete(call.data["user"])
def update_user(user_id, company_id, data):
# type: (str, str, Dict) -> Tuple[int, Dict]
def update_user(user_id, company_id, data: dict) -> Tuple[int, dict]:
"""
Update user.
:param user_id: user ID to update
@@ -150,31 +155,29 @@ def update_user(user_id, company_id, data):
@endpoint("users.update", required_fields=["user"], response_data_model=UpdateResponse)
def update(call, company_id, _):
assert isinstance(call, APICall)
user_id = call.data["user"]
update_count, updated_fields = update_user(user_id, company_id, call.data)
call.result.data_model = UpdateResponse(updated=update_count, fields=updated_fields)
def get_user_preferences(call):
def get_user_preferences(call: APICall, company_id):
user_id = call.identity.user
preferences = get_user(call, user_id, ["preferences"]).get("preferences")
preferences = get_user(call, company_id, user_id, only=["preferences"]).get(
"preferences"
)
if preferences and isinstance(preferences, str):
preferences = loads(preferences)
return preferences or {}
@endpoint("users.get_preferences")
def get_preferences(call):
assert isinstance(call, APICall)
return {"preferences": get_user_preferences(call)}
def get_preferences(call: APICall, company_id, _):
return {"preferences": get_user_preferences(call, company_id)}
@endpoint("users.set_preferences", request_data_model=SetPreferencesRequest)
def set_preferences(call, company_id, req_model):
# type: (APICall, str, SetPreferencesRequest) -> Dict
assert isinstance(call, APICall)
changes = req_model.preferences
def set_preferences(call: APICall, company_id, request: SetPreferencesRequest):
changes = request.preferences
def invalid_key(_, key, __):
if not isinstance(key, str):
@@ -187,7 +190,7 @@ def set_preferences(call, company_id, req_model):
remap(changes, visit=invalid_key)
base_preferences = get_user_preferences(call)
base_preferences = get_user_preferences(call, company_id)
new_preferences = deepcopy(base_preferences)
for key, value in changes.items():
try:

View File

@@ -1,5 +1,7 @@
from typing import Union, Sequence, Tuple
from apierrors import errors
from database.model.base import GetMixin
from database.utils import partition_tags
from service_repo import APICall
from service_repo.base import PartialVersion
@@ -19,13 +21,13 @@ def conform_output_tags(call: APICall, documents: Union[dict, Sequence[dict]]):
doc["tags"] = list(set(doc.get("tags", [])) | set(system_tags))
def conform_tag_fields(call: APICall, document: dict):
def conform_tag_fields(call: APICall, document: dict, validate=False):
"""
Upgrade old client tags in place
"""
if "tags" in document:
tags, system_tags = conform_tags(
call, document["tags"], document.get("system_tags")
call, document["tags"], document.get("system_tags"), validate
)
if tags != document.get("tags"):
document["tags"] = tags
@@ -34,16 +36,18 @@ def conform_tag_fields(call: APICall, document: dict):
def conform_tags(
call: APICall, tags: Sequence, system_tags: Sequence
call: APICall, tags: Sequence, system_tags: Sequence, validate=False
) -> Tuple[Sequence, Sequence]:
"""
Make sure that 'tags' from the old SDK clients
are correctly split into 'tags' and 'system_tags'
Make sure that there are no duplicate tags
"""
if validate:
validate_tags(tags, system_tags)
if call.requested_endpoint_version < PartialVersion("2.3"):
tags, system_tags = _upgrade_tags(call, tags, system_tags)
return _get_unique_values(tags), _get_unique_values(system_tags)
return tags, system_tags
def _upgrade_tags(call: APICall, tags: Sequence, system_tags: Sequence):
@@ -55,9 +59,12 @@ def _upgrade_tags(call: APICall, tags: Sequence, system_tags: Sequence):
return tags, system_tags
def _get_unique_values(values: Sequence) -> Sequence:
"""Get unique values from the given sequence"""
if not values:
return values
return list(set(values))
def validate_tags(tags: Sequence[str], system_tags: Sequence[str]):
for values in filter(None, (tags, system_tags)):
unsupported = [
t for t in values if t.startswith(GetMixin.ListFieldBucketHelper.op_prefix)
]
if unsupported:
raise errors.bad_request.FieldsValueError(
"unsupported tag prefix", values=unsupported
)

View File

@@ -54,6 +54,10 @@ class TestService(TestCase, TestServiceInterface):
)
return object_id
@staticmethod
def update_missing(target: dict, **update):
target.update({k: v for k, v in update.items() if k not in target})
def create_temp(self, service, *, client=None, delete_params=None, **kwargs) -> str:
return self._create_temp_helper(
service=service,

View File

@@ -61,7 +61,7 @@ class TestEntityOrdering(TestService):
page_size=page_size,
).tasks
def _assertSorted(self, vals: Sequence, ascending=True):
def _assertSorted(self, vals: Sequence, ascending=True, is_numeric=False):
"""
Assert that vals are sorted in the ascending or descending order
with None values are always coming from the end
@@ -80,6 +80,9 @@ class TestEntityOrdering(TestService):
self.assertTrue(all(val == empty_value for val in none_tail))
self.assertTrue(all(val != empty_value for val in vals))
if is_numeric:
vals = list(map(int, vals))
if ascending:
cmp = operator.le
else:
@@ -106,14 +109,18 @@ class TestEntityOrdering(TestService):
# test that the output is correctly ordered
field_name = order_by if not order_by.startswith("-") else order_by[1:]
field_vals = [self._get_value_for_path(t, field_name.split(".")) for t in tasks]
self._assertSorted(field_vals, ascending=not order_by.startswith("-"))
self._assertSorted(
field_vals,
ascending=not order_by.startswith("-"),
is_numeric=field_name.startswith("execution.parameters.")
)
def _create_tasks(self):
tasks = [
self._temp_task(
**(dict(execution={"parameters": {"test": f"{i}"} if i >= 5 else {}}))
)
for i in range(10)
for i in range(20)
]
for idx, task in zip(range(5), tasks):
self.api.tasks.started(task=task)

View File

@@ -0,0 +1,36 @@
from tests.automated import TestService
class TestOrganization(TestService):
def setUp(self, version="2.8"):
super().setUp(version=version)
def test_tags(self):
tag1 = "Orgtest tag1"
tag2 = "Orgtest tag2"
system_tag = "Orgtest system tag"
model = self.create_temp(
"models", name="test_org", uri="file:///a", tags=[tag1]
)
task = self.create_temp(
"tasks", name="test org", type="training", input=dict(view={}), tags=[tag1]
)
data = self.api.organization.get_tags()
self.assertTrue(tag1 in data.tags)
self.api.tasks.edit(task=task, tags=[tag2], system_tags=[system_tag])
data = self.api.organization.get_tags(include_system=True)
self.assertTrue({tag1, tag2}.issubset(set(data.tags)))
self.assertTrue(system_tag in data.system_tags)
data = self.api.organization.get_tags(
filter={"system_tags": ["__$not", system_tag]}
)
self.assertTrue(tag1 in data.tags)
self.assertFalse(tag2 in data.tags)
self.api.models.delete(model=model)
data = self.api.organization.get_tags()
self.assertFalse(tag1 in data.tags)
self.assertTrue(tag2 in data.tags)

View File

@@ -208,25 +208,21 @@ class TestTags(TestService):
self.api.tasks.stopped(task=task_id)
def _temp_queue(self, **kwargs):
self._update_missing(kwargs, name="Test tags")
self.update_missing(kwargs, name="Test tags")
return self.create_temp("queues", **kwargs)
def _temp_project(self, **kwargs):
self._update_missing(kwargs, name="Test tags", description="test")
self.update_missing(kwargs, name="Test tags", description="test")
return self.create_temp("projects", **kwargs)
def _temp_model(self, **kwargs):
self._update_missing(kwargs, name="Test tags", uri="file:///a/b", labels={})
self.update_missing(kwargs, name="Test tags", uri="file:///a/b", labels={})
return self.create_temp("models", **kwargs)
def _temp_task(self, **kwargs):
self._update_missing(kwargs, name="Test tags", type="testing", input=dict(view=dict()))
self.update_missing(kwargs, name="Test tags", type="testing", input=dict(view=dict()))
return self.create_temp("tasks", **kwargs)
@staticmethod
def _update_missing(target: dict, **update):
target.update({k: v for k, v in update.items() if k not in target})
def _send(self, service, action, **kwargs):
api = kwargs.pop("api", self.api)
return AttrDict(

View File

@@ -2,13 +2,14 @@
Comprehensive test of all(?) use cases of datasets and frames
"""
import json
import time
import operator
import unittest
from functools import partial
from statistics import mean
from typing import Sequence
import es_factory
from apierrors.errors.bad_request import EventsNotAdded
from tests.automated import TestService
@@ -22,21 +23,17 @@ class TestTaskEvents(TestService):
)
return self.create_temp("tasks", **task_input)
def _create_task_event(self, type_, task, iteration, **kwargs):
@staticmethod
def _create_task_event(type_, task, iteration, **kwargs):
return {
"worker": "test",
"type": type_,
"task": task,
"iter": iteration,
"timestamp": es_factory.get_timestamp_millis(),
"timestamp": kwargs.get("timestamp") or es_factory.get_timestamp_millis(),
**kwargs,
}
def _copy_and_update(self, src_obj, new_data):
obj = src_obj.copy()
obj.update(new_data)
return obj
def test_task_metrics(self):
tasks = {
self._temp_task(): {
@@ -83,8 +80,7 @@ class TestTaskEvents(TestService):
# test empty
res = self.api.events.debug_images(
metrics=[{"task": task, "metric": metric}],
iters=5,
metrics=[{"task": task, "metric": metric}], iters=5,
)
self.assertFalse(res.metrics)
@@ -116,11 +112,11 @@ class TestTaskEvents(TestService):
# test forward navigation
for page in range(3):
scroll_id = assert_debug_images(scroll_id=scroll_id, page=page)
scroll_id = assert_debug_images(scroll_id=scroll_id, expected_page=page)
# test backwards navigation
scroll_id = assert_debug_images(
scroll_id=scroll_id, page=0, navigate_earlier=False
scroll_id=scroll_id, expected_page=0, navigate_earlier=False
)
# beyond the latest iteration and back
@@ -131,10 +127,10 @@ class TestTaskEvents(TestService):
navigate_earlier=False,
)
self.assertEqual(len(res["metrics"][0]["iterations"]), 0)
assert_debug_images(scroll_id=scroll_id, page=1)
assert_debug_images(scroll_id=scroll_id, expected_page=1)
# refresh
assert_debug_images(scroll_id=scroll_id, page=0, refresh=True)
assert_debug_images(scroll_id=scroll_id, expected_page=0, refresh=True)
def _assertDebugImages(
self,
@@ -143,7 +139,7 @@ class TestTaskEvents(TestService):
max_iter: int,
unique_images: Sequence[int],
scroll_id,
page: int,
expected_page: int,
iters: int = 5,
**extra_params,
):
@@ -156,7 +152,7 @@ class TestTaskEvents(TestService):
data = res["metrics"][0]
self.assertEqual(data["task"], task)
self.assertEqual(data["metric"], metric)
left_iterations = max(0, max(unique_images) - page * iters)
left_iterations = max(0, max(unique_images) - expected_page * iters)
self.assertEqual(len(data["iterations"]), min(iters, left_iterations))
for it in data["iterations"]:
events_per_iter = sum(
@@ -165,27 +161,93 @@ class TestTaskEvents(TestService):
self.assertEqual(len(it["events"]), events_per_iter)
return res.scroll_id
def test_task_logs(self):
events = []
def test_error_events(self):
task = self._temp_task()
for iter_ in range(10):
log_event = self._create_task_event("log", task, iteration=iter_)
events.append(
self._copy_and_update(
log_event,
{"msg": "This is a log message from test task iter " + str(iter_)},
)
events = [
self._create_task_event("unknown type", task, iteration=1),
self._create_task_event("training_debug_image", task=None, iteration=1),
self._create_task_event(
"training_debug_image", task="Invalid task", iteration=1
),
]
# failure if no events added
with self.api.raises(EventsNotAdded):
self.send_batch(events)
events.append(
self._create_task_event("training_debug_image", task=task, iteration=1)
)
# success if at least one event added
res = self.send_batch(events)
self.assertEqual(res["added"], 1)
self.assertEqual(res["errors"], 3)
self.assertEqual(len(res["errors_info"]), 3)
res = self.api.events.get_task_events(task=task)
self.assertEqual(len(res.events), 1)
def test_task_logs(self):
# this test will fail until the new api is uncommented
task = self._temp_task()
timestamp = es_factory.get_timestamp_millis()
events = [
self._create_task_event(
"log",
task=task,
iteration=iter_,
timestamp=timestamp + iter_ * 1000,
msg=f"This is a log message from test task iter {iter_}",
)
# sleep so timestamp is not the same
time.sleep(0.01)
for iter_ in range(10)
]
self.send_batch(events)
data = self.api.events.get_task_log(task=task)
assert len(data["events"]) == 10
# test forward navigation
scroll_id = None
for page in range(3):
scroll_id = self._assert_log_events(
task=task, scroll_id=scroll_id, expected_page=page
)
self.api.tasks.reset(task=task)
data = self.api.events.get_task_log(task=task)
assert len(data["events"]) == 0
# test backwards navigation
scroll_id = self._assert_log_events(
task=task, scroll_id=scroll_id, navigate_earlier=False
)
# refresh
self._assert_log_events(task=task, scroll_id=scroll_id)
self._assert_log_events(task=task, scroll_id=scroll_id, refresh=True)
def _assert_log_events(
self,
task,
scroll_id,
batch_size: int = 5,
expected_total: int = 10,
expected_page: int = 0,
**extra_params,
):
res = self.api.events.get_task_log(
task=task, batch_size=batch_size, scroll_id=scroll_id, **extra_params,
)
self.assertEqual(res.total, expected_total)
expected_events = max(
0, batch_size - max(0, (expected_page + 1) * batch_size - expected_total)
)
self.assertEqual(res.returned, expected_events)
self.assertEqual(len(res.events), expected_events)
unique_events = len({ev.iter for ev in res.events})
self.assertEqual(len(res.events), unique_events)
if res.events:
cmp_operator = operator.ge
if not extra_params.get("navigate_earlier", True):
cmp_operator = operator.le
self.assertTrue(
all(
cmp_operator(first.timestamp, second.timestamp)
for first, second in zip(res.events, res.events[1:])
)
)
return res.scroll_id
def test_task_metric_value_intervals_keys(self):
metric = "Metric1"
@@ -393,7 +455,8 @@ class TestTaskEvents(TestService):
assert len(data["plots"]) == 0
def send_batch(self, events):
self.api.send_batch("events.add_batch", events)
_, data = self.api.send_batch("events.add_batch", events)
return data
def send(self, event):
self.api.send("events.add", event)

View File

@@ -14,9 +14,13 @@ class TestTasksDiff(TestService):
"tasks", name="test", type="testing", input=dict(view=dict()), **kwargs
)
def _compare_script(self, task, script):
for key, value in script.items():
self.assertEqual(task.script[key], value)
def _compare_script(self, task_id, script):
task = self.api.tasks.get_by_id(task=task_id).task
if not script:
self.assertFalse(task.get("script", None))
else:
for key, value in script.items():
self.assertEqual(task.script[key], value)
def test_not_deleted(self):
task_id = self.new_task()
@@ -28,11 +32,14 @@ class TestTasksDiff(TestService):
)
self.api.tasks.edit(task=task_id, script=script)
self.api.tasks.started(task=task_id)
self.api.tasks.reset(task=task_id)
task = self.api.tasks.get_by_id(task=task_id).task
self._compare_script(task, script)
self._compare_script(task_id, script)
new_reqs = dict()
self.api.tasks.set_requirements(task=task_id, requirements=new_reqs)
script["requirements"] = new_reqs
task = self.api.tasks.get_by_id(task=task_id).task
self._compare_script(task, script)
self._compare_script(task_id, script)
self.api.tasks.reset(task=task_id, clear_all=True)
self._compare_script(task_id, {})

View File

@@ -1,3 +1,4 @@
from apierrors.errors.bad_request import InvalidModelId, ValidationError
from config import config
from tests.automated import TestService
@@ -10,12 +11,37 @@ class TestTasksEdit(TestService):
super().setUp(version=2.5)
def new_task(self, **kwargs):
return self.create_temp(
"tasks", type="testing", name="test", input=dict(view=dict()), **kwargs
self.update_missing(
kwargs, type="testing", name="test", input=dict(view=dict())
)
return self.create_temp("tasks", **kwargs)
def new_model(self):
return self.create_temp("models", name="test", uri="file:///a/b", labels={})
def new_model(self, **kwargs):
self.update_missing(kwargs, name="test", uri="file:///a/b", labels={})
return self.create_temp("models", **kwargs)
def test_task_types(self):
with self.api.raises(ValidationError):
task = self.new_task(type="Unsupported")
types = ["controller", "optimizer"]
p1 = self.create_temp("projects", name="Test tasks1", description="test")
task1 = self.new_task(project=p1, type=types[0])
p2 = self.create_temp("projects", name="Test tasks2", description="test")
task2 = self.new_task(project=p2, type=types[1])
# all company types
res = self.api.tasks.get_types()
self.assertTrue(set(types).issubset(set(res["types"])))
# projects array
res = self.api.tasks.get_types(projects=[p1, p2])
self.assertEqual(set(types), set(res["types"]))
# single project
for p, t in zip((p1, p2), types):
res = self.api.tasks.get_types(projects=[p])
self.assertEqual([t], res["types"])
def test_edit_model_ready(self):
task = self.new_task()
@@ -38,6 +64,23 @@ class TestTasksEdit(TestService):
self.assertFalse(self.api.models.get_by_id(model=not_ready_model).model.ready)
self.api.tasks.edit(task=task, execution=dict(model=not_ready_model))
def test_task_with_model_reset(self):
# on task reset output model deleted
task = self.new_task()
self.api.tasks.started(task=task)
model_id = self.api.models.update_for_task(task=task, uri="file:///b")["id"]
self.api.tasks.reset(task=task)
with self.api.raises(InvalidModelId):
self.api.models.get_by_id(model=model_id)
# unless it is input of some task
task = self.new_task()
self.api.tasks.started(task=task)
model_id = self.api.models.update_for_task(task=task, uri="file:///b")["id"]
task_2 = self.new_task(execution=dict(model=model_id))
self.api.tasks.reset(task=task)
self.api.models.get_by_id(model=model_id)
def test_clone_task(self):
script = dict(
binary="python",
@@ -56,13 +99,13 @@ class TestTasksEdit(TestService):
new_name = "new test"
new_tags = ["by"]
execution_overrides = dict(framework="Caffe")
new_task_id = self.api.tasks.clone(
new_task_id = self._clone_task(
task=task,
new_task_name=new_name,
new_task_tags=new_tags,
execution_overrides=execution_overrides,
new_task_parent=task,
).id
)
new_task = self.api.tasks.get_by_id(task=new_task_id).task
self.assertEqual(new_task.name, new_name)
self.assertEqual(new_task.type, "testing")
@@ -73,3 +116,32 @@ class TestTasksEdit(TestService):
self.assertEqual(new_task.execution.parameters, execution["parameters"])
self.assertEqual(new_task.execution.framework, execution_overrides["framework"])
self.assertEqual(new_task.system_tags, [])
def test_model_check_in_clone(self):
model = self.new_model()
task = self.new_task(execution=dict(model=model))
# task with deleted model still can be copied
self.api.models.delete(model=model, force=True)
self._clone_task(task=task, new_task_name="clone test")
# unless check for refs is done
with self.api.raises(InvalidModelId):
self._clone_task(
task=task, new_task_name="clone test2", validate_references=True
)
# if the model is overriden then it is always checked
with self.api.raises(InvalidModelId):
self._clone_task(
task=task,
new_task_name="clone test3",
execution_overrides=dict(model="not existing"),
)
def _clone_task(self, task, **kwargs):
new_task = self.api.tasks.clone(task=task, **kwargs).id
self.defer(
self.api.tasks.delete, task=new_task, move_to_trash=False, force=True
)
return new_task

View File

@@ -0,0 +1,89 @@
from typing import Sequence
from uuid import uuid4
from apierrors import errors
from config import config
from tests.automated import TestService
log = config.logger(__file__)
class TestUsersService(TestService):
def setUp(self, version="2.8"):
super(TestUsersService, self).setUp(version=version)
self.company = self.api.users.get_current_user().user.company.id
def new_user(self):
user_name = uuid4().hex
user_id = self.api.auth.create_user(
company=self.company, name=user_name, email="{0}@{0}.com".format(user_name)
).id
self.defer(self.api.users.delete, user=user_id)
return user_id
def test_active_users(self):
user_1 = self.new_user()
user_2 = self.new_user()
user_3 = self.new_user()
model = (
self.api.impersonate(user_2)
.models.create(name="test", uri="file:///a", labels={})
.id
)
self.defer(self.api.models.delete, model=model)
project = self.create_temp("projects", name="users test", description="")
task = (
self.api.impersonate(user_3)
.tasks.create(
name="test", type="testing", input=dict(view={}), project=project
)
.id
)
self.defer(self.api.tasks.delete, task=task, move_to_trash=False)
user_ids = [user_1, user_2, user_3]
# no projects filtering
users = self.api.users.get_all_ex(id=user_ids).users
self._assertUsers((user_1, user_2, user_3), users)
# all projects
users = self.api.users.get_all_ex(id=user_ids, active_in_projects=[]).users
self._assertUsers((user_2, user_3), users)
# specific project
users = self.api.users.get_all_ex(active_in_projects=[project]).users
self._assertUsers((user_3,), users)
def _assertUsers(self, expected: Sequence, users: Sequence):
self.assertEqual(set(expected), set(u.id for u in users))
def test_no_preferences(self):
user = self.new_user()
assert self.api.impersonate(user).users.get_preferences().preferences == {}
def _test_update(self, user, tests):
"""
Check that all for each (updates, expected_result) pair, ``updates`` yield ``result``.
"""
new_user_client = self.api.impersonate(user)
for update, expected in tests:
new_user_client.users.set_preferences(user=user, preferences=update)
preferences = new_user_client.users.get_preferences(user=user).preferences
self.assertEqual(preferences, expected)
def test_nested_update(self):
tests = [
({"a": 0}, {"a": 0}),
({"b": 1}, {"a": 0, "b": 1}),
({"section": {"a": 2}}, {"a": 0, "b": 1, "section": {"a": 2}}),
]
self._test_update(self.new_user(), tests)
def test_delete(self):
tests = [
({"section": {"a": 0, "b": 1}},) * 2,
({"section": {"a": None}}, {"section": {"a": None}}),
({"section": None}, {"section": None}),
]
self._test_update(self.new_user(), tests)

View File

@@ -0,0 +1,10 @@
from enum import Enum
class StringEnum(Enum):
def __str__(self):
return self.value
# noinspection PyMethodParameters
def _generate_next_value_(name, start, count, last_values):
return name

View File

@@ -10,7 +10,7 @@ class ThreadsManager:
def __init__(self, name=None, **threads):
super(ThreadsManager, self).__init__()
self.name = name or self.__class__.name
self.name = name or self.__class__.__name__
self.objects = {}
self.lock = Lock()

View File

@@ -1 +1 @@
__version__ = "0.14.2"
__version__ = "0.15.0"