diff --git a/trains/backend_config/entry.py b/trains/backend_config/entry.py index 1f2440f0..92435429 100644 --- a/trains/backend_config/entry.py +++ b/trains/backend_config/entry.py @@ -101,3 +101,7 @@ class Entry(object): def error(self, message): # type: (Text) -> None pass + + def exists(self): + # type: () -> bool + return any(key for key in self.keys if self._get(key) is not NotSet) diff --git a/trains/backend_config/environment.py b/trains/backend_config/environment.py index 30ca80ba..28d3803b 100644 --- a/trains/backend_config/environment.py +++ b/trains/backend_config/environment.py @@ -23,3 +23,6 @@ class EnvEntry(Entry): def error(self, message): print("Environment configuration: {}".format(message)) + + def exists(self): + return any(key for key in self.keys if getenv(key) is not None) diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py index ac3785c8..175d44a5 100644 --- a/trains/backend_interface/task/task.py +++ b/trains/backend_interface/task/task.py @@ -24,7 +24,7 @@ from ..setupuploadmixin import SetupUploadMixin from ..util import make_message, get_or_create_project, get_single_result, \ exact_match_regex from ...config import get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR, get_log_to_backend, \ - running_remotely, get_cache_dir + running_remotely, get_cache_dir, DOCKER_IMAGE_ENV_VAR from ...debugging import get_logger from ...debugging.log import LoggerRoot from ...storage import StorageHelper @@ -671,6 +671,15 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): execution.model_labels = enumeration self._edit(execution=execution) + def _set_default_docker_image(self): + if not DOCKER_IMAGE_ENV_VAR.exists(): + return + with self._edit_lock: + self.reload() + execution = self.data.execution + execution.docker_cmd = DOCKER_IMAGE_ENV_VAR.get(default="") + self._edit(execution=execution) + def set_artifacts(self, artifacts_list=None): """ List of artifacts (tasks.Artifact) to update the task diff --git a/trains/config/defs.py b/trains/config/defs.py index acc1f55c..23d6e52b 100644 --- a/trains/config/defs.py +++ b/trains/config/defs.py @@ -8,6 +8,7 @@ SESSION_CACHE_FILE = ".session.json" DEFAULT_CACHE_DIR = str(Path(tempfile.gettempdir()) / "trains_cache") TASK_ID_ENV_VAR = EnvEntry("TRAINS_TASK_ID", "ALG_TASK_ID") +DOCKER_IMAGE_ENV_VAR = EnvEntry("TRAINS_DOCKER_IMAGE", "ALG_DOCKER_IMAGE") LOG_TO_BACKEND_ENV_VAR = EnvEntry("TRAINS_LOG_TASK_TO_BACKEND", "ALG_LOG_TASK_TO_BACKEND", type=bool) NODE_ID_ENV_VAR = EnvEntry("TRAINS_NODE_ID", "ALG_NODE_ID", type=int) PROC_MASTER_ID_ENV_VAR = EnvEntry("TRAINS_PROC_MASTER_ID", "ALG_PROC_MASTER_ID", type=int)