diff --git a/clearml_agent/backend_api/session/session.py b/clearml_agent/backend_api/session/session.py index 22c59a5..df71cb3 100644 --- a/clearml_agent/backend_api/session/session.py +++ b/clearml_agent/backend_api/session/session.py @@ -64,6 +64,7 @@ class Session(TokenManager): default_key = "EGRTCO8JMSIGI6S39GTP43NFWXDQOW" default_secret = "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8" force_max_api_version = ENV_FORCE_MAX_API_VERSION.get() + server_version = "1.0.0" # TODO: add requests.codes.gateway_timeout once we support async commits _retry_codes = [ @@ -191,6 +192,7 @@ class Session(TokenManager): Session.api_version = str(api_version) Session.feature_set = str(token_dict.get('feature_set', self.feature_set) or "basic") + Session.server_version = token_dict.get('server_version', self.server_version) except (jwt.DecodeError, ValueError): pass @@ -651,11 +653,14 @@ class Session(TokenManager): """ Return True if Session.api_version is greater or equal >= to min_api_version """ - def version_tuple(v): - v = tuple(map(int, (v.split(".")))) - return v + (0,) * max(0, 3 - len(v)) return version_tuple(cls.api_version) >= version_tuple(str(min_api_version)) + @classmethod + def check_min_server_version(cls, min_server_version): + """ + Return True if Session.server_version is greater or equal >= to min_server_version + """ + return version_tuple(cls.server_version) >= version_tuple(str(min_server_version)) def _do_refresh_token(self, current_token, exp=None): """ TokenManager abstract method implementation. Here we ignore the old token and simply obtain a new token. @@ -733,3 +738,8 @@ class Session(TokenManager): def propagate_exceptions_on_send(self, value): # type: (bool) -> None self._propagate_exceptions_on_send = value + + +def version_tuple(v): + v = tuple(map(int, (v.split(".")))) + return v + (0,) * max(0, 3 - len(v)) diff --git a/clearml_agent/glue/k8s.py b/clearml_agent/glue/k8s.py index e5cf49f..798d779 100644 --- a/clearml_agent/glue/k8s.py +++ b/clearml_agent/glue/k8s.py @@ -194,6 +194,10 @@ class K8sIntegration(Worker): self._min_cleanup_interval_per_ns_sec = 1.0 self._last_pod_cleanup_per_ns = defaultdict(lambda: 0.) + self._server_supports_same_state_transition = ( + self._session.feature_set != "basic" and self._session.check_min_server_version("3.22.3") + ) + def _create_daemon_instance(self, cls_, **kwargs): return cls_(agent=self, **kwargs) @@ -435,7 +439,9 @@ class K8sIntegration(Worker): if self._is_same_tenant(task_session): try: print('Pushing task {} into temporary pending queue'.format(task_id)) - _ = session.api_client.tasks.stop(task_id, force=True, status_reason="moving to k8s pending queue") + + if not self._server_supports_same_state_transition: + _ = session.api_client.tasks.stop(task_id, force=True, status_reason="moving to k8s pending queue") # Just make sure to clean up in case the task is stuck in the queue (known issue) self._session.api_client.queues.remove_task( @@ -956,7 +962,7 @@ class K8sIntegration(Worker): result = self._session.get( service='tasks', action='get_all', - json={"id": task_ids, "status": ["in_progress", "queued"], "only_fields": ["id", "status"]}, + json={"id": task_ids, "status": ["in_progress", "queued"], "only_fields": ["id", "status", "status_reason"]}, method=Request.def_method, ) tasks_to_abort = result["tasks"] @@ -966,9 +972,13 @@ class K8sIntegration(Worker): for task in tasks_to_abort: task_id = task.get("id") status = task.get("status") + status_reason = (task.get("status_reason") or "").lower() if not task_id or not status: self.log.warning('Failed getting task information: id={}, status={}'.format(task_id, status)) continue + if status == "queued" and "pushed back by policy manager" in status_reason: + # Task was pushed back to policy queue by policy manager, don't touch it + continue try: if status == "queued": self._session.get(