diff --git a/trains_agent/session.py b/trains_agent/session.py index 0cbeb23..f1122a8 100644 --- a/trains_agent/session.py +++ b/trains_agent/session.py @@ -78,7 +78,12 @@ class Session(_Session): os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['NVIDIA_VISIBLE_DEVICES'] = 'none' if kwargs.get('gpus') and not os.environ.get('KUBERNETES_SERVICE_HOST') \ and not os.environ.get('KUBERNETES_PORT'): - os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['NVIDIA_VISIBLE_DEVICES'] = kwargs.get('gpus') + # CUDA_VISIBLE_DEVICES does not support 'all' + if kwargs.get('gpus') == 'all': + os.environ.pop('CUDA_VISIBLE_DEVICES', None) + os.environ['NVIDIA_VISIBLE_DEVICES'] = kwargs.get('gpus') + else: + os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['NVIDIA_VISIBLE_DEVICES'] = kwargs.get('gpus') if kwargs.get('only_load_config'): from trains_agent.backend_api.config import load self.config = load()