diff --git a/clearml/debugging/trace.py b/clearml/debugging/trace.py index fb4fa89b..a5570ae8 100644 --- a/clearml/debugging/trace.py +++ b/clearml/debugging/trace.py @@ -21,7 +21,7 @@ def _thread_linux_id(): def _thread_py_id(): # return threading.get_ident() - return zipfile.crc32(int(threading.get_ident()).to_bytes(8, 'little')) + return zipfile.crc32(int(threading.get_ident()).to_bytes(8, "little")) def _log_stderr(name, fnc, args, kwargs, is_return): @@ -32,20 +32,22 @@ def _log_stderr(name, fnc, args, kwargs, is_return): return if __trace_level not in (1, 2, -1, -2): return - fnc_address = str(fnc).split(' at ') - fnc_address = '{}'.format(fnc_address[-1].replace('>', '')) if len(fnc_address) > 1 else '' + fnc_address = str(fnc).split(" at ") + fnc_address = "{}".format(fnc_address[-1].replace(">", "")) if len(fnc_address) > 1 else "" if __trace_level == 1 or __trace_level == -1: - t = '{:14} {}'.format(fnc_address, name) + t = "{:14} {}".format(fnc_address, name) elif __trace_level == 2 or __trace_level == -2: - a_args = str(args)[1:-1] if args else '' - a_kwargs = ' {}'.format(kwargs) if kwargs else '' - t = '{:14} {} ({}{})'.format(fnc_address, name, a_args, a_kwargs) + a_args = str(args)[1:-1] if args else "" + a_kwargs = " {}".format(kwargs) if kwargs else "" + t = "{:14} {} ({}{})".format(fnc_address, name, a_args, a_kwargs) # get a nicer thread id h = int(__thread_id()) ts = time.time() - __trace_start - __stream_write('{}{:<9.3f}:{:5}:{:8x}: [{}] {}\n'.format( - '-' if is_return else '', ts, os.getpid(), - h, threading.current_thread().name, t)) + __stream_write( + "{}{:<9.3f}:{:5}:{:8x}: [{}] {}\n".format( + "-" if is_return else "", ts, os.getpid(), h, threading.current_thread().name, t + ) + ) if __stream_flush: __stream_flush() except Exception: @@ -64,6 +66,7 @@ def _traced_call_method(name, fnc): if r: raise r return ret + return _traced_call_int @@ -82,7 +85,7 @@ def _traced_call_cls(name, fnc): raise r return ret - return WrapperClass.__dict__['_traced_call_int'] + return WrapperClass.__dict__["_traced_call_int"] def _traced_call_static(name, fnc): @@ -99,7 +102,8 @@ def _traced_call_static(name, fnc): if r: raise r return ret - return WrapperStatic.__dict__['_traced_call_int'] + + return WrapperStatic.__dict__["_traced_call_int"] def _traced_call_func(name, fnc): @@ -114,14 +118,16 @@ def _traced_call_func(name, fnc): if r: raise r return ret + return _traced_call_int -def _patch_module(module, prefix='', basepath=None, basemodule=None, exclude_prefixes=[], only_prefix=[]): +def _patch_module(module, prefix="", basepath=None, basemodule=None, exclude_prefixes=[], only_prefix=[]): if isinstance(module, str): if basemodule is None: - basemodule = module + '.' + basemodule = module + "." import importlib + importlib.import_module(module) module = sys.modules.get(module) if not module: @@ -130,8 +136,8 @@ def _patch_module(module, prefix='', basepath=None, basemodule=None, exclude_pre basepath = os.path.sep.join(module.__file__.split(os.path.sep)[:-1]) + os.path.sep # only sub modules - if not hasattr(module, '__file__') or (inspect.ismodule(module) and not module.__file__.startswith(basepath)): - if hasattr(module, '__module__') and module.__module__.startswith(basemodule): + if not hasattr(module, "__file__") or (inspect.ismodule(module) and not module.__file__.startswith(basepath)): + if hasattr(module, "__module__") and module.__module__.startswith(basemodule): # this is one of ours pass else: @@ -139,25 +145,25 @@ def _patch_module(module, prefix='', basepath=None, basemodule=None, exclude_pre return # Do not patch ourselves - if hasattr(module, '__file__') and module.__file__ == __file__: + if hasattr(module, "__file__") and module.__file__ == __file__: return - prefix += module.__name__.split('.')[-1] + '.' + prefix += module.__name__.split(".")[-1] + "." # Do not patch low level network layer - if prefix.startswith('clearml.backend_api.session.') and prefix != 'clearml.backend_api.session.': - if not prefix.endswith('.Session.') and '.token_manager.' not in prefix: + if prefix.startswith("clearml.backend_api.session.") and prefix != "clearml.backend_api.session.": + if not prefix.endswith(".Session.") and ".token_manager." not in prefix: # print('SKIPPING: {}'.format(prefix)) return - if prefix.startswith('clearml.backend_api.services.'): + if prefix.startswith("clearml.backend_api.services."): return for skip in exclude_prefixes: if prefix.startswith(skip): return - for fn in (m for m in dir(module) if not m.startswith('__')): - if fn in ('schema_property') or fn.startswith('_PostImportHookPatching__'): + for fn in (m for m in dir(module) if not m.startswith("__")): + if fn in ("schema_property") or fn.startswith("_PostImportHookPatching__"): continue # noinspection PyBroadException try: @@ -165,27 +171,39 @@ def _patch_module(module, prefix='', basepath=None, basemodule=None, exclude_pre except Exception: continue if inspect.ismodule(fnc): - _patch_module(fnc, prefix=prefix, basepath=basepath, basemodule=basemodule, - exclude_prefixes=exclude_prefixes, only_prefix=only_prefix) + _patch_module( + fnc, + prefix=prefix, + basepath=basepath, + basemodule=basemodule, + exclude_prefixes=exclude_prefixes, + only_prefix=only_prefix, + ) elif inspect.isclass(fnc): - _patch_module(fnc, prefix=prefix, basepath=basepath, basemodule=basemodule, - exclude_prefixes=exclude_prefixes, only_prefix=only_prefix) + _patch_module( + fnc, + prefix=prefix, + basepath=basepath, + basemodule=basemodule, + exclude_prefixes=exclude_prefixes, + only_prefix=only_prefix, + ) elif inspect.isroutine(fnc): - if only_prefix and all(p not in (prefix+str(fn)) for p in only_prefix): + if only_prefix and all(p not in (prefix + str(fn)) for p in only_prefix): continue for skip in exclude_prefixes: - if (prefix+str(fn)).startswith(skip): + if (prefix + str(fn)).startswith(skip): continue # _log_stderr('Patching: {}'.format(prefix+fn)) if inspect.isclass(module): # check if this is even in our module - if hasattr(fnc, '__module__') and fnc.__module__ != module.__module__: + if hasattr(fnc, "__module__") and fnc.__module__ != module.__module__: pass # print('not ours {} {}'.format(module, fnc)) - elif hasattr(fnc, '__qualname__') and fnc.__qualname__.startswith(module.__name__ + '.'): + elif hasattr(fnc, "__qualname__") and fnc.__qualname__.startswith(module.__name__ + "."): if isinstance(module.__dict__[fn], classmethod): setattr(module, fn, _traced_call_cls(prefix + fn, fnc)) elif isinstance(module.__dict__[fn], staticmethod): @@ -194,7 +212,7 @@ def _patch_module(module, prefix='', basepath=None, basemodule=None, exclude_pre setattr(module, fn, _traced_call_method(prefix + fn, fnc)) else: # probably not ours hopefully static function - if hasattr(fnc, '__qualname__') and not fnc.__qualname__.startswith(module.__name__ + '.'): + if hasattr(fnc, "__qualname__") and not fnc.__qualname__.startswith(module.__name__ + "."): pass # print('not ours {} {}'.format(module, fnc)) else: # we should not get here @@ -226,17 +244,18 @@ def trace_trains(stream=None, level=1, exclude_prefixes=[], only_prefix=[]): return __patched_trace = True if not __thread_id: - if sys.platform == 'linux': + if sys.platform == "linux": import ctypes - __thread_so = ctypes.cdll.LoadLibrary('libc.so.6') + + __thread_so = ctypes.cdll.LoadLibrary("libc.so.6") __thread_id = _thread_linux_id else: __thread_id = _thread_py_id - stderr_write = sys.stderr._original_write if hasattr(sys.stderr, '_original_write') else sys.stderr.write + stderr_write = sys.stderr._original_write if hasattr(sys.stderr, "_original_write") else sys.stderr.write if stream: if isinstance(stream, str): - stream = open(stream, 'w') + stream = open(stream, "w") __stream_write = stream.write __stream_flush = stream.flush else: @@ -244,15 +263,16 @@ def trace_trains(stream=None, level=1, exclude_prefixes=[], only_prefix=[]): __stream_flush = None from ..version import __version__ - msg = 'ClearML v{} - Starting Trace\n\n'.format(__version__) + + msg = "ClearML v{} - Starting Trace\n\n".format(__version__) # print to actual stderr stderr_write(msg) # store to stream __stream_write(msg) - __stream_write('{:9}:{:5}:{:8}: {:14}\n'.format('seconds', 'pid', 'tid', 'self')) - __stream_write('{:9}:{:5}:{:8}:{:15}\n'.format('-' * 9, '-' * 5, '-' * 8, '-' * 15)) + __stream_write("{:9}:{:5}:{:8}: {:14}\n".format("seconds", "pid", "tid", "self")) + __stream_write("{:9}:{:5}:{:8}:{:15}\n".format("-" * 9, "-" * 5, "-" * 8, "-" * 15)) __trace_start = time.time() - _patch_module('clearml', exclude_prefixes=exclude_prefixes or [], only_prefix=only_prefix or []) + _patch_module("clearml", exclude_prefixes=exclude_prefixes or [], only_prefix=only_prefix or []) def trace_level(level=1): @@ -286,25 +306,25 @@ def print_traced_files(glob_mask, lines_per_tid=5, stream=sys.stdout, specify_pi from glob import glob def hash_line(a_line): - return hash(':'.join(a_line.split(':')[1:])) + return hash(":".join(a_line.split(":")[1:])) pids = {} orphan_calls = set() print_orphans = False for fname in glob(glob_mask, recursive=False): - with open(fname, 'rt') as fd: + with open(fname, "rt") as fd: lines = fd.readlines() for line in lines: # noinspection PyBroadException try: - _, pid, tid = line.split(':')[:3] + _, pid, tid = line.split(":")[:3] pid = int(pid) except Exception: continue if specify_pids and pid not in specify_pids: continue - if line.startswith('-'): + if line.startswith("-"): print_orphans = True line = line[1:] h = hash_line(line) @@ -323,16 +343,16 @@ def print_traced_files(glob_mask, lines_per_tid=5, stream=sys.stdout, specify_pi by_time = {} for p, tids in pids.items(): for t, lines in tids.items(): - ts = float(lines[-1].split(':')[0].strip()) + 0.000001 * len(by_time) + ts = float(lines[-1].split(":")[0].strip()) + 0.000001 * len(by_time) if print_orphans: for i, line in enumerate(lines): if i > 0 and hash_line(line) in orphan_calls: - lines[i] = ' ### Orphan ### {}'.format(line) - by_time[ts] = ''.join(lines) + '\n' + lines[i] = " ### Orphan ### {}".format(line) + by_time[ts] = "".join(lines) + "\n" - out_stream = open(stream, 'w') if isinstance(stream, str) else stream + out_stream = open(stream, "w") if isinstance(stream, str) else stream for k in sorted(by_time.keys()): - out_stream.write(by_time[k] + '\n') + out_stream.write(by_time[k] + "\n") if isinstance(stream, str): out_stream.close() @@ -345,11 +365,11 @@ def end_of_program(): def stdout_print(*args, **kwargs): if len(args) == 1 and not kwargs: line = str(args[0]) - if not line.endswith('\n'): - line += '\n' + if not line.endswith("\n"): + line += "\n" else: - line = '{} {}\n'.format(args or '', kwargs or '') - if hasattr(sys.stdout, '_original_write'): + line = "{} {}\n".format(args or "", kwargs or "") + if hasattr(sys.stdout, "_original_write"): sys.stdout._original_write(line) else: sys.stdout.write(line) @@ -361,16 +381,18 @@ def debug_print(*args, **kwargs): Example: [pid=123, t=0.003] message here """ global tic - tic = globals().get('tic', time.time()) + tic = globals().get("tic", time.time()) stdout_print( - "\033[1;33m[pid={}, t={:.04f}] ".format(os.getpid(), time.time()-tic) - + str(args[0] if len(args) == 1 else ("" if not args else args)) + "\033[0m", **kwargs - ) + "\033[1;33m[pid={}, t={:.04f}] ".format(os.getpid(), time.time() - tic) + + str(args[0] if len(args) == 1 else ("" if not args else args)) + + "\033[0m", + **kwargs + ) tic = time.time() -if __name__ == '__main__': +if __name__ == "__main__": # from clearml import Task # task = Task.init(project_name="examples", task_name="trace test") # trace_trains('_trace.txt', level=2) - print_traced_files('_trace_*.txt', lines_per_tid=10) + print_traced_files("_trace_*.txt", lines_per_tid=10)