Support stopping (instead of resetting) the task in Task.execute_remotely() when the server supports enqueueing stopped tasks

This commit is contained in:
allegroai 2021-03-31 23:51:25 +03:00
parent 1986ec43fd
commit 31d3b6dbc5
2 changed files with 27 additions and 10 deletions

View File

@ -57,6 +57,7 @@ class Session(TokenManager):
_client = [(__package__.partition(".")[0], __version__)]
api_version = '2.1'
max_api_version = '2.1'
default_demo_host = "https://demoapi.demo.clear.ml"
default_host = default_demo_host
default_web = "https://demoapp.demo.clear.ml"
@ -180,7 +181,7 @@ class Session(TokenManager):
if not any(True for c in Session._client if c[0] == 'clearml-server'):
Session._client.append(('clearml-server', token_dict.get('server_version'), ))
Session.api_version = str(api_version)
Session.max_api_version = Session.api_version = str(api_version)
except (jwt.DecodeError, ValueError):
(self._logger or get_logger()).warning(
"Failed parsing server API level, defaulting to {}".format(Session.api_version))
@ -193,7 +194,7 @@ class Session(TokenManager):
self.__class__._sessions_created += 1
if self.force_max_api_version and self.check_min_api_version(self.force_max_api_version):
Session.api_version = str(self.force_max_api_version)
Session.max_api_version = Session.api_version = str(self.force_max_api_version)
def _send_request(
self,
@ -549,10 +550,6 @@ class Session(TokenManager):
"""
Return True if Session.api_version is greater or equal >= to min_api_version
"""
def version_tuple(v):
v = tuple(map(int, (v.split("."))))
return v + (0,) * max(0, 3 - len(v))
# If no session was created, create a default one, in order to get the backend api version.
if cls._sessions_created <= 0:
if cls._offline_mode:
@ -567,7 +564,7 @@ class Session(TokenManager):
cls._offline_default_version = str(offline_api)
except ValueError:
pass
cls.api_version = cls._offline_default_version
cls.max_api_version = cls.api_version = cls._offline_default_version
else:
# noinspection PyBroadException
try:
@ -575,7 +572,18 @@ class Session(TokenManager):
except Exception:
pass
return version_tuple(cls.api_version) >= version_tuple(str(min_api_version))
return cls._version_tuple(cls.api_version) >= cls._version_tuple(str(min_api_version))
@classmethod
def check_min_api_server_version(cls, min_api_version):
    """
    Check whether the API version reported by the server satisfies a minimum requirement.

    Notice this compares against Session.max_api_version (the api version the server
    reported), not only the api version the SDK is currently using.

    :param min_api_version: Minimum required API version (converted with ``str()`` before parsing)
    :return: True if Session.api_version or Session.max_api_version is greater than
        or equal to min_api_version, False otherwise
    """
    # Only consult the server-reported version when the regular api_version check fails
    if not cls.check_min_api_version(min_api_version):
        reported = cls._version_tuple(cls.max_api_version)
        required = cls._version_tuple(str(min_api_version))
        return reported >= required
    return True
@classmethod
def get_worker_host_name(cls):
@ -586,6 +594,11 @@ class Session(TokenManager):
def get_clients(cls):
return cls._client
@staticmethod
def _version_tuple(v):
    """
    Convert a dot-separated version string into a tuple of integers.

    The result is right-padded with zeros to at least three components,
    e.g. "2.1" becomes (2, 1, 0). Longer versions are returned unchanged in length.

    :param v: Version string made of integer components separated by "."
    :return: Tuple of ints with at least three elements
    """
    components = [int(piece) for piece in v.split(".")]
    missing = 3 - len(components)
    if missing > 0:
        # pad short versions so that "2.1" and "2.1.0" compare as equal
        components.extend([0] * missing)
    return tuple(components)
def _do_refresh_token(self, old_token, exp=None):
""" TokenManager abstract method implementation.
Here we ignore the old token and simply obtain a new token.

View File

@ -1851,7 +1851,7 @@ class Task(_Task):
If ``clone==False``, then ``exit_process`` must be ``True``.
:return Task: return the task object of the newly generated remotely excuting task
:return Task: return the task object of the newly generated remotely executing task
"""
# do nothing, we are running remotely
if running_remotely() and self.is_main_task():
@ -1878,7 +1878,11 @@ class Task(_Task):
task = Task.clone(self)
else:
task = self
self.reset()
# check if the server supports enqueueing aborted/stopped Tasks
if Session.check_min_api_server_version('2.10'):
self.mark_stopped(force=True)
else:
self.reset()
# enqueue ourselves
if queue_name: