Black formatting

allegroai 2024-04-13 22:09:10 +03:00
parent 7887b0ac89
commit c751bea4dc


@@ -14,23 +14,23 @@ from six import BytesIO
 default_level = logging.INFO
 _levelToName = {
-    logging.CRITICAL: 'CRITICAL',
-    logging.ERROR: 'ERROR',
-    logging.WARNING: 'WARNING',
-    logging.INFO: 'INFO',
-    logging.DEBUG: 'DEBUG',
-    logging.NOTSET: 'NOTSET',
+    logging.CRITICAL: "CRITICAL",
+    logging.ERROR: "ERROR",
+    logging.WARNING: "WARNING",
+    logging.INFO: "INFO",
+    logging.DEBUG: "DEBUG",
+    logging.NOTSET: "NOTSET",
 }
 _nameToLevel = {
-    'CRITICAL': logging.CRITICAL,
-    'FATAL': logging.FATAL,
-    'ERROR': logging.ERROR,
-    'WARN': logging.WARNING,
-    'WARNING': logging.WARNING,
-    'INFO': logging.INFO,
-    'DEBUG': logging.DEBUG,
-    'NOTSET': logging.NOTSET,
+    "CRITICAL": logging.CRITICAL,
+    "FATAL": logging.FATAL,
+    "ERROR": logging.ERROR,
+    "WARN": logging.WARNING,
+    "WARNING": logging.WARNING,
+    "INFO": logging.INFO,
+    "DEBUG": logging.DEBUG,
+    "NOTSET": logging.NOTSET,
 }
@@ -48,7 +48,6 @@ def resolve_logging_level(level):
 class PickledLogger(logging.getLoggerClass()):
     def __init__(self, *args, **kwargs):
         super(PickledLogger, self).__init__(*args, **kwargs)
         self._init_kwargs = None
@@ -59,14 +58,17 @@ class PickledLogger(logging.getLoggerClass()):
         if sys.version_info.major >= 3 and sys.version_info.minor >= 7:
             return a_instance
-        safe_logger = PickledLogger(name=kwargs.get('name'))
+        safe_logger = PickledLogger(name=kwargs.get("name"))
         safe_logger.__dict__ = a_instance.__dict__
-        if 'stream' in kwargs and kwargs['stream']:
-            kwargs['stream'] = 'stdout' if kwargs['stream'] == sys.stdout else (
-                'stderr' if kwargs['stream'] == sys.stderr else kwargs['stream'])
+        if "stream" in kwargs and kwargs["stream"]:
+            kwargs["stream"] = (
+                "stdout"
+                if kwargs["stream"] == sys.stdout
+                else ("stderr" if kwargs["stream"] == sys.stderr else kwargs["stream"])
+            )
         else:
-            kwargs['stream'] = None
-        kwargs['_func'] = func
+            kwargs["stream"] = None
+        kwargs["_func"] = func
         safe_logger._init_kwargs = kwargs
         return safe_logger
@@ -74,15 +76,17 @@ class PickledLogger(logging.getLoggerClass()):
         return self._init_kwargs or {}

     def __setstate__(self, state):
-        state['stream'] = sys.stdout if state['stream'] == 'stdout' else (
-            sys.stderr if state['stream'] == 'stderr' else state['stream'])
-        _func = state.pop('_func') or self.__class__
+        state["stream"] = (
+            sys.stdout
+            if state["stream"] == "stdout"
+            else (sys.stderr if state["stream"] == "stderr" else state["stream"])
+        )
+        _func = state.pop("_func") or self.__class__
         self.__dict__ = _func(**state).__dict__


 class _LevelRangeFilter(logging.Filter):
-    def __init__(self, min_level, max_level, name=''):
+    def __init__(self, min_level, max_level, name=""):
         super(_LevelRangeFilter, self).__init__(name)
         self.min_level = min_level
         self.max_level = max_level
@@ -103,21 +107,18 @@ class LoggerRoot(object):
         if level is None and getenv("CLEARML_LOG_LEVEL"):
             level = resolve_logging_level(getenv("CLEARML_LOG_LEVEL").strip())
             if level is None:
-                print('Invalid value in environment variable CLEARML_LOG_LEVEL: %s' % getenv("CLEARML_LOG_LEVEL"))
+                print("Invalid value in environment variable CLEARML_LOG_LEVEL: %s" % getenv("CLEARML_LOG_LEVEL"))

-        clearml_logger = logging.getLogger('clearml')
+        clearml_logger = logging.getLogger("clearml")
         if level is None:
             level = clearml_logger.level

         # avoid nested imports
         from ..config import get_log_redirect_level
+
         LoggerRoot.__base_logger = PickledLogger.wrapper(
-            clearml_logger,
-            func=cls.get_base_logger,
-            level=level,
-            stream=stream,
-            colored=colored
+            clearml_logger, func=cls.get_base_logger, level=level, stream=stream, colored=colored
         )
         LoggerRoot.__base_logger.setLevel(level)
@@ -129,9 +130,7 @@ class LoggerRoot(object):
             # Adjust redirect level in case requested level is higher (e.g. logger is requested for CRITICAL
             # and redirect is set for ERROR, in which case we redirect from CRITICAL)
             redirect_level = max(level, redirect_level)
-            LoggerRoot.__base_logger.addHandler(
-                ClearmlStreamHandler(redirect_level, sys.stderr, colored)
-            )
+            LoggerRoot.__base_logger.addHandler(ClearmlStreamHandler(redirect_level, sys.stderr, colored))

             if level < redirect_level:
                 # Not all levels were redirected, remaining should be sent to requested stream
@@ -139,9 +138,7 @@ class LoggerRoot(object):
                 handler.addFilter(_LevelRangeFilter(min_level=level, max_level=redirect_level - 1))
                 LoggerRoot.__base_logger.addHandler(handler)
         else:
-            LoggerRoot.__base_logger.addHandler(
-                ClearmlStreamHandler(level, stream, colored)
-            )
+            LoggerRoot.__base_logger.addHandler(ClearmlStreamHandler(level, stream, colored))

         LoggerRoot.__base_logger.propagate = False
         return LoggerRoot.__base_logger
@@ -157,28 +154,27 @@ class LoggerRoot(object):
         # https://github.com/pytest-dev/pytest/issues/5502#issuecomment-647157873
         loggers = [logging.getLogger()] + list(logging.Logger.manager.loggerDict.values())
         for logger in loggers:
-            handlers = getattr(logger, 'handlers', [])
+            handlers = getattr(logger, "handlers", [])
             for handler in handlers:
                 if isinstance(handler, ClearmlLoggerHandler):
                     logger.removeHandler(handler)


 def add_options(parser):
-    """ Add logging options to an argparse.ArgumentParser object """
+    """Add logging options to an argparse.ArgumentParser object"""
     level = logging.getLevelName(default_level)
-    parser.add_argument(
-        '--log-level', '-l', default=level, help='Log level (default is %s)' % level)
+    parser.add_argument("--log-level", "-l", default=level, help="Log level (default is %s)" % level)


 def apply_logging_args(args):
-    """ Apply logging args from an argparse.ArgumentParser parsed args """
+    """Apply logging args from an argparse.ArgumentParser parsed args"""
     global default_level
     default_level = logging.getLevelName(args.log_level.upper())


 def get_logger(path=None, level=None, stream=None, colored=False):
-    """ Get a python logging object named using the provided filename and preconfigured with a color-formatted
+    """Get a python logging object named using the provided filename and preconfigured with a color-formatted
     stream handler
     """
     # noinspection PyBroadException
     try:
@@ -196,42 +192,42 @@ def get_logger(path=None, level=None, stream=None, colored=False):
         ch.setLevel(level)
         log.addHandler(ch)
     log.propagate = True
-    return PickledLogger.wrapper(
-        log, func=get_logger, path=path, level=level, stream=stream, colored=colored)
+    return PickledLogger.wrapper(log, func=get_logger, path=path, level=level, stream=stream, colored=colored)


 def _add_file_handler(logger, log_dir, fh, formatter=None):
-    """ Adds a file handler to a logger """
+    """Adds a file handler to a logger"""
     Path(log_dir).mkdir(parents=True, exist_ok=True)
     if not formatter:
-        log_format = '%(asctime)s %(name)s x_x[%(levelname)s] %(message)s'
+        log_format = "%(asctime)s %(name)s x_x[%(levelname)s] %(message)s"
         formatter = logging.Formatter(log_format)
     fh.setFormatter(formatter)
     logger.addHandler(fh)


-def add_rotating_file_handler(logger, log_dir, log_file_prefix, max_bytes=10 * 1024 * 1024, backup_count=20,
-                              formatter=None):
-    """ Create and add a rotating file handler to a logger """
+def add_rotating_file_handler(
+    logger, log_dir, log_file_prefix, max_bytes=10 * 1024 * 1024, backup_count=20, formatter=None
+):
+    """Create and add a rotating file handler to a logger"""
     fh = ClearmlRotatingFileHandler(
-        str(Path(log_dir) / ('%s.log' % log_file_prefix)), maxBytes=max_bytes, backupCount=backup_count)
+        str(Path(log_dir) / ("%s.log" % log_file_prefix)), maxBytes=max_bytes, backupCount=backup_count
+    )
     _add_file_handler(logger, log_dir, fh, formatter)


-def add_time_rotating_file_handler(logger, log_dir, log_file_prefix, when='midnight', formatter=None):
+def add_time_rotating_file_handler(logger, log_dir, log_file_prefix, when="midnight", formatter=None):
     """
     Create and add a time rotating file handler to a logger.
     Possible values for when are 'midnight', weekdays ('w0'-'W6', when 0 is Monday), and 's', 'm', 'h' amd 'd' for
     seconds, minutes, hours and days respectively (case-insensitive)
     """
-    fh = ClearmlTimedRotatingFileHandler(
-        str(Path(log_dir) / ('%s.log' % log_file_prefix)), when=when)
+    fh = ClearmlTimedRotatingFileHandler(str(Path(log_dir) / ("%s.log" % log_file_prefix)), when=when)
     _add_file_handler(logger, log_dir, fh, formatter)


 def get_null_logger(name=None):
-    """ Get a logger with a null handler """
-    log = logging.getLogger(name if name else 'null')
+    """Get a logger with a null handler"""
+    log = logging.getLogger(name if name else "null")
     if not log.handlers:
         # avoid nested imports
         from ..config import config
@@ -242,10 +238,10 @@ def get_null_logger(name=None):
 class TqdmLog(object):
-    """ Tqdm (progressbar) wrapped logging class """
+    """Tqdm (progressbar) wrapped logging class"""

     class _TqdmIO(BytesIO):
-        """ IO wrapper class for Tqdm """
+        """IO wrapper class for Tqdm"""

         def __init__(self, level=20, logger=None, *args, **kwargs):
             self._log = logger or get_null_logger()
@@ -253,18 +249,24 @@ class TqdmLog(object):
             BytesIO.__init__(self, *args, **kwargs)

         def write(self, buf):
-            self._buf = buf.strip('\r\n\t ')
+            self._buf = buf.strip("\r\n\t ")

         def flush(self):
             self._log.log(self._level, self._buf)

-    def __init__(self, total, desc='', log_level=20, ascii=False, logger=None, smoothing=0, mininterval=5, initial=0):
+    def __init__(self, total, desc="", log_level=20, ascii=False, logger=None, smoothing=0, mininterval=5, initial=0):
         from tqdm import tqdm
+
         self._io = self._TqdmIO(level=log_level, logger=logger)
-        self._tqdm = tqdm(total=total, desc=desc, file=self._io, ascii=ascii if not system() == 'Windows' else True,
-                          smoothing=smoothing,
-                          mininterval=mininterval, initial=initial)
+        self._tqdm = tqdm(
+            total=total,
+            desc=desc,
+            file=self._io,
+            ascii=ascii if not system() == "Windows" else True,
+            smoothing=smoothing,
+            mininterval=mininterval,
+            initial=initial,
+        )

     def update(self, n=None):
         if n is not None:
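
The changes in this commit are mechanical quote and line-wrapping normalizations consistent with Black's output. As a hedged illustration (not part of the commit), the same kind of reformatting can be reproduced programmatically; the line length of 120 used below is an assumption inferred from the width of the reformatted lines, and the snippet requires the black package to be installed.

# Hedged sketch: reproduce Black-style normalization on a small fragment of the
# module above. The line-length value (120) is an assumption, not taken from the commit.
import black

OLD_FRAGMENT = (
    "_levelToName = {\n"
    "    logging.CRITICAL: 'CRITICAL',\n"
    "    logging.ERROR: 'ERROR',\n"
    "}\n"
)

# black.format_str() returns the reformatted source as a string; single quotes
# become double quotes and long call sites are joined or exploded to fit the width.
formatted = black.format_str(OLD_FRAGMENT, mode=black.Mode(line_length=120))
print(formatted)

Running the snippet prints the fragment with double-quoted strings, matching the "+" lines in the first hunk above.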