Fix a sync issue with ConfigWrapper when loading configuration in session

Optimize the load sequence by not loading the configuration on session start if it is not provided (worst case, it will be loaded by the ConfigWrapper)
allegroai 2022-12-04 12:51:27 +02:00
parent 911a72f561
commit ec70b5113a
2 changed files with 54 additions and 29 deletions
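
The optimization works because ConfigWrapper caches its configuration at class level and only loads it on first access. A minimal, self-contained sketch of that pattern (an illustrative stand-in, not the actual clearml implementation):

class LazyConfig:
    _config = None  # class-level cache shared by all callers

    @classmethod
    def _init(cls):
        # load the configuration only once, on first access
        if cls._config is None:
            cls._config = {"api.verify_certificate": True}  # stands in for load_config()
        return cls._config

    @classmethod
    def get(cls, key, default=None):
        return cls._init().get(key, default)

print(LazyConfig.get("api.verify_certificate"))  # loaded here, on first use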

clearml/backend_api/session/session.py

@@ -33,7 +33,6 @@ from .defs import (
)
from .request import Request, BatchRequest # noqa: F401
from .token_manager import TokenManager
from ..config import load
from ..utils import get_http_session_with_retry, urllib_log_warning_setup
from ...debugging import get_logger
from ...debugging.log import resolve_logging_level
@@ -125,18 +124,15 @@ class Session(TokenManager):
host=None,
logger=None,
verbose=None,
initialize_logging=True,
config=None,
http_retries_config=None,
**kwargs
):
if config is not None:
self.config = config
else:
self.config = load()
if initialize_logging:
self.config.initialize_logging()
from clearml.config import ConfigWrapper
self.config = ConfigWrapper._init()
token_expiration_threshold_sec = self.config.get(
"auth.token_expiration_threshold_sec", 60
@@ -233,7 +229,10 @@ class Session(TokenManager):
# update only after we have max_api
self.__class__._sessions_created += 1
self._load_vaults()
if self._load_vaults():
from clearml.config import ConfigWrapper, ConfigSDKWrapper
ConfigWrapper.set_config_impl(self.config)
ConfigSDKWrapper.clear_config_impl()
self._apply_config_sections(local_logger)
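
The sync fix itself: _load_vaults() now reports whether vault overrides were actually applied, and only then is the session's enriched config pushed back into the global wrappers. The same lines, annotated:

if self._load_vaults():  # True only if vault overrides were applied
    from clearml.config import ConfigWrapper, ConfigSDKWrapper
    # share the session's config (including the overrides) with every
    # caller going through the wrapper ...
    ConfigWrapper.set_config_impl(self.config)
    # ... and drop the cached SDK view so it is rebuilt lazily from the
    # updated configuration on next access
    ConfigSDKWrapper.clear_config_impl()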
@@ -269,6 +268,7 @@ class Session(TokenManager):
return list(retry_codes)
def _load_vaults(self):
# type: () -> Optional[bool]
if not self.check_min_api_version("2.15") or self.feature_set == "basic":
return
@@ -297,6 +297,7 @@ class Session(TokenManager):
data = list(filter(None, map(parse, vaults)))
if data:
self.config.set_overrides(*data)
return True
elif res.status_code != 404:
raise Exception(res.json().get("meta", {}).get("result_msg", res.text))
except Exception as ex:

clearml/config/__init__.py

@@ -25,6 +25,7 @@ class ConfigWrapper(object):
if cls._config is None:
cls._config = load_config(Path(__file__).parent) # noqa: F405
cls._config.initialize_logging()
return cls._config
@classmethod
def get(cls, *args, **kwargs):
@@ -36,6 +37,14 @@ class ConfigWrapper(object):
cls._init()
return cls._config.set_overrides(*args, **kwargs)
@classmethod
def set_config_impl(cls, value):
try:
if issubclass(value, ConfigWrapper):
cls._config = value._config
except TypeError:
cls._config = value
class ConfigSDKWrapper(object):
_config_sdk = None
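
The try/except in set_config_impl works because issubclass() raises TypeError when its first argument is an instance rather than a class, so the except branch handles plain config objects. A quick standalone demonstration:

class Wrapper:
    _config = {"from": "class"}

print(issubclass(Wrapper, Wrapper))   # True: a class was passed, copy its _config

try:
    issubclass(object(), Wrapper)     # an *instance* was passed
except TypeError:
    print("instance, not a class")    # store the value itself as the config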
@@ -55,16 +64,28 @@ class ConfigSDKWrapper(object):
cls._init()
return cls._config_sdk.set_overrides(*args, **kwargs)
@classmethod
def clear_config_impl(cls):
cls._config_sdk = None
def deferred_config(key=None, default=Config._MISSING, transform=None, multi=None):
return LazyEvalWrapper(
callback=lambda:
(ConfigSDKWrapper.get(key, default) if not multi else
next((ConfigSDKWrapper.get(*a) for a in multi if ConfigSDKWrapper.get(*a)), None))
callback=lambda: (
ConfigSDKWrapper.get(key, default)
if not multi
else next((ConfigSDKWrapper.get(*a) for a in multi if ConfigSDKWrapper.get(*a)), None)
)
if transform is None
else (transform() if key is None else transform(ConfigSDKWrapper.get(key, default) if not multi else # noqa
next((ConfigSDKWrapper.get(*a) for a in multi
if ConfigSDKWrapper.get(*a)), None)))
else (
transform()
if key is None
else transform(
ConfigSDKWrapper.get(key, default)
if not multi
else next((ConfigSDKWrapper.get(*a) for a in multi if ConfigSDKWrapper.get(*a)), None) # noqa
)
)
)
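
deferred_config wraps the lookup in a LazyEvalWrapper so config keys resolve when the value is first read, after vaults may have updated ConfigSDKWrapper, rather than at import time. A minimal stand-in for that behavior (LazyValue here is hypothetical, not the clearml class):

class LazyValue:
    def __init__(self, callback):
        self._callback = callback

    def get(self):
        # evaluated only now, at read time
        return self._callback()

config = {}  # stands in for ConfigSDKWrapper's backing store
value = LazyValue(lambda: config.get("agent.cuda_version", "n/a"))
config["agent.cuda_version"] = "11.7"  # e.g. vault overrides arrive later
print(value.get())  # "11.7" -- the override is visible at read time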
@@ -78,9 +99,9 @@ def get_cache_dir():
cache_base_dir = Path( # noqa: F405
expandvars(
expanduser(
CLEARML_CACHE_DIR.get() or # noqa: F405
config.get("storage.cache.default_base_dir") or
DEFAULT_CACHE_DIR # noqa: F405
CLEARML_CACHE_DIR.get() # noqa: F405
or config.get("storage.cache.default_base_dir")
or DEFAULT_CACHE_DIR # noqa: F405
)
)
)
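
The reflowed expression is the usual three-level fallback: environment variable, then config file, then a built-in default. A self-contained sketch, with plain os.environ standing in for the CLEARML_CACHE_DIR entry object (the default path is illustrative):

import os
from os.path import expanduser, expandvars
from pathlib import Path

DEFAULT_CACHE_DIR = "~/.clearml/cache"  # illustrative default
config = {}                             # stands in for the loaded config

cache_base_dir = Path(expandvars(expanduser(
    os.environ.get("CLEARML_CACHE_DIR")
    or config.get("storage.cache.default_base_dir")
    or DEFAULT_CACHE_DIR
)))
print(cache_base_dir)  # e.g. /home/user/.clearml/cache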
@@ -89,8 +110,8 @@ def get_cache_dir():
def get_offline_dir(task_id=None):
if not task_id:
return get_cache_dir() / 'offline'
return get_cache_dir() / 'offline' / task_id
return get_cache_dir() / "offline"
return get_cache_dir() / "offline" / task_id
def get_config_for_bucket(base_url, extra_configurations=None):
@@ -117,23 +138,24 @@ def get_log_to_backend(default=None):
def get_node_count():
# noinspection PyBroadException
try:
mpi_world_rank = int(os.environ.get('OMPI_COMM_WORLD_NODE_RANK', os.environ.get('PMI_RANK')))
mpi_world_rank = int(os.environ.get("OMPI_COMM_WORLD_NODE_RANK", os.environ.get("PMI_RANK")))
return mpi_world_rank
except Exception:
pass
# noinspection PyBroadException
try:
mpi_rank = int(os.environ.get('OMPI_COMM_WORLD_RANK', os.environ.get('SLURM_JOB_NUM_NODES')))
mpi_rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", os.environ.get("SLURM_JOB_NUM_NODES")))
return mpi_rank
except Exception:
pass
# check if we have pyTorch node/worker ID (only if torch was already imported)
if 'torch' in sys.modules:
if "torch" in sys.modules:
# noinspection PyBroadException
try:
from torch.utils.data.dataloader import get_worker_info # noqa
worker_info = get_worker_info()
if worker_info:
return int(worker_info.num_workers)
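
The try/except blocks above lean on int(None) raising TypeError when neither environment variable is set. The same intent, written out explicitly as a standalone helper:

import os

def first_int_env(*names):
    # equivalent to int(os.environ.get(a, os.environ.get(b))) plus the
    # surrounding try/except, with the fallback order made explicit
    for name in names:
        value = os.environ.get(name)
        if value is not None:
            return int(value)
    return None

os.environ["SLURM_JOB_NUM_NODES"] = "4"
print(first_int_env("OMPI_COMM_WORLD_RANK", "SLURM_JOB_NUM_NODES"))  # 4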
@@ -148,14 +170,14 @@ def get_node_id(default=0):
# noinspection PyBroadException
try:
mpi_world_rank = int(os.environ.get('OMPI_COMM_WORLD_NODE_RANK', os.environ.get('PMI_RANK')))
mpi_world_rank = int(os.environ.get("OMPI_COMM_WORLD_NODE_RANK", os.environ.get("PMI_RANK")))
except Exception:
mpi_world_rank = None
# noinspection PyBroadException
try:
mpi_rank = int(os.environ.get(
'OMPI_COMM_WORLD_RANK', os.environ.get('SLURM_PROCID', os.environ.get('SLURM_NODEID')))
mpi_rank = int(
os.environ.get("OMPI_COMM_WORLD_RANK", os.environ.get("SLURM_PROCID", os.environ.get("SLURM_NODEID")))
)
except Exception:
mpi_rank = None
@@ -170,10 +192,11 @@ def get_node_id(default=0):
torch_rank = None
# check if we have pyTorch node/worker ID (only if torch was already imported)
if 'torch' in sys.modules:
if "torch" in sys.modules:
# noinspection PyBroadException
try:
from torch.utils.data.dataloader import get_worker_info # noqa
worker_info = get_worker_info()
if not worker_info:
torch_rank = None
@@ -185,8 +208,9 @@ def get_node_id(default=0):
except Exception:
# guess a number based on the worker id (w_id), hopefully a unique value
import hashlib
h = hashlib.md5()
h.update(str(w_id).encode('utf-8'))
h.update(str(w_id).encode("utf-8"))
torch_rank = int(h.hexdigest(), 16)
except Exception:
torch_rank = None
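
The fallback in the except branch derives a deterministic integer rank from a non-numeric worker id by hashing it; a self-contained version:

import hashlib

def rank_from_worker_id(w_id):
    # md5 of the stringified id -> stable (hopefully unique) integer
    h = hashlib.md5()
    h.update(str(w_id).encode("utf-8"))
    return int(h.hexdigest(), 16)

print(rank_from_worker_id("worker-a"))  # same large integer on every run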
@@ -208,7 +232,7 @@ def get_is_master_node():
def get_log_redirect_level():
""" Returns which log level (and up) should be redirected to stderr. None means no redirection. """
"""Returns which log level (and up) should be redirected to stderr. None means no redirection."""
value = LOG_STDERR_REDIRECT_LEVEL.get() # noqa: F405
try:
if value:
@@ -225,8 +249,8 @@ def __set_is_master_node():
# noinspection PyBroadException
try:
# pop both, use the first one that is set
env_a = os.environ.pop('CLEARML_FORCE_MASTER_NODE', None)
env_b = os.environ.pop('TRAINS_FORCE_MASTER_NODE', None)
env_a = os.environ.pop("CLEARML_FORCE_MASTER_NODE", None)
env_b = os.environ.pop("TRAINS_FORCE_MASTER_NODE", None)
force_master_node = env_a or env_b
except Exception:
force_master_node = None
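
The pop-both pattern removes the current CLEARML_ variable and its legacy TRAINS_ alias from the environment (so child processes do not inherit them) and keeps the first one that was set:

import os

os.environ["TRAINS_FORCE_MASTER_NODE"] = "1"  # legacy alias set by the user
env_a = os.environ.pop("CLEARML_FORCE_MASTER_NODE", None)
env_b = os.environ.pop("TRAINS_FORCE_MASTER_NODE", None)
force_master_node = env_a or env_b
print(force_master_node)                      # "1"; both vars are now unset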