Fix: only send the pip freeze update on RANK 0, and only update the task status on exit on RANK 0

This commit is contained in:
allegroai 2024-07-29 17:40:24 +03:00
parent 79d0abe707
commit d9f2a1999a

View File

@ -2098,11 +2098,7 @@ class Worker(ServiceCommandSection):
def report_lines(lines, source, a_multi_node_single_task=None): def report_lines(lines, source, a_multi_node_single_task=None):
# support colored multi-node reporting on the same Task for easier debugging # support colored multi-node reporting on the same Task for easier debugging
if lines and a_multi_node_single_task and a_multi_node_single_task > 0: if lines and a_multi_node_single_task and a_multi_node_single_task > 0:
# noinspection PyBroadException rank = self._get_node_rank()
try:
rank = int(os.environ.get("RANK", os.environ.get('SLURM_PROCID')) or 0)
except Exception:
rank = 0
if rank: if rank:
# see ANSI color: https://en.wikipedia.org/wiki/ANSI_escape_code#8-bit # see ANSI color: https://en.wikipedia.org/wiki/ANSI_escape_code#8-bit
# Only the "RANK x: line is colored to preserve the original color reporting # Only the "RANK x: line is colored to preserve the original color reporting
@ -2873,6 +2869,10 @@ class Worker(ServiceCommandSection):
skip_freeze_update = self.is_conda and not self._session.config.get( skip_freeze_update = self.is_conda and not self._session.config.get(
"agent.package_manager.conda_full_env_update", False) "agent.package_manager.conda_full_env_update", False)
# skip update requirements on nodes that are not Rank 0 (only update requirements on RANK 0)
if self._get_node_rank():
skip_freeze_update = True
freeze = self.freeze_task_environment( freeze = self.freeze_task_environment(
task_id=current_task.id, task_id=current_task.id,
requirements_manager=requirements_manager, requirements_manager=requirements_manager,
@ -3273,41 +3273,58 @@ class Worker(ServiceCommandSection):
) )
self.log_traceback(e) self.log_traceback(e)
@staticmethod
def _get_node_rank():
    # type: () -> int
    """
    Return the rank of the current node as read from the environment.

    The value is taken from the ``RANK`` environment variable; when ``RANK``
    is not set at all, ``SLURM_PROCID`` is used instead.

    :return: The rank as an integer. A missing, empty, or non-integer value
        yields 0.
    """
    raw_rank = os.environ.get("RANK", os.environ.get("SLURM_PROCID"))
    # noinspection PyBroadException
    try:
        # `or 0` maps both None (nothing set) and "" (set but empty) to rank 0
        return int(raw_rank or 0)
    except Exception:
        # best effort: an unparsable value is treated as rank 0
        return 0
def handle_task_process_termination(self, task_id, exit_code, session=None):
    # type: (Text, int) -> None
    """
    Log the termination of a task process and report the resulting task status.

    The exit code selects one of three outcomes: success (``COMMAND_SUCCESS``),
    user interruption (``ExitStatus.interrupted``, also when offset by 256),
    or failure (any other value). Every node writes the outcome to the worker
    log and to the task log, but only the node whose rank is 0 sends the
    status-change request. On other nodes the messages carry a
    " - rank N" suffix.

    :param task_id: ID of the task whose process terminated.
    :param exit_code: Exit code of the task process.
    :param session: Session used for sending logs and API requests;
        defaults to ``self._session`` when not given.
    """
    session = session or self._session
    rank = self._get_node_rank()
    # the suffix is empty on rank 0 so single-node messages stay unchanged
    rank_text = " - rank {}".format(rank) if rank else ""
    self.log("Task process terminated" + rank_text)

    # only RANK 0 can change the Task status.
    if exit_code == COMMAND_SUCCESS:
        self.log("Task success: completing" + rank_text)
        self.send_logs(task_id, ["Process completed successfully" + rank_text], session=session)
        if not rank:
            session.send_api(
                tasks_api.CompletedRequest(
                    task=task_id,
                    status_reason="worker execution done",
                    status_message=self._task_status_change_message,
                )
            )
    elif exit_code in (ExitStatus.interrupted, 256 + ExitStatus.interrupted):
        self.log("Task interrupted: stopping" + rank_text)
        self.send_logs(task_id, ["Process terminated by user" + rank_text], session=session)
        if not rank:
            session.send_api(
                tasks_api.StoppedRequest(
                    task=task_id,
                    status_reason="user abort",
                    status_message=self._task_status_change_message,
                )
            )
    else:
        self.log("Task failure: setting status to 'failed'" + rank_text)
        # format the exit code into the message BEFORE appending the rank suffix:
        # `"...{}" + rank_text.format(exit_code)` would apply .format() to the
        # suffix only and leave a literal "{}" in the reported line
        failure_message = "Process failed, exit code {}".format(exit_code) + rank_text
        self.send_logs(task_id, [failure_message], session=session)
        if not rank:
            session.send_api(
                tasks_api.FailedRequest(
                    task=task_id,
                    status_reason="worker execution exit code {}".format(exit_code),
                    status_message=self._task_status_change_message,
                )
            )
def freeze_task_environment(self, task_id=None, requirements_manager=None, def freeze_task_environment(self, task_id=None, requirements_manager=None,
add_venv_folder_cache=None, execution_info=None, update_requirements=False): add_venv_folder_cache=None, execution_info=None, update_requirements=False):