Compare commits

17 Commits

Author SHA1 Message Date
allegroai
702b6dc9c8 Version bump to v1.14.0 2024-01-10 15:31:11 +02:00
allegroai
db15f235e4 Make sure files downloaded from the apiserver are not cached by browsers 2024-01-10 15:31:01 +02:00
allegroai
8c347f8fa9 Fix include and exclude filters not processing "no tags" condition 2024-01-10 15:26:55 +02:00
allegroai
768c3d80ff Remove callback_url_prefix and state parameters from login.supported_modes and does not return urls 2024-01-10 15:26:22 +02:00
allegroai
a5c3ef6385 Fix query filter so that the default operator between different query operations on the same parameter is AND instead of OR 2024-01-10 15:24:37 +02:00
allegroai
11b7a384af Set API version 2.28 2024-01-10 15:23:54 +02:00
allegroai
9a70ade4a6 Support task models with missing model field in data_tool import 2024-01-10 15:18:58 +02:00
allegroai
91ce140901 Add "queue watched" indication to pipelines.start_pipeline 2024-01-10 15:15:43 +02:00
allegroai
49084a9c49 Optimize task statistics for projects dashboard and statistics reporter 2024-01-10 15:13:25 +02:00
allegroai
8a99eb6812 Fix model_metrics parameter name in get_multi_task_metrics schema 2024-01-10 15:12:56 +02:00
allegroai
811ab2bf4f Support exporting users with data tool 2024-01-10 15:12:07 +02:00
allegroai
3752db122b Add events.get_multi_task_metrics 2024-01-10 15:11:27 +02:00
allegroai
439911b84c Upgrade werkzeug and flask dependencies 2024-01-10 15:10:46 +02:00
allegroai
262a301e28 Check for dictionary type for some model and task fields 2024-01-10 15:10:41 +02:00
allegroai
a604451b01 Refactor check for tasks write permission 2024-01-10 15:08:20 +02:00
allegroai
88a7773621 Allow filtering on event metrics in multi-task endpoints get_task_single_value_metrics, multi_task_scalar_metrics_iter_histogram and get_multi_task_plots 2024-01-10 15:07:46 +02:00
allegroai
35c4061992 Support filtering by task or model ids in projects.get_unique_metric_variants 2024-01-10 15:06:21 +02:00
36 changed files with 1021 additions and 316 deletions

View File

@@ -41,6 +41,7 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
)
],
)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
model_events: bool = BoolField(default=False)
@@ -148,18 +149,23 @@ class MultiTasksRequestBase(Base):
class SingleValueMetricsRequest(MultiTasksRequestBase):
pass
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class TaskMetricsRequest(MultiTasksRequestBase):
event_type: EventType = ActualEnumField(EventType, required=True)
class MultiTaskMetricsRequest(MultiTasksRequestBase):
event_type: EventType = ActualEnumField(EventType, default=EventType.all)
class MultiTaskPlotsRequest(MultiTasksRequestBase):
iters: int = IntField(default=1)
scroll_id: str = StringField()
no_scroll: bool = BoolField(default=False)
last_iters_per_task_metric: bool = BoolField(default=True)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class TaskPlotsRequest(Base):

View File

@@ -5,8 +5,9 @@ from apiserver.apimodels import DictField, callable_default
class GetSupportedModesRequest(Base):
state = StringField(help_text="ASCII base64 encoded application state")
callback_url_prefix = StringField()
pass
# state = StringField(help_text="ASCII base64 encoded application state")
# callback_url_prefix = StringField()
class BasicGuestMode(Base):

View File

@@ -18,8 +18,4 @@ class StartPipelineRequest(models.Base):
task = fields.StringField(required=True)
queue = fields.StringField(required=True)
args = ListField(Arg)
class StartPipelineResponse(models.Base):
pipeline = fields.StringField(required=True)
enqueued = fields.BoolField(required=True)
verify_watched_queue = fields.BoolField(default=False)

View File

@@ -33,6 +33,7 @@ class ProjectOrNoneRequest(models.Base):
class GetUniqueMetricsRequest(ProjectOrNoneRequest):
model_metrics = fields.BoolField(default=False)
ids = fields.ListField(str)
class GetParamsRequest(ProjectOrNoneRequest):

View File

@@ -31,6 +31,7 @@ from apiserver.bll.event.history_plots_iterator import HistoryPlotsIterator
from apiserver.bll.event.metric_debug_images_iterator import MetricDebugImagesIterator
from apiserver.bll.event.metric_plots_iterator import MetricPlotsIterator
from apiserver.bll.model import ModelBLL
from apiserver.bll.task.utils import get_many_tasks_for_writing
from apiserver.bll.util import parallel_chunked_decorator
from apiserver.database import utils as dbutils
from apiserver.database.model.model import Model
@@ -42,6 +43,7 @@ from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.redis_manager import redman
from apiserver.service_repo.auth import Identity
from apiserver.tools import safe_get
from apiserver.utilities.dicts import nested_get
from apiserver.utilities.json import loads
@@ -55,7 +57,9 @@ MIN_LONG = -(2**63)
log = config.logger(__file__)
async_task_events_delete = config.get("services.tasks.async_events_delete", False)
async_delete_threshold = config.get("services.tasks.async_events_delete_threshold", 100_000)
async_delete_threshold = config.get(
"services.tasks.async_events_delete_threshold", 100_000
)
class EventBLL(object):
@@ -97,7 +101,9 @@ class EventBLL(object):
return self._metrics
@staticmethod
def _get_valid_entities(company_id, ids: Mapping[str, bool], model=False) -> Set:
def _get_valid_entities(
company_id, ids: Mapping[str, bool], identity: Identity, model=False
) -> Set:
"""Verify that task or model exists and can be updated"""
if not ids:
return set()
@@ -116,20 +122,34 @@ class EventBLL(object):
):
if not requested_ids:
continue
query = Q(id__in=requested_ids, company=company_id)
res.update(
(Model if model else Task).objects(query & locked_q).scalar("id")
)
query = Q(id__in=requested_ids) & locked_q
if model:
ids = Model.objects(query & Q(company=company_id)).scalar("id")
else:
ids = {
t.id
for t in get_many_tasks_for_writing(
company_id=company_id,
identity=identity,
query=query,
only=("id",),
throw_on_forbidden=False,
)
}
res.update(ids)
return res
def add_events(
self,
company_id: str,
user_id: str,
identity: Identity,
events: Sequence[dict],
worker: str,
) -> Tuple[int, int, dict]:
user_id = identity.user
task_ids = {}
model_ids = {}
for event in events:
@@ -161,8 +181,12 @@ class EventBLL(object):
"Inconsistent model_event setting in the passed events",
tasks=found_in_both,
)
valid_models = self._get_valid_entities(company_id, ids=model_ids, model=True)
valid_tasks = self._get_valid_entities(company_id, ids=task_ids)
valid_models = self._get_valid_entities(
company_id, ids=model_ids, identity=identity, model=True
)
valid_tasks = self._get_valid_entities(
company_id, ids=task_ids, identity=identity
)
actions: List[dict] = []
used_task_ids = set()

View File

@@ -21,6 +21,7 @@ from apiserver.bll.event.event_common import (
TaskCompanies,
)
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
from apiserver.bll.query import Builder as QueryBuilder
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.tools import safe_get
@@ -161,7 +162,9 @@ class EventMetrics:
return res
def get_task_single_value_metrics(
self, companies: TaskCompanies
self,
companies: TaskCompanies,
metric_variants: MetricVariants = None,
) -> Mapping[str, dict]:
"""
For the requested tasks return all the events delivered for the single iteration (-2**31)
@@ -179,7 +182,13 @@ class EventMetrics:
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
task_events = list(
itertools.chain.from_iterable(
pool.map(self._get_task_single_value_metrics, companies.items())
pool.map(
partial(
self._get_task_single_value_metrics,
metric_variants=metric_variants,
),
companies.items(),
)
),
)
@@ -195,19 +204,19 @@ class EventMetrics:
}
def _get_task_single_value_metrics(
self, tasks: Tuple[str, Sequence[str]]
self, tasks: Tuple[str, Sequence[str]], metric_variants: MetricVariants = None
) -> Sequence[dict]:
company_id, task_ids = tasks
must = [
{"terms": {"task": task_ids}},
{"term": {"iter": SINGLE_SCALAR_ITERATION}},
]
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
es_req = {
"size": 10000,
"query": {
"bool": {
"must": [
{"terms": {"task": task_ids}},
{"term": {"iter": SINGLE_SCALAR_ITERATION}},
]
}
},
"query": {"bool": {"must": must}},
}
with translate_errors_context():
es_res = search_company_events(
@@ -280,7 +289,8 @@ class EventMetrics:
query = {"bool": {"must": must}}
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
query=query,
**search_args,
)
max_variants = int(max_variants // 2)
es_req = {
@@ -366,7 +376,8 @@ class EventMetrics:
query = self._get_task_metrics_query(task_id=task_id, metrics=metrics)
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
query=query,
**search_args,
)
max_variants = int(max_variants // 2)
es_req = {
@@ -432,7 +443,9 @@ class EventMetrics:
@classmethod
def _get_task_metrics_query(
cls, task_id: str, metrics: Sequence[Tuple[str, str]],
cls,
task_id: str,
metrics: Sequence[Tuple[str, str]],
):
must = cls._task_conditions(task_id)
if metrics:
@@ -451,12 +464,96 @@ class EventMetrics:
return {"bool": {"must": must}}
def get_multi_task_metrics(self, companies: TaskCompanies, event_type: EventType) -> Mapping[str, list]:
"""
For the requested tasks return reported metrics and variants
"""
tasks_ids = {
company: [t.id for t in tasks]
for company, tasks in companies.items()
}
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
companies_res: Sequence = list(
pool.map(
partial(
self._get_multi_task_metrics,
event_type=event_type,
),
tasks_ids.items(),
)
)
if len(companies_res) == 1:
return companies_res[0]
res = defaultdict(set)
for c_res in companies_res:
for m, vars_ in c_res.items():
res[m].update(vars_)
return {
k: list(v)
for k, v in res.items()
}
def _get_multi_task_metrics(
self, company_tasks: Tuple[str, Sequence[str]], event_type: EventType
) -> Mapping[str, list]:
company_id, task_ids = company_tasks
if check_empty_data(self.es, company_id, event_type):
return {}
search_args = dict(
es=self.es,
company_id=company_id,
event_type=event_type,
)
query = QueryBuilder.terms("task", task_ids)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query,
**search_args,
)
es_req = {
"size": 0,
"query": query,
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": max_metrics,
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": max_variants,
"order": {"_key": "asc"},
},
}
}
}
},
}
es_res = search_company_events(
body=es_req,
**search_args,
)
aggs_result = es_res.get("aggregations")
if not aggs_result:
return {}
return {
mb["key"]: [vb["key"] for vb in mb["variants"]["buckets"]]
for mb in aggs_result["metrics"]["buckets"]
}
def get_task_metrics(
self, company_id, task_ids: Sequence, event_type: EventType
) -> Sequence:
"""
For the requested tasks return all the metrics that
reported events of the requested types
For the requested tasks return reported metrics per task
"""
if check_empty_data(self.es, company_id, event_type):
return {}

View File

@@ -10,6 +10,7 @@ from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.service_repo.auth import Identity
from .metadata import Metadata
@@ -57,14 +58,15 @@ class ModelBLL:
cls,
model_id: str,
company_id: str,
user_id: str,
identity: Identity,
force_publish_task: bool = False,
publish_task_func: Callable[[str, str, str, bool], dict] = None,
publish_task_func: Callable[[str, str, Identity, bool], dict] = None,
) -> Tuple[int, ModelTaskPublishResponse]:
model = cls.get_company_model_by_id(company_id=company_id, model_id=model_id)
if model.ready:
raise errors.bad_request.ModelIsReady(company=company_id, model=model_id)
user_id = identity.user
published_task = None
if model.task and publish_task_func:
task = (
@@ -74,7 +76,7 @@ class ModelBLL:
)
if task and task.status != TaskStatus.published:
task_publish_res = publish_task_func(
model.task, company_id, user_id, force_publish_task
model.task, company_id, identity, force_publish_task
)
published_task = ModelTaskPublishResponse(
id=model.task, data=task_publish_res

View File

@@ -341,6 +341,17 @@ class ProjectBLL:
) -> Tuple[Sequence, Sequence]:
archived = EntityVisibility.archived.value
def project_task_fields():
return {
"$project": {
"project": 1,
"status": 1,
"system_tags": 1,
"started": 1,
"completed": 1,
}
}
def ensure_valid_fields():
"""
Make sure system tags is always an array (required by subsequent $in in archived_tasks_cond
@@ -368,6 +379,7 @@ class ProjectBLL:
users=users,
)
},
project_task_fields(),
ensure_valid_fields(),
{
"$group": {
@@ -516,6 +528,7 @@ class ProjectBLL:
users=users,
)
},
project_task_fields(),
ensure_valid_fields(),
{
# for each project
@@ -1112,11 +1125,7 @@ class ProjectBLL:
helper = GetMixin.NewListFieldBucketHelper(
field, data=field_filter, legacy=True
)
op = (
Q.OR
if helper.explicit_operator and helper.global_operator == Q.OR
else Q.AND
)
op = helper.global_operator
db_query = {op: helper.actions}
else:
helper = GetMixin.ListQueryFilter.from_data(field, field_filter)
@@ -1125,7 +1134,7 @@ class ProjectBLL:
for op, actions in db_query.items():
field_conditions = {}
for action, values in actions.items():
value = list(set(values))
value = list(set(values)) if isinstance(values, list) else values
for key in reversed(action.split("__")):
value = {f"${key}": value}
field_conditions.update(value)

View File

@@ -239,6 +239,7 @@ class ProjectQueries:
company_id,
project_ids: Sequence[str],
include_subprojects: bool,
ids: Sequence[str],
model_metrics: bool = False,
):
pipeline = [
@@ -246,6 +247,7 @@ class ProjectQueries:
"$match": {
**cls._get_company_constraint(company_id),
**cls._get_project_constraint(project_ids, include_subprojects),
**({"_id": {"$in": ids}} if ids else {}),
}
},
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},

View File

@@ -152,7 +152,7 @@ class QueueBLL(object):
for item in queue.entries:
try:
task = Task.get_for_writing(
task = Task.get(
company=company_id,
id=item.task,
_only=[

View File

@@ -254,6 +254,14 @@ class StatisticsReporter:
**({"last_worker": {"$in": workers}} if workers else {}),
}
},
{
"$project": {
"last_worker": 1,
"last_update": 1,
"started": 1,
"last_iteration": 1,
}
},
{
"$group": {
"_id": "$last_worker" if workers else None,

View File

@@ -5,6 +5,7 @@ from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
from apiserver.bll.task.utils import get_task_for_update, update_task
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
from apiserver.database.utils import hash_field_name
from apiserver.service_repo.auth import Identity
from apiserver.utilities.dicts import nested_get, nested_set
from apiserver.utilities.parameter_key_escaper import mongoengine_safe
@@ -48,12 +49,14 @@ class Artifacts:
def add_or_update_artifacts(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
artifacts: Sequence[ApiArtifact],
force: bool,
) -> int:
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,)
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force, identity=identity
)
artifacts = {
get_artifact_id(a): Artifact(**a)
@@ -64,18 +67,20 @@ class Artifacts:
f"set__execution__artifacts__{mongoengine_safe(name)}": value
for name, value in artifacts.items()
}
return update_task(task, user_id=user_id, update_cmds=update_cmds)
return update_task(task, user_id=identity.user, update_cmds=update_cmds)
@classmethod
def delete_artifacts(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
artifact_ids: Sequence[ArtifactId],
force: bool,
) -> int:
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,)
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force, identity=identity
)
artifact_ids = [
get_artifact_id(a)
@@ -85,4 +90,4 @@ class Artifacts:
f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
}
return update_task(task, user_id=user_id, update_cmds=delete_cmds)
return update_task(task, user_id=identity.user, update_cmds=delete_cmds)

View File

@@ -15,6 +15,7 @@ from apiserver.bll.task import TaskBLL
from apiserver.bll.task.utils import get_task_for_update, update_task
from apiserver.config_repo import config
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
from apiserver.service_repo.auth import Identity
from apiserver.utilities.parameter_key_escaper import (
ParameterKeyEscaper,
mongoengine_safe,
@@ -31,7 +32,10 @@ class HyperParams:
def get_params(cls, company_id: str, task_ids: Sequence[str]) -> Dict[str, dict]:
only = ("id", "hyperparams")
tasks = task_bll.assert_exists(
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
company_id=company_id,
task_ids=task_ids,
only=only,
allow_public=True,
)
return {
@@ -63,7 +67,7 @@ class HyperParams:
def delete_params(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
hyperparams: Sequence[HyperParamKey],
force: bool,
@@ -74,6 +78,7 @@ class HyperParams:
task_id=task_id,
allow_all_statuses=properties_only,
force=force,
identity=identity,
)
with_param, without_param = iterutils.partition(
@@ -96,7 +101,7 @@ class HyperParams:
return update_task(
task,
user_id=user_id,
user_id=identity.user,
update_cmds=delete_cmds,
set_last_update=not properties_only,
)
@@ -105,7 +110,7 @@ class HyperParams:
def edit_params(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
hyperparams: Sequence[HyperParamItem],
replace_hyperparams: str,
@@ -117,6 +122,7 @@ class HyperParams:
task_id=task_id,
allow_all_statuses=properties_only,
force=force,
identity=identity,
)
update_cmds = dict()
@@ -135,7 +141,7 @@ class HyperParams:
return update_task(
task,
user_id=user_id,
user_id=identity.user,
update_cmds=update_cmds,
set_last_update=not properties_only,
)
@@ -163,7 +169,10 @@ class HyperParams:
else:
only.append("configuration")
tasks = task_bll.assert_exists(
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
company_id=company_id,
task_ids=task_ids,
only=only,
allow_public=True,
)
return {
@@ -209,13 +218,15 @@ class HyperParams:
def edit_configuration(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
configuration: Sequence[Configuration],
replace_configuration: bool,
force: bool,
) -> int:
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force, identity=identity
)
update_cmds = dict()
configuration = {
@@ -228,22 +239,24 @@ class HyperParams:
for name, value in configuration.items():
update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
return update_task(task, user_id=user_id, update_cmds=update_cmds)
return update_task(task, user_id=identity.user, update_cmds=update_cmds)
@classmethod
def delete_configuration(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
configuration: Sequence[str],
force: bool,
) -> int:
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force, identity=identity
)
delete_cmds = {
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
for name in set(configuration)
}
return update_task(task, user_id=user_id, update_cmds=delete_cmds)
return update_task(task, user_id=identity.user, update_cmds=delete_cmds)

View File

@@ -58,27 +58,6 @@ class TaskBLL:
self.events_es = events_es or es_factory.connect("events")
self.redis: StrictRedis = redis or redman.connection("apiserver")
@staticmethod
def get_task_with_access(
task_id, company_id, only=None, allow_public=False, requires_write_access=False
) -> Task:
"""
Gets a task that has a required write access
:except errors.bad_request.InvalidTaskId: if the task is not found
:except errors.forbidden.NoWritePermission: if write_access was required and the task cannot be modified
"""
with translate_errors_context():
query = dict(id=task_id, company=company_id)
if requires_write_access:
task = Task.get_for_writing(_only=only, **query)
else:
task = Task.get(_only=only, **query, include_public=allow_public)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
return task
@staticmethod
def get_by_id(
company_id,

View File

@@ -9,6 +9,7 @@ from apiserver.bll.task import (
ChangeStatusRequest,
)
from apiserver.bll.task.task_cleanup import cleanup_task, CleanupResult
from apiserver.bll.task.utils import get_task_with_write_access
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
@@ -24,6 +25,7 @@ from apiserver.database.model.task.task import (
DEFAULT_LAST_ITERATION,
)
from apiserver.database.utils import get_options
from apiserver.service_repo.auth import Identity
from apiserver.utilities.dicts import nested_set
log = config.logger(__file__)
@@ -33,7 +35,7 @@ queue_bll = QueueBLL()
def archive_task(
task: Union[str, Task],
company_id: str,
user_id: str,
identity: Identity,
status_message: str,
status_reason: str,
) -> int:
@@ -42,9 +44,10 @@ def archive_task(
Return 1 if successful
"""
if isinstance(task, str):
task = TaskBLL.get_task_with_access(
task = get_task_with_write_access(
task,
company_id=company_id,
identity=identity,
only=(
"id",
"company",
@@ -54,8 +57,9 @@ def archive_task(
"system_tags",
"enqueue_status",
),
requires_write_access=True,
)
user_id = identity.user
try:
TaskBLL.dequeue_and_change_status(
task,
@@ -79,34 +83,34 @@ def archive_task(
def unarchive_task(
task: str,
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
status_message: str,
status_reason: str,
) -> int:
"""
Unarchive task. Return 1 if successful
"""
task = TaskBLL.get_task_with_access(
task,
task = get_task_with_write_access(
task_id,
company_id=company_id,
identity=identity,
only=("id",),
requires_write_access=True,
)
return task.update(
status_message=status_message,
status_reason=status_reason,
pull__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
last_changed_by=user_id,
last_changed_by=identity.user,
)
def dequeue_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
status_message: str,
status_reason: str,
remove_from_all_queues: bool = False,
@@ -119,7 +123,19 @@ def dequeue_task(
task = Task.get(
id=task_id,
company=company_id,
_only=(
_only=("id",),
include_public=True,
)
if not task:
TaskBLL.remove_task_from_all_queues(company_id, task_id=task_id)
return 1, {"updated": 0}
user_id = identity.user
task = get_task_with_write_access(
task_id,
company_id=company_id,
identity=identity,
only=(
"id",
"company",
"execution",
@@ -127,11 +143,7 @@ def dequeue_task(
"project",
"enqueue_status",
),
include_public=True,
)
if not task:
TaskBLL.remove_task_from_all_queues(company_id, task_id=task_id)
return 1, {"updated": 0}
res = TaskBLL.dequeue_and_change_status(
task,
@@ -148,7 +160,7 @@ def dequeue_task(
def enqueue_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
queue_id: str,
status_message: str,
status_reason: str,
@@ -173,11 +185,11 @@ def enqueue_task(
# try to get default queue
queue_id = queue_bll.get_default(company_id).id
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(**query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
task = get_task_with_write_access(
task_id=task_id, company_id=company_id, identity=identity
)
user_id = identity.user
if validate:
TaskBLL.validate(task)
@@ -207,9 +219,9 @@ def enqueue_task(
# set the current queue ID in the task
if task.execution:
Task.objects(**query).update(execution__queue=queue_id, multi=False)
Task.objects(id=task_id).update(execution__queue=queue_id, multi=False)
else:
Task.objects(**query).update(execution=Execution(queue=queue_id), multi=False)
Task.objects(id=task_id).update(execution=Execution(queue=queue_id), multi=False)
nested_set(res, ("fields", "execution.queue"), queue_id)
return 1, res
@@ -242,7 +254,7 @@ def move_tasks_to_trash(tasks: Sequence[str]) -> int:
def delete_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
move_to_trash: bool,
force: bool,
return_file_urls: bool,
@@ -251,8 +263,9 @@ def delete_task(
status_reason: str,
delete_external_artifacts: bool,
) -> Tuple[int, Task, CleanupResult]:
task = TaskBLL.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
user_id = identity.user
task = get_task_with_write_access(
task_id, company_id=company_id, identity=identity
)
if (
@@ -305,15 +318,16 @@ def delete_task(
def reset_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
force: bool,
return_file_urls: bool,
delete_output_models: bool,
clear_all: bool,
delete_external_artifacts: bool,
) -> Tuple[dict, CleanupResult, dict]:
task = TaskBLL.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
user_id = identity.user
task = get_task_with_write_access(
task_id, company_id=company_id, identity=identity
)
if not force and task.status == TaskStatus.published:
@@ -392,14 +406,15 @@ def reset_task(
def publish_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
force: bool,
publish_model_func: Callable[[str, str, str], Any] = None,
publish_model_func: Callable[[str, str, Identity], Any] = None,
status_message: str = "",
status_reason: str = "",
) -> dict:
task = TaskBLL.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
user_id = identity.user
task = get_task_with_write_access(
task_id, company_id=company_id, identity=identity
)
if not force:
validate_status_change(task.status, TaskStatus.published)
@@ -422,7 +437,7 @@ def publish_task(
.first()
)
if model and not model.ready:
publish_model_func(model.id, company_id, user_id)
publish_model_func(model.id, company_id, identity)
# set task status to published, and update (or set) it's new output (view and models)
return ChangeStatusRequest(
@@ -446,7 +461,7 @@ def publish_task(
def stop_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
user_name: str,
status_reason: str,
force: bool,
@@ -459,10 +474,11 @@ def stop_task(
is set to 'stopping' to allow the worker to stop the task and report by itself
:return: updated task fields
"""
task = TaskBLL.get_task_with_access(
user_id = identity.user
task = get_task_with_write_access(
task_id,
company_id=company_id,
identity=identity,
only=(
"status",
"project",
@@ -472,7 +488,6 @@ def stop_task(
"last_update",
"execution.queue",
),
requires_write_access=True,
)
def is_run_by_worker(t: Task) -> bool:

View File

@@ -1,7 +1,9 @@
from datetime import datetime
from typing import Sequence
import attr
import six
from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.bll.util import update_project_time
@@ -10,6 +12,7 @@ from apiserver.database.errors import translate_errors_context
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
from apiserver.database.utils import get_options
from apiserver.service_repo.auth import Identity
from apiserver.utilities.attrs import typed_attrs
valid_statuses = get_options(TaskStatus)
@@ -157,15 +160,75 @@ def get_possible_status_changes(current_status):
return possible
def get_many_tasks_for_writing(
company_id: str,
identity: Identity,
query: Q = None,
only: Sequence = None,
throw_on_forbidden: bool = True,
) -> Sequence[Task]:
if only:
missing = [f for f in ("company", ) if f not in only]
if missing:
only = [*only, *missing]
result = list(
Task.get_many(
company=company_id,
query=query,
override_projection=only,
allow_public=True,
return_dicts=False,
)
)
forbidden_tasks = {task.id for task in result if not task.company}
if forbidden_tasks:
if throw_on_forbidden:
raise errors.forbidden.NoWritePermission(
f"cannot modify public task(s), ids={tuple(forbidden_tasks)}"
)
result = [task for task in result if task.id not in forbidden_tasks]
return result
def get_task_with_write_access(
task_id: str,
company_id: str,
identity: Identity,
only=None,
) -> Task:
"""
Gets a task that has a required write access
:except errors.bad_request.InvalidTaskId: if the task is not found
:except errors.forbidden.NoWritePermission: if write_access was required and the task cannot be modified
"""
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(_only=only, **query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
return task
def get_task_for_update(
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
company_id: str,
task_id: str,
identity: Identity,
allow_all_statuses: bool = False,
force: bool = False
) -> Task:
"""
Loads only task id and return the task only if it is updatable (status == 'created')
"""
task = Task.get_for_writing(company=company_id, id=task_id, _only=("id", "status"))
if not task:
raise errors.bad_request.InvalidTaskId(id=task_id)
task = get_task_with_write_access(
task_id=task_id,
company_id=company_id,
only=("id", "status"),
identity=identity,
)
if allow_all_statuses:
return task

View File

@@ -146,9 +146,10 @@ class GetMixin(PropsMixin):
"__$any": Q.OR,
"__$or": Q.OR,
}
default_operator = Q.OR
default_global_operator = Q.AND
default_context = Q.OR
# not_all modifier currently not supported due to the backwards compatibility
mongo_modifiers = {
# not_all modifier currently not supported due to the backwards compatibility
Q.AND: {True: "all", False: "nin"},
Q.OR: {True: "in", False: "nin"},
}
@@ -165,24 +166,22 @@ class GetMixin(PropsMixin):
self.allow_empty = False
self.global_operator = None
self.actions = defaultdict(list)
self.explicit_operator = False
self._support_legacy = legacy
current_context = self.default_operator
current_context = self.default_context
for d in self._get_next_term(data):
if d.operator is not None:
current_context = d.operator
self._support_legacy = False
if self.global_operator is None:
self.global_operator = d.operator
self.explicit_operator = True
continue
if self.global_operator is None:
self.global_operator = self.default_operator
self.global_operator = self.default_global_operator
if d.reset:
current_context = self.default_operator
current_context = self.default_context
self._support_legacy = legacy
continue
@@ -195,7 +194,7 @@ class GetMixin(PropsMixin):
)
if self.global_operator is None:
self.global_operator = self.default_operator
self.global_operator = self.default_global_operator
def _get_next_term(self, data: Sequence[str]) -> Generator[Term, None, None]:
unary_operator = None
@@ -618,7 +617,20 @@ class GetMixin(PropsMixin):
):
if not vals:
continue
operations[self._db_modifiers[(op, include)]] = list(set(vals))
unique = set(vals)
if None in unique:
# noinspection PyTypeChecker
unique.remove(None)
if include:
operations["size"] = 0
else:
operations["not__size"] = 0
if not unique:
continue
operations[self._db_modifiers[(op, include)]] = list(unique)
self.db_query[op] = operations
@@ -656,7 +668,8 @@ class GetMixin(PropsMixin):
ops = []
for action, vals in actions.items():
if not vals:
# cannot just check vals here since 0 is acceptable value
if vals is None or vals == []:
continue
ops.append(RegexQ(**{f"{mongoengine_field}__{action}": vals}))
@@ -1283,21 +1296,6 @@ class GetMixin(PropsMixin):
)
return result
@classmethod
def get_many_for_writing(cls, company, *args, **kwargs):
result = cls.get_many(
company=company,
*args,
**dict(return_dicts=False, **kwargs),
allow_public=True,
)
forbidden_objects = {obj.id for obj in result if not obj.company}
if forbidden_objects:
object_name = cls.__name__.lower()
raise errors.forbidden.NoWritePermission(
f"cannot modify public {object_name}(s), ids={tuple(forbidden_objects)}"
)
return result
class UpdateMixin(object):

View File

@@ -44,6 +44,7 @@ from apiserver.bll.task.param_utils import (
from apiserver.config_repo import config
from apiserver.config.info import get_default_company
from apiserver.database.model import EntityVisibility, User
from apiserver.database.model.auth import Role, User as AuthUser
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import (
@@ -54,6 +55,7 @@ from apiserver.database.model.task.task import (
TaskModelNames,
)
from apiserver.database.utils import get_options
from apiserver.service_repo.auth import Identity
from apiserver.utilities import json
from apiserver.utilities.dicts import nested_get, nested_set, nested_delete
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
@@ -66,6 +68,7 @@ class PrePopulate:
export_tag_prefix = "Exported:"
export_tag = f"{export_tag_prefix} %Y-%m-%d %H:%M:%S"
metadata_filename = "metadata.json"
users_filename = "users.json"
zip_args = dict(mode="w", compression=ZIP_BZIP2)
artifacts_ext = ".artifacts"
img_source_regex = re.compile(
@@ -78,6 +81,7 @@ class PrePopulate:
project_cls: Type[Project]
model_cls: Type[Model]
user_cls: Type[User]
auth_user_cls: Type[AuthUser]
# noinspection PyTypeChecker
@classmethod
@@ -90,6 +94,8 @@ class PrePopulate:
cls.project_cls = cls._get_entity_type("database.model.project.Project")
if not hasattr(cls, "user_cls"):
cls.user_cls = cls._get_entity_type("database.model.User")
if not hasattr(cls, "auth_user_cls"):
cls.auth_user_cls = cls._get_entity_type("database.model.auth.User")
class JsonLinesWriter:
def __init__(self, file: BinaryIO):
@@ -205,6 +211,8 @@ class PrePopulate:
task_statuses: Sequence[str] = None,
tag_exported_entities: bool = False,
metadata: Mapping[str, Any] = None,
export_events: bool = True,
export_users: bool = False,
) -> Sequence[str]:
cls._init_entity_types()
@@ -240,11 +248,15 @@ class PrePopulate:
with ZipFile(file, **cls.zip_args) as zfile:
if metadata:
zfile.writestr(cls.metadata_filename, meta_str)
if export_users:
cls._export_users(zfile)
artifacts = cls._export(
zfile,
entities=entities,
hash_=hash_,
tag_entities=tag_exported_entities,
export_events=export_events,
cleanup_users=not export_users,
)
file_with_hash = file.with_name(f"{file.stem}_{hash_.hexdigest()}{file.suffix}")
@@ -265,6 +277,9 @@ class PrePopulate:
metadata_hash=metadata_hash,
)
if created_files:
print("Created files:\n" + "\n".join(file for file in created_files))
return created_files
@classmethod
@@ -296,18 +311,26 @@ class PrePopulate:
except Exception:
pass
if not user_id:
user_id, user_name = "__allegroai__", "Allegro.ai"
# Make sure we won't end up with an invalid company ID
if company_id is None:
company_id = ""
user_mapping = cls._import_users(zfile, company_id)
if not user_id:
user_id, user_name = "__allegroai__", "Allegro.ai"
existing_user = cls.user_cls.objects(id=user_id).only("id").first()
if not existing_user:
cls.user_cls(id=user_id, name=user_name, company=company_id).save()
cls._import(zfile, company_id, user_id, metadata)
cls._import(
zfile,
company_id=company_id,
user_id=user_id,
metadata=metadata,
user_mapping=user_mapping,
)
if artifacts_path and os.path.isdir(artifacts_path):
artifacts_file = Path(filename).with_suffix(cls.artifacts_ext)
@@ -438,7 +461,7 @@ class PrePopulate:
projects: Sequence[str] = None,
task_statuses: Sequence[str] = None,
) -> Dict[Type[mongoengine.Document], Set[mongoengine.Document]]:
entities = defaultdict(set)
entities: Dict[Any] = defaultdict(set)
if projects:
print("Reading projects...")
@@ -497,7 +520,6 @@ class PrePopulate:
@classmethod
def _cleanup_model(cls, model: Model):
model.company = ""
model.user = ""
model.tags = cls._filter_out_export_tags(model.tags)
@classmethod
@@ -505,7 +527,6 @@ class PrePopulate:
task.comment = "Auto generated by Allegro.ai"
task.status_message = ""
task.status_reason = ""
task.user = ""
task.company = ""
task.tags = cls._filter_out_export_tags(task.tags)
if task.output:
@@ -513,17 +534,32 @@ class PrePopulate:
@classmethod
def _cleanup_project(cls, project: Project):
project.user = ""
project.company = ""
project.tags = cls._filter_out_export_tags(project.tags)
@classmethod
def _cleanup_entity(cls, entity_cls, entity):
def _cleanup_auth_user(cls, user: AuthUser):
user.company = ""
for cred in user.credentials:
if getattr(cred, "company", None):
cred["company"] = ""
return user
@classmethod
def _cleanup_be_user(cls, user: User):
user.company = ""
user.preferences = None
return user
@classmethod
def _cleanup_entity(cls, entity_cls, entity, cleanup_users):
if cleanup_users:
entity.user = ""
if entity_cls == cls.task_cls:
cls._cleanup_task(entity)
elif entity_cls == cls.model_cls:
cls._cleanup_model(entity)
elif entity == cls.project_cls:
elif entity_cls == cls.project_cls:
cls._cleanup_project(entity)
@classmethod
@@ -633,6 +669,38 @@ class PrePopulate:
else:
print(f"Artifact {full_path} not found")
@classmethod
def _export_users(cls, writer: ZipFile):
auth_users = {
user.id: cls._cleanup_auth_user(user)
for user in cls.auth_user_cls.objects(role__in=(Role.admin, Role.user))
}
if not auth_users:
return
be_users = {
user.id: cls._cleanup_be_user(user)
for user in cls.user_cls.objects(id__in=list(auth_users))
}
if not be_users:
return
auth_users = {uid: data for uid, data in auth_users.items() if uid in be_users}
print(f"Writing {len(auth_users)} users into {writer.filename}")
data = {}
for field, users in (("auth", auth_users), ("backend", be_users)):
with BytesIO() as f:
with cls.JsonLinesWriter(f) as w:
for user in users.values():
w.write(user.to_json())
data[field] = f.getvalue()
def get_field_bytes(k: str, v: bytes) -> bytes:
return f'"{k}": '.encode("utf-8") + v
data_str = b",\n".join(get_field_bytes(k, v) for k, v in data.items())
writer.writestr(cls.users_filename, b"{\n" + data_str + b"\n}")
@classmethod
def _get_base_filename(cls, cls_: type):
name = f"{cls_.__module__}.{cls_.__name__}"
@@ -642,7 +710,13 @@ class PrePopulate:
@classmethod
def _export(
cls, writer: ZipFile, entities: dict, hash_, tag_entities: bool = False
cls,
writer: ZipFile,
entities: dict,
hash_,
tag_entities: bool = False,
export_events: bool = True,
cleanup_users: bool = True,
) -> Sequence[str]:
"""
Export the requested experiments, projects and models and return the list of artifact files
@@ -656,18 +730,19 @@ class PrePopulate:
if not items:
continue
base_filename = cls._get_base_filename(cls_)
for item in items:
artifacts.extend(
cls._export_entity_related_data(
cls_, item, base_filename, writer, hash_
if export_events:
for item in items:
artifacts.extend(
cls._export_entity_related_data(
cls_, item, base_filename, writer, hash_
)
)
)
filename = base_filename + ".json"
print(f"Writing {len(items)} items into {writer.filename}:{filename}")
with BytesIO() as f:
with cls.JsonLinesWriter(f) as w:
for item in items:
cls._cleanup_entity(cls_, item)
cls._cleanup_entity(cls_, item, cleanup_users=cleanup_users)
w.write(item.to_json())
data = f.getvalue()
hash_.update(data)
@@ -717,7 +792,10 @@ class PrePopulate:
@classmethod
def _generate_new_ids(
cls, reader: ZipFile, entity_files: Sequence, metadata: Mapping[str, Any],
cls,
reader: ZipFile,
entity_files: Sequence,
metadata: Mapping[str, Any],
) -> Mapping[str, str]:
if not metadata or not any(
metadata.get(key) for key in ("new_ids", "example_ids", "private_ids")
@@ -745,6 +823,68 @@ class PrePopulate:
)
return ids
@classmethod
def _import_users(cls, reader: ZipFile, company_id: str = "") -> dict:
"""
Import users to db and return the mapping of old user ids to the new ones
If no users were in the users file then the mapping was empty
If the user in the file has the same email as one of the existing ones then this user is skipped
and its id is mapped to the existing user with the same email
If the user with the same id exists in backend or auth db then its creation is skipped
"""
users_file = first(
fi for fi in reader.filelist if fi.orig_filename == cls.users_filename
)
if not users_file:
return {}
existing_user_ids = set(cls.user_cls.objects().scalar("id")) | set(
cls.auth_user_cls.objects().scalar("id")
)
existing_user_emails = {u.email: u.id for u in cls.auth_user_cls.objects()}
user_id_mappings = {}
with reader.open(users_file) as f:
data = json.loads(f.read())
auth_users = {u["_id"]: u for u in data["auth"]}
be_users = {u["_id"]: u for u in data["backend"]}
for uid, user in auth_users.items():
email = user.get("email")
existing_user_id = existing_user_emails.get(email)
if existing_user_id:
user_id_mappings[uid] = existing_user_id
continue
user_id_mappings[uid] = uid
if uid in existing_user_ids:
continue
credentials = user.get("credentials", [])
for c in credentials:
if c.get("company") == "":
c["company"] = company_id
if hasattr(cls.auth_user_cls, "sec_groups"):
user_role = user.get("role", Role.user)
if user_role == Role.user:
user["sec_groups"] = ["30795571-a470-4717-a80d-e8705fc776bf"]
else:
user["sec_groups"] = [
"c14a3cc6-1144-4896-8ea6-fb186ee19896",
"30795571-a470-4717-a80d-e8705fc776bf",
"30795571a4704717a80de8705897ytuyg",
]
auth_user = cls.auth_user_cls.from_json(json.dumps(user), created=True)
auth_user.company = company_id
auth_user.save()
be_user = cls.user_cls.from_json(json.dumps(be_users[uid]), created=True)
be_user.company = company_id
be_user.save()
return user_id_mappings
@classmethod
def _import(
cls,
@@ -753,6 +893,7 @@ class PrePopulate:
user_id: str = None,
metadata: Mapping[str, Any] = None,
sort_tasks_by_last_updated: bool = True,
user_mapping: Mapping[str, str] = None,
):
"""
Import entities and events from the zip file
@@ -763,7 +904,7 @@ class PrePopulate:
fi
for fi in reader.filelist
if not fi.orig_filename.endswith(event_file_ending)
and fi.orig_filename != cls.metadata_filename
and fi.orig_filename not in (cls.metadata_filename, cls.users_filename)
]
metadata = metadata or {}
old_to_new_ids = cls._generate_new_ids(reader, entity_files, metadata)
@@ -773,7 +914,13 @@ class PrePopulate:
full_name = splitext(entity_file.orig_filename)[0]
print(f"Reading {reader.filename}:{full_name}...")
res = cls._import_entity(
f, full_name, company_id, user_id, metadata, old_to_new_ids
f,
full_name=full_name,
company_id=company_id,
user_id=user_id,
metadata=metadata,
old_to_new_ids=old_to_new_ids,
user_mapping=user_mapping,
)
if res:
tasks = res
@@ -794,7 +941,7 @@ class PrePopulate:
with reader.open(events_file) as f:
full_name = splitext(events_file.orig_filename)[0]
print(f"Reading {reader.filename}:{full_name}...")
cls._import_events(f, company_id, user_id, task.id)
cls._import_events(f, company_id, task.user, task.id)
@classmethod
def _get_entity_type(cls, full_name) -> Type[mongoengine.Document]:
@@ -874,7 +1021,7 @@ class PrePopulate:
):
old_path = old_field.split(".")
old_model = nested_get(task_data, old_path)
new_models = models.get(type_, [])
new_models = [m for m in models.get(type_, []) if m.get("model") is not None]
name = TaskModelNames[type_]
if old_model and not any(
m
@@ -908,7 +1055,9 @@ class PrePopulate:
user_id: str,
metadata: Mapping[str, Any],
old_to_new_ids: Mapping[str, str] = None,
user_mapping: Mapping[str, str] = None,
) -> Optional[Sequence[Task]]:
user_mapping = user_mapping or {}
cls_ = cls._get_entity_type(full_name)
print(f"Writing {cls_.__name__.lower()}s into database")
tasks = []
@@ -930,7 +1079,7 @@ class PrePopulate:
doc = cls_.from_json(item, created=True)
if hasattr(doc, "user"):
doc.user = user_id
doc.user = user_mapping.get(doc.user, user_id) if doc.user else user_id
if hasattr(doc, "company"):
doc.company = company_id
if isinstance(doc, cls.project_cls):
@@ -970,7 +1119,7 @@ class PrePopulate:
ev["allow_locked"] = True
cls.event_bll.add_events(
company_id=company_id,
user_id=user_id,
identity=Identity(user_id, company=company_id, role=Role.admin),
events=events,
worker="",
)

View File

@@ -10,7 +10,7 @@ elasticsearch==7.17.9
fastjsonschema>=2.8
flask-compress>=1.4.0
flask-cors>=3.0.5
flask>=2.3.2
flask>=2.3.3
furl>=2.0.0
google-cloud-storage>=2.8.0
gunicorn>=20.1.0
@@ -34,3 +34,4 @@ setuptools>=65.5.1
six
validators>=0.12.4
urllib3>=1.26.18
werkzeug>=3.0.1

View File

@@ -754,6 +754,42 @@ get_task_metrics{
}
}
}
get_multi_task_metrics {
"2.28" {
description: """Get unique metrics and variants from the events of the specified type.
Only events reported for the passed task or model ids are analyzed."""
request {
type: object
required: [ tasks ]
properties {
tasks {
description: task ids to get metrics from
type: array
items {type: string}
}
model_events {
description: If not set or set to false then passed ids are task ids otherwise model ids
type: boolean
default: false
}
event_type {
"description": Event type. If not specified then metrics are collected from the reported events of all types
"$ref": "#/definitions/event_type_enum"
}
}
}
response {
type: object
properties {
metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/metric_variants" }
}
}
}
}
}
get_task_log {
"1.5" {
description: "Get all 'log' events for this task"
@@ -971,10 +1007,17 @@ get_task_events {
}
}
"2.22": ${get_task_events."2.1"} {
request.properties.model_events {
type: boolean
description: If set then get retrieving model events. Otherwise task events
default: false
request.properties {
model_events {
type: boolean
description: If set then get retrieving model events. Otherwise task events
default: false
}
metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/metric_variants" }
}
}
}
}
@@ -1156,6 +1199,13 @@ get_multi_task_plots {
default: true
}
}
"2.28": ${get_multi_task_plots."2.26"} {
request.properties.metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/metric_variants" }
}
}
}
get_vector_metrics_and_variants {
"2.1" {
@@ -1342,6 +1392,13 @@ multi_task_scalar_metrics_iter_histogram {
default: false
}
}
"2.28": ${multi_task_scalar_metrics_iter_histogram."2.22"} {
request.properties.metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/metric_variants" }
}
}
}
get_task_single_value_metrics {
"2.20" {
@@ -1369,6 +1426,13 @@ get_task_single_value_metrics {
default: false
}
}
"2.28": ${get_task_single_value_metrics."2.22"} {
request.properties.metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/metric_variants" }
}
}
}
get_task_latest_scalar_values {
"2.1" {

View File

@@ -11,16 +11,7 @@ supported_modes {
description: """ Return supported login modes."""
request {
type: object
properties {
state {
description: "ASCII base64 encoded application state"
type: string
}
callback_url_prefix {
description: "URL prefix used to generate the callback URL for each supported SSO provider"
type: string
}
}
additionalProperties: false
}
response {
type: object

View File

@@ -79,4 +79,15 @@ start_pipeline {
}
}
}
"2.28": ${start_pipeline."2.17"} {
request.properties.verify_watched_queue {
description: If passed then check wheter there are any workers watiching the queue
type: boolean
default: false
}
response.properties.queue_watched {
description: Returns true if there are workers or autscalers working with the queue
type: boolean
}
}
}

View File

@@ -949,6 +949,13 @@ get_unique_metric_variants {
default: false
}
}
"2.28": ${get_unique_metric_variants."2.25"} {
request.properties.ids {
description: IDs of the tasks or models to get metrics from
type: array
items {type: string}
}
}
}
get_hyperparam_values {
"2.13" {

View File

@@ -42,7 +42,10 @@ class RequestHandlers:
response = redirect(call.result.redirect.url, call.result.redirect.code)
else:
headers = None
disable_cache = False
if call.result.filename:
# make sure that downloaded files are not cached by the client
disable_cache = True
try:
call.result.filename.encode("ascii")
except UnicodeEncodeError:
@@ -61,6 +64,9 @@ class RequestHandlers:
status=call.result.code,
headers=headers,
)
if disable_cache:
response.cache_control.no_store = True
response.cache_control.max_age = 0
if call.result.cookies:
for key, value in call.result.cookies.items():

View File

@@ -39,7 +39,7 @@ class ServiceRepo(object):
"""If the check is set, parsing will fail for endpoint request with the version that is grater than the current
maximum """
_max_version = PartialVersion("2.27")
_max_version = PartialVersion("2.28")
""" Maximum version number (the highest min_version value across all endpoints) """
_endpoint_exp = (

View File

@@ -31,6 +31,7 @@ from apiserver.apimodels.events import (
GetMetricSamplesRequest,
TaskMetric,
MultiTaskPlotsRequest,
MultiTaskMetricsRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_common import EventType, MetricVariants, TaskCompanies
@@ -38,6 +39,7 @@ from apiserver.bll.event.events_iterator import Scroll
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
from apiserver.bll.model import ModelBLL
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.utils import get_task_with_write_access
from apiserver.config_repo import config
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
@@ -73,7 +75,7 @@ def add(call: APICall, company_id, _):
data = call.data.copy()
added, err_count, err_info = event_bll.add_events(
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
events=[data],
worker=call.worker,
)
@@ -88,7 +90,7 @@ def add_batch(call: APICall, company_id, _):
added, err_count, err_info = event_bll.add_events(
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
events=events,
worker=call.worker,
)
@@ -521,6 +523,7 @@ def multi_task_scalar_metrics_iter_histogram(
),
samples=request.samples,
key=request.key,
metric_variants=_get_metric_variants_from_request(request.metrics),
)
)
@@ -548,7 +551,8 @@ def get_task_single_value_metrics(
tasks=_get_single_value_metrics_response(
companies=companies,
value_metrics=event_bll.metrics.get_task_single_value_metrics(
companies=companies
companies=companies,
metric_variants=_get_metric_variants_from_request(request.metrics),
),
)
)
@@ -591,10 +595,11 @@ def _get_multitask_plots(
companies: TaskCompanies,
last_iters: int,
last_iters_per_task_metric: bool,
metrics: MetricVariants = None,
request_metrics: Sequence[ApiMetrics] = None,
scroll_id=None,
no_scroll=True,
) -> Tuple[dict, int, str]:
metrics = _get_metric_variants_from_request(request_metrics)
task_names = {
t.id: t.name for t in itertools.chain.from_iterable(companies.values())
}
@@ -629,6 +634,7 @@ def get_multi_task_plots(call, company_id, request: MultiTaskPlotsRequest):
scroll_id=request.scroll_id,
no_scroll=request.no_scroll,
last_iters_per_task_metric=request.last_iters_per_task_metric,
request_metrics=request.metrics,
)
call.result.data = dict(
plots=return_events,
@@ -960,12 +966,38 @@ def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest):
}
@endpoint("events.get_multi_task_metrics")
def get_multi_task_metrics(call: APICall, company_id, request: MultiTaskMetricsRequest):
companies = _get_task_or_model_index_companies(
company_id, request.tasks, model_events=request.model_events
)
if not companies:
return {"metrics": []}
metrics = event_bll.metrics.get_multi_task_metrics(
companies=companies,
event_type=request.event_type
)
res = [
{
"metric": m,
"variants": sorted(vars_),
}
for m, vars_ in metrics.items()
]
call.result.data = {
"metrics": sorted(res, key=itemgetter("metric"))
}
@endpoint("events.delete_for_task", required_fields=["task"])
def delete_for_task(call, company_id, _):
task_id = call.data["task"]
allow_locked = call.data.get("allow_locked", False)
task_bll.assert_exists(company_id, task_id, return_tasks=False)
get_task_with_write_access(
task_id=task_id, company_id=company_id, identity=call.identity, only=("id",)
)
call.result.data = dict(
deleted=event_bll.delete_task_events(
company_id, task_id, allow_locked=allow_locked
@@ -990,7 +1022,9 @@ def delete_for_model(call: APICall, company_id: str, _):
def clear_task_log(call: APICall, company_id: str, request: ClearTaskLogRequest):
task_id = request.task
task_bll.assert_exists(company_id, task_id, return_tasks=False)
get_task_with_write_access(
task_id=task_id, company_id=company_id, identity=call.identity, only=("id",)
)
call.result.data = dict(
deleted=event_bll.clear_task_log(
company_id=company_id,

View File

@@ -28,6 +28,7 @@ from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.task_operations import publish_task
from apiserver.bll.task.utils import get_task_with_write_access
from apiserver.bll.util import run_batch_operation
from apiserver.config_repo import config
from apiserver.database.model import validate_id
@@ -46,6 +47,7 @@ from apiserver.database.utils import (
filter_fields,
)
from apiserver.service_repo import APICall, endpoint
from apiserver.service_repo.auth import Identity
from apiserver.services.utils import (
conform_tag_fields,
conform_output_tags,
@@ -191,7 +193,7 @@ create_fields = {
"project": Project,
"parent": Model,
"framework": None,
"design": None,
"design": dict,
"labels": dict,
"ready": None,
"metadata": list,
@@ -249,13 +251,12 @@ def update_for_task(call: APICall, company_id, _):
)
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(
id=task_id,
company=company_id,
_only=["models", "execution", "name", "status", "project"],
task = get_task_with_write_access(
task_id=task_id,
company_id=company_id,
identity=call.identity,
only=("models", "execution", "name", "status", "project"),
)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
allowed_states = [TaskStatus.created, TaskStatus.in_progress]
if task.status not in allowed_states:
@@ -343,7 +344,7 @@ def create(call: APICall, company_id, req_model: CreateModelRequest):
task = req_model.task
req_data = req_model.to_struct()
if task:
validate_task(company_id, req_data)
validate_task(company_id, call.identity, req_data)
fields = filter_fields(Model, req_data)
conform_tag_fields(call, fields, validate=True)
@@ -373,7 +374,7 @@ def prepare_update_fields(call, company_id, fields: dict):
# clear UI cache if URI is provided (model updated)
fields["ui_cache"] = fields.pop("ui_cache", {})
if "task" in fields:
validate_task(company_id, fields)
validate_task(company_id, call.identity, fields)
if "labels" in fields:
labels = fields["labels"]
@@ -403,8 +404,11 @@ def prepare_update_fields(call, company_id, fields: dict):
return fields
def validate_task(company_id, fields: dict):
Task.get_for_writing(company=company_id, id=fields["task"], _only=["id"])
def validate_task(company_id: str, identity: Identity, fields: dict):
task_id = fields["task"]
get_task_with_write_access(
task_id=task_id, company_id=company_id, identity=identity, only=("id",)
)
@endpoint("models.edit", required_fields=["model"], response_data_model=UpdateResponse)
@@ -514,7 +518,7 @@ def set_ready(call: APICall, company_id: str, request: PublishModelRequest):
updated, published_task = ModelBLL.publish_model(
model_id=request.model,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
force_publish_task=request.force_publish_task,
publish_task_func=publish_task if request.publish_task else None,
)
@@ -533,7 +537,7 @@ def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest):
func=partial(
ModelBLL.publish_model,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
force_publish_task=request.force_publish_task,
publish_task_func=publish_task if request.publish_task else None,
),

View File

@@ -5,22 +5,24 @@ import attr
from apiserver.apierrors.errors.bad_request import CannotRemoveAllRuns
from apiserver.apimodels.pipelines import (
StartPipelineResponse,
StartPipelineRequest,
DeleteRunsRequest,
)
from apiserver.bll.organization import OrgBLL
from apiserver.bll.project import ProjectBLL
from apiserver.bll.queue import QueueBLL
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.task_operations import enqueue_task, delete_task
from apiserver.bll.util import run_batch_operation
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskType
from apiserver.service_repo import APICall, endpoint
from apiserver.utilities.dicts import nested_get
org_bll = OrgBLL()
project_bll = ProjectBLL()
task_bll = TaskBLL()
queue_bll = QueueBLL()
def _update_task_name(task: Task):
@@ -57,7 +59,7 @@ def delete_runs(call: APICall, company_id: str, request: DeleteRunsRequest):
func=partial(
delete_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
move_to_trash=False,
force=True,
return_file_urls=False,
@@ -79,9 +81,7 @@ def delete_runs(call: APICall, company_id: str, request: DeleteRunsRequest):
call.result.data = dict(succeeded=succeeded, failed=failures)
@endpoint(
"pipelines.start_pipeline", response_data_model=StartPipelineResponse,
)
@endpoint("pipelines.start_pipeline")
def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest):
hyperparams = None
if request.args:
@@ -108,10 +108,19 @@ def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest
queued, res = enqueue_task(
task_id=task.id,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
queue_id=request.queue,
status_message="Starting pipeline",
status_reason="",
)
extra = {}
if request.verify_watched_queue and queued:
res_queue = nested_get(res, ("fields", "execution.queue"))
if res_queue:
extra["queue_watched"] = queue_bll.check_for_workers(company_id, res_queue)
return StartPipelineResponse(pipeline=task.id, enqueued=bool(queued))
call.result.data = dict(
pipeline=task.id,
enqueued=bool(queued),
**extra,
)

View File

@@ -380,6 +380,7 @@ def get_unique_metric_variants(
company_id,
[request.project] if request.project else None,
include_subprojects=request.include_subprojects,
ids=request.ids,
model_metrics=request.model_metrics,
)

View File

@@ -19,7 +19,9 @@ from apiserver.apimodels.reports import (
from apiserver.apierrors import errors
from apiserver.apimodels.base import UpdateResponse
from apiserver.bll.project.project_bll import reports_project_name, reports_tag
from apiserver.bll.task.utils import get_task_with_write_access
from apiserver.database.model.model import Model
from apiserver.service_repo.auth import Identity
from apiserver.services.models import conform_model_data
from apiserver.services.utils import process_include_subprojects, sort_tags_response
from apiserver.bll.organization import OrgBLL
@@ -57,15 +59,15 @@ update_fields = {
}
def _assert_report(company_id, task_id, only_fields=None, requires_write_access=True):
def _assert_report(company_id: str, task_id: str, identity: Identity, only_fields=None):
if only_fields and "type" not in only_fields:
only_fields += ("type",)
task = TaskBLL.get_task_with_access(
task = get_task_with_write_access(
task_id=task_id,
company_id=company_id,
identity=identity,
only=only_fields,
requires_write_access=requires_write_access,
)
if task.type != TaskType.report:
raise errors.bad_request.OperationSupportedOnReportsOnly(id=task_id)
@@ -78,6 +80,7 @@ def update_report(call: APICall, company_id: str, request: UpdateReportRequest):
task = _assert_report(
task_id=request.task,
company_id=company_id,
identity=call.identity,
only_fields=("status",),
)
@@ -265,7 +268,7 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
res["plots"] = _get_multitask_plots(
companies=companies,
last_iters=request.plots.iters,
metrics=_get_metric_variants_from_request(request.plots.metrics),
request_metrics=request.plots.metrics,
last_iters_per_task_metric=request.plots.last_iters_per_task_metric,
)[0]
@@ -302,6 +305,7 @@ def move(call: APICall, company_id: str, request: MoveReportRequest):
task = _assert_report(
company_id=company_id,
task_id=request.task,
identity=call.identity,
only_fields=("project",),
)
user_id = call.identity.user
@@ -337,7 +341,9 @@ def move(call: APICall, company_id: str, request: MoveReportRequest):
response_data_model=UpdateResponse,
)
def publish(call: APICall, company_id, request: PublishReportRequest):
task = _assert_report(company_id=company_id, task_id=request.task)
task = _assert_report(
company_id=company_id, task_id=request.task, identity=call.identity
)
updates = ChangeStatusRequest(
task=task,
new_status=TaskStatus.published,
@@ -352,7 +358,9 @@ def publish(call: APICall, company_id, request: PublishReportRequest):
@endpoint("reports.archive")
def archive(call: APICall, company_id, request: ArchiveReportRequest):
task = _assert_report(company_id=company_id, task_id=request.task)
task = _assert_report(
company_id=company_id, task_id=request.task, identity=call.identity
)
archived = task.update(
status_message=request.message,
status_reason="",
@@ -366,7 +374,9 @@ def archive(call: APICall, company_id, request: ArchiveReportRequest):
@endpoint("reports.unarchive")
def unarchive(call: APICall, company_id, request: ArchiveReportRequest):
task = _assert_report(company_id=company_id, task_id=request.task)
task = _assert_report(
company_id=company_id, task_id=request.task, identity=call.identity
)
unarchived = task.update(
status_message=request.message,
status_reason="",
@@ -394,6 +404,7 @@ def delete(call: APICall, company_id, request: DeleteReportRequest):
task = _assert_report(
company_id=company_id,
task_id=request.task,
identity=call.identity,
only_fields=("project",),
)
if (

View File

@@ -100,10 +100,17 @@ from apiserver.bll.task.task_operations import (
unarchive_task,
move_tasks_to_trash,
)
from apiserver.bll.task.utils import update_task, get_task_for_update, deleted_prefix
from apiserver.bll.task.utils import (
update_task,
get_task_for_update,
deleted_prefix,
get_many_tasks_for_writing,
get_task_with_write_access,
)
from apiserver.bll.util import run_batch_operation, update_project_time
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
from apiserver.database.model.project import Project
from apiserver.database.model.task.output import Output
from apiserver.database.model.task.task import (
Task,
@@ -118,6 +125,7 @@ from apiserver.database.utils import (
get_options,
)
from apiserver.service_repo import APICall, endpoint
from apiserver.service_repo.auth import Identity
from apiserver.services.utils import (
conform_tag_fields,
conform_output_tags,
@@ -142,14 +150,34 @@ org_bll = OrgBLL()
project_bll = ProjectBLL()
def _assert_writable_tasks(
company_id: str, identity: Identity, ids: Sequence[str], only=("id",)
) -> Sequence[Task]:
tasks = get_many_tasks_for_writing(
company_id=company_id,
identity=identity,
query=Q(id__in=ids),
only=only,
)
missing_ids = set(ids) - {t.id for t in tasks}
if missing_ids:
raise errors.bad_request.InvalidTaskId(ids=list(missing_ids))
return tasks
def set_task_status_from_call(
request: UpdateRequest, company_id: str, user_id: str, new_status=None, **set_fields
request: UpdateRequest,
company_id: str,
identity: Identity,
new_status=None,
**set_fields,
) -> dict:
task = TaskBLL.get_task_with_access(
task = get_task_with_write_access(
request.task,
company_id=company_id,
identity=identity,
only=("id", "status", "project"),
requires_write_access=True,
)
status_reason = request.status_reason
@@ -161,15 +189,17 @@ def set_task_status_from_call(
status_reason=status_reason,
status_message=status_message,
force=force,
user_id=user_id,
user_id=identity.user,
).execute(**set_fields)
@endpoint("tasks.get_by_id", request_data_model=TaskRequest)
def get_by_id(call: APICall, company_id, req_model: TaskRequest):
task = TaskBLL.get_task_with_access(
req_model.task, company_id=company_id, allow_public=True
)
def get_by_id(call: APICall, company_id, request: TaskRequest):
task = TaskBLL.assert_exists(
company_id,
task_ids=request.task,
allow_public=True,
)[0]
task_dict = task.to_proper_dict()
conform_task_data(call, task_dict)
call.result.data = {"task": task_dict}
@@ -227,7 +257,9 @@ def get_by_id_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
call_data = escape_execution_parameters(call.data)
tasks = Task.get_many_with_join(
company=company_id, query_dict=call_data, allow_public=True,
company=company_id,
query_dict=call_data,
allow_public=True,
)
conform_task_data(call, tasks)
@@ -278,7 +310,7 @@ def stop(call: APICall, company_id, req_model: UpdateRequest):
**stop_task(
task_id=req_model.task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
user_name=call.identity.user_name,
status_reason=req_model.status_reason,
force=req_model.force,
@@ -296,7 +328,7 @@ def stop_many(call: APICall, company_id, request: StopManyRequest):
func=partial(
stop_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
user_name=call.identity.user_name,
status_reason=request.status_reason,
force=request.force,
@@ -319,7 +351,7 @@ def stopped(call: APICall, company_id, req_model: UpdateRequest):
**set_task_status_from_call(
req_model,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
new_status=TaskStatus.stopped,
completed=datetime.utcnow(),
)
@@ -336,7 +368,7 @@ def started(call: APICall, company_id, req_model: UpdateRequest):
**set_task_status_from_call(
req_model,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
new_status=TaskStatus.in_progress,
min__started=datetime.utcnow(), # don't override a previous, smaller "started" field value
)
@@ -353,7 +385,7 @@ def failed(call: APICall, company_id, req_model: UpdateRequest):
**set_task_status_from_call(
req_model,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
new_status=TaskStatus.failed,
)
)
@@ -367,7 +399,7 @@ def close(call: APICall, company_id, req_model: UpdateRequest):
**set_task_status_from_call(
req_model,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
new_status=TaskStatus.closed,
)
)
@@ -381,18 +413,19 @@ create_fields = {
"error": None,
"comment": None,
"parent": Task,
"project": None,
"project": Project,
"input": None,
"models": None,
"container": None,
"container": dict,
"output_dest": None,
"execution": None,
"hyperparams": None,
"configuration": None,
"hyperparams": dict,
"configuration": dict,
"script": None,
"runtime": None,
"runtime": dict,
}
dict_fields_paths = [("execution", "model_labels"), "container"]
@@ -433,13 +466,17 @@ def conform_task_data(call: APICall, tasks_data: Union[Sequence[dict], dict]):
for data in tasks_data:
params_unprepare_from_saved(
fields=data, copy_to_legacy=need_legacy_params,
fields=data,
copy_to_legacy=need_legacy_params,
)
artifacts_unprepare_from_saved(fields=data)
def prepare_create_fields(
call: APICall, valid_fields=None, output=None, previous_task: Task = None,
call: APICall,
valid_fields=None,
output=None,
previous_task: Task = None,
):
valid_fields = valid_fields if valid_fields is not None else create_fields
t_fields = task_fields
@@ -566,11 +603,12 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
task_id = req_model.task
with translate_errors_context():
task = Task.get_for_writing(
id=task_id, company=company_id, _only=["id", "project"]
task = get_task_with_write_access(
task_id=task_id,
company_id=company_id,
identity=call.identity,
only=("id", "project"),
)
if not task:
raise errors.bad_request.InvalidTaskId(id=task_id)
partial_update_dict, valid_fields = prepare_update_fields(call, call.data)
@@ -582,7 +620,8 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
id=task_id,
partial_update_dict=partial_update_dict,
injected_update=dict(
last_change=datetime.utcnow(), last_changed_by=call.identity.user,
last_change=datetime.utcnow(),
last_changed_by=call.identity.user,
),
)
if updated_count:
@@ -606,11 +645,11 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
def set_requirements(call: APICall, company_id, req_model: SetRequirementsRequest):
requirements = req_model.requirements
with translate_errors_context():
task = TaskBLL.get_task_with_access(
task = get_task_with_write_access(
req_model.task,
company_id=company_id,
identity=call.identity,
only=("status", "script"),
requires_write_access=True,
)
if not task.script:
raise errors.bad_request.MissingTaskFields(
@@ -636,8 +675,11 @@ def update_batch(call: APICall, company_id, _):
items = {i["task"]: i for i in items}
tasks = {
t.id: t
for t in Task.get_many_for_writing(
company=company_id, query=Q(id__in=list(items))
for t in _assert_writable_tasks(
identity=call.identity,
company_id=company_id,
ids=list(items),
only=("id", "project"),
)
}
@@ -656,7 +698,8 @@ def update_batch(call: APICall, company_id, _):
if not partial_update_dict:
continue
partial_update_dict.update(
last_change=now, last_changed_by=call.identity.user,
last_change=now,
last_changed_by=call.identity.user,
)
update_op = UpdateOne(
{"_id": id, "company": company_id}, {"$set": partial_update_dict}
@@ -690,9 +733,11 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
force = req_model.force
with translate_errors_context():
task = Task.get_for_writing(id=task_id, company=company_id)
if not task:
raise errors.bad_request.InvalidTaskId(id=task_id)
task = get_task_with_write_access(
task_id=task_id,
company_id=company_id,
identity=call.identity,
)
if not force and task.status != TaskStatus.created:
raise errors.bad_request.InvalidTaskStatus(
@@ -756,7 +801,8 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
@endpoint(
"tasks.get_hyper_params", request_data_model=GetHyperParamsRequest,
"tasks.get_hyper_params",
request_data_model=GetHyperParamsRequest,
)
def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest):
tasks_params = HyperParams.get_params(company_id, task_ids=request.tasks)
@@ -771,7 +817,7 @@ def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest
call.result.data = {
"updated": HyperParams.edit_params(
company_id,
user_id=call.identity.user,
identity=call.identity,
task_id=request.task,
hyperparams=request.hyperparams,
replace_hyperparams=request.replace_hyperparams,
@@ -785,7 +831,7 @@ def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsReq
call.result.data = {
"deleted": HyperParams.delete_params(
company_id,
user_id=call.identity.user,
identity=call.identity,
task_id=request.task,
hyperparams=request.hyperparams,
force=request.force,
@@ -794,7 +840,8 @@ def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsReq
@endpoint(
"tasks.get_configurations", request_data_model=GetConfigurationsRequest,
"tasks.get_configurations",
request_data_model=GetConfigurationsRequest,
)
def get_configurations(call: APICall, company_id, request: GetConfigurationsRequest):
tasks_params = HyperParams.get_configurations(
@@ -809,7 +856,8 @@ def get_configurations(call: APICall, company_id, request: GetConfigurationsRequ
@endpoint(
"tasks.get_configuration_names", request_data_model=GetConfigurationNamesRequest,
"tasks.get_configuration_names",
request_data_model=GetConfigurationNamesRequest,
)
def get_configuration_names(
call: APICall, company_id, request: GetConfigurationNamesRequest
@@ -830,7 +878,7 @@ def edit_configuration(call: APICall, company_id, request: EditConfigurationRequ
call.result.data = {
"updated": HyperParams.edit_configuration(
company_id,
user_id=call.identity.user,
identity=call.identity,
task_id=request.task,
configuration=request.configuration,
replace_configuration=request.replace_configuration,
@@ -846,7 +894,7 @@ def delete_configuration(
call.result.data = {
"deleted": HyperParams.delete_configuration(
company_id,
user_id=call.identity.user,
identity=call.identity,
task_id=request.task,
configuration=request.configuration,
force=request.force,
@@ -863,7 +911,7 @@ def enqueue(call: APICall, company_id, request: EnqueueRequest):
queued, res = enqueue_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
queue_id=request.queue,
status_message=request.status_message,
status_reason=request.status_reason,
@@ -888,7 +936,7 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
func=partial(
enqueue_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
queue_id=request.queue,
status_message=request.status_message,
status_reason=request.status_reason,
@@ -915,13 +963,14 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
@endpoint(
"tasks.dequeue", response_data_model=DequeueResponse,
"tasks.dequeue",
response_data_model=DequeueResponse,
)
def dequeue(call: APICall, company_id, request: DequeueRequest):
dequeued, res = dequeue_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
status_message=request.status_message,
status_reason=request.status_reason,
remove_from_all_queues=request.remove_from_all_queues,
@@ -931,14 +980,15 @@ def dequeue(call: APICall, company_id, request: DequeueRequest):
@endpoint(
"tasks.dequeue_many", response_data_model=DequeueManyResponse,
"tasks.dequeue_many",
response_data_model=DequeueManyResponse,
)
def dequeue_many(call: APICall, company_id, request: DequeueManyRequest):
results, failures = run_batch_operation(
func=partial(
dequeue_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
status_message=request.status_message,
status_reason=request.status_reason,
remove_from_all_queues=request.remove_from_all_queues,
@@ -962,7 +1012,7 @@ def reset(call: APICall, company_id, request: ResetRequest):
dequeued, cleanup_res, updates = reset_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
@@ -990,7 +1040,7 @@ def reset_many(call: APICall, company_id, request: ResetManyRequest):
func=partial(
reset_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
@@ -1027,9 +1077,11 @@ def reset_many(call: APICall, company_id, request: ResetManyRequest):
response_data_model=ArchiveResponse,
)
def archive(call: APICall, company_id, request: ArchiveRequest):
tasks = TaskBLL.assert_exists(
archived = 0
tasks = _assert_writable_tasks(
company_id,
task_ids=request.tasks,
call.identity,
ids=request.tasks,
only=(
"id",
"company",
@@ -1040,11 +1092,10 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
"enqueue_status",
),
)
archived = 0
for task in tasks:
archived += archive_task(
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
task=task,
status_message=request.status_message,
status_reason=request.status_reason,
@@ -1063,7 +1114,7 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest):
func=partial(
archive_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
status_message=request.status_message,
status_reason=request.status_reason,
),
@@ -1085,7 +1136,7 @@ def unarchive_many(call: APICall, company_id, request: TaskBatchRequest):
func=partial(
unarchive_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
status_message=request.status_message,
status_reason=request.status_reason,
),
@@ -1104,7 +1155,7 @@ def delete(call: APICall, company_id, request: DeleteRequest):
deleted, task, cleanup_res = delete_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
move_to_trash=request.move_to_trash,
force=request.force,
return_file_urls=request.return_file_urls,
@@ -1126,7 +1177,7 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest):
func=partial(
delete_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
move_to_trash=request.move_to_trash,
force=request.force,
return_file_urls=request.return_file_urls,
@@ -1164,7 +1215,7 @@ def publish(call: APICall, company_id, request: PublishRequest):
updates = publish_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
force=request.force,
publish_model_func=ModelBLL.publish_model if request.publish_model else None,
status_reason=request.status_reason,
@@ -1183,7 +1234,7 @@ def publish_many(call: APICall, company_id, request: PublishManyRequest):
func=partial(
publish_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
force=request.force,
publish_model_func=ModelBLL.publish_model
if request.publish_model
@@ -1211,7 +1262,7 @@ def completed(call: APICall, company_id, request: CompletedRequest):
**set_task_status_from_call(
request,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
new_status=TaskStatus.completed,
completed=datetime.utcnow(),
)
@@ -1221,7 +1272,7 @@ def completed(call: APICall, company_id, request: CompletedRequest):
publish_res = publish_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
force=request.force,
publish_model_func=ModelBLL.publish_model,
status_reason=request.status_reason,
@@ -1256,7 +1307,7 @@ def add_or_update_artifacts(
call.result.data = {
"updated": Artifacts.add_or_update_artifacts(
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
task_id=request.task,
artifacts=request.artifacts,
force=True,
@@ -1273,7 +1324,7 @@ def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest)
call.result.data = {
"deleted": Artifacts.delete_artifacts(
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
task_id=request.task,
artifact_ids=request.artifacts,
force=True,
@@ -1310,6 +1361,7 @@ def move(call: APICall, company_id: str, request: MoveRequest):
"project or project_name is required"
)
_assert_writable_tasks(company_id, call.identity, request.ids)
updated_projects = set(
t.project for t in Task.objects(id__in=request.ids).only("project") if t.project
)
@@ -1330,7 +1382,8 @@ def move(call: APICall, company_id: str, request: MoveRequest):
@endpoint("tasks.update_tags")
def update_tags(_, company_id: str, request: UpdateTagsRequest):
def update_tags(call: APICall, company_id: str, request: UpdateTagsRequest):
_assert_writable_tasks(company_id, call.identity, request.ids)
return {
"updated": org_bll.edit_entity_tags(
company_id=company_id,
@@ -1344,7 +1397,9 @@ def update_tags(_, company_id: str, request: UpdateTagsRequest):
@endpoint("tasks.add_or_update_model", min_version="2.13")
def add_or_update_model(call: APICall, company_id: str, request: AddUpdateModelRequest):
get_task_for_update(company_id=company_id, task_id=request.task, force=True)
get_task_for_update(
company_id=company_id, task_id=request.task, force=True, identity=call.identity
)
models_field = f"models__{request.type}"
model = ModelItem(name=request.name, model=request.model, updated=datetime.utcnow())
@@ -1364,7 +1419,9 @@ def add_or_update_model(call: APICall, company_id: str, request: AddUpdateModelR
@endpoint("tasks.delete_models", min_version="2.13")
def delete_models(call: APICall, company_id: str, request: DeleteModelsRequest):
task = get_task_for_update(company_id=company_id, task_id=request.task, force=True)
task = get_task_for_update(
company_id=company_id, task_id=request.task, force=True, identity=call.identity
)
delete_names = {
type_: [m.name for m in request.models if m.type == type_]
@@ -1377,6 +1434,8 @@ def delete_models(call: APICall, company_id: str, request: DeleteModelsRequest):
}
updated = task.update(
last_change=datetime.utcnow(), last_changed_by=call.identity.user, **commands,
last_change=datetime.utcnow(),
last_changed_by=call.identity.user,
**commands,
)
return {"updated": updated}

View File

@@ -3,6 +3,30 @@ from apiserver.tests.automated import TestService
class TestGetAllExFilters(TestService):
def test_no_tags_filter(self):
task = self._temp_task(tags=["test"])
task_no_tags = self._temp_task()
tasks = [task, task_no_tags]
for cond, op, tags, expected_tasks in (
("any", "include", [None], [task_no_tags]),
("any", "include", ["test"], [task]),
("any", "include", ["test", None], [task, task_no_tags]),
("any", "exclude", [None], [task]),
("any", "exclude", ["test"], [task_no_tags]),
("any", "exclude", ["test", None], [task, task_no_tags]),
("all", "include", [None], [task_no_tags]),
("all", "include", ["test"], [task]),
("all", "include", ["test", None], []),
("all", "exclude", [None], [task]),
("all", "exclude", ["test"], [task_no_tags]),
("all", "exclude", ["test", None], []),
):
res = self.api.tasks.get_all_ex(
id=tasks, filters={"tags": {cond: {op: tags}}}
).tasks
self.assertEqual({t.id for t in res}, set(expected_tasks))
def test_list_filters(self):
tags = ["a", "b", "c", "d"]
tasks = [self._temp_task(tags=tags[:i]) for i in range(len(tags) + 1)]

View File

@@ -37,29 +37,44 @@ class TestPipelines(TestService):
res = self.api.pipelines.start_pipeline(task=task, queue=queue, args=args)
pipeline_task = res.pipeline
try:
self.assertTrue(res.enqueued)
pipeline = self.api.tasks.get_all_ex(id=[pipeline_task]).tasks[0]
self.assertTrue(pipeline.name.startswith(task_name))
self.assertEqual(pipeline.status, "queued")
self.assertEqual(pipeline.project.id, project)
self.assertEqual(
pipeline.hyperparams.Args,
{
a["name"]: {
"section": "Args",
"name": a["name"],
"value": a["value"],
}
for a in args
},
)
finally:
self.api.tasks.delete(task=pipeline_task, force=True)
self.assertTrue(res.enqueued)
pipeline = self.api.tasks.get_all_ex(id=[pipeline_task]).tasks[0]
self.assertTrue(pipeline.name.startswith(task_name))
self.assertEqual(pipeline.status, "queued")
self.assertEqual(pipeline.project.id, project)
self.assertEqual(
pipeline.hyperparams.Args,
{
a["name"]: {
"section": "Args",
"name": a["name"],
"value": a["value"],
}
for a in args
},
)
# watched queue
queue = self._temp_queue("test pipelines")
project, task = self._temp_project_and_task(name="pipelines test1")
res = self.api.pipelines.start_pipeline(
task=task, queue=queue, verify_watched_queue=True
)
self.assertEqual(res.queue_watched, False)
self.api.workers.register(worker="test pipelines", queues=[queue])
project, task = self._temp_project_and_task(name="pipelines test2")
res = self.api.pipelines.start_pipeline(
task=task, queue=queue, verify_watched_queue=True
)
self.assertEqual(res.queue_watched, True)
def _temp_project_and_task(self, name) -> Tuple[str, str]:
project = self.create_temp(
"projects", name=name, description="test", delete_params=dict(force=True),
"projects",
name=name,
description="test",
delete_params=dict(force=True, delete_contents=True),
)
return (
@@ -72,3 +87,6 @@ class TestPipelines(TestService):
system_tags=["pipeline"],
),
)
def _temp_queue(self, queue_name, **kwargs):
return self.create_temp("queues", name=queue_name, **kwargs)

View File

@@ -16,10 +16,18 @@ class TestTaskEvents(TestService):
delete_params = dict(can_fail=True, force=True)
default_task_name = "test task events"
def _temp_task(self, name=default_task_name):
task_input = dict(name=name, type="training",)
def _temp_project(self, name=default_task_name):
return self.create_temp(
"tasks", delete_paramse=self.delete_params, **task_input
"projects",
name=name,
description="test",
delete_params=self.delete_params,
)
def _temp_task(self, name=default_task_name, **kwargs):
self.update_missing(kwargs, name=name, type="training")
return self.create_temp(
"tasks", delete_paramse=self.delete_params, **kwargs
)
def _temp_model(self, name="test model events", **kwargs):
@@ -62,6 +70,26 @@ class TestTaskEvents(TestService):
self._assert_task_metrics(tasks, "log")
self._assert_task_metrics(tasks, "training_stats_scalar")
self._assert_multitask_metrics(
tasks=list(tasks), metrics=["Metric1", "Metric2", "Metric3"]
)
self._assert_multitask_metrics(
tasks=list(tasks),
event_type="training_debug_image",
metrics=["Metric1", "Metric2", "Metric3"],
)
self._assert_multitask_metrics(tasks=list(tasks), event_type="plot", metrics=[])
def _assert_multitask_metrics(
self, tasks: Sequence[str], metrics: Sequence[str], event_type: str = None
):
res = self.api.events.get_multi_task_metrics(
tasks=tasks,
**({"event_type": event_type} if event_type else {}),
).metrics
self.assertEqual([r.metric for r in res], metrics)
self.assertTrue(all(r.variants == ["Test variant"] for r in res))
def _assert_task_metrics(self, tasks: dict, event_type: str):
res = self.api.events.get_task_metrics(tasks=list(tasks), event_type=event_type)
for task, metrics in tasks.items():
@@ -122,6 +150,15 @@ class TestTaskEvents(TestService):
self.assertEqual(value.metric, metric)
self.assertEqual(value.variant, variant)
self.assertEqual(value.value, 0)
# test metrics parameter
res = self.api.events.get_task_single_value_metrics(
tasks=[task], metrics=[{"metric": metric, "variants": [variant]}]
).tasks
self.assertEqual(len(res), 1)
res = self.api.events.get_task_single_value_metrics(
tasks=[task], metrics=[{"metric": "non_existing", "variants": [variant]}]
).tasks
self.assertEqual(len(res), 0)
# update is working
task_data = self.api.tasks.get_by_id(task=task).task
@@ -248,6 +285,15 @@ class TestTaskEvents(TestService):
self._assert_log_events(task=task, expected_total=1)
metrics = self.api.events.get_multi_task_metrics(
tasks=[model],
event_type="training_stats_scalar",
model_events=True,
).metrics
self.assertEqual([m.metric for m in metrics], [f"Metric{i}" for i in range(5)])
variants = [f"Variant{i}" for i in range(5)]
self.assertTrue(all(m.variants == variants for m in metrics))
def test_error_events(self):
task = self._temp_task()
events = [
@@ -340,6 +386,30 @@ class TestTaskEvents(TestService):
else (None, None)
)
def test_task_unique_metric_variants(self):
project = self._temp_project()
task1 = self._temp_task(project=project)
task2 = self._temp_task(project=project)
metric1 = "Metric1"
metric2 = "Metric2"
events = [
{
**self._create_task_event("training_stats_scalar", task, 0),
"metric": metric,
"variant": "Variant",
"value": 10,
}
for task, metric in ((task1, metric1), (task2, metric2))
]
self.send_batch(events)
metrics = self.api.projects.get_unique_metric_variants(project=project).metrics
self.assertEqual({m.metric for m in metrics}, {metric1, metric2})
metrics = self.api.projects.get_unique_metric_variants(ids=[task1, task2]).metrics
self.assertEqual({m.metric for m in metrics}, {metric1, metric2})
metrics = self.api.projects.get_unique_metric_variants(ids=[task1]).metrics
self.assertEqual([m.metric for m in metrics], [metric1])
def test_task_metric_value_intervals_keys(self):
metric = "Metric1"
variant = "Variant1"
@@ -395,6 +465,25 @@ class TestTaskEvents(TestService):
iterations=iter_count,
)
# test metrics
data = self.api.events.multi_task_scalar_metrics_iter_histogram(
tasks=tasks,
metrics=[
{
"metric": f"Metric{m_idx}",
"variants": [f"Variant{v_idx}" for v_idx in range(4)],
}
for m_idx in range(2)
],
)
self._assert_metrics_and_variants(
data.metrics,
metrics=2,
variants=4,
tasks=tasks,
iterations=iter_count,
)
def _assert_metrics_and_variants(
self, data: dict, metrics: int, variants: int, tasks: Sequence, iterations: int
):
@@ -515,6 +604,13 @@ class TestTaskEvents(TestService):
self.assertEqual(plots.C.CX[task1]["3"]["plots"][0]["plot_str"], "Task1_3_C_CX")
self.assertEqual(plots.C.CX[task2]["1"]["plots"][0]["plot_str"], "Task2_1_C_CX")
# test metrics
plots = self.api.events.get_multi_task_plots(
tasks=[task1, task2], metrics=[{"metric": "A"}]
).plots
self.assertEqual(len(plots), 1)
self.assertEqual(len(plots.A), 2)
def test_task_plots(self):
task = self._temp_task()
event = self._create_task_event("plot", task, 0)

View File

@@ -1 +1 @@
__version__ = "1.13.0"
__version__ = "1.14.0"

View File

@@ -1,8 +1,9 @@
boltons>=19.1.0
flask-compress>=1.4.0
flask-cors>=3.0.5
flask>=2.3.2
flask>=2.3.3
gunicorn>=20.1.0
pyhocon>=0.3.35
setuptools>=65.5.1
urllib3>=1.26.18
urllib3>=1.26.18
werkzeug>=3.0.1