From 31d3b6dbc526df94815042a22c731f05267891e3 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Wed, 31 Mar 2021 23:51:25 +0300 Subject: [PATCH] Support stopping instead of resetting in Task.execute_remotely() in case server supports enqueueing stopped tasks --- clearml/backend_api/session/session.py | 29 +++++++++++++++++++------- clearml/task.py | 8 +++++-- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/clearml/backend_api/session/session.py b/clearml/backend_api/session/session.py index c2611a32..6b6a8731 100644 --- a/clearml/backend_api/session/session.py +++ b/clearml/backend_api/session/session.py @@ -57,6 +57,7 @@ class Session(TokenManager): _client = [(__package__.partition(".")[0], __version__)] api_version = '2.1' + max_api_version = '2.1' default_demo_host = "https://demoapi.demo.clear.ml" default_host = default_demo_host default_web = "https://demoapp.demo.clear.ml" @@ -180,7 +181,7 @@ class Session(TokenManager): if not any(True for c in Session._client if c[0] == 'clearml-server'): Session._client.append(('clearml-server', token_dict.get('server_version'), )) - Session.api_version = str(api_version) + Session.max_api_version = Session.api_version = str(api_version) except (jwt.DecodeError, ValueError): (self._logger or get_logger()).warning( "Failed parsing server API level, defaulting to {}".format(Session.api_version)) @@ -193,7 +194,7 @@ class Session(TokenManager): self.__class__._sessions_created += 1 if self.force_max_api_version and self.check_min_api_version(self.force_max_api_version): - Session.api_version = str(self.force_max_api_version) + Session.max_api_version = Session.api_version = str(self.force_max_api_version) def _send_request( self, @@ -549,10 +550,6 @@ class Session(TokenManager): """ Return True if Session.api_version is greater or equal >= to min_api_version """ - def version_tuple(v): - v = tuple(map(int, (v.split(".")))) - return v + (0,) * max(0, 3 - len(v)) - # If no session was created, create a default one, in order to get the backend api version. if cls._sessions_created <= 0: if cls._offline_mode: @@ -567,7 +564,7 @@ class Session(TokenManager): cls._offline_default_version = str(offline_api) except ValueError: pass - cls.api_version = cls._offline_default_version + cls.max_api_version = cls.api_version = cls._offline_default_version else: # noinspection PyBroadException try: @@ -575,7 +572,18 @@ class Session(TokenManager): except Exception: pass - return version_tuple(cls.api_version) >= version_tuple(str(min_api_version)) + return cls._version_tuple(cls.api_version) >= cls._version_tuple(str(min_api_version)) + + @classmethod + def check_min_api_server_version(cls, min_api_version): + """ + Return True if Session.max_api_version is greater or equal >= to min_api_version + Notice this is the api version server reported, not the current SDK max supported api version + """ + if cls.check_min_api_version(min_api_version): + return True + + return cls._version_tuple(cls.max_api_version) >= cls._version_tuple(str(min_api_version)) @classmethod def get_worker_host_name(cls): @@ -586,6 +594,11 @@ class Session(TokenManager): def get_clients(cls): return cls._client + @staticmethod + def _version_tuple(v): + v = tuple(map(int, (v.split(".")))) + return v + (0,) * max(0, 3 - len(v)) + def _do_refresh_token(self, old_token, exp=None): """ TokenManager abstract method implementation. Here we ignore the old token and simply obtain a new token. diff --git a/clearml/task.py b/clearml/task.py index 62c6bfdf..9737f904 100644 --- a/clearml/task.py +++ b/clearml/task.py @@ -1851,7 +1851,7 @@ class Task(_Task): If ``clone==False``, then ``exit_process`` must be ``True``. - :return Task: return the task object of the newly generated remotely excuting task + :return Task: return the task object of the newly generated remotely executing task """ # do nothing, we are running remotely if running_remotely() and self.is_main_task(): @@ -1878,7 +1878,11 @@ class Task(_Task): task = Task.clone(self) else: task = self - self.reset() + # check if the server supports enqueueing aborted/stopped Tasks + if Session.check_min_api_server_version('2.10'): + self.mark_stopped(force=True) + else: + self.reset() # enqueue ourselves if queue_name: