Suppress "\r" when reading a current chunk of a file. Add agent.suppress_carriage_return (default True) to support previous behavior.

This commit is contained in:
allegroai 2020-10-11 11:21:08 +03:00
parent 7cd9fa6c41
commit 15f4aa613e
2 changed files with 38 additions and 29 deletions

View File

@ -398,6 +398,7 @@ class Worker(ServiceCommandSection):
self._redirected_stdout_file_no = None self._redirected_stdout_file_no = None
self._uptime_config = self._session.config.get("agent.uptime", None) self._uptime_config = self._session.config.get("agent.uptime", None)
self._downtime_config = self._session.config.get("agent.downtime", None) self._downtime_config = self._session.config.get("agent.downtime", None)
self._suppress_cr = self._session.config.get("agent.suppress_carriage_return", True)
# True - supported # True - supported
# None - not initialized # None - not initialized
@ -984,21 +985,26 @@ class Worker(ServiceCommandSection):
**kwargs # type: Any **kwargs # type: Any
): ):
# type: (...) -> Tuple[Optional[int], TaskStopReason] # type: (...) -> Tuple[Optional[int], TaskStopReason]
def _print_file(file_path, prev_line_count): def _print_file(file_path, prev_pos=0):
with open(file_path, "rb") as f: with open(file_path, "rb") as f:
f.seek(prev_pos)
binary_text = f.read() binary_text = f.read()
if not binary_text: pos = f.tell()
return []
# skip the previously printed lines, # skip the previously printed lines,
blines = binary_text.split(b'\n')[prev_line_count:] blines = binary_text.split(b'\n') if binary_text else []
if not blines: if not blines:
return blines return blines, pos
return decode_binary_lines(blines if blines[-1] else blines[:-1]) return (
decode_binary_lines(blines if blines[-1] else blines[:-1],
replace_cr=not self._suppress_cr,
overwrite_cr=self._suppress_cr),
pos
)
stdout = open(stdout_path, "wt") stdout = open(stdout_path, "wt")
stderr = open(stderr_path, "wt") if stderr_path else stdout stderr = open(stderr_path, "wt") if stderr_path else stdout
stdout_line_count, stdout_last_lines = 0, [] stdout_line_count, stdout_pos_count, stdout_last_lines = 0, 0, []
stderr_line_count, stderr_last_lines = 0, [] stderr_line_count, stderr_pos_count, stderr_last_lines = 0, 0, []
service_mode_internal_agent_started = None service_mode_internal_agent_started = None
stopping = False stopping = False
status = None status = None
@ -1037,7 +1043,7 @@ class Worker(ServiceCommandSection):
stderr.flush() stderr.flush()
# get diff from previous poll # get diff from previous poll
printed_lines = _print_file(stdout_path, stdout_line_count) printed_lines, stdout_pos_count = _print_file(stdout_path, stdout_pos_count)
if self._services_mode and not stopping and not status: if self._services_mode and not stopping and not status:
# if the internal agent started, we stop logging, it will take over logging. # if the internal agent started, we stop logging, it will take over logging.
# if the internal agent started running the task itself, it will return status==0, # if the internal agent started running the task itself, it will return status==0,
@ -1047,13 +1053,10 @@ class Worker(ServiceCommandSection):
if status is not None: if status is not None:
stop_reason = 'Service started' stop_reason = 'Service started'
stdout_line_count += self.send_logs( stdout_line_count += self.send_logs(task_id, printed_lines)
task_id, printed_lines
)
if stderr_path: if stderr_path:
stderr_line_count += self.send_logs( printed_lines, stderr_pos_count = _print_file(stderr_path, stderr_pos_count)
task_id, _print_file(stderr_path, stderr_line_count) stderr_line_count += self.send_logs(task_id, printed_lines)
)
except subprocess.CalledProcessError as ex: except subprocess.CalledProcessError as ex:
# non zero return code # non zero return code
@ -1064,9 +1067,11 @@ class Worker(ServiceCommandSection):
raise raise
except Exception: except Exception:
# we should not get here, but better safe than sorry # we should not get here, but better safe than sorry
stdout_line_count += self.send_logs(task_id, _print_file(stdout_path, stdout_line_count)) printed_lines, stdout_pos_count = _print_file(stdout_path, stdout_pos_count)
stdout_line_count += self.send_logs(task_id, printed_lines)
if stderr_path: if stderr_path:
stderr_line_count += self.send_logs(task_id, _print_file(stderr_path, stderr_line_count)) printed_lines, stderr_pos_count = _print_file(stderr_path, stderr_pos_count)
stderr_line_count += self.send_logs(task_id, printed_lines)
stop_reason = 'Exception occurred' stop_reason = 'Exception occurred'
status = -1 status = -1
@ -1080,13 +1085,11 @@ class Worker(ServiceCommandSection):
stderr.close() stderr.close()
# Send last lines # Send last lines
stdout_line_count += self.send_logs( printed_lines, stdout_pos_count = _print_file(stdout_path, stdout_pos_count)
task_id, _print_file(stdout_path, stdout_line_count) stdout_line_count += self.send_logs(task_id, printed_lines)
)
if stderr_path: if stderr_path:
stderr_line_count += self.send_logs( printed_lines, stderr_pos_count = _print_file(stderr_path, stderr_pos_count)
task_id, _print_file(stderr_path, stderr_line_count) stderr_line_count += self.send_logs(task_id, printed_lines)
)
return status, stop_reason return status, stop_reason

View File

@ -9,7 +9,7 @@ from attr import attrs, attrib
import six import six
from six import binary_type, text_type from six import binary_type, text_type
from trains_agent.helper.base import nonstrict_in_place_sort, create_tree from trains_agent.helper.base import nonstrict_in_place_sort
def print_text(text, newline=True): def print_text(text, newline=True):
@ -22,15 +22,21 @@ def print_text(text, newline=True):
sys.stdout.write(data) sys.stdout.write(data)
def decode_binary_lines(binary_lines, encoding='utf-8'): def decode_binary_lines(binary_lines, encoding='utf-8', replace_cr=False, overwrite_cr=False):
# decode per line, if we failed decoding skip the line # decode per line, if we failed decoding skip the line
lines = [] lines = []
for b in binary_lines: for b in binary_lines:
# noinspection PyBroadException
try: try:
l = b.decode(encoding=encoding, errors='replace').replace('\r', '\n') line = b.decode(encoding=encoding, errors='replace')
except: if replace_cr:
l = '' line = line.replace('\r', '\n')
lines.append(l + '\n' if l and l[-1] != '\n' else l) elif overwrite_cr:
cr_lines = line.split('\r')
line = cr_lines[-1] if cr_lines[-1] or len(cr_lines) < 2 else cr_lines[-2]
except Exception:
line = ''
lines.append(line + '\n' if not line or line[-1] != '\n' else line)
return lines return lines