clearml-server/apiserver/bll/queue/queue_bll.py
2021-01-05 16:22:34 +02:00

271 lines
10 KiB
Python

from collections import defaultdict
from datetime import datetime
from typing import Callable, Sequence, Optional, Tuple
from elasticsearch import Elasticsearch
import database
import es_factory
from apierrors import errors
from bll.queue.queue_metrics import QueueMetrics
from bll.workers import WorkerBLL
from config import config
from database.errors import translate_errors_context
from database.model.queue import Queue, Entry
# Module-level logger, created through the project config for this file's path.
log = config.logger(__file__)
class QueueBLL(object):
    """
    Business-logic layer for queues: creating, querying, updating and deleting
    Queue documents and managing the task entries they hold.
    """

    def __init__(self, worker_bll: WorkerBLL = None, es: Elasticsearch = None):
        """
        :param worker_bll: workers BLL used to look up the workers serving a queue;
            a new WorkerBLL is created when omitted
        :param es: Elasticsearch client handed to QueueMetrics; when omitted
            a connection is obtained with es_factory.connect("workers")
        """
        self.worker_bll = worker_bll or WorkerBLL()
        self.es = es or es_factory.connect("workers")
        self._metrics = QueueMetrics(self.es)

    @property
    def metrics(self) -> QueueMetrics:
        """The QueueMetrics instance bound to this BLL's Elasticsearch client"""
        return self._metrics

    @staticmethod
    def create(
        company_id: str,
        name: str,
        tags: Optional[Sequence[str]] = None,
        system_tags: Optional[Sequence[str]] = None,
    ) -> Queue:
        """
        Creates a queue
        :param company_id: company that owns the new queue
        :param name: queue name
        :param tags: optional user tags (stored as an empty list when omitted)
        :param system_tags: optional system tags (stored as an empty list when omitted)
        :return: the saved Queue document
        """
        with translate_errors_context():
            # a single timestamp so that 'created' and 'last_update' agree
            now = datetime.utcnow()
            queue = Queue(
                id=database.utils.id(),
                company=company_id,
                created=now,
                name=name,
                tags=tags or [],
                system_tags=system_tags or [],
                last_update=now,
            )
            queue.save()
            return queue

    def get_by_id(
        self, company_id: str, queue_id: str, only: Optional[Sequence[str]] = None
    ) -> Queue:
        """
        Get queue by id
        :param only: optional list of field names to restrict the fetched document to
        :raise errors.bad_request.InvalidQueueId: if the queue is not found
        """
        with translate_errors_context():
            query = dict(id=queue_id, company=company_id)
            qs = Queue.objects(**query)
            if only:
                qs = qs.only(*only)
            queue = qs.first()
            if not queue:
                raise errors.bad_request.InvalidQueueId(**query)

            return queue

    @classmethod
    def get_queue_with_task(cls, company_id: str, queue_id: str, task_id: str) -> Queue:
        """
        Get the queue only if it currently holds an entry for the given task
        :raise errors.bad_request.InvalidQueueOrTaskNotQueued: if no queue matches
            the id, the company and the task
        """
        with translate_errors_context():
            query = dict(id=queue_id, company=company_id)
            queue = Queue.objects(entries__task=task_id, **query).first()
            if not queue:
                raise errors.bad_request.InvalidQueueOrTaskNotQueued(
                    task=task_id, **query
                )

            return queue

    def get_default(self, company_id: str) -> Queue:
        """
        Get the default queue
        :raise errors.bad_request.NoDefaultQueue: if the default queue not found
        :raise errors.bad_request.MultipleDefaultQueues: if more than one default queue is found
        :return: the default queue, loaded with the "id" and "name" fields only
        """
        with translate_errors_context():
            # Fetch the matches once so that the emptiness check, the count and
            # the returned document all come from the same result set
            defaults = list(
                Queue.objects(company=company_id, system_tags="default").only(
                    "id", "name"
                )
            )
            if not defaults:
                raise errors.bad_request.NoDefaultQueue()
            if len(defaults) > 1:
                raise errors.bad_request.MultipleDefaultQueues(
                    queues=tuple(q.id for q in defaults)
                )

            return defaults[0]

    def update(
        self, company_id: str, queue_id: str, **update_fields
    ) -> Tuple[int, dict]:
        """
        Partial update of the queue from update_fields
        :raise errors.bad_request.InvalidQueueId: if the queue is not found
        :return: number of updated objects and updated fields dictionary
        """
        with translate_errors_context():
            # validate the queue exists
            self.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
            return Queue.safe_update(company_id, queue_id, update_fields)

    def delete(self, company_id: str, queue_id: str, force: bool) -> None:
        """
        Delete the queue
        :param force: delete the queue even if it still holds entries
        :raise errors.bad_request.InvalidQueueId: if the queue is not found
        :raise errors.bad_request.QueueNotEmpty: if the queue is not empty and 'force' not set
        """
        with translate_errors_context():
            queue = self.get_by_id(company_id=company_id, queue_id=queue_id)
            if queue.entries and not force:
                raise errors.bad_request.QueueNotEmpty(
                    "use force=true to delete", id=queue_id
                )
            queue.delete()

    def get_all(self, company_id: str, query_dict: dict) -> Sequence[dict]:
        """Get all the queues according to the query"""
        with translate_errors_context():
            # query_dict is passed both as the call parameters and as the query
            return Queue.get_many(
                company=company_id, parameters=query_dict, query_dict=query_dict
            )

    def get_queue_infos(self, company_id: str, query_dict: dict) -> Sequence[dict]:
        """
        Get infos on all the company queues, including queue tasks and workers
        :return: the queried queues, each extended with a "workers" list holding
            the name, ip and current task of every worker that lists the queue
        """
        projection = Queue.get_extra_projection("entries.task.name")
        with translate_errors_context():
            res = Queue.get_many_with_join(
                company=company_id,
                query_dict=query_dict,
                override_projection=projection,
            )

            # group the company workers by the queues they report
            queue_workers = defaultdict(list)
            for worker in self.worker_bll.get_all(company_id):
                for queue in worker.queues:
                    queue_workers[queue].append(worker)

            for item in res:
                # .get() keeps the defaultdict from growing a key per worker-less queue
                item["workers"] = [
                    {
                        "name": w.id,
                        "ip": w.ip,
                        "task": w.task.to_struct() if w.task else None,
                    }
                    for w in queue_workers.get(item["id"], [])
                ]

            return res

    def add_task(self, company_id: str, queue_id: str, task_id: str) -> dict:
        """
        Add the task to the queue and return the queue update results
        :raise errors.bad_request.InvalidQueueId: if the queue is not found
        :raise errors.bad_request.TaskAlreadyQueued: if the task is already in the queue
        :raise errors.bad_request.InvalidQueueOrTaskNotQueued: if the queue update operation failed
        """
        with translate_errors_context():
            queue = self.get_by_id(company_id=company_id, queue_id=queue_id)
            if any(e.task == task_id for e in queue.entries):
                raise errors.bad_request.TaskAlreadyQueued(task=task_id)

            self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])

            # one timestamp for both the entry and the queue's last_update
            now = datetime.utcnow()
            entry = Entry(added=now, task=task_id)
            query = dict(id=queue_id, company=company_id)
            # entries__task__ne re-checks in the update filter itself that the
            # task is not queued, in addition to the in-memory check above
            res = Queue.objects(entries__task__ne=task_id, **query).update_one(
                push__entries=entry, last_update=now, upsert=False
            )
            if not res:
                raise errors.bad_request.InvalidQueueOrTaskNotQueued(
                    task=task_id, **query
                )

            return res

    def get_next_task(self, company_id: str, queue_id: str) -> Optional[Entry]:
        """
        Atomically pop and return the first task from the queue (or None)
        :raise errors.bad_request.InvalidQueueId: if the queue does not exist
        """
        with translate_errors_context():
            query = dict(id=queue_id, company=company_id)
            # The returned document is used as the queue state from before the
            # pop: its first entry is what gets handed back to the caller
            queue = Queue.objects(**query).modify(pop__entries=-1, upsert=False)
            if not queue:
                raise errors.bad_request.InvalidQueueId(**query)

            self.metrics.log_queue_metrics_to_es(company_id, queues=[queue])

            if not queue.entries:
                return None

            # Best effort only: failing to refresh last_update must not lose
            # the entry that was already popped
            try:
                Queue.objects(**query).update(last_update=datetime.utcnow())
            except Exception:
                log.exception("Error while updating Queue.last_update")

            return queue.entries[0]

    def remove_task(self, company_id: str, queue_id: str, task_id: str) -> int:
        """
        Removes the task from the queue and returns the number of removed items
        :raise errors.bad_request.InvalidQueueOrTaskNotQueued: if the task is not found in the queue
        """
        with translate_errors_context():
            queue = self.get_queue_with_task(
                company_id=company_id, queue_id=queue_id, task_id=task_id
            )
            self.metrics.log_queue_metrics_to_es(company_id, queues=[queue])

            entries_to_remove = [e for e in queue.entries if e.task == task_id]
            query = dict(id=queue_id, company=company_id)
            res = Queue.objects(entries__task=task_id, **query).update_one(
                pull_all__entries=entries_to_remove, last_update=datetime.utcnow()
            )

            # the count comes from the document fetched above, not from the update
            return len(entries_to_remove) if res else 0

    def reposition_task(
        self,
        company_id: str,
        queue_id: str,
        task_id: str,
        pos_func: Callable[[int], int],
    ) -> int:
        """
        Moves the task in the queue to the position calculated by pos_func
        Returns the updated task position in the queue
        :param pos_func: receives the current index of the task entry and returns
            the requested index. A negative result appends the entry at the end
            of the queue, and that negative value is what gets returned
        :raise errors.bad_request.InvalidQueueOrTaskNotQueued: if the task is not found in the queue
        :raise errors.bad_request.RemovedDuringReposition: if the entry could not
            be pulled from the queue
        :raise errors.bad_request.FailedAddingDuringReposition: if the entry could
            not be pushed back to the queue
        """
        with translate_errors_context():
            queue = self.get_queue_with_task(
                company_id=company_id, queue_id=queue_id, task_id=task_id
            )

            position = next(i for i, e in enumerate(queue.entries) if e.task == task_id)
            new_position = pos_func(position)

            if new_position != position:
                entry = queue.entries[position]
                query = dict(id=queue_id, company=company_id)

                # step 1: pull the entry out of its current place
                updated = Queue.objects(entries__task=task_id, **query).update_one(
                    pull__entries=entry, last_update=datetime.utcnow()
                )
                if not updated:
                    raise errors.bad_request.RemovedDuringReposition(
                        task=task_id, **query
                    )

                # step 2: push it back; $position is only set for non-negative
                # targets, otherwise the push carries no position
                inst = {"$push": {"entries": {"$each": [entry.to_proper_dict()]}}}
                if new_position >= 0:
                    inst["$push"]["entries"]["$position"] = new_position
                res = Queue.objects(entries__task__ne=task_id, **query).update_one(
                    __raw__=inst
                )
                if not res:
                    raise errors.bad_request.FailedAddingDuringReposition(
                        task=task_id, **query
                    )

            return new_position