Fix a sync issue with ConfigWrapper when loading configuration in session

Optimize load sequence by not loading configuration on session start if not provided (worst case, it will be loaded by the ConfigWrapper)
This commit is contained in:
allegroai 2022-12-04 12:51:27 +02:00
parent 911a72f561
commit ec70b5113a
2 changed files with 54 additions and 29 deletions

View File

@ -33,7 +33,6 @@ from .defs import (
) )
from .request import Request, BatchRequest # noqa: F401 from .request import Request, BatchRequest # noqa: F401
from .token_manager import TokenManager from .token_manager import TokenManager
from ..config import load
from ..utils import get_http_session_with_retry, urllib_log_warning_setup from ..utils import get_http_session_with_retry, urllib_log_warning_setup
from ...debugging import get_logger from ...debugging import get_logger
from ...debugging.log import resolve_logging_level from ...debugging.log import resolve_logging_level
@ -125,18 +124,15 @@ class Session(TokenManager):
host=None, host=None,
logger=None, logger=None,
verbose=None, verbose=None,
initialize_logging=True,
config=None, config=None,
http_retries_config=None, http_retries_config=None,
**kwargs **kwargs
): ):
if config is not None: if config is not None:
self.config = config self.config = config
else: else:
self.config = load() from clearml.config import ConfigWrapper
if initialize_logging: self.config = ConfigWrapper._init()
self.config.initialize_logging()
token_expiration_threshold_sec = self.config.get( token_expiration_threshold_sec = self.config.get(
"auth.token_expiration_threshold_sec", 60 "auth.token_expiration_threshold_sec", 60
@ -233,7 +229,10 @@ class Session(TokenManager):
# update only after we have max_api # update only after we have max_api
self.__class__._sessions_created += 1 self.__class__._sessions_created += 1
self._load_vaults() if self._load_vaults():
from clearml.config import ConfigWrapper, ConfigSDKWrapper
ConfigWrapper.set_config_impl(self.config)
ConfigSDKWrapper.clear_config_impl()
self._apply_config_sections(local_logger) self._apply_config_sections(local_logger)
@ -269,6 +268,7 @@ class Session(TokenManager):
return list(retry_codes) return list(retry_codes)
def _load_vaults(self): def _load_vaults(self):
# () -> Optional[bool]
if not self.check_min_api_version("2.15") or self.feature_set == "basic": if not self.check_min_api_version("2.15") or self.feature_set == "basic":
return return
@ -297,6 +297,7 @@ class Session(TokenManager):
data = list(filter(None, map(parse, vaults))) data = list(filter(None, map(parse, vaults)))
if data: if data:
self.config.set_overrides(*data) self.config.set_overrides(*data)
return True
elif res.status_code != 404: elif res.status_code != 404:
raise Exception(res.json().get("meta", {}).get("result_msg", res.text)) raise Exception(res.json().get("meta", {}).get("result_msg", res.text))
except Exception as ex: except Exception as ex:

View File

@ -25,6 +25,7 @@ class ConfigWrapper(object):
if cls._config is None: if cls._config is None:
cls._config = load_config(Path(__file__).parent) # noqa: F405 cls._config = load_config(Path(__file__).parent) # noqa: F405
cls._config.initialize_logging() cls._config.initialize_logging()
return cls._config
@classmethod @classmethod
def get(cls, *args, **kwargs): def get(cls, *args, **kwargs):
@ -36,6 +37,14 @@ class ConfigWrapper(object):
cls._init() cls._init()
return cls._config.set_overrides(*args, **kwargs) return cls._config.set_overrides(*args, **kwargs)
@classmethod
def set_config_impl(cls, value):
try:
if issubclass(value, ConfigWrapper):
cls._config = value._config
except TypeError:
cls._config = value
class ConfigSDKWrapper(object): class ConfigSDKWrapper(object):
_config_sdk = None _config_sdk = None
@ -55,16 +64,28 @@ class ConfigSDKWrapper(object):
cls._init() cls._init()
return cls._config_sdk.set_overrides(*args, **kwargs) return cls._config_sdk.set_overrides(*args, **kwargs)
@classmethod
def clear_config_impl(cls):
cls._config_sdk = None
def deferred_config(key=None, default=Config._MISSING, transform=None, multi=None): def deferred_config(key=None, default=Config._MISSING, transform=None, multi=None):
return LazyEvalWrapper( return LazyEvalWrapper(
callback=lambda: callback=lambda: (
(ConfigSDKWrapper.get(key, default) if not multi else ConfigSDKWrapper.get(key, default)
next((ConfigSDKWrapper.get(*a) for a in multi if ConfigSDKWrapper.get(*a)), None)) if not multi
else next((ConfigSDKWrapper.get(*a) for a in multi if ConfigSDKWrapper.get(*a)), None)
)
if transform is None if transform is None
else (transform() if key is None else transform(ConfigSDKWrapper.get(key, default) if not multi else # noqa else (
next((ConfigSDKWrapper.get(*a) for a in multi transform()
if ConfigSDKWrapper.get(*a)), None))) if key is None
else transform(
ConfigSDKWrapper.get(key, default)
if not multi
else next((ConfigSDKWrapper.get(*a) for a in multi if ConfigSDKWrapper.get(*a)), None) # noqa
)
)
) )
@ -78,9 +99,9 @@ def get_cache_dir():
cache_base_dir = Path( # noqa: F405 cache_base_dir = Path( # noqa: F405
expandvars( expandvars(
expanduser( expanduser(
CLEARML_CACHE_DIR.get() or # noqa: F405 CLEARML_CACHE_DIR.get() # noqa: F405
config.get("storage.cache.default_base_dir") or or config.get("storage.cache.default_base_dir")
DEFAULT_CACHE_DIR # noqa: F405 or DEFAULT_CACHE_DIR # noqa: F405
) )
) )
) )
@ -89,8 +110,8 @@ def get_cache_dir():
def get_offline_dir(task_id=None): def get_offline_dir(task_id=None):
if not task_id: if not task_id:
return get_cache_dir() / 'offline' return get_cache_dir() / "offline"
return get_cache_dir() / 'offline' / task_id return get_cache_dir() / "offline" / task_id
def get_config_for_bucket(base_url, extra_configurations=None): def get_config_for_bucket(base_url, extra_configurations=None):
@ -117,23 +138,24 @@ def get_log_to_backend(default=None):
def get_node_count(): def get_node_count():
# noinspection PyBroadException # noinspection PyBroadException
try: try:
mpi_world_rank = int(os.environ.get('OMPI_COMM_WORLD_NODE_RANK', os.environ.get('PMI_RANK'))) mpi_world_rank = int(os.environ.get("OMPI_COMM_WORLD_NODE_RANK", os.environ.get("PMI_RANK")))
return mpi_world_rank return mpi_world_rank
except Exception: except Exception:
pass pass
# noinspection PyBroadException # noinspection PyBroadException
try: try:
mpi_rank = int(os.environ.get('OMPI_COMM_WORLD_RANK', os.environ.get('SLURM_JOB_NUM_NODES'))) mpi_rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", os.environ.get("SLURM_JOB_NUM_NODES")))
return mpi_rank return mpi_rank
except Exception: except Exception:
pass pass
# check if we have pyTorch node/worker ID (only if torch was already imported) # check if we have pyTorch node/worker ID (only if torch was already imported)
if 'torch' in sys.modules: if "torch" in sys.modules:
# noinspection PyBroadException # noinspection PyBroadException
try: try:
from torch.utils.data.dataloader import get_worker_info # noqa from torch.utils.data.dataloader import get_worker_info # noqa
worker_info = get_worker_info() worker_info = get_worker_info()
if worker_info: if worker_info:
return int(worker_info.num_workers) return int(worker_info.num_workers)
@ -148,14 +170,14 @@ def get_node_id(default=0):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
mpi_world_rank = int(os.environ.get('OMPI_COMM_WORLD_NODE_RANK', os.environ.get('PMI_RANK'))) mpi_world_rank = int(os.environ.get("OMPI_COMM_WORLD_NODE_RANK", os.environ.get("PMI_RANK")))
except Exception: except Exception:
mpi_world_rank = None mpi_world_rank = None
# noinspection PyBroadException # noinspection PyBroadException
try: try:
mpi_rank = int(os.environ.get( mpi_rank = int(
'OMPI_COMM_WORLD_RANK', os.environ.get('SLURM_PROCID', os.environ.get('SLURM_NODEID'))) os.environ.get("OMPI_COMM_WORLD_RANK", os.environ.get("SLURM_PROCID", os.environ.get("SLURM_NODEID")))
) )
except Exception: except Exception:
mpi_rank = None mpi_rank = None
@ -170,10 +192,11 @@ def get_node_id(default=0):
torch_rank = None torch_rank = None
# check if we have pyTorch node/worker ID (only if torch was already imported) # check if we have pyTorch node/worker ID (only if torch was already imported)
if 'torch' in sys.modules: if "torch" in sys.modules:
# noinspection PyBroadException # noinspection PyBroadException
try: try:
from torch.utils.data.dataloader import get_worker_info # noqa from torch.utils.data.dataloader import get_worker_info # noqa
worker_info = get_worker_info() worker_info = get_worker_info()
if not worker_info: if not worker_info:
torch_rank = None torch_rank = None
@ -185,8 +208,9 @@ def get_node_id(default=0):
except Exception: except Exception:
# guess a number based on wid hopefully unique value # guess a number based on wid hopefully unique value
import hashlib import hashlib
h = hashlib.md5() h = hashlib.md5()
h.update(str(w_id).encode('utf-8')) h.update(str(w_id).encode("utf-8"))
torch_rank = int(h.hexdigest(), 16) torch_rank = int(h.hexdigest(), 16)
except Exception: except Exception:
torch_rank = None torch_rank = None
@ -208,7 +232,7 @@ def get_is_master_node():
def get_log_redirect_level(): def get_log_redirect_level():
""" Returns which log level (and up) should be redirected to stderr. None means no redirection. """ """Returns which log level (and up) should be redirected to stderr. None means no redirection."""
value = LOG_STDERR_REDIRECT_LEVEL.get() # noqa: F405 value = LOG_STDERR_REDIRECT_LEVEL.get() # noqa: F405
try: try:
if value: if value:
@ -225,8 +249,8 @@ def __set_is_master_node():
# noinspection PyBroadException # noinspection PyBroadException
try: try:
# pop both set the first # pop both set the first
env_a = os.environ.pop('CLEARML_FORCE_MASTER_NODE', None) env_a = os.environ.pop("CLEARML_FORCE_MASTER_NODE", None)
env_b = os.environ.pop('TRAINS_FORCE_MASTER_NODE', None) env_b = os.environ.pop("TRAINS_FORCE_MASTER_NODE", None)
force_master_node = env_a or env_b force_master_node = env_a or env_b
except Exception: except Exception:
force_master_node = None force_master_node = None