mirror of
https://github.com/clearml/clearml
synced 2025-01-31 09:07:00 +00:00
Fix a sync issue with ConrfigWrapper when loading configuration in session
Optimize load sequence by not loading configuration on session start if not provided (worse case it will be loaded by the ConfigWrapper)
This commit is contained in:
parent
911a72f561
commit
ec70b5113a
@ -33,7 +33,6 @@ from .defs import (
|
||||
)
|
||||
from .request import Request, BatchRequest # noqa: F401
|
||||
from .token_manager import TokenManager
|
||||
from ..config import load
|
||||
from ..utils import get_http_session_with_retry, urllib_log_warning_setup
|
||||
from ...debugging import get_logger
|
||||
from ...debugging.log import resolve_logging_level
|
||||
@ -125,18 +124,15 @@ class Session(TokenManager):
|
||||
host=None,
|
||||
logger=None,
|
||||
verbose=None,
|
||||
initialize_logging=True,
|
||||
config=None,
|
||||
http_retries_config=None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
if config is not None:
|
||||
self.config = config
|
||||
else:
|
||||
self.config = load()
|
||||
if initialize_logging:
|
||||
self.config.initialize_logging()
|
||||
from clearml.config import ConfigWrapper
|
||||
self.config = ConfigWrapper._init()
|
||||
|
||||
token_expiration_threshold_sec = self.config.get(
|
||||
"auth.token_expiration_threshold_sec", 60
|
||||
@ -233,7 +229,10 @@ class Session(TokenManager):
|
||||
# update only after we have max_api
|
||||
self.__class__._sessions_created += 1
|
||||
|
||||
self._load_vaults()
|
||||
if self._load_vaults():
|
||||
from clearml.config import ConfigWrapper, ConfigSDKWrapper
|
||||
ConfigWrapper.set_config_impl(self.config)
|
||||
ConfigSDKWrapper.clear_config_impl()
|
||||
|
||||
self._apply_config_sections(local_logger)
|
||||
|
||||
@ -269,6 +268,7 @@ class Session(TokenManager):
|
||||
return list(retry_codes)
|
||||
|
||||
def _load_vaults(self):
|
||||
# () -> Optional[bool]
|
||||
if not self.check_min_api_version("2.15") or self.feature_set == "basic":
|
||||
return
|
||||
|
||||
@ -297,6 +297,7 @@ class Session(TokenManager):
|
||||
data = list(filter(None, map(parse, vaults)))
|
||||
if data:
|
||||
self.config.set_overrides(*data)
|
||||
return True
|
||||
elif res.status_code != 404:
|
||||
raise Exception(res.json().get("meta", {}).get("result_msg", res.text))
|
||||
except Exception as ex:
|
||||
|
@ -25,6 +25,7 @@ class ConfigWrapper(object):
|
||||
if cls._config is None:
|
||||
cls._config = load_config(Path(__file__).parent) # noqa: F405
|
||||
cls._config.initialize_logging()
|
||||
return cls._config
|
||||
|
||||
@classmethod
|
||||
def get(cls, *args, **kwargs):
|
||||
@ -36,6 +37,14 @@ class ConfigWrapper(object):
|
||||
cls._init()
|
||||
return cls._config.set_overrides(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def set_config_impl(cls, value):
|
||||
try:
|
||||
if issubclass(value, ConfigWrapper):
|
||||
cls._config = value._config
|
||||
except TypeError:
|
||||
cls._config = value
|
||||
|
||||
|
||||
class ConfigSDKWrapper(object):
|
||||
_config_sdk = None
|
||||
@ -55,16 +64,28 @@ class ConfigSDKWrapper(object):
|
||||
cls._init()
|
||||
return cls._config_sdk.set_overrides(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def clear_config_impl(cls):
|
||||
cls._config_sdk = None
|
||||
|
||||
|
||||
def deferred_config(key=None, default=Config._MISSING, transform=None, multi=None):
|
||||
return LazyEvalWrapper(
|
||||
callback=lambda:
|
||||
(ConfigSDKWrapper.get(key, default) if not multi else
|
||||
next((ConfigSDKWrapper.get(*a) for a in multi if ConfigSDKWrapper.get(*a)), None))
|
||||
callback=lambda: (
|
||||
ConfigSDKWrapper.get(key, default)
|
||||
if not multi
|
||||
else next((ConfigSDKWrapper.get(*a) for a in multi if ConfigSDKWrapper.get(*a)), None)
|
||||
)
|
||||
if transform is None
|
||||
else (transform() if key is None else transform(ConfigSDKWrapper.get(key, default) if not multi else # noqa
|
||||
next((ConfigSDKWrapper.get(*a) for a in multi
|
||||
if ConfigSDKWrapper.get(*a)), None)))
|
||||
else (
|
||||
transform()
|
||||
if key is None
|
||||
else transform(
|
||||
ConfigSDKWrapper.get(key, default)
|
||||
if not multi
|
||||
else next((ConfigSDKWrapper.get(*a) for a in multi if ConfigSDKWrapper.get(*a)), None) # noqa
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -78,9 +99,9 @@ def get_cache_dir():
|
||||
cache_base_dir = Path( # noqa: F405
|
||||
expandvars(
|
||||
expanduser(
|
||||
CLEARML_CACHE_DIR.get() or # noqa: F405
|
||||
config.get("storage.cache.default_base_dir") or
|
||||
DEFAULT_CACHE_DIR # noqa: F405
|
||||
CLEARML_CACHE_DIR.get() # noqa: F405
|
||||
or config.get("storage.cache.default_base_dir")
|
||||
or DEFAULT_CACHE_DIR # noqa: F405
|
||||
)
|
||||
)
|
||||
)
|
||||
@ -89,8 +110,8 @@ def get_cache_dir():
|
||||
|
||||
def get_offline_dir(task_id=None):
|
||||
if not task_id:
|
||||
return get_cache_dir() / 'offline'
|
||||
return get_cache_dir() / 'offline' / task_id
|
||||
return get_cache_dir() / "offline"
|
||||
return get_cache_dir() / "offline" / task_id
|
||||
|
||||
|
||||
def get_config_for_bucket(base_url, extra_configurations=None):
|
||||
@ -117,23 +138,24 @@ def get_log_to_backend(default=None):
|
||||
def get_node_count():
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
mpi_world_rank = int(os.environ.get('OMPI_COMM_WORLD_NODE_RANK', os.environ.get('PMI_RANK')))
|
||||
mpi_world_rank = int(os.environ.get("OMPI_COMM_WORLD_NODE_RANK", os.environ.get("PMI_RANK")))
|
||||
return mpi_world_rank
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
mpi_rank = int(os.environ.get('OMPI_COMM_WORLD_RANK', os.environ.get('SLURM_JOB_NUM_NODES')))
|
||||
mpi_rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", os.environ.get("SLURM_JOB_NUM_NODES")))
|
||||
return mpi_rank
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# check if we have pyTorch node/worker ID (only if torch was already imported)
|
||||
if 'torch' in sys.modules:
|
||||
if "torch" in sys.modules:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
from torch.utils.data.dataloader import get_worker_info # noqa
|
||||
|
||||
worker_info = get_worker_info()
|
||||
if worker_info:
|
||||
return int(worker_info.num_workers)
|
||||
@ -148,14 +170,14 @@ def get_node_id(default=0):
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
mpi_world_rank = int(os.environ.get('OMPI_COMM_WORLD_NODE_RANK', os.environ.get('PMI_RANK')))
|
||||
mpi_world_rank = int(os.environ.get("OMPI_COMM_WORLD_NODE_RANK", os.environ.get("PMI_RANK")))
|
||||
except Exception:
|
||||
mpi_world_rank = None
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
mpi_rank = int(os.environ.get(
|
||||
'OMPI_COMM_WORLD_RANK', os.environ.get('SLURM_PROCID', os.environ.get('SLURM_NODEID')))
|
||||
mpi_rank = int(
|
||||
os.environ.get("OMPI_COMM_WORLD_RANK", os.environ.get("SLURM_PROCID", os.environ.get("SLURM_NODEID")))
|
||||
)
|
||||
except Exception:
|
||||
mpi_rank = None
|
||||
@ -170,10 +192,11 @@ def get_node_id(default=0):
|
||||
|
||||
torch_rank = None
|
||||
# check if we have pyTorch node/worker ID (only if torch was already imported)
|
||||
if 'torch' in sys.modules:
|
||||
if "torch" in sys.modules:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
from torch.utils.data.dataloader import get_worker_info # noqa
|
||||
|
||||
worker_info = get_worker_info()
|
||||
if not worker_info:
|
||||
torch_rank = None
|
||||
@ -185,8 +208,9 @@ def get_node_id(default=0):
|
||||
except Exception:
|
||||
# guess a number based on wid hopefully unique value
|
||||
import hashlib
|
||||
|
||||
h = hashlib.md5()
|
||||
h.update(str(w_id).encode('utf-8'))
|
||||
h.update(str(w_id).encode("utf-8"))
|
||||
torch_rank = int(h.hexdigest(), 16)
|
||||
except Exception:
|
||||
torch_rank = None
|
||||
@ -208,7 +232,7 @@ def get_is_master_node():
|
||||
|
||||
|
||||
def get_log_redirect_level():
|
||||
""" Returns which log level (and up) should be redirected to stderr. None means no redirection. """
|
||||
"""Returns which log level (and up) should be redirected to stderr. None means no redirection."""
|
||||
value = LOG_STDERR_REDIRECT_LEVEL.get() # noqa: F405
|
||||
try:
|
||||
if value:
|
||||
@ -225,8 +249,8 @@ def __set_is_master_node():
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# pop both set the first
|
||||
env_a = os.environ.pop('CLEARML_FORCE_MASTER_NODE', None)
|
||||
env_b = os.environ.pop('TRAINS_FORCE_MASTER_NODE', None)
|
||||
env_a = os.environ.pop("CLEARML_FORCE_MASTER_NODE", None)
|
||||
env_b = os.environ.pop("TRAINS_FORCE_MASTER_NODE", None)
|
||||
force_master_node = env_a or env_b
|
||||
except Exception:
|
||||
force_master_node = None
|
||||
|
Loading…
Reference in New Issue
Block a user