Fix use same state transition if supported by the server (instead of stopping the task before re-enqueue)

This commit is contained in:
allegroai 2024-08-27 22:54:45 +03:00
parent 99e1e54f94
commit b8c762401b
2 changed files with 25 additions and 5 deletions

View File

@ -64,6 +64,7 @@ class Session(TokenManager):
default_key = "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
default_secret = "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
force_max_api_version = ENV_FORCE_MAX_API_VERSION.get()
server_version = "1.0.0"
# TODO: add requests.codes.gateway_timeout once we support async commits
_retry_codes = [
@ -191,6 +192,7 @@ class Session(TokenManager):
Session.api_version = str(api_version)
Session.feature_set = str(token_dict.get('feature_set', self.feature_set) or "basic")
Session.server_version = token_dict.get('server_version', self.server_version)
except (jwt.DecodeError, ValueError):
pass
@ -651,11 +653,14 @@ class Session(TokenManager):
"""
Return True if Session.api_version is greater or equal >= to min_api_version
"""
def version_tuple(v):
v = tuple(map(int, (v.split("."))))
return v + (0,) * max(0, 3 - len(v))
return version_tuple(cls.api_version) >= version_tuple(str(min_api_version))
@classmethod
def check_min_server_version(cls, min_server_version):
"""
Return True if Session.server_version is greater or equal >= to min_server_version
"""
return version_tuple(cls.server_version) >= version_tuple(str(min_server_version))
def _do_refresh_token(self, current_token, exp=None):
""" TokenManager abstract method implementation.
Here we ignore the old token and simply obtain a new token.
@ -733,3 +738,8 @@ class Session(TokenManager):
def propagate_exceptions_on_send(self, value):
# type: (bool) -> None
self._propagate_exceptions_on_send = value
def version_tuple(v):
v = tuple(map(int, (v.split("."))))
return v + (0,) * max(0, 3 - len(v))

View File

@ -194,6 +194,10 @@ class K8sIntegration(Worker):
self._min_cleanup_interval_per_ns_sec = 1.0
self._last_pod_cleanup_per_ns = defaultdict(lambda: 0.)
self._server_supports_same_state_transition = (
self._session.feature_set != "basic" and self._session.check_min_server_version("3.22.3")
)
def _create_daemon_instance(self, cls_, **kwargs):
return cls_(agent=self, **kwargs)
@ -435,7 +439,9 @@ class K8sIntegration(Worker):
if self._is_same_tenant(task_session):
try:
print('Pushing task {} into temporary pending queue'.format(task_id))
_ = session.api_client.tasks.stop(task_id, force=True, status_reason="moving to k8s pending queue")
if not self._server_supports_same_state_transition:
_ = session.api_client.tasks.stop(task_id, force=True, status_reason="moving to k8s pending queue")
# Just make sure to clean up in case the task is stuck in the queue (known issue)
self._session.api_client.queues.remove_task(
@ -956,7 +962,7 @@ class K8sIntegration(Worker):
result = self._session.get(
service='tasks',
action='get_all',
json={"id": task_ids, "status": ["in_progress", "queued"], "only_fields": ["id", "status"]},
json={"id": task_ids, "status": ["in_progress", "queued"], "only_fields": ["id", "status", "status_reason"]},
method=Request.def_method,
)
tasks_to_abort = result["tasks"]
@ -966,9 +972,13 @@ class K8sIntegration(Worker):
for task in tasks_to_abort:
task_id = task.get("id")
status = task.get("status")
status_reason = (task.get("status_reason") or "").lower()
if not task_id or not status:
self.log.warning('Failed getting task information: id={}, status={}'.format(task_id, status))
continue
if status == "queued" and "pushed back by policy manager" in status_reason:
# Task was pushed back to policy queue by policy manager, don't touch it
continue
try:
if status == "queued":
self._session.get(