Fix: CLEARML_MULTI_NODE_SINGLE_TASK should be read once, not on every reported line

This commit is contained in:
allegroai 2024-07-24 17:45:02 +03:00
parent b2a4bf08ac
commit 2f0553b873

View File

@ -2081,9 +2081,9 @@ class Worker(ServiceCommandSection):
stderr_line_count, stderr_pos_count, stderr_last_lines = 0, 0, []
lines_buffer = defaultdict(list)
def report_lines(lines, source):
def report_lines(lines, source, a_multi_node_single_task=None):
# support colored multi-node reporting on the same Task for easier debugging
if lines and ENV_MULTI_NODE_SINGLE_TASK.get() and ENV_MULTI_NODE_SINGLE_TASK.get() > 0:
if lines and a_multi_node_single_task and a_multi_node_single_task > 0:
# noinspection PyBroadException
try:
rank = int(os.environ.get("RANK", os.environ.get('SLURM_PROCID')) or 0)
@ -2111,6 +2111,7 @@ class Worker(ServiceCommandSection):
status = None
process = None
last_task_ping = 0
multi_node_single_task = ENV_MULTI_NODE_SINGLE_TASK.get()
try:
_last_machine_update_ts = time()
stop_reason = None
@ -2169,11 +2170,11 @@ class Worker(ServiceCommandSection):
if status is not None:
stop_reason = 'Service started'
stdout_line_count += report_lines(printed_lines, "stdout")
stdout_line_count += report_lines(printed_lines, "stdout", multi_node_single_task)
if stderr_path:
printed_lines, stderr_pos_count = _print_file(stderr_path, stderr_pos_count)
stderr_line_count += report_lines(printed_lines, "stderr")
stderr_line_count += report_lines(printed_lines, "stderr", multi_node_single_task)
except subprocess.CalledProcessError as ex:
# non zero return code
@ -2187,10 +2188,10 @@ class Worker(ServiceCommandSection):
except Exception:
# we should not get here, but better safe than sorry
printed_lines, stdout_pos_count = _print_file(stdout_path, stdout_pos_count)
stdout_line_count += report_lines(printed_lines, "stdout")
stdout_line_count += report_lines(printed_lines, "stdout", multi_node_single_task)
if stderr_path:
printed_lines, stderr_pos_count = _print_file(stderr_path, stderr_pos_count)
stderr_line_count += report_lines(printed_lines, "stderr")
stderr_line_count += report_lines(printed_lines, "stderr", multi_node_single_task)
stop_reason = TaskStopReason.exception
status = -1