from collections import OrderedDict
from datetime import datetime
from operator import attrgetter
from random import random
from time import sleep
from typing import Collection, Sequence, Tuple, Any, Optional, List, Dict

import dpath
import pymongo.results
import six
from mongoengine import Q
from six import string_types

import database.utils as dbutils
import es_factory
from apierrors import errors
from apimodels.tasks import Artifact as ApiArtifact
from bll.organization import OrgBLL, Tags
from config import config
from database.errors import translate_errors_context
from database.model.model import Model
from database.model.project import Project
from database.model.task.metrics import EventStats, MetricEventStats
from database.model.task.output import Output
from database.model.task.task import (
    Task,
    TaskStatus,
    TaskStatusMessage,
    TaskSystemTags,
    ArtifactModes,
    Artifact,
    external_task_types,
)
from database.utils import get_company_or_none_constraint, id as create_id
from service_repo import APICall
from timing_context import TimingContext
from utilities.dicts import deep_merge
from utilities.parameter_key_escaper import ParameterKeyEscaper
from .param_utils import params_prepare_for_save
from .utils import ChangeStatusRequest, validate_status_change

log = config.logger(__file__)
org_bll = OrgBLL()


class TaskBLL(object):
    """
    Business logic for tasks: lookups with access checks, creation and cloning,
    reference validation, publishing/stopping, statistics and artifact updates.
    """

    def __init__(self, events_es=None):
        """
        :param events_es: events Elasticsearch client. When not provided, a
            connection is obtained from es_factory.connect("events").
        """
        self.events_es = (
            events_es if events_es is not None else es_factory.connect("events")
        )

    @classmethod
    def get_types(cls, company, project_ids: Optional[Sequence]) -> set:
        """
        Return the set of unique task types used by company and public tasks,
        limited to the external task types.
        If project ids are passed then only tasks from these projects are considered.
        """
        query = get_company_or_none_constraint(company)
        if project_ids:
            query &= Q(project__in=project_ids)
        res = Task.objects(query).distinct(field="type")
        return set(res).intersection(external_task_types)

    @staticmethod
    def get_task_with_access(
        task_id, company_id, only=None, allow_public=False, requires_write_access=False
    ) -> Task:
        """
        Get a task of the company, optionally requiring write access to it.

        :param only: optional list of fields to load
        :param allow_public: when write access is not required, also look in public tasks
        :param requires_write_access: load the task through Task.get_for_writing
        :except errors.bad_request.InvalidTaskId: if the task is not found
        :except errors.forbidden.NoWritePermission: if write access was required and the
            task cannot be modified (presumably raised by Task.get_for_writing)
        """
        with translate_errors_context():
            query = dict(id=task_id, company=company_id)
            with TimingContext("mongo", "task_with_access"):
                if requires_write_access:
                    task = Task.get_for_writing(_only=only, **query)
                else:
                    task = Task.get(_only=only, **query, include_public=allow_public)

            if not task:
                raise errors.bad_request.InvalidTaskId(**query)

            return task

    @staticmethod
    def get_by_id(
        company_id,
        task_id,
        required_status=None,
        only_fields=None,
        allow_public=False,
    ):
        """
        Return the task with the given id.

        :param required_status: if set, the task must be in this status
        :param only_fields: a field name or a collection of field names to load.
            "status" is always added since it is needed for the status check.
        :param allow_public: also look in public tasks
        :except errors.bad_request.InvalidTaskId: if the task is not found
        :except errors.bad_request.InvalidTaskStatus: if required_status does not match
        """
        if only_fields:
            if isinstance(only_fields, string_types):
                only_fields = [only_fields]
            else:
                only_fields = list(only_fields)
            only_fields = only_fields + ["status"]

        with TimingContext("mongo", "task_by_id_all"):
            tasks = Task.get_many(
                company=company_id,
                query=Q(id=task_id),
                allow_public=allow_public,
                override_projection=only_fields,
                return_dicts=False,
            )
            task = None if not tasks else tasks[0]

        if not task:
            raise errors.bad_request.InvalidTaskId(id=task_id)

        if required_status and not task.status == required_status:
            raise errors.bad_request.InvalidTaskStatus(expected=required_status)

        return task

    @staticmethod
    def assert_exists(
        company_id, task_ids, only=None, allow_public=False, return_tasks=True
    ) -> Optional[Sequence[Task]]:
        """
        Verify that all the given task ids exist for the company.

        :param task_ids: a single task id or a collection of ids
        :param only: optional list of fields to load for the returned tasks
        :param return_tasks: if True, return the found tasks, otherwise return None
        :except errors.bad_request.InvalidTaskId: if any of the ids is not found
        """
        task_ids = [task_ids] if isinstance(task_ids, six.string_types) else task_ids
        with translate_errors_context(), TimingContext("mongo", "task_exists"):
            ids = set(task_ids)
            q = Task.get_many(
                company=company_id,
                query=Q(id__in=ids),
                allow_public=allow_public,
                return_dicts=False,
            )
            if only:
                # Make sure to reset fields filters (some fields are excluded by default) since this
                # is an internal call and specific fields were requested.
                q = q.all_fields().only(*only)

            if q.count() != len(ids):
                raise errors.bad_request.InvalidTaskId(ids=task_ids)

            if return_tasks:
                return list(q)

    @staticmethod
    def create(call: APICall, fields: dict):
        """
        Build a new (unsaved) Task owned by the calling identity's user and company,
        with a fresh id and created/last_update set to the current UTC time.
        """
        identity = call.identity
        now = datetime.utcnow()
        return Task(
            id=create_id(),
            user=identity.user,
            company=identity.company,
            created=now,
            last_update=now,
            **fields,
        )

    @staticmethod
    def validate_execution_model(task, allow_only_public=False):
        """
        Verify that the model referenced by task.execution.model exists.

        :param allow_only_public: look only in public models instead of the task's company
        :return: the model, or None if the task has no execution model
        :except errors.bad_request.InvalidModelId: if the model is not found
        """
        if not task.execution or not task.execution.model:
            return

        company = None if allow_only_public else task.company
        model_id = task.execution.model
        model = Model.objects(
            Q(id=model_id) & get_company_or_none_constraint(company)
        ).first()
        if not model:
            raise errors.bad_request.InvalidModelId(model=model_id)

        return model

    @classmethod
    def clone_task(
        cls,
        company_id,
        user_id,
        task_id,
        name: Optional[str] = None,
        comment: Optional[str] = None,
        parent: Optional[str] = None,
        project: Optional[str] = None,
        tags: Optional[Sequence[str]] = None,
        system_tags: Optional[Sequence[str]] = None,
        hyperparams: Optional[dict] = None,
        configuration: Optional[dict] = None,
        execution_overrides: Optional[dict] = None,
        validate_references: bool = False,
    ) -> Task:
        """
        Create and save a copy of an existing (company or public) task.

        Fields not supplied by the caller are taken from the source task, except
        system_tags, which default to an empty list.

        :param hyperparams: replaces the source task's hyperparams if given
        :param configuration: replaces the source task's configuration if given
        :param execution_overrides: deep-merged into the source task's execution section.
            The legacy "parameters" and "configuration" keys are extracted from it and
            passed on to params_prepare_for_save under params_dict["execution"].
            The caller's dict is not modified.
        :param validate_references: always validate model, parent and project references
            (otherwise each one is validated only when overridden)
        :return: the saved new task
        """
        task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)

        execution_dict = task.execution.to_proper_dict() if task.execution else {}
        execution_model_overriden = False
        params_dict = {
            field: value
            for field, value in (
                ("hyperparams", hyperparams),
                ("configuration", configuration),
            )
            if value is not None
        }
        if execution_overrides:
            # Copy so that popping the legacy keys does not mutate the caller's dict
            execution_overrides = dict(execution_overrides)
            params_dict["execution"] = {}
            for legacy_param in ("parameters", "configuration"):
                legacy_value = execution_overrides.pop(legacy_param, None)
                if legacy_value is not None:
                    # Keep each legacy value under its own name so that both can be
                    # passed and neither overwrites the other
                    params_dict["execution"][legacy_param] = legacy_value
            execution_dict = deep_merge(execution_dict, execution_overrides)
            execution_model_overriden = execution_overrides.get("model") is not None
        params_prepare_for_save(params_dict, previous_task=task)

        artifacts = execution_dict.get("artifacts")
        if artifacts:
            # Output-mode artifacts are not carried over to the clone
            execution_dict["artifacts"] = [
                a for a in artifacts if a.get("mode") != ArtifactModes.output
            ]
        now = datetime.utcnow()

        with translate_errors_context():
            new_task = Task(
                id=create_id(),
                user=user_id,
                company=company_id,
                created=now,
                last_update=now,
                name=name or task.name,
                comment=comment or task.comment,
                parent=parent or task.parent,
                project=project or task.project,
                tags=tags or task.tags,
                system_tags=system_tags or [],
                type=task.type,
                script=task.script,
                output=Output(destination=task.output.destination)
                if task.output
                else None,
                execution=execution_dict,
                configuration=params_dict.get("configuration") or task.configuration,
                hyperparams=params_dict.get("hyperparams") or task.hyperparams,
            )
            cls.validate(
                new_task,
                validate_model=validate_references or execution_model_overriden,
                validate_parent=validate_references or parent,
                validate_project=validate_references or project,
            )
            new_task.save()

            # Within the same project only the explicitly passed tags are new to it;
            # in a different project all of the clone's tags must be registered
            if task.project == new_task.project:
                updated_tags = tags
                updated_system_tags = system_tags
            else:
                updated_tags = new_task.tags
                updated_system_tags = new_task.system_tags
            org_bll.update_tags(
                company_id,
                Tags.Task,
                project=new_task.project,
                tags=updated_tags,
                system_tags=updated_system_tags,
            )

            return new_task

    @classmethod
    def validate(
        cls,
        task: Task,
        validate_model=True,
        validate_parent=True,
        validate_project=True,
    ):
        """
        Verify the task's references.

        :except errors.bad_request.InvalidTaskId: parent task not found (company or public)
        :except errors.bad_request.InvalidProjectId: project not found for writing
        :except errors.bad_request.InvalidModelId: execution model not found
        """
        if (
            validate_parent
            and task.parent
            and not Task.get(
                company=task.company, id=task.parent, _only=("id",), include_public=True
            )
        ):
            raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent)

        if (
            validate_project
            and task.project
            and not Project.get_for_writing(company=task.company, id=task.project)
        ):
            raise errors.bad_request.InvalidProjectId(id=task.project)

        if validate_model:
            cls.validate_execution_model(task)

    @staticmethod
    def get_unique_metric_variants(company_id, project_ids=None):
        """
        Return one entry per distinct (metric, variant) pair found in the tasks'
        last_metrics, sorted by metric and then variant.
        Each entry is a dict with the keys metric, metric_hash, variant and variant_hash.
        If project_ids are passed then only tasks from these projects are considered.
        """
        pipeline = [
            {
                "$match": dict(
                    company=company_id,
                    **({"project": {"$in": project_ids}} if project_ids else {}),
                )
            },
            {"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
            {"$unwind": "$metrics"},
            {
                "$project": {
                    "metric": "$metrics.k",
                    "variants": {"$objectToArray": "$metrics.v"},
                }
            },
            {"$unwind": "$variants"},
            {
                "$group": {
                    "_id": {
                        "metric": "$variants.v.metric",
                        "variant": "$variants.v.variant",
                    },
                    "metrics": {
                        "$addToSet": {
                            "metric": "$variants.v.metric",
                            "metric_hash": "$metric",
                            "variant": "$variants.v.variant",
                            "variant_hash": "$variants.k",
                        }
                    },
                }
            },
            {"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
        ]

        with translate_errors_context():
            result = Task.aggregate(pipeline)
            return [r["metrics"][0] for r in result]

    @staticmethod
    def set_last_update(
        task_ids: Collection[str], company_id: str, last_update: datetime
    ):
        """Set last_update on the given company tasks and return the update result."""
        return Task.objects(id__in=task_ids, company=company_id).update(
            upsert=False, last_update=last_update
        )

    @staticmethod
    def update_statistics(
        task_id: str,
        company_id: str,
        last_update: datetime = None,
        last_iteration: int = None,
        last_iteration_max: int = None,
        last_scalar_values: Sequence[Tuple[Tuple[str, ...], Any]] = None,
        last_events: Dict[str, Dict[str, dict]] = None,
        **extra_updates,
    ):
        """
        Update task statistics
        :param task_id: Task's ID.
        :param company_id: Task's company ID.
        :param last_update: Last update time. If not provided, defaults to datetime.utcnow().
        :param last_iteration: Last reported iteration. Use this to set a value regardless
            of current task's last iteration value.
        :param last_iteration_max: Last reported iteration. Use this to conditionally set a value only
            if the current task's last iteration value is smaller than the provided value.
        :param last_scalar_values: Last reported metrics summary for scalar events (value, metric, variant).
            Each item is (path, value); a path ending with "min_value"/"max_value" is
            updated with the min/max operator, any other path is set.
        :param last_events: Last reported metrics summary (value, metric, event type).
        :param extra_updates: Extra task updates to include in this update call.
        :return:
        """
        last_update = last_update or datetime.utcnow()

        if last_iteration is not None:
            extra_updates.update(last_iteration=last_iteration)
        elif last_iteration_max is not None:
            extra_updates.update(max__last_iteration=last_iteration_max)

        if last_scalar_values is not None:

            def op_path(op, *path):
                # Build an update keyword such as "set__last_metrics__<path parts>"
                return "__".join((op, "last_metrics") + path)

            for path, value in last_scalar_values:
                if path[-1] == "min_value":
                    extra_updates[op_path("min", *path[:-1], "min_value")] = value
                elif path[-1] == "max_value":
                    extra_updates[op_path("max", *path[:-1], "max_value")] = value
                else:
                    extra_updates[op_path("set", *path)] = value

        if last_events is not None:

            def events_per_type(metric_data: Dict[str, dict]) -> Dict[str, EventStats]:
                return {
                    event_type: EventStats(last_update=event["timestamp"])
                    for event_type, event in metric_data.items()
                }

            metric_stats = {
                dbutils.hash_field_name(metric_key): MetricEventStats(
                    metric=metric_key, event_stats_by_type=events_per_type(metric_data)
                )
                for metric_key, metric_data in last_events.items()
            }
            extra_updates["metric_stats"] = metric_stats

        Task.objects(id=task_id, company=company_id).update(
            upsert=False, last_update=last_update, **extra_updates
        )

    @classmethod
    def model_set_ready(
        cls,
        model_id: str,
        company_id: str,
        publish_task: bool,
        force_publish_task: bool = False,
    ) -> tuple:
        """
        Mark a model as ready, optionally publishing the task that created it.

        :return: (model update result, published task data). The published task data
            contains "data" and "id" only if the model's task was published here.
        :except errors.bad_request.InvalidModelId: if the model is not found
        :except errors.bad_request.ModelIsReady: if the model is already ready
        """
        with translate_errors_context():
            query = dict(id=model_id, company=company_id)
            model = Model.objects(**query).first()
            if not model:
                raise errors.bad_request.InvalidModelId(**query)
            elif model.ready:
                raise errors.bad_request.ModelIsReady(**query)

            published_task_data = {}
            if model.task and publish_task:
                task = (
                    Task.objects(id=model.task, company=company_id)
                    .only("id", "status")
                    .first()
                )
                if task and task.status != TaskStatus.published:
                    published_task_data["data"] = cls.publish_task(
                        task_id=model.task,
                        company_id=company_id,
                        publish_model=False,
                        force=force_publish_task,
                    )
                    published_task_data["id"] = model.task

            updated = model.update(upsert=False, ready=True)
            return updated, published_task_data

    @classmethod
    def publish_task(
        cls,
        task_id: str,
        company_id: str,
        publish_model: bool,
        force: bool,
        status_reason: str = "",
        status_message: str = "",
    ) -> dict:
        """
        Publish a task, optionally marking its output model as ready.

        The task is moved to the "publishing" status while its model is handled. If
        any step fails, the previous status is restored and the error is re-raised.

        :param publish_model: also set the task's output model as ready
        :param force: skip the status change validation
        :return: the result of the status change request
        """
        task = cls.get_task_with_access(
            task_id, company_id=company_id, requires_write_access=True
        )
        if not force:
            validate_status_change(task.status, TaskStatus.published)

        previous_task_status = task.status
        # The task may have no output section: use an empty one instead of task.output
        output = task.output or Output()
        publish_failed = False

        try:
            # set state to publishing
            task.status = TaskStatus.publishing
            task.save()

            # publish task models
            if output.model and publish_model:
                output_model = (
                    Model.objects(id=output.model)
                    .only("id", "task", "ready")
                    .first()
                )
                if output_model and not output_model.ready:
                    cls.model_set_ready(
                        model_id=output.model,
                        company_id=company_id,
                        publish_task=False,
                    )

            # set task status to published, and update (or set) its new output (view and models)
            return ChangeStatusRequest(
                task=task,
                new_status=TaskStatus.published,
                force=force,
                status_reason=status_reason,
                status_message=status_message,
            ).execute(published=datetime.utcnow(), output=output)

        except Exception:
            publish_failed = True
            raise
        finally:
            if publish_failed:
                task.status = previous_task_status
                task.save()

    @classmethod
    def stop_task(
        cls,
        task_id: str,
        company_id: str,
        user_name: str,
        status_reason: str,
        force: bool,
    ) -> dict:
        """
        Stop a running task. Requires task status 'in_progress' and
        execution_progress 'running', or force=True.
        A development task, or a task that has no active worker, is stopped immediately.
        For a non-development task with an active worker only the status message is set
        to 'stopping', so that the worker can stop the task and report by itself.
        :return: updated task fields
        """

        task = cls.get_task_with_access(
            task_id,
            company_id=company_id,
            only=(
                "status",
                "project",
                "tags",
                "system_tags",
                "last_worker",
                "last_update",
            ),
            requires_write_access=True,
        )

        def is_run_by_worker(t: Task) -> bool:
            """
            Check if there is an active worker running the task: the task has a
            last worker and was updated within apiserver.workers.task_update_timeout
            seconds (default 600).
            """
            update_timeout = config.get("apiserver.workers.task_update_timeout", 600)
            return (
                t.last_worker
                and t.last_update
                and (datetime.utcnow() - t.last_update).total_seconds()
                < update_timeout
            )

        if TaskSystemTags.development in task.system_tags or not is_run_by_worker(task):
            new_status = TaskStatus.stopped
            status_message = f"Stopped by {user_name}"
        else:
            new_status = task.status
            status_message = TaskStatusMessage.stopping

        return ChangeStatusRequest(
            task=task,
            new_status=new_status,
            status_reason=status_reason,
            status_message=status_message,
            force=force,
        ).execute()

    @classmethod
    def add_or_update_artifacts(
        cls, task_id: str, company_id: str, artifacts: List[ApiArtifact]
    ) -> Tuple[List[str], List[str]]:
        """
        Add new artifacts to the task and replace existing ones, identified by (key, mode).

        The update is applied conditionally on the artifacts read from the task. If
        another party changed them in the meantime, the operation is retried after a
        random delay, up to services.tasks.artifacts.update_attempts times.

        :return: (keys of added artifacts, keys of updated artifacts)
        :except errors.server_error.UpdateFailed: if all the attempts failed
        """
        key = attrgetter("key", "mode")

        if not artifacts:
            return [], []

        with translate_errors_context(), TimingContext("mongo", "update_artifacts"):
            artifacts: List[Artifact] = [
                Artifact(**artifact.to_struct()) for artifact in artifacts
            ]

            attempts = int(config.get("services.tasks.artifacts.update_attempts", 10))

            for retry in range(attempts):
                task = cls.get_task_with_access(
                    task_id, company_id=company_id, requires_write_access=True
                )

                # The task may have no execution section or no artifacts yet
                existing = (task.execution.artifacts if task.execution else None) or []
                current = list(map(key, existing))
                current_keys = set(current)
                updated = [a for a in artifacts if key(a) in current_keys]
                added = [a for a in artifacts if key(a) not in current_keys]

                # Optimistic concurrency: the update matches only if the stored artifacts
                # still have the same count and contain every (key, mode) pair read above
                update_filter = {"_id": task_id, "company": company_id}
                update = {}
                array_filters = None
                if current:
                    update_filter["execution.artifacts"] = {
                        "$size": len(current),
                        "$all": [
                            {"$elemMatch": {"key": art_key, "mode": art_mode}}
                            for art_key, art_mode in current
                        ],
                    }
                else:
                    update_filter["$or"] = [
                        {"execution.artifacts": {"$exists": False}},
                        {"execution.artifacts": {"$size": 0}},
                    ]

                if added:
                    update["$push"] = {
                        "execution.artifacts": {"$each": [a.to_mongo() for a in added]}
                    }
                if updated:
                    # Each updated artifact is addressed through its own array filter
                    update["$set"] = {
                        f"execution.artifacts.$[artifact{index}]": artifact.to_mongo()
                        for index, artifact in enumerate(updated)
                    }
                    array_filters = [
                        {
                            f"artifact{index}.key": artifact.key,
                            f"artifact{index}.mode": artifact.mode,
                        }
                        for index, artifact in enumerate(updated)
                    ]

                if not update:
                    return [], []

                result: pymongo.results.UpdateResult = Task._get_collection().update_one(
                    filter=update_filter,
                    update=update,
                    array_filters=array_filters,
                    upsert=False,
                )
                if result.matched_count >= 1:
                    break

                wait_msec = random() * int(
                    config.get("services.tasks.artifacts.update_retry_msec", 500)
                )
                log.warning(
                    f"Failed to update artifacts for task {task_id} (updated by another party),"
                    f" retrying {retry+1}/{attempts} in {wait_msec}ms"
                )
                sleep(wait_msec / 1000)
            else:
                raise errors.server_error.UpdateFailed(
                    "task artifacts updated by another party"
                )

            return [a.key for a in added], [a.key for a in updated]

    @staticmethod
    def get_aggregated_project_parameters(
        company_id,
        project_ids: Sequence[str] = None,
        page: int = 0,
        page_size: int = 500,
    ) -> Tuple[int, int, Sequence[dict]]:
        """
        Return the distinct hyperparameter (section, name) pairs used by the company's
        tasks, sorted by section and name and paginated.
        If project_ids are passed then only tasks from these projects are considered.

        :param page: zero-based page number (negative values are treated as 0)
        :param page_size: page size (values below 1 are treated as 1)
        :return: (total number of pairs, number of pairs after this page,
            list of {"section": ..., "name": ...} dicts with unescaped values)
        """
        page = max(0, page)
        page_size = max(1, page_size)
        pipeline = [
            {
                "$match": {
                    "company": company_id,
                    "hyperparams": {"$exists": True, "$gt": {}},
                    **({"project": {"$in": project_ids}} if project_ids else {}),
                }
            },
            {"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
            {"$unwind": "$sections"},
            {
                "$project": {
                    "section": "$sections.k",
                    "names": {"$objectToArray": "$sections.v"},
                }
            },
            {"$unwind": "$names"},
            {"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
            {"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
            {
                "$group": {
                    "_id": 1,
                    "total": {"$sum": 1},
                    "results": {"$push": "$$ROOT"},
                }
            },
            {
                "$project": {
                    "total": 1,
                    "results": {"$slice": ["$results", page * page_size, page_size]},
                }
            },
        ]

        with translate_errors_context():
            result = next(Task.aggregate(pipeline), None)

        total = 0
        remaining = 0
        results = []

        if result:
            total = int(result.get("total", -1))
            results = [
                {
                    "section": ParameterKeyEscaper.unescape(
                        dpath.get(r, "_id/section")
                    ),
                    "name": ParameterKeyEscaper.unescape(dpath.get(r, "_id/name")),
                }
                for r in result.get("results", [])
            ]
            remaining = max(0, total - (len(results) + page * page_size))

        return total, remaining, results