import re
from collections import OrderedDict
from datetime import datetime, timedelta
from operator import attrgetter
from random import random
from time import sleep
from typing import Collection, Sequence, Tuple, Any, Optional, List

import pymongo.results
import six
from mongoengine import Q
from six import string_types

import es_factory
from apierrors import errors
from apimodels.tasks import Artifact as ApiArtifact
from config import config
from database.errors import translate_errors_context
from database.model.model import Model
from database.model.project import Project
from database.model.task.output import Output
from database.model.task.task import (
    Task,
    TaskStatus,
    TaskStatusMessage,
    TaskSystemTags,
)
from database.utils import get_company_or_none_constraint, id as create_id
from service_repo import APICall
from timing_context import TimingContext
from utilities.threads_manager import ThreadsManager
from .utils import ChangeStatusRequest, validate_status_change

log = config.logger(__file__)


class TaskBLL(object):
    """Business-logic helpers for querying, validating and updating tasks."""

    threads = ThreadsManager("TaskBLL")

    def __init__(self, events_es=None):
        """
        :param events_es: Events search client to use. When None, a connection
            is obtained from es_factory.connect("events").
        """
        self.events_es = (
            events_es if events_es is not None else es_factory.connect("events")
        )

    @staticmethod
    def get_task_with_access(
        task_id, company_id, only=None, allow_public=False, requires_write_access=False
    ) -> Task:
        """
        Gets a task, optionally requiring write access to it
        :param task_id: Task's ID.
        :param company_id: Task's company ID.
        :param only: Optional field projection passed to the query as `_only`.
        :param allow_public: Passed as `include_public` when write access is not required.
        :param requires_write_access: If True, the task is fetched using Task.get_for_writing().
        :except errors.bad_request.InvalidTaskId: if the task is not found
        :except errors.forbidden.NoWritePermission: if write_access was required and the task cannot be modified
        """
        with translate_errors_context():
            query = dict(id=task_id, company=company_id)
            with TimingContext("mongo", "task_with_access"):
                if requires_write_access:
                    task = Task.get_for_writing(_only=only, **query)
                else:
                    task = Task.get(_only=only, **query, include_public=allow_public)

            if not task:
                raise errors.bad_request.InvalidTaskId(**query)

            return task

    @staticmethod
    def get_by_id(
        company_id,
        task_id,
        required_status=None,
        required_dataset=None,
        only_fields=None,
    ):
        """
        Get a task by ID, optionally verifying its status and input datasets
        :param company_id: Task's company ID.
        :param task_id: Task's ID.
        :param required_status: If provided, the task's status must equal this value.
        :param required_dataset: If provided, must be one of the datasets in the task's input view entries.
        :param only_fields: Optional field name (string) or sequence of field names to load.
        :except errors.bad_request.InvalidTaskId: if the task is not found
        :except errors.bad_request.InvalidTaskStatus: if the task's status is not the required status
        :except errors.bad_request.InvalidId: if the required dataset is not in the task's input view
        """
        with TimingContext("mongo", "task_by_id_all"):
            qs = Task.objects(id=task_id, company=company_id)
            if only_fields:
                qs = (
                    qs.only(only_fields)
                    if isinstance(only_fields, string_types)
                    else qs.only(*only_fields)
                )
                # make sure all fields we rely on here are also returned
                # (only needed when a projection was requested - without one,
                # all fields are loaded anyway)
                qs = qs.only("status", "input")
            task = qs.first()

        if not task:
            raise errors.bad_request.InvalidTaskId(id=task_id)

        if required_status and not task.status == required_status:
            raise errors.bad_request.InvalidTaskStatus(expected=required_status)

        if required_dataset and required_dataset not in (
            entry.dataset for entry in task.input.view.entries
        ):
            raise errors.bad_request.InvalidId(
                "not in input view", dataset=required_dataset
            )

        return task

    @staticmethod
    def assert_exists(company_id, task_ids, only=None, allow_public=False):
        """
        Verify that all the provided task IDs exist
        :param company_id: Company ID used for the query.
        :param task_ids: A single task ID (string) or a collection of task IDs.
        :param only: Optional sequence of field names to load. When provided, the
            projected query result is returned; otherwise the first matching task is returned.
        :param allow_public: Passed to Task.get_many() as `allow_public`.
        :except errors.bad_request.InvalidTaskId: if the number of tasks found differs
            from the number of unique IDs requested
        """
        task_ids = [task_ids] if isinstance(task_ids, six.string_types) else task_ids
        with translate_errors_context(), TimingContext("mongo", "task_exists"):
            ids = set(task_ids)
            q = Task.get_many(
                company=company_id,
                query=Q(id__in=ids),
                allow_public=allow_public,
                return_dicts=False,
            )
            if only:
                res = q.only(*only)
                count = len(res)
            else:
                count = q.count()
                res = q.first()
            if count != len(ids):
                raise errors.bad_request.InvalidTaskId(ids=task_ids)
            return res

    @staticmethod
    def create(call: APICall, fields: dict):
        """
        Build a new (unsaved) Task owned by the call's identity
        :param call: API call whose identity provides the task's user and company.
        :param fields: Additional Task constructor keyword arguments.
        :return: A Task with a newly generated ID; `created` and `last_update`
            are both set to the current UTC time.
        """
        identity = call.identity
        now = datetime.utcnow()
        return Task(
            id=create_id(),
            user=identity.user,
            company=identity.company,
            created=now,
            last_update=now,
            **fields,
        )

    @staticmethod
    def validate_execution_model(task, allow_only_public=False):
        """
        Verify that the task's execution model (if any) exists
        :param task: Task to validate.
        :param allow_only_public: If True, the model lookup is constrained using
            a None company instead of the task's company.
        :return: The model, or None if the task has no execution model.
        :except errors.bad_request.InvalidModelId: if the model is not found
        """
        if not task.execution or not task.execution.model:
            return

        company = None if allow_only_public else task.company
        model_id = task.execution.model
        model = Model.objects(
            Q(id=model_id) & get_company_or_none_constraint(company)
        ).first()
        if not model:
            raise errors.bad_request.InvalidModelId(model=model_id)

        return model

    @classmethod
    def validate(cls, task: Task):
        """
        Validate a task's parent, project, execution model and execution parameters
        :except errors.bad_request.InvalidTaskId: if the task's parent is not found
        """
        assert isinstance(task, Task)

        if task.parent and not Task.get(
            company=task.company, id=task.parent, _only=("id",), include_public=True
        ):
            raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent)

        if task.project:
            # called for validation only, the returned project is not used
            Project.get_for_writing(company=task.company, id=task.project)

        cls.validate_execution_model(task)

        if task.execution and task.execution.parameters:
            cls._validate_execution_parameters(task.execution.parameters)

    @staticmethod
    def _validate_execution_parameters(parameters):
        """
        Verify that no execution parameter key contains whitespace
        :except errors.bad_request.ValidationError: listing the offending keys
        """
        invalid_keys = [k for k in parameters if re.search(r"\s", k)]
        if invalid_keys:
            raise errors.bad_request.ValidationError(
                "execution.parameters keys contain whitespace", keys=invalid_keys
            )

    @staticmethod
    def get_unique_metric_variants(company_id, project_ids=None):
        """
        Get the unique metric/variant combinations found in tasks' last_metrics
        :param company_id: Company ID to match.
        :param project_ids: Optional project IDs used to narrow down the matched tasks.
        :return: A list of dicts with the keys metric, metric_hash, variant and
            variant_hash, sorted by metric and variant.
        """
        pipeline = [
            {
                "$match": dict(
                    company=company_id,
                    **({"project": {"$in": project_ids}} if project_ids else {}),
                )
            },
            {"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
            {"$unwind": "$metrics"},
            {
                "$project": {
                    "metric": "$metrics.k",
                    "variants": {"$objectToArray": "$metrics.v"},
                }
            },
            {"$unwind": "$variants"},
            {
                "$group": {
                    "_id": {
                        "metric": "$variants.v.metric",
                        "variant": "$variants.v.variant",
                    },
                    "metrics": {
                        "$addToSet": {
                            "metric": "$variants.v.metric",
                            "metric_hash": "$metric",
                            "variant": "$variants.v.variant",
                            "variant_hash": "$variants.k",
                        }
                    },
                }
            },
            {"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
        ]

        with translate_errors_context():
            result = Task.aggregate(*pipeline)
            # each group holds a set of entries, only the first one is reported
            return [r["metrics"][0] for r in result]

    @staticmethod
    def set_last_update(
        task_ids: Collection[str], company_id: str, last_update: datetime
    ):
        """
        Set the last update time of the provided tasks
        :return: The result of the queryset update() call.
        """
        return Task.objects(id__in=task_ids, company=company_id).update(
            upsert=False, last_update=last_update
        )

    @staticmethod
    def update_statistics(
        task_id: str,
        company_id: str,
        last_update: datetime = None,
        last_iteration: int = None,
        last_iteration_max: int = None,
        last_values: Sequence[Tuple[Tuple[str, ...], Any]] = None,
        **extra_updates,
    ):
        """
        Update task statistics
        :param task_id: Task's ID.
        :param company_id: Task's company ID.
        :param last_update: Last update time. If not provided, defaults to datetime.utcnow().
        :param last_iteration: Last reported iteration. Use this to set a value regardless of current
            task's last iteration value.
        :param last_iteration_max: Last reported iteration. Use this to conditionally set a value only
            if the current task's last iteration value is smaller than the provided value.
        :param last_values: Last reported metrics summary (value, metric, variant).
        :param extra_updates: Extra task updates to include in this update call.
        :return:
        """
        last_update = last_update or datetime.utcnow()

        # an explicit last_iteration takes precedence over last_iteration_max
        if last_iteration is not None:
            extra_updates.update(last_iteration=last_iteration)
        elif last_iteration_max is not None:
            extra_updates.update(max__last_iteration=last_iteration_max)

        if last_values is not None:

            def op_path(op, *path):
                """Build an update keyword: <op>__last_metrics__<path...>"""
                return "__".join((op, "last_metrics") + path)

            for path, value in last_values:
                extra_updates[op_path("set", *path)] = value
                if path[-1] == "value":
                    # also track the min/max next to the reported value
                    extra_updates[op_path("min", *path[:-1], "min_value")] = value
                    extra_updates[op_path("max", *path[:-1], "max_value")] = value

        Task.objects(id=task_id, company=company_id).update(
            upsert=False, last_update=last_update, **extra_updates
        )

    @classmethod
    def model_set_ready(
        cls,
        model_id: str,
        company_id: str,
        publish_task: bool,
        force_publish_task: bool = False,
    ) -> tuple:
        """
        Mark a model as ready, optionally publishing the task associated with it
        :param model_id: Model's ID.
        :param company_id: Model's company ID.
        :param publish_task: If True and the model's task is not published, publish it.
        :param force_publish_task: Passed as `force` when publishing the task.
        :return: A tuple of (model update result, published task data). The published
            task data is a dict containing "id" and "data" if the task was published,
            otherwise it is empty.
        :except errors.bad_request.InvalidModelId: if the model is not found
        :except errors.bad_request.ModelIsReady: if the model is already ready
        """
        with translate_errors_context():
            query = dict(id=model_id, company=company_id)
            model = Model.objects(**query).first()
            if not model:
                raise errors.bad_request.InvalidModelId(**query)
            elif model.ready:
                raise errors.bad_request.ModelIsReady(**query)

            published_task_data = {}
            if model.task and publish_task:
                task = (
                    Task.objects(id=model.task, company=company_id)
                    .only("id", "status")
                    .first()
                )
                if task and task.status != TaskStatus.published:
                    # publish_model=False: the model is being set ready right here
                    published_task_data["data"] = cls.publish_task(
                        task_id=model.task,
                        company_id=company_id,
                        publish_model=False,
                        force=force_publish_task,
                    )
                    published_task_data["id"] = model.task

            updated = model.update(upsert=False, ready=True)
            return updated, published_task_data

    @classmethod
    def publish_task(
        cls,
        task_id: str,
        company_id: str,
        publish_model: bool,
        force: bool,
        status_reason: str = "",
        status_message: str = "",
    ) -> dict:
        """
        Publish a task, optionally setting its output model as ready
        :param task_id: Task's ID.
        :param company_id: Task's company ID.
        :param publish_model: If True and the task's output model is not ready, set it as ready.
        :param force: If True, skip the status change validation.
        :param status_reason: Status reason to set on the task.
        :param status_message: Status message to set on the task.
        :return: The result of executing the status change request.

        If anything fails after the task was moved to the 'publishing' status, the
        task's previous status is restored and the error is re-raised.
        """
        task = cls.get_task_with_access(
            task_id, company_id=company_id, requires_write_access=True
        )
        if not force:
            validate_status_change(task.status, TaskStatus.published)

        previous_task_status = task.status
        output = task.output or Output()

        try:
            # set state to publishing
            task.status = TaskStatus.publishing
            task.save()

            # publish task models
            # (use the local output, task.output itself may be unset)
            if output.model and publish_model:
                output_model = (
                    Model.objects(id=output.model)
                    .only("id", "task", "ready")
                    .first()
                )
                if output_model and not output_model.ready:
                    # publish_task=False: this task is already being published
                    cls.model_set_ready(
                        model_id=output.model,
                        company_id=company_id,
                        publish_task=False,
                    )

            # set task status to published, and update (or set) its new output (view and models)
            return ChangeStatusRequest(
                task=task,
                new_status=TaskStatus.published,
                force=force,
                status_reason=status_reason,
                status_message=status_message,
            ).execute(published=datetime.utcnow(), output=output)

        except Exception:
            # publishing failed, restore the previous status and re-raise
            task.status = previous_task_status
            task.save()
            raise

    @classmethod
    def stop_task(
        cls,
        task_id: str,
        company_id: str,
        user_name: str,
        status_reason: str,
        force: bool,
    ) -> dict:
        """
        Stop a running task. Requires task status 'in_progress' and
        execution_progress 'running', or force=True. Development task or
        task that has no associated worker is stopped immediately.
        For a non-development task with worker only the status message
        is set to 'stopping' to allow the worker to stop the task and report by itself
        :return: updated task fields
        """

        task = cls.get_task_with_access(
            task_id,
            company_id=company_id,
            only=(
                "status",
                "project",
                "tags",
                "system_tags",
                "last_worker",
                "last_update",
            ),
            requires_write_access=True,
        )

        def is_run_by_worker(t: Task) -> bool:
            """Checks if there is an active worker running the task"""
            update_timeout = config.get("apiserver.workers.task_update_timeout", 600)
            return bool(
                t.last_worker
                and t.last_update
                and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout
            )

        is_development = TaskSystemTags.development in (task.system_tags or ())
        if is_development or not is_run_by_worker(task):
            new_status = TaskStatus.stopped
            status_message = f"Stopped by {user_name}"
        else:
            # keep the current status and let the worker stop the task by itself
            new_status = task.status
            status_message = TaskStatusMessage.stopping

        return ChangeStatusRequest(
            task=task,
            new_status=new_status,
            status_reason=status_reason,
            status_message=status_message,
            force=force,
        ).execute()

    @classmethod
    def add_or_update_artifacts(
        cls, task_id: str, company_id: str, artifacts: List[ApiArtifact]
    ) -> Tuple[List[str], List[str]]:
        """
        Add new artifacts to a task and update its existing ones
        Artifacts are identified by their (key, mode) pair. The update is only
        applied if the task's artifacts were not changed since they were read,
        otherwise the operation is retried after a random wait.
        :param task_id: Task's ID.
        :param company_id: Task's company ID.
        :param artifacts: Artifacts to add or update.
        :return: A tuple of (keys of added artifacts, keys of updated artifacts).
        :except errors.server_error.UpdateFailed: if all update attempts failed
        """
        # NOTE(review): the original code referenced `Artifact` without importing
        # it (only the API model is imported, as ApiArtifact). The DB model is
        # presumably declared alongside Task - confirm. Imported locally so that
        # a wrong location can only affect this method.
        from database.model.task.task import Artifact

        artifact_id = attrgetter("key", "mode")

        if not artifacts:
            return [], []

        with translate_errors_context(), TimingContext("mongo", "update_artifacts"):
            db_artifacts: List[Artifact] = [
                Artifact(**artifact.to_struct()) for artifact in artifacts
            ]

            attempts = int(config.get("services.tasks.artifacts.update_attempts", 10))

            for retry in range(attempts):
                task = cls.get_task_with_access(
                    task_id, company_id=company_id, requires_write_access=True
                )

                current = list(map(artifact_id, task.execution.artifacts))
                current_ids = set(current)
                updated = [a for a in db_artifacts if artifact_id(a) in current_ids]
                added = [a for a in db_artifacts if artifact_id(a) not in current_ids]

                # match the task only if its artifacts are exactly the ones read above
                query_filter = {"_id": task_id, "company": company_id}
                update = {}
                array_filters = None
                if current:
                    query_filter["execution.artifacts"] = {
                        "$size": len(current),
                        "$all": [
                            {"$elemMatch": {"key": art_key, "mode": art_mode}}
                            for art_key, art_mode in current
                        ],
                    }
                else:
                    query_filter["$or"] = [
                        {"execution.artifacts": {"$exists": False}},
                        {"execution.artifacts": {"$size": 0}},
                    ]

                if added:
                    update["$push"] = {
                        "execution.artifacts": {"$each": [a.to_mongo() for a in added]}
                    }
                if updated:
                    # each updated artifact is addressed using its own array filter
                    update["$set"] = {
                        f"execution.artifacts.$[artifact{index}]": artifact.to_mongo()
                        for index, artifact in enumerate(updated)
                    }
                    array_filters = [
                        {
                            f"artifact{index}.key": artifact.key,
                            f"artifact{index}.mode": artifact.mode,
                        }
                        for index, artifact in enumerate(updated)
                    ]

                if not update:
                    return [], []

                result: pymongo.results.UpdateResult = Task._get_collection().update_one(
                    filter=query_filter,
                    update=update,
                    array_filters=array_filters,
                    upsert=False,
                )
                if result.matched_count >= 1:
                    break

                wait_msec = random() * int(
                    config.get("services.tasks.artifacts.update_retry_msec", 500)
                )

                log.warning(
                    f"Failed to update artifacts for task {task_id} (updated by another party),"
                    f" retrying {retry+1}/{attempts} in {wait_msec}ms"
                )
                sleep(wait_msec / 1000)
            else:
                # all attempts were used without a successful update
                raise errors.server_error.UpdateFailed(
                    "task artifacts updated by another party"
                )

            return [a.key for a in added], [a.key for a in updated]

    @classmethod
    @threads.register("non_responsive_tasks_watchdog", daemon=True)
    def start_non_responsive_tasks_watchdog(cls):
        """
        Periodically stop in-progress tasks that were not updated for longer than
        the configured threshold. Runs forever; errors in a cleanup cycle are
        logged and the next cycle starts after the watch interval.
        """
        watchdog_log = config.logger("non_responsive_tasks_watchdog")
        relevant_status = (TaskStatus.in_progress,)
        threshold = timedelta(
            seconds=config.get(
                "services.tasks.non_responsive_tasks_watchdog.threshold_sec", 7200
            )
        )
        while True:
            sleep(
                config.get(
                    "services.tasks.non_responsive_tasks_watchdog.watch_interval_sec",
                    900,
                )
            )
            try:
                ref_time = datetime.utcnow() - threshold

                watchdog_log.info(
                    f"Starting cleanup cycle for running tasks last updated before {ref_time}"
                )

                tasks = list(
                    Task.objects(
                        status__in=relevant_status, last_update__lt=ref_time
                    ).only("id", "name", "status", "project", "last_update")
                )

                if tasks:
                    watchdog_log.info(f"Stopping {len(tasks)} non-responsive tasks")

                    for task in tasks:
                        watchdog_log.info(
                            f"Stopping {task.id} ({task.name}), last updated at {task.last_update}"
                        )
                        ChangeStatusRequest(
                            task=task,
                            new_status=TaskStatus.stopped,
                            status_reason="Forced stop (non-responsive)",
                            status_message="Forced stop (non-responsive)",
                            force=True,
                        ).execute()

                watchdog_log.info("Done")
            except Exception as ex:
                # keep the watchdog alive, the next cycle will try again
                watchdog_log.exception(f"Failed stopping tasks: {str(ex)}")

    @staticmethod
    def get_aggregated_project_execution_parameters(
        company_id,
        project_ids: Sequence[str] = None,
        page: int = 0,
        page_size: int = 500,
    ) -> Tuple[int, int, Sequence[str]]:
        """
        Get a page of the unique execution parameter names used by tasks
        :param company_id: Company ID to match.
        :param project_ids: Optional project IDs used to narrow down the matched tasks.
        :param page: Zero-based page number (negative values are treated as 0).
        :param page_size: Page size (values smaller than 1 are treated as 1).
        :return: A tuple of (total number of unique names, number of names remaining
            after this page, the sorted names in this page).
        """
        page = max(0, page)
        page_size = max(1, page_size)

        pipeline = [
            {
                "$match": {
                    "company": company_id,
                    "execution.parameters": {"$exists": True, "$gt": {}},
                    **({"project": {"$in": project_ids}} if project_ids else {}),
                }
            },
            {"$project": {"parameters": {"$objectToArray": "$execution.parameters"}}},
            {"$unwind": "$parameters"},
            {"$group": {"_id": "$parameters.k"}},
            {"$sort": {"_id": 1}},
            {
                # collect all names into a single document in order to count and slice them
                "$group": {
                    "_id": 1,
                    "total": {"$sum": 1},
                    "results": {"$push": "$$ROOT"},
                }
            },
            {
                "$project": {
                    "total": 1,
                    "results": {"$slice": ["$results", page * page_size, page_size]},
                }
            },
        ]

        with translate_errors_context():
            result = next(Task.aggregate(*pipeline), None)

        total = 0
        remaining = 0
        results = []

        if result:
            total = int(result.get("total", -1))
            results = [r["_id"] for r in result.get("results", [])]
            remaining = max(0, total - (len(results) + page * page_size))

        return total, remaining, results