Fix internal logging.Logger can't be pickled

This commit is contained in:
allegroai 2021-07-23 16:02:38 +03:00
parent f9a6139168
commit 94829dd199

View File

@ -12,6 +12,35 @@ from six import BytesIO
default_level = logging.INFO default_level = logging.INFO
class PickledLogger(logging.getLoggerClass()):
def __init__(self, *args, **kwargs):
super(PickledLogger, self).__init__(*args, **kwargs)
self._init_kwargs = None
@staticmethod
def wrapper(a_instance, func, **kwargs):
safe_logger = PickledLogger(name=kwargs.get('name'))
safe_logger.__dict__ = a_instance.__dict__
if 'stream' in kwargs and kwargs['stream']:
kwargs['stream'] = 'stdout' if kwargs['stream'] == sys.stdout else (
'stderr' if kwargs['stream'] == sys.stderr else kwargs['stream'])
else:
kwargs['stream'] = None
kwargs['_func'] = func
safe_logger._init_kwargs = kwargs
return safe_logger
def __getstate__(self):
return self._init_kwargs or {}
def __setstate__(self, state):
state['stream'] = sys.stdout if state['stream'] == 'stdout' else (
sys.stderr if state['stream'] == 'stderr' else state['stream'])
_func = state.pop('_func') or self.__class__
self.__dict__ = _func(**state).__dict__
class _LevelRangeFilter(logging.Filter): class _LevelRangeFilter(logging.Filter):
def __init__(self, min_level, max_level, name=''): def __init__(self, min_level, max_level, name=''):
@ -55,8 +84,10 @@ class LoggerRoot(object):
return LoggerRoot.__base_logger return LoggerRoot.__base_logger
# avoid nested imports # avoid nested imports
from ..config import get_log_redirect_level from ..config import get_log_redirect_level
LoggerRoot.__base_logger = PickledLogger.wrapper(
LoggerRoot.__base_logger = logging.getLogger('clearml') logging.getLogger('clearml'),
func=cls.get_base_logger,
level=level, stream=stream, colored=colored)
level = level if level is not None else default_level level = level if level is not None else default_level
LoggerRoot.__base_logger.setLevel(level) LoggerRoot.__base_logger.setLevel(level)
@ -123,7 +154,8 @@ def get_logger(path=None, level=None, stream=None, colored=False):
if level is not None: if level is not None:
ch.setLevel(level) ch.setLevel(level)
log.propagate = True log.propagate = True
return log return PickledLogger.wrapper(
log, func=get_logger, path=path, level=level, stream=stream, colored=colored)
def _add_file_handler(logger, log_dir, fh, formatter=None): def _add_file_handler(logger, log_dir, fh, formatter=None):
@ -164,7 +196,7 @@ def get_null_logger(name=None):
log.addHandler(logging.NullHandler()) log.addHandler(logging.NullHandler())
log.propagate = config.get("log.null_log_propagate", False) log.propagate = config.get("log.null_log_propagate", False)
return log return PickledLogger.wrapper(log, func=get_null_logger, name=name)
class TqdmLog(object): class TqdmLog(object):