Model events are fully supported

This commit is contained in:
allegroai 2023-05-25 19:17:40 +03:00
parent 2e4e060a82
commit 58465fbc17
19 changed files with 341 additions and 228 deletions

View File

@ -29,6 +29,10 @@ class ProjectOrNoneRequest(models.Base):
include_subprojects = fields.BoolField(default=True)
class GetUniqueMetricsRequest(ProjectOrNoneRequest):
model_metrics = fields.BoolField(default=False)
class GetParamsRequest(ProjectOrNoneRequest):
page = fields.IntField(default=0)
page_size = fields.IntField(default=500)

View File

@ -66,6 +66,7 @@ class GetTasksDataRequest(Base):
plots: EventsRequest = EmbeddedField(EventsRequest)
scalar_metrics_iter_histogram: ScalarMetricsIterHistogram = EmbeddedField(ScalarMetricsIterHistogram)
allow_public = BoolField(default=True)
model_events: bool = BoolField(default=False)
class GetAllRequest(Base):

View File

@ -30,6 +30,7 @@ from apiserver.bll.event.history_debug_image_iterator import HistoryDebugImageIt
from apiserver.bll.event.history_plots_iterator import HistoryPlotsIterator
from apiserver.bll.event.metric_debug_images_iterator import MetricDebugImagesIterator
from apiserver.bll.event.metric_plots_iterator import MetricPlotsIterator
from apiserver.bll.model import ModelBLL
from apiserver.bll.util import parallel_chunked_decorator
from apiserver.database import utils as dbutils
from apiserver.database.model.model import Model
@ -250,7 +251,6 @@ class EventBLL(object):
task_or_model_ids.add(task_or_model_id)
if (
iter is not None
and not model_events
and event.get("metric") not in self._skip_iteration_for_metric
):
task_iteration[task_or_model_id] = max(
@ -263,8 +263,7 @@ class EventBLL(object):
)
if event_type == EventType.metrics_scalar.value:
self._update_last_scalar_events_for_task(
last_events=task_last_scalar_events[task_or_model_id],
event=event,
last_events=task_last_scalar_events[task_or_model_id], event=event,
)
actions.append(es_action)
@ -303,12 +302,21 @@ class EventBLL(object):
else:
errors_per_type["Error when indexing events batch"] += 1
if not model_events:
remaining_tasks = set()
now = datetime.utcnow()
for task_or_model_id in task_or_model_ids:
# Update related tasks. For performance reasons we prefer to update
# all of them and not only those whose events were successful
if model_events:
ModelBLL.update_statistics(
company_id=company_id,
model_id=task_or_model_id,
last_iteration_max=task_iteration.get(task_or_model_id),
last_scalar_events=task_last_scalar_events.get(
task_or_model_id
),
)
else:
updated = self._update_task(
company_id=company_id,
task_id=task_or_model_id,
@ -319,7 +327,6 @@ class EventBLL(object):
),
last_events=task_last_events.get(task_or_model_id),
)
if not updated:
remaining_tasks.add(task_or_model_id)
continue
@ -484,7 +491,9 @@ class EventBLL(object):
)
def _get_event_id(self, event):
id_values = (str(event[field]) for field in self.event_id_fields if field in event)
id_values = (
str(event[field]) for field in self.event_id_fields if field in event
)
return hashlib.md5("-".join(id_values).encode()).hexdigest()
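For context, hashing the id fields makes event ids deterministic, so re-sending an identical event maps to the same Elasticsearch document id instead of creating a duplicate. A standalone sketch (the field list here is hypothetical; the real one comes from self.event_id_fields):

import hashlib

# Hypothetical id fields, for illustration only
EVENT_ID_FIELDS = ("task", "type", "metric", "variant", "iter")

def make_event_id(event: dict) -> str:
    # Join the id fields that are present with "-" and hash the result,
    # yielding a stable id for identical events
    id_values = (str(event[f]) for f in EVENT_ID_FIELDS if f in event)
    return hashlib.md5("-".join(id_values).encode()).hexdigest()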
def scroll_task_events(
@ -556,9 +565,7 @@ class EventBLL(object):
must.append(get_metric_variants_condition(metric_variants))
query = {"bool": {"must": must}}
search_args = dict(
es=self.es, company_id=company_id, event_type=event_type,
)
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
)
@ -586,7 +593,7 @@ class EventBLL(object):
"events": {
"top_hits": {
"sort": {"iter": {"order": "desc"}},
"size": last_iterations_per_plot
"size": last_iterations_per_plot,
}
}
},
@ -597,11 +604,7 @@ class EventBLL(object):
}
with translate_errors_context():
es_response = search_company_events(
body=es_req,
ignore=404,
**search_args,
)
es_response = search_company_events(body=es_req, ignore=404, **search_args)
aggs_result = es_response.get("aggregations")
if not aggs_result:
@ -614,9 +617,7 @@ class EventBLL(object):
for hit in variants_bucket["events"]["hits"]["hits"]
]
self.uncompress_plots(events)
return TaskEventsResult(
events=events, total_events=len(events)
)
return TaskEventsResult(events=events, total_events=len(events))
def _get_events_from_es_res(self, es_res: dict) -> Tuple[list, int, Optional[str]]:
"""
@ -731,12 +732,7 @@ class EventBLL(object):
if not company_ids:
return TaskEventsResult()
task_ids = (
[task_id]
if isinstance(task_id, str)
else task_id
)
task_ids = [task_id] if isinstance(task_id, str) else task_id
must = []
if metrics:
@ -967,7 +963,7 @@ class EventBLL(object):
event_type: EventType,
task_id: Union[str, Sequence[str]],
iters: int,
metrics: MetricVariants = None
metrics: MetricVariants = None,
) -> Mapping[str, Sequence]:
company_ids = [company_id] if isinstance(company_id, str) else company_id
company_ids = [

View File

@ -5,7 +5,7 @@ from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.apimodels.models import ModelTaskPublishResponse
from apiserver.bll.task.utils import deleted_prefix
from apiserver.bll.task.utils import deleted_prefix, get_last_metric_updates
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task, TaskStatus
@ -28,11 +28,7 @@ class ModelBLL:
@staticmethod
def assert_exists(
company_id,
model_ids,
only=None,
allow_public=False,
return_models=True,
company_id, model_ids, only=None, allow_public=False, return_models=True,
) -> Optional[Sequence[Model]]:
model_ids = [model_ids] if isinstance(model_ids, str) else model_ids
ids = set(model_ids)
@ -179,12 +175,36 @@ class ModelBLL:
"labels_count": {"$size": {"$objectToArray": "$labels"}}
}
},
{
"$project": {"labels_count": 1},
},
{"$project": {"labels_count": 1}},
]
)
return {
r.pop("_id"): r
for r in result
}
return {r.pop("_id"): r for r in result}
@staticmethod
def update_statistics(
company_id: str,
model_id: str,
last_iteration_max: int = None,
last_scalar_events: Dict[str, Dict[str, dict]] = None,
):
updates = {"last_update": datetime.utcnow()}
if last_iteration_max is not None:
updates.update(max__last_iteration=last_iteration_max)
raw_updates = {}
if last_scalar_events is not None:
get_last_metric_updates(
task_id=model_id,
last_scalar_events=last_scalar_events,
raw_updates=raw_updates,
extra_updates=updates,
model_events=True,
)
ret = Model.objects(id=model_id).update_one(**updates)
if ret and raw_updates:
Model.objects(id=model_id).update_one(__raw__=[{"$set": raw_updates}])
return ret
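A hedged usage sketch of the new helper (ids and values are illustrative; the shape of last_scalar_events matches what the ingestion code in event_bll accumulates per metric/variant hash):

from apiserver.bll.model import ModelBLL

ModelBLL.update_statistics(
    company_id="<company-id>",
    model_id="<model-id>",
    last_iteration_max=100,
    last_scalar_events={
        "<metric_hash>": {
            "<variant_hash>": {
                "metric": "loss",
                "variant": "total",
                "value": 0.1,
                "min_value": 0.1,
                "min_value_iter": 100,
                "max_value": 0.9,
                "max_value_iter": 1,
            }
        }
    },
)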

View File

@ -209,7 +209,11 @@ class ProjectQueries:
@classmethod
def get_unique_metric_variants(
cls, company_id, project_ids: Sequence[str], include_subprojects: bool
cls,
company_id,
project_ids: Sequence[str],
include_subprojects: bool,
model_metrics: bool = False,
):
pipeline = [
{
@ -246,7 +250,8 @@ class ProjectQueries:
{"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
]
result = Task.aggregate(pipeline)
entity_cls = Model if model_metrics else Task
result = entity_cls.aggregate(pipeline)
return [r["metrics"][0] for r in result]
@classmethod

View File

@ -40,6 +40,7 @@ from .utils import (
ChangeStatusRequest,
update_project_time,
deleted_prefix,
get_last_metric_updates,
)
log = config.logger(__file__)
@ -412,77 +413,12 @@ class TaskBLL:
raw_updates = {}
if last_scalar_events is not None:
max_values = config.get("services.tasks.max_last_metrics", 2000)
total_metrics = set()
if max_values:
query = dict(id=task_id)
to_add = sum(len(v) for m, v in last_scalar_events.items())
if to_add <= max_values:
query[f"unique_metrics__{max_values-to_add}__exists"] = True
task = Task.objects(**query).only("unique_metrics").first()
if task and task.unique_metrics:
total_metrics = set(task.unique_metrics)
new_metrics = []
def add_last_metric_conditional_update(
metric_path: str, metric_value, iter_value: int, is_min: bool
):
"""
Build an aggregation for an atomic update of the min or max value and the corresponding iteration
"""
if is_min:
field_prefix = "min"
op = "$gt"
else:
field_prefix = "max"
op = "$lt"
value_field = f"{metric_path}__{field_prefix}_value".replace("__", ".")
condition = {
"$or": [
{"$lte": [f"${value_field}", None]},
{op: [f"${value_field}", metric_value]},
]
}
raw_updates[value_field] = {
"$cond": [condition, metric_value, f"${value_field}"]
}
value_iteration_field = f"{metric_path}__{field_prefix}_value_iteration".replace(
"__", "."
get_last_metric_updates(
task_id=task_id,
last_scalar_events=last_scalar_events,
raw_updates=raw_updates,
extra_updates=extra_updates,
)
raw_updates[value_iteration_field] = {
"$cond": [condition, iter_value, f"${value_iteration_field}",]
}
for metric_key, metric_data in last_scalar_events.items():
for variant_key, variant_data in metric_data.items():
metric = (
f"{variant_data.get('metric')}/{variant_data.get('variant')}"
)
if max_values:
if (
len(total_metrics) >= max_values
and metric not in total_metrics
):
continue
total_metrics.add(metric)
new_metrics.append(metric)
path = f"last_metrics__{metric_key}__{variant_key}"
for key, value in variant_data.items():
if key in ("min_value", "max_value"):
add_last_metric_conditional_update(
metric_path=path,
metric_value=value,
iter_value=variant_data.get(f"{key}_iter", 0),
is_min=(key == "min_value"),
)
elif key in ("metric", "variant", "value"):
extra_updates[f"set__{path}__{key}"] = value
if new_metrics:
extra_updates["add_to_set__unique_metrics"] = new_metrics
if last_events is not None:

View File

@ -5,7 +5,9 @@ import attr
import six
from apiserver.apierrors import errors
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
from apiserver.database.utils import get_options
@ -189,9 +191,88 @@ def get_task_for_update(
return task
def update_task(task: Task, user_id: str, update_cmds: dict, set_last_update: bool = True):
def update_task(
task: Task, user_id: str, update_cmds: dict, set_last_update: bool = True
):
now = datetime.utcnow()
last_updates = dict(last_change=now, last_changed_by=user_id)
if set_last_update:
last_updates.update(last_update=now)
return task.update(**update_cmds, **last_updates)
def get_last_metric_updates(
task_id: str,
last_scalar_events: dict,
raw_updates: dict,
extra_updates: dict,
model_events: bool = False,
):
max_values = config.get("services.tasks.max_last_metrics", 2000)
total_metrics = set()
if max_values:
query = dict(id=task_id)
to_add = sum(len(v) for m, v in last_scalar_events.items())
if to_add <= max_values:
query[f"unique_metrics__{max_values - to_add}__exists"] = True
db_cls = Model if model_events else Task
task = db_cls.objects(**query).only("unique_metrics").first()
if task and task.unique_metrics:
total_metrics = set(task.unique_metrics)
new_metrics = []
def add_last_metric_conditional_update(
metric_path: str, metric_value, iter_value: int, is_min: bool
):
"""
Build an aggregation for an atomic update of the min or max value and the corresponding iteration
"""
if is_min:
field_prefix = "min"
op = "$gt"
else:
field_prefix = "max"
op = "$lt"
value_field = f"{metric_path}__{field_prefix}_value".replace("__", ".")
condition = {
"$or": [
{"$lte": [f"${value_field}", None]},
{op: [f"${value_field}", metric_value]},
]
}
raw_updates[value_field] = {
"$cond": [condition, metric_value, f"${value_field}"]
}
value_iteration_field = f"{metric_path}__{field_prefix}_value_iteration".replace(
"__", "."
)
raw_updates[value_iteration_field] = {
"$cond": [condition, iter_value, f"${value_iteration_field}"]
}
for metric_key, metric_data in last_scalar_events.items():
for variant_key, variant_data in metric_data.items():
metric = f"{variant_data.get('metric')}/{variant_data.get('variant')}"
if max_values:
if len(total_metrics) >= max_values and metric not in total_metrics:
continue
total_metrics.add(metric)
new_metrics.append(metric)
path = f"last_metrics__{metric_key}__{variant_key}"
for key, value in variant_data.items():
if key in ("min_value", "max_value"):
add_last_metric_conditional_update(
metric_path=path,
metric_value=value,
iter_value=variant_data.get(f"{key}_iter", 0),
is_min=(key == "min_value"),
)
elif key in ("metric", "variant", "value"):
extra_updates[f"set__{path}__{key}"] = value
if new_metrics:
extra_updates["add_to_set__unique_metrics"] = new_metrics
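To make the conditional update concrete: for a max_value update with metric_value=0.42 and iter_value=7 on path last_metrics__<mh>__<vh> (hashes shortened), the helper adds roughly the following to raw_updates. The $or clause treats a missing field ($lte against None) the same as a smaller value, so the first reported event always wins. An illustration, not captured output:

raw_updates = {}
value_field = "last_metrics.<mh>.<vh>.max_value"
condition = {
    "$or": [
        {"$lte": [f"${value_field}", None]},
        {"$lt": [f"${value_field}", 0.42]},
    ]
}
# Update the value and its iteration atomically, only when the condition holds
raw_updates[value_field] = {"$cond": [condition, 0.42, f"${value_field}"]}
raw_updates[f"{value_field}_iteration"] = {
    "$cond": [condition, 7, f"${value_field}_iteration"]
}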

View File

@ -3,6 +3,8 @@ from mongoengine import (
DateTimeField,
BooleanField,
EmbeddedDocumentField,
IntField,
ListField,
)
from apiserver.database import Database, strict
@ -17,12 +19,14 @@ from apiserver.database.model.base import GetMixin
from apiserver.database.model.metadata import MetadataItem
from apiserver.database.model.model_labels import ModelLabels
from apiserver.database.model.project import Project
from apiserver.database.model.task.metrics import MetricEvent
from apiserver.database.model.task.task import Task
class Model(AttributedDocument):
_field_collation_overrides = {
"metadata.": AttributedDocument._numeric_locale,
"last_metrics.": AttributedDocument._numeric_locale,
}
meta = {
@ -67,6 +71,7 @@ class Model(AttributedDocument):
"parent",
"metadata.*",
),
range_fields=("last_metrics.*", "last_iteration"),
datetime_fields=("last_update",),
)
@ -92,6 +97,9 @@ class Model(AttributedDocument):
metadata = SafeMapField(
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
)
last_iteration = IntField(default=0)
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
unique_metrics = ListField(StringField(required=True), exclude_by_default=True)
def get_index_company(self) -> str:
return self.company or self.company_origin or ""

View File

@ -1,6 +1,6 @@
_description: """This service provides a management interface for models (results of training tasks) stored in the system."""
_definitions {
include "_common.conf"
include "_tasks_common.conf"
multi_field_pattern_data {
type: object
properties {
@ -104,6 +104,17 @@ _definitions {
"$ref": "#/definitions/metadata_item"
}
}
last_iteration {
description: "Last iteration reported for this model"
type: integer
}
last_metrics {
description: "Last metric variants (hash to events), one for each metric hash"
type: object
additionalProperties {
"$ref": "#/definitions/last_metrics_variants"
}
}
stats {
description: "Model statistics"
type: object
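For illustration, the new fields on a returned model could look like this (hashes shortened, values hypothetical; the iteration fields follow the last_metrics_variants definition shared with tasks):

"last_iteration": 100,
"last_metrics": {
    "<metric_hash>": {
        "<variant_hash>": {
            "metric": "loss",
            "variant": "total",
            "value": 0.1,
            "min_value": 0.1,
            "min_value_iteration": 100,
            "max_value": 0.9,
            "max_value_iteration": 1
        }
    }
}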

View File

@ -898,6 +898,13 @@ get_unique_metric_variants {
}
}
}
"999.0": ${get_unique_metric_variants."2.13"} {
request.properties.model_metrics {
description: If set to true then return unique metric and variant names from the project models, otherwise from the project tasks
type: boolean
default: false
}
}
}
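A hedged call sketch for the new flag (server URL and basic-auth credentials are assumptions; the endpoint and field names are from this commit):

import requests

res = requests.post(
    "http://localhost:8008/projects.get_unique_metric_variants",
    json={"project": "<project-id>", "model_metrics": True},
    auth=("<access_key>", "<secret_key>"),
)
print(res.json()["data"]["metrics"])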
get_hyperparam_values {
"2.13" {

View File

@ -568,6 +568,13 @@ get_task_data {
}
}
}
"999.0": ${get_task_data."2.23"} {
request.properties.model_events {
type: boolean
description: If set then model events are retrieved, otherwise task events
default: false
}
}
}
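And similarly for reports.get_task_data (same assumptions as above); with model_events set, the requested ids refer to models rather than tasks:

import requests

res = requests.post(
    "http://localhost:8008/reports.get_task_data",
    json={
        "id": ["<model-id>"],
        "only_fields": ["name"],
        "model_events": True,
        "plots": {"metrics": [{"metric": "Metric_1"}]},
    },
    auth=("<access_key>", "<secret_key>"),
)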
get_all_ex {
"2.23" {

View File

@ -561,7 +561,6 @@ def _get_multitask_plots(
metrics: MetricVariants = None,
scroll_id=None,
no_scroll=True,
model_events=False,
) -> Tuple[dict, int, str]:
task_names = {
t.id: t.name for t in itertools.chain.from_iterable(companies.values())
@ -598,7 +597,6 @@ def get_multi_task_plots(call, company_id, _):
last_iters=iters,
scroll_id=scroll_id,
no_scroll=no_scroll,
model_events=model_events,
)
call.result.data = dict(
plots=return_events,

View File

@ -1,6 +1,6 @@
from datetime import datetime
from functools import partial
from typing import Sequence
from typing import Sequence, Union
from mongoengine import Q, EmbeddedDocument
@ -59,6 +59,11 @@ org_bll = OrgBLL()
project_bll = ProjectBLL()
def conform_model_data(call: APICall, model_data: Union[Sequence[dict], dict]):
conform_output_tags(call, model_data)
unescape_metadata(call, model_data)
@endpoint("models.get_by_id", required_fields=["model"])
def get_by_id(call: APICall, company_id, _):
model_id = call.data["model"]
@ -74,8 +79,7 @@ def get_by_id(call: APICall, company_id, _):
raise errors.bad_request.InvalidModelId(
"no such public or company model", id=model_id, company=company_id,
)
conform_output_tags(call, models[0])
unescape_metadata(call, models[0])
conform_model_data(call, models[0])
call.result.data = {"model": models[0]}
@ -102,8 +106,7 @@ def get_by_task_id(call: APICall, company_id, _):
"no such public or company model", id=model_id, company=company_id,
)
model_dict = model.to_proper_dict()
conform_output_tags(call, model_dict)
unescape_metadata(call, model_dict)
conform_model_data(call, model_dict)
call.result.data = {"model": model_dict}
@ -119,8 +122,7 @@ def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
allow_public=request.allow_public,
ret_params=ret_params,
)
conform_output_tags(call, models)
unescape_metadata(call, models)
conform_model_data(call, models)
if not request.include_stats:
call.result.data = {"models": models, **ret_params}
@ -142,8 +144,7 @@ def get_by_id_ex(call: APICall, company_id, _):
models = Model.get_many_with_join(
company=company_id, query_dict=call.data, allow_public=True
)
conform_output_tags(call, models)
unescape_metadata(call, models)
conform_model_data(call, models)
call.result.data = {"models": models}
@ -159,8 +160,7 @@ def get_all(call: APICall, company_id, _):
allow_public=True,
ret_params=ret_params,
)
conform_output_tags(call, models)
unescape_metadata(call, models)
conform_model_data(call, models)
call.result.data = {"models": models, **ret_params}
@ -428,8 +428,7 @@ def edit(call: APICall, company_id, _):
_reset_cached_tags(company_id, projects=[new_project, model.project])
else:
_update_cached_tags(company_id, project=model.project, fields=fields)
conform_output_tags(call, fields)
unescape_metadata(call, fields)
conform_model_data(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
else:
call.result.data_model = UpdateResponse(updated=0)
@ -461,8 +460,7 @@ def _update_model(call: APICall, company_id, model_id=None):
_update_cached_tags(
company_id, project=model.project, fields=updated_fields
)
conform_output_tags(call, updated_fields)
unescape_metadata(call, updated_fields)
conform_model_data(call, updated_fields)
return UpdateResponse(updated=updated_count, fields=updated_fields)

View File

@ -15,10 +15,10 @@ from apiserver.apimodels.projects import (
DeleteRequest,
MoveRequest,
MergeRequest,
ProjectOrNoneRequest,
ProjectRequest,
ProjectModelMetadataValuesRequest,
ProjectChildrenType,
GetUniqueMetricsRequest,
)
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL, ProjectQueries
@ -345,16 +345,17 @@ def delete(call: APICall, company_id: str, request: DeleteRequest):
@endpoint(
"projects.get_unique_metric_variants", request_data_model=ProjectOrNoneRequest
"projects.get_unique_metric_variants", request_data_model=GetUniqueMetricsRequest
)
def get_unique_metric_variants(
call: APICall, company_id: str, request: ProjectOrNoneRequest
call: APICall, company_id: str, request: GetUniqueMetricsRequest,
):
metrics = project_queries.get_unique_metric_variants(
company_id,
[request.project] if request.project else None,
include_subprojects=request.include_subprojects,
model_metrics=request.model_metrics,
)
call.result.data = {"metrics": metrics}

View File

@ -1,3 +1,5 @@
from typing import Union, Sequence
from mongoengine import Q
from apiserver.apimodels.base import UpdateResponse
@ -39,14 +41,18 @@ worker_bll = WorkerBLL()
queue_bll = QueueBLL(worker_bll)
def conform_queue_data(call: APICall, queue_data: Union[Sequence[dict], dict]):
conform_output_tags(call, queue_data)
unescape_metadata(call, queue_data)
@endpoint("queues.get_by_id", min_version="2.4", request_data_model=GetByIdRequest)
def get_by_id(call: APICall, company_id, request: GetByIdRequest):
queue = queue_bll.get_by_id(
company_id, request.queue, max_task_entries=request.max_task_entries
)
queue_dict = queue.to_proper_dict()
conform_output_tags(call, queue_dict)
unescape_metadata(call, queue_dict)
conform_queue_data(call, queue_dict)
call.result.data = {"queue": queue_dict}
@ -85,8 +91,7 @@ def get_all_ex(call: APICall, company: str, request: GetAllRequest):
max_task_entries=request.max_task_entries,
ret_params=ret_params,
)
conform_output_tags(call, queues)
unescape_metadata(call, queues)
conform_queue_data(call, queues)
call.result.data = {"queues": queues, **ret_params}
@ -102,8 +107,7 @@ def get_all(call: APICall, company: str, request: GetAllRequest):
max_task_entries=request.max_task_entries,
ret_params=ret_params,
)
conform_output_tags(call, queues)
unescape_metadata(call, queues)
conform_queue_data(call, queues)
call.result.data = {"queues": queues, **ret_params}
@ -135,8 +139,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
updated, fields = queue_bll.update(
company_id=company_id, queue_id=req_model.queue, **data
)
conform_output_tags(call, fields)
unescape_metadata(call, fields)
conform_queue_data(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)

View File

@ -17,6 +17,8 @@ from apiserver.apimodels.reports import (
from apiserver.apierrors import errors
from apiserver.apimodels.base import UpdateResponse
from apiserver.bll.project.project_bll import reports_project_name, reports_tag
from apiserver.database.model.model import Model
from apiserver.services.models import conform_model_data
from apiserver.services.utils import process_include_subprojects, sort_tags_response
from apiserver.bll.organization import OrgBLL
from apiserver.bll.project import ProjectBLL
@ -35,7 +37,7 @@ from apiserver.services.events import (
from apiserver.services.tasks import (
escape_execution_parameters,
_hidden_query,
unprepare_from_saved,
conform_task_data,
)
org_bll = OrgBLL()
@ -178,7 +180,7 @@ def get_all_ex(call: APICall, company_id, request: GetAllRequest):
allow_public=request.allow_public,
ret_params=ret_params,
)
unprepare_from_saved(call, tasks)
conform_task_data(call, tasks)
call.result.data = {"tasks": tasks, **ret_params}
@ -198,18 +200,25 @@ def _get_task_metrics_from_request(
@endpoint("reports.get_task_data")
def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
if request.model_events:
entity_cls = Model
conform_data = conform_model_data
else:
entity_cls = Task
conform_data = conform_task_data
call_data = escape_execution_parameters(call)
process_include_subprojects(call_data)
ret_params = {}
tasks = Task.get_many_with_join(
tasks = entity_cls.get_many_with_join(
company=company_id,
query_dict=call_data,
query=_hidden_query(call_data),
allow_public=request.allow_public,
ret_params=ret_params,
)
unprepare_from_saved(call, tasks)
conform_data(call, tasks)
res = {"tasks": tasks, **ret_params}
if not (
request.debug_images or request.plots or request.scalar_metrics_iter_histogram
@ -217,7 +226,9 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
return res
task_ids = [task["id"] for task in tasks]
companies = _get_task_or_model_index_companies(company_id, task_ids=task_ids)
companies = _get_task_or_model_index_companies(
company_id, task_ids=task_ids, model_events=request.model_events
)
if request.debug_images:
result = event_bll.debug_images_iterator.get_task_events(
companies={

View File

@ -182,7 +182,7 @@ def get_by_id(call: APICall, company_id, req_model: TaskRequest):
req_model.task, company_id=company_id, allow_public=True
)
task_dict = task.to_proper_dict()
unprepare_from_saved(call, task_dict)
conform_task_data(call, task_dict)
call.result.data = {"task": task_dict}
@ -231,7 +231,7 @@ def get_all_ex(call: APICall, company_id, request: GetAllReq):
allow_public=request.allow_public,
ret_params=ret_params,
)
unprepare_from_saved(call, tasks)
conform_task_data(call, tasks)
call.result.data = {"tasks": tasks, **ret_params}
@ -245,7 +245,7 @@ def get_by_id_ex(call: APICall, company_id, _):
company=company_id, query_dict=call_data, allow_public=True,
)
unprepare_from_saved(call, tasks)
conform_task_data(call, tasks)
call.result.data = {"tasks": tasks}
@ -264,7 +264,7 @@ def get_all(call: APICall, company_id, _):
allow_public=True,
ret_params=ret_params,
)
unprepare_from_saved(call, tasks)
conform_task_data(call, tasks)
call.result.data = {"tasks": tasks, **ret_params}
@ -430,7 +430,7 @@ def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
return fields
def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict]):
def conform_task_data(call: APICall, tasks_data: Union[Sequence[dict], dict]):
if isinstance(tasks_data, dict):
tasks_data = [tasks_data]
@ -608,7 +608,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
company_id, project=task.project, fields=updated_fields
)
update_project_time(updated_fields.get("project"))
unprepare_from_saved(call, updated_fields)
conform_task_data(call, updated_fields)
return UpdateResponse(updated=updated_count, fields=updated_fields)
@ -763,7 +763,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
company_id, project=task.project, fields=fixed_fields
)
update_project_time(fields.get("project"))
unprepare_from_saved(call, fields)
conform_task_data(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
else:
call.result.data_model = UpdateResponse(updated=0)

View File

@ -113,7 +113,13 @@ class TestReports(TestService):
def test_reports_task_data(self):
report_task = self._temp_report(name="Rep1")
for model_events in (False, True):
if model_events:
non_report_task = self._temp_model(name="hello")
event_args = {"model_event": True}
else:
non_report_task = self._temp_task(name="hello")
event_args = {}
debug_image_events = [
self._create_task_event(
task=non_report_task,
@ -122,6 +128,7 @@ class TestReports(TestService):
metric=f"Metric_{m}",
variant=f"Variant_{v}",
url=f"{m}_{v}",
**event_args,
)
for m in range(2)
for v in range(2)
@ -134,6 +141,7 @@ class TestReports(TestService):
metric=f"Metric_{m}",
variant=f"Variant_{v}",
plot_str=f"Hello plot",
**event_args,
)
for m in range(2)
for v in range(2)
@ -141,7 +149,7 @@ class TestReports(TestService):
self.send_batch([*debug_image_events, *plot_events])
res = self.api.reports.get_task_data(
id=[non_report_task], only_fields=["name"],
id=[non_report_task], only_fields=["name"], model_events=model_events
)
self.assertEqual(len(res.tasks), 1)
self.assertEqual(res.tasks[0].id, non_report_task)
@ -152,6 +160,7 @@ class TestReports(TestService):
only_fields=["name"],
debug_images={"metrics": []},
plots={"metrics": [{"metric": "Metric_1"}]},
model_events=model_events,
)
self.assertEqual(len(res.debug_images), 1)
task_events = res.debug_images[0]
@ -185,12 +194,14 @@ class TestReports(TestService):
**kwargs,
}
delete_params = {"force": True}
def _temp_report(self, name, **kwargs):
return self.create_temp(
"reports",
name=name,
object_name="task",
delete_params={"force": True},
delete_params=self.delete_params,
**kwargs,
)
@ -199,10 +210,16 @@ class TestReports(TestService):
"tasks",
name=name,
type="training",
delete_params={"force": True},
delete_params=self.delete_params,
**kwargs,
)
def _temp_model(self, name="test model events", **kwargs):
self.update_missing(
kwargs, name=name, uri="file:///a/b", labels={}, ready=False
)
return self.create_temp("models", delete_params=self.delete_params, **kwargs)
def send_batch(self, events):
_, data = self.api.send_batch("events.add_batch", events)
return data

View File

@ -16,9 +16,7 @@ class TestTaskEvents(TestService):
delete_params = dict(can_fail=True, force=True)
def _temp_task(self, name="test task events"):
task_input = dict(
name=name, type="training",
)
task_input = dict(name=name, type="training")
return self.create_temp(
"tasks", delete_paramse=self.delete_params, **task_input
)
@ -220,6 +218,17 @@ class TestTaskEvents(TestService):
self.assertEqual(variant_data.x, [0, 1])
self.assertEqual(variant_data.y, [0.0, 1.0])
model_data = self.api.models.get_all_ex(
id=[model], only_fields=["last_metrics", "last_iteration"]
).models[0]
metric_data = first(first(model_data.last_metrics.values()).values())
self.assertEqual(1, model_data.last_iteration)
self.assertEqual(1, metric_data.value)
self.assertEqual(1, metric_data.max_value)
self.assertEqual(1, metric_data.max_value_iteration)
self.assertEqual(0, metric_data.min_value)
self.assertEqual(0, metric_data.min_value_iteration)
def test_error_events(self):
task = self._temp_task()
events = [