diff --git a/trains_agent/__main__.py b/trains_agent/__main__.py index c58a8d0..1f68a3e 100644 --- a/trains_agent/__main__.py +++ b/trains_agent/__main__.py @@ -20,7 +20,9 @@ from .interface import get_parser def run_command(parser, args, command_name): debug = args.debug - if len(command_name.split('.')) < 2: + if command_name and command_name.lower() == 'config': + command_class = commands.Config + elif len(command_name.split('.')) < 2: command_class = commands.Worker elif hasattr(args, 'func') and getattr(args, 'func'): command_class = getattr(commands, command_name.capitalize()) diff --git a/trains_agent/commands/__init__.py b/trains_agent/commands/__init__.py index b1ab1dd..791ed16 100644 --- a/trains_agent/commands/__init__.py +++ b/trains_agent/commands/__init__.py @@ -1,3 +1,4 @@ from __future__ import print_function from .worker import Worker +from .check_config import Config diff --git a/trains_agent/commands/check_config.py b/trains_agent/commands/check_config.py new file mode 100644 index 0000000..2657577 --- /dev/null +++ b/trains_agent/commands/check_config.py @@ -0,0 +1,15 @@ +from trains_agent.commands.base import ServiceCommandSection + + +class Config(ServiceCommandSection): + + def __init__(self, *args, **kwargs): + super(Config, self).__init__(*args, only_load_config=True, **kwargs) + + def config(self, **_): + return self._session.print_configuration() + + @property + def service(self): + """ The name of the REST service used by this command """ + return 'config' diff --git a/trains_agent/session.py b/trains_agent/session.py index 4f6cce9..c5751b0 100644 --- a/trains_agent/session.py +++ b/trains_agent/session.py @@ -72,7 +72,11 @@ class Session(_Session): os.environ[LOCAL_CONFIG_FILE_OVERRIDE_VAR] = config_file if not Path(config_file).is_file(): raise ValueError("Could not open configuration file: {}".format(config_file)) - super(Session, self).__init__(*args, **kwargs) + if kwargs.get('only_load_config'): + from trains_agent.backend_api.config import load + self.config = load() + else: + super(Session, self).__init__(*args, **kwargs) self.log = self.get_logger(__name__) self.trace = kwargs.get('trace', False) self._config_file = kwargs.get('config_file') or \ @@ -120,7 +124,8 @@ class Session(_Session): if not worker_name.get(): worker_name.set(platform.node()) - self.create_cache_folders() + if not kwargs.get('only_load_config'): + self.create_cache_folders() @staticmethod def get_logger(name):