from datetime import datetime
from typing import Collection, Sequence, Tuple, Optional, Dict

import six
from mongoengine import Q
from redis import StrictRedis
from six import string_types

import apiserver.database.utils as dbutils
from apiserver.apierrors import errors, APIError
from apiserver.apimodels.tasks import TaskInputModel
from apiserver.bll.queue import QueueBLL
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.metrics import EventStats, MetricEventStats
from apiserver.database.model.task.output import Output
from apiserver.database.model.task.task import (
    Task,
    TaskStatus,
    TaskSystemTags,
    ArtifactModes,
    ModelItem,
    Models,
    DEFAULT_ARTIFACT_MODE,
    TaskModelNames,
    TaskModelTypes,
)
from apiserver.database.model import EntityVisibility
from apiserver.database.model.queue import Queue
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
from apiserver.es_factory import es_factory
from apiserver.redis_manager import redman
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
from .artifacts import artifacts_prepare_for_save
from .param_utils import params_prepare_for_save
from .utils import (
    ChangeStatusRequest,
    update_project_time,
    deleted_prefix,
    get_last_metric_updates,
)

log = config.logger(__file__)
org_bll = OrgBLL()
queue_bll = QueueBLL()
project_bll = ProjectBLL()


class TaskBLL:
    """Business-logic helpers for fetching, creating, cloning, validating and updating tasks."""

    def __init__(self, events_es=None, redis=None):
        """
        :param events_es: events ES connection. If not provided, the "events"
            connection is obtained from es_factory.
        :param redis: redis connection. If not provided, the "apiserver"
            connection is obtained from redman.
        """
        self.events_es = events_es or es_factory.connect("events")
        self.redis: StrictRedis = redis or redman.connection("apiserver")

    @staticmethod
    def get_task_with_access(
        task_id, company_id, only=None, allow_public=False, requires_write_access=False
    ) -> Task:
        """
        Gets a task, optionally requiring write access to it
        :param task_id: ID of the task to fetch
        :param company_id: company the task is looked up in
        :param only: passed as the `_only` argument of the underlying lookup
        :param allow_public: passed as `include_public` to Task.get.
            Not used when requires_write_access is set
        :param requires_write_access: if set, the task is fetched using Task.get_for_writing
        :except errors.bad_request.InvalidTaskId: if the task is not found
        :except errors.forbidden.NoWritePermission: if write_access was required and the task cannot be modified
        """
        with translate_errors_context():
            query = dict(id=task_id, company=company_id)
            if requires_write_access:
                task = Task.get_for_writing(_only=only, **query)
            else:
                task = Task.get(_only=only, **query, include_public=allow_public)

            if not task:
                raise errors.bad_request.InvalidTaskId(**query)

            return task

    @staticmethod
    def get_by_id(
        company_id,
        task_id,
        required_status=None,
        only_fields=None,
        allow_public=False,
    ):
        """
        Get a single task by ID, optionally asserting its status
        :param company_id: company the task is looked up in
        :param task_id: ID of the task to fetch
        :param required_status: if provided, the task status must be equal to this value
        :param only_fields: a field name or a collection of field names to use as the
            query projection. "status" is always added to a non-empty projection
        :param allow_public: passed as `allow_public` to Task.get_many
        :raise errors.bad_request.InvalidTaskId: if the task is not found
        :raise errors.bad_request.InvalidTaskStatus: if required_status was provided
            and the task has a different status
        :return: the task
        """
        if only_fields:
            if isinstance(only_fields, string_types):
                only_fields = [only_fields]
            else:
                only_fields = list(only_fields)
            # status is needed for the required_status check below
            only_fields = only_fields + ["status"]

        tasks = Task.get_many(
            company=company_id,
            query=Q(id=task_id),
            allow_public=allow_public,
            override_projection=only_fields,
            return_dicts=False,
        )
        task = None if not tasks else tasks[0]

        if not task:
            raise errors.bad_request.InvalidTaskId(id=task_id)

        if required_status and task.status != required_status:
            raise errors.bad_request.InvalidTaskStatus(expected=required_status)

        return task

    @staticmethod
    def assert_exists(
        company_id, task_ids, only=None, allow_public=False, return_tasks=True
    ) -> Optional[Sequence[Task]]:
        """
        Make sure that all the requested tasks exist
        :param company_id: company the tasks are looked up in
        :param task_ids: a single task ID or a collection of task IDs.
            Duplicate IDs are counted once
        :param only: if provided, only these fields are loaded for the returned tasks
        :param allow_public: passed as `allow_public` to Task.get_many
        :param return_tasks: if set, the found tasks are returned
        :raise errors.bad_request.InvalidTaskId: if the number of found tasks differs
            from the number of distinct requested IDs
        :return: list of the found tasks if return_tasks is set, otherwise None
        """
        task_ids = [task_ids] if isinstance(task_ids, six.string_types) else task_ids
        with translate_errors_context():
            ids = set(task_ids)
            q = Task.get_many(
                company=company_id,
                query=Q(id__in=ids),
                allow_public=allow_public,
                return_dicts=False,
            )
            if only:
                # Make sure to reset fields filters (some fields are excluded by default) since this
                # is an internal call and specific fields were requested.
                q = q.all_fields().only(*only)

            if q.count() != len(ids):
                raise errors.bad_request.InvalidTaskId(ids=task_ids)

            if return_tasks:
                return list(q)

            return None

    @staticmethod
    def create(company: str, user: str, fields: dict):
        """
        Build a new Task document with a newly generated ID. The task is not saved here
        :param company: company ID set on the task
        :param user: user ID set as the task user and as last_changed_by
        :param fields: additional Task fields passed as keyword arguments
        :return: the new Task object. created, last_update and last_change are all
            set to the same current UTC time
        """
        now = datetime.utcnow()
        return Task(
            id=create_id(),
            user=user,
            company=company,
            created=now,
            last_update=now,
            last_change=now,
            last_changed_by=user,
            **fields,
        )

    @staticmethod
    def validate_input_models(task, allow_only_public=False):
        """
        Make sure that all the input models referenced by the task exist
        :param task: the task whose models.input entries are checked
        :param allow_only_public: if set, the company constraint is built with company
            None instead of the task's company
        :raise errors.bad_request.InvalidModelId: if some of the referenced models
            were not found
        """
        if not task.models.input:
            return

        company = None if allow_only_public else task.company
        model_ids = {m.model for m in task.models.input}
        models = Model.objects(
            Q(id__in=model_ids) & get_company_or_none_constraint(company)
        ).only("id")
        missing = model_ids - {m.id for m in models}
        if missing:
            raise errors.bad_request.InvalidModelId(models=missing)

    @classmethod
    def clone_task(
        cls,
        company_id: str,
        user_id: str,
        task_id: str,
        name: Optional[str] = None,
        comment: Optional[str] = None,
        parent: Optional[str] = None,
        project: Optional[str] = None,
        tags: Optional[Sequence[str]] = None,
        system_tags: Optional[Sequence[str]] = None,
        hyperparams: Optional[dict] = None,
        configuration: Optional[dict] = None,
        container: Optional[dict] = None,
        execution_overrides: Optional[dict] = None,
        input_models: Optional[Sequence[TaskInputModel]] = None,
        validate_references: bool = False,
        new_project_name: Optional[str] = None,
    ) -> Tuple[Task, dict]:
        """
        Create and save a new task based on an existing one
        Any value that is not provided (or is empty) is taken from the source task
        :param company_id: company of the new task. The source task may also be public
        :param user_id: user set as the new task user and as last_changed_by
        :param task_id: ID of the source task
        :param name: new task name
        :param comment: new task comment
        :param parent: new task parent. Defaults to the source task parent if it is
            set and does not start with the deleted prefix, otherwise to the source task ID
        :param project: new task project
        :param tags: new task tags
        :param system_tags: new task system tags. If not provided then the source task
            system tags are used without the development and archived tags
        :param hyperparams: new task hyperparams
        :param configuration: new task configuration
        :param container: new task container. If not provided and the execution overrides
            contain "docker_cmd" then the container is built from it
        :param execution_overrides: values to set over the source task execution.
            The passed dictionary itself is not modified
        :param input_models: new task input models. If not provided and the execution
            overrides contain "model" then this model is used as a single input model
        :param validate_references: if set then the parent, project and input models are
            validated even if they were not explicitly passed
        :param new_project_name: name of the project to find or create for the new task.
            Used only if project is not provided
        :return: tuple of the new task and the data (id and name) of the project that was
            resolved from new_project_name, or None if it was not used
        """
        validate_tags(tags, system_tags)
        params_dict = {
            field: value
            for field, value in (
                ("hyperparams", hyperparams),
                ("configuration", configuration),
            )
            if value is not None
        }

        task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)

        now = datetime.utcnow()
        if input_models:
            input_models = [
                ModelItem(model=m.model, name=m.name, updated=now) for m in input_models
            ]

        execution_dict = task.execution.to_proper_dict() if task.execution else {}
        if execution_overrides:
            # Work on a shallow copy so that the pops and the in-place helpers below
            # do not alter the dictionary owned by the caller
            execution_overrides = dict(execution_overrides)

            execution_model = execution_overrides.pop("model", None)
            if not input_models and execution_model:
                input_models = [
                    ModelItem(
                        model=execution_model,
                        name=TaskModelNames[TaskModelTypes.input],
                        updated=now,
                    )
                ]

            docker_cmd = execution_overrides.pop("docker_cmd", None)
            if not container and docker_cmd:
                # everything up to the first space is the image, the rest are the arguments
                image, _, arguments = docker_cmd.partition(" ")
                container = {"image": image, "arguments": arguments}

            artifacts_prepare_for_save({"execution": execution_overrides})

            params_dict["execution"] = {}
            for legacy_param in ("parameters", "configuration"):
                legacy_value = execution_overrides.pop(legacy_param, None)
                if legacy_value is not None:
                    # NOTE(review): each legacy value replaces params_dict["execution"]
                    # as a whole (the last one wins) instead of being stored under its
                    # own key - confirm against params_prepare_for_save that this is intended
                    params_dict["execution"] = legacy_value

            escape_dict_field(execution_overrides, "model_labels")
            execution_dict.update(execution_overrides)

        params_prepare_for_save(params_dict, previous_task=task)

        artifacts = execution_dict.get("artifacts")
        if artifacts:
            # output artifacts of the source task are not passed to the new task
            execution_dict["artifacts"] = {
                k: a
                for k, a in artifacts.items()
                if a.get("mode", DEFAULT_ARTIFACT_MODE) != ArtifactModes.output
            }
        execution_dict.pop("queue", None)

        new_project_data = None
        if not project and new_project_name:
            # Use a project with the provided name, or create a new project
            project = ProjectBLL.find_or_create(
                project_name=new_project_name,
                user=user_id,
                company=company_id,
                description="",
            )
            new_project_data = {"id": project, "name": new_project_name}

        def clean_system_tags(input_tags: Sequence[str]) -> Sequence[str]:
            """Return the tags without the development and archived system tags"""
            if not input_tags:
                return input_tags

            return [
                tag
                for tag in input_tags
                if tag
                not in [TaskSystemTags.development, EntityVisibility.archived.value]
            ]

        def ensure_int_labels(execution: dict) -> dict:
            """Convert the model_labels values to int (in place) and return the passed dict"""
            if not execution:
                return execution

            model_labels = execution.get("model_labels")
            if model_labels:
                execution["model_labels"] = {k: int(v) for k, v in model_labels.items()}

            return execution

        parent_task = (
            task.parent
            if task.parent and not task.parent.startswith(deleted_prefix)
            else task.id
        )
        new_task = Task(
            id=create_id(),
            user=user_id,
            company=company_id,
            created=now,
            last_update=now,
            last_change=now,
            last_changed_by=user_id,
            name=name or task.name,
            comment=comment or task.comment,
            parent=parent or parent_task,
            project=project or task.project,
            tags=tags or task.tags,
            system_tags=system_tags or clean_system_tags(task.system_tags),
            type=task.type,
            script=task.script,
            output=Output(destination=task.output.destination) if task.output else None,
            models=Models(input=input_models or task.models.input),
            container=escape_dict(container) or task.container,
            execution=ensure_int_labels(execution_dict),
            configuration=params_dict.get("configuration") or task.configuration,
            hyperparams=params_dict.get("hyperparams") or task.hyperparams,
        )
        # references that were explicitly provided are validated regardless of the flag
        cls.validate(
            new_task,
            validate_models=validate_references or input_models,
            validate_parent=validate_references or parent,
            validate_project=validate_references or project,
        )
        new_task.save()

        if task.project == new_task.project:
            # same project as the source task: report only the explicitly passed tags
            updated_tags = tags
            updated_system_tags = system_tags
        else:
            updated_tags = new_task.tags
            updated_system_tags = new_task.system_tags
        org_bll.update_tags(
            company_id,
            Tags.Task,
            project=new_task.project,
            tags=updated_tags,
            system_tags=updated_system_tags,
        )
        update_project_time(new_task.project)

        return new_task, new_project_data

    @classmethod
    def validate(
        cls,
        task: Task,
        validate_models=True,
        validate_parent=True,
        validate_project=True,
    ):
        """
        Validate task properties according to the flag
        Task project is always checked for being writable
        in order to disable the modification of public projects
        :param task: the task to validate
        :param validate_models: if set, the task input models are validated
        :param validate_parent: if set, a parent that does not start with the deleted
            prefix should exist (public tasks included)
        :param validate_project: if set, a project that could not be fetched for
            writing is reported as invalid
        :raise errors.bad_request.InvalidTaskId: if the parent task was not found
        :raise errors.bad_request.InvalidProjectId: if the project was not found
        """
        if (
            validate_parent
            and task.parent
            and not task.parent.startswith(deleted_prefix)
            and not Task.get(
                company=task.company, id=task.parent, _only=("id",), include_public=True
            )
        ):
            raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent)

        if task.project:
            # get_for_writing is called even if validate_project is not set
            project = Project.get_for_writing(company=task.company, id=task.project)
            if validate_project and not project:
                raise errors.bad_request.InvalidProjectId(id=task.project)

        if validate_models:
            cls.validate_input_models(task)

    @staticmethod
    def set_last_update(
        task_ids: Collection[str],
        company_id: str,
        last_update: datetime,
        **extra_updates,
    ):
        """
        Set last_update and last_change for the passed tasks, one task at a time
        For tasks that are in progress and have a started time, active_duration is
        also set to the number of seconds passed since the task was started
        :param task_ids: IDs of the tasks to update
        :param company_id: company of the tasks
        :param last_update: the time to set as both last_update and last_change
        :param extra_updates: extra task updates to include in each update call.
            An explicitly passed active_duration overrides the calculated one
        :return: the sum of the values returned by the update calls
        """
        tasks = Task.objects(id__in=task_ids, company=company_id).only(
            "status", "started"
        )
        count = 0
        for task in tasks:
            updates = extra_updates
            if task.status == TaskStatus.in_progress and task.started:
                updates = {
                    "active_duration": (
                        datetime.utcnow() - task.started
                    ).total_seconds(),
                    **extra_updates,
                }
            count += Task.objects(id=task.id, company=company_id).update(
                upsert=False,
                last_update=last_update,
                last_change=last_update,
                **updates,
            )
        return count

    @staticmethod
    def update_statistics(
        task_id: str,
        company_id: str,
        last_update: datetime = None,
        last_iteration: int = None,
        last_iteration_max: int = None,
        last_scalar_events: Dict[str, Dict[str, dict]] = None,
        last_events: Dict[str, Dict[str, dict]] = None,
        **extra_updates,
    ):
        """
        Update task statistics
        :param task_id: Task's ID.
        :param company_id: Task's company ID.
        :param last_update: Last update time. If not provided, defaults to datetime.utcnow().
        :param last_iteration: Last reported iteration. Use this to set a value regardless of current
            task's last iteration value. Takes precedence over last_iteration_max.
        :param last_iteration_max: Last reported iteration. Use this to conditionally set a value only
            if the current task's last iteration value is smaller than the provided value.
        :param last_scalar_events: Last reported metrics summary for scalar events (value, metric, variant).
        :param last_events: Last reported metrics summary (value, metric, event type).
            Each event should contain a "timestamp" value.
        :param extra_updates: Extra task updates to include in this update call.
        :return: the result of the set_last_update call
        """
        last_update = last_update or datetime.utcnow()

        if last_iteration is not None:
            extra_updates.update(last_iteration=last_iteration)
        elif last_iteration_max is not None:
            extra_updates.update(max__last_iteration=last_iteration_max)

        raw_updates = {}
        if last_scalar_events is not None:
            # raw_updates and extra_updates are passed to the helper to be filled in
            get_last_metric_updates(
                task_id=task_id,
                last_scalar_events=last_scalar_events,
                raw_updates=raw_updates,
                extra_updates=extra_updates,
            )

        if last_events is not None:

            def events_per_type(metric_data_: Dict[str, dict]) -> Dict[str, EventStats]:
                """Build the per event type stats from the last event of each type"""
                return {
                    event_type: EventStats(last_update=event["timestamp"])
                    for event_type, event in metric_data_.items()
                }

            # the stats are stored under the hashed metric name
            metric_stats = {
                dbutils.hash_field_name(metric_key): MetricEventStats(
                    metric=metric_key,
                    event_stats_by_type=events_per_type(metric_data),
                )
                for metric_key, metric_data in last_events.items()
            }
            extra_updates["metric_stats"] = metric_stats

        ret = TaskBLL.set_last_update(
            task_ids=[task_id],
            company_id=company_id,
            last_update=last_update,
            **extra_updates,
        )
        # the raw updates are applied only if the regular update reported a change
        if ret and raw_updates:
            Task.objects(id=task_id).update_one(__raw__=[{"$set": raw_updates}])

        return ret

    @staticmethod
    def remove_task_from_all_queues(company_id: str, task: Task) -> int:
        """
        Remove the task entries from all the company queues that contain the task
        and set last_update of these queues to the current UTC time
        :param company_id: company of the queues
        :param task: the task to remove
        :return: the result of the queues update call
        """
        return Queue.objects(company=company_id, entries__task=task.id).update(
            pull__entries__task=task.id, last_update=datetime.utcnow()
        )

    @classmethod
    def dequeue_and_change_status(
        cls,
        task: Task,
        company_id: str,
        user_id: str,
        status_message: str,
        status_reason: str,
        remove_from_all_queues=False,
    ):
        """
        Dequeue the task and move it back to the status it had before it was enqueued
        :param task: the task to dequeue
        :param company_id: task's company ID.
        :param user_id: the user passed to the status change request
        :param status_message: the message passed to the status change request
        :param status_reason: the reason passed to the status change request
        :param remove_from_all_queues: if set, the task is also removed from any
            company queue that contains it
        :return: {"updated": 0} if the task is neither queued nor in progress,
            otherwise the result of the status change request. The new status is the
            task's enqueue_status, or created if enqueue_status is not set
        """
        try:
            cls.dequeue(task, company_id, silent_fail=True)
        except APIError:
            # dequeue may fail if the queue was deleted
            pass

        if remove_from_all_queues:
            cls.remove_task_from_all_queues(company_id=company_id, task=task)

        if task.status not in [TaskStatus.queued, TaskStatus.in_progress]:
            return {"updated": 0}

        return ChangeStatusRequest(
            task=task,
            new_status=task.enqueue_status or TaskStatus.created,
            status_reason=status_reason,
            status_message=status_message,
            user_id=user_id,
        ).execute(enqueue_status=None)

    @classmethod
    def dequeue(cls, task: Task, company_id: str, silent_fail=False):
        """
        Dequeue the task from the queue
        :param task: task to dequeue
        :param company_id: task's company ID.
        :param silent_fail: do not throw exceptions. APIError is still thrown
        :raise errors.bad_request.InvalidTaskId: if the task's status is not queued
        :raise errors.bad_request.MissingRequiredFields: if the task has no execution queue value
        :raise APIError or errors.server_error.TransactionError: if internal call to queues.remove_task fails
        :return: dictionary with the result of queues.remove_task call under the "removed" key.
            None in case of silent failure
        """
        if task.status not in (TaskStatus.queued,):
            if silent_fail:
                return
            raise errors.bad_request.InvalidTaskId(
                status=task.status, expected=TaskStatus.queued
            )

        if not task.execution or not task.execution.queue:
            if silent_fail:
                return
            raise errors.bad_request.MissingRequiredFields(
                "task has no queue value", field="execution.queue"
            )

        return {
            "removed": queue_bll.remove_task(
                company_id=company_id, queue_id=task.execution.queue, task_id=task.id
            )
        }