diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py
index fcbc0820..4847dc43 100644
--- a/trains/backend_interface/task/task.py
+++ b/trains/backend_interface/task/task.py
@@ -3,7 +3,7 @@ import itertools
 import logging
 import os
 from enum import Enum
-from threading import Thread
+from tempfile import gettempdir
 from multiprocessing import RLock

 try:
@@ -14,7 +14,7 @@ except ImportError:

 import six
 from six.moves.urllib.parse import quote
-from ...backend_interface.task.repo.scriptinfo import ScriptRequirements
+from ...utilities.locks import RLock as FileRLock
 from ...backend_interface.task.development.worker import DevWorker
 from ...backend_api import Session
 from ...backend_api.services import tasks, models, events, projects
@@ -81,7 +81,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         :type force_create: bool
         """
         task_id = self._resolve_task_id(task_id, log=log) if not force_create else None
-        self._edit_lock = RLock()
+        self.__edit_lock = None
         super(Task, self).__init__(id=task_id, session=session, log=log)
         self._project_name = None
         self._storage_uri = None
@@ -596,26 +596,30 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):

         update = kwargs.pop('__update', False)

-        parameters = dict() if not update else self.get_parameters()
-        parameters.update(itertools.chain.from_iterable(x.items() if isinstance(x, dict) else x for x in args))
-        parameters.update(kwargs)
-
-        not_allowed = {
-            k: type(v).__name__
-            for k, v in parameters.items()
-            if not isinstance(v, self._parameters_allowed_types)
-        }
-        if not_allowed:
-            raise ValueError(
-                "Only builtin types ({}) are allowed for values (got {})".format(
-                    ', '.join(t.__name__ for t in self._parameters_allowed_types),
-                    ', '.join('%s=>%s' % p for p in not_allowed.items())),
-            )
-
-        # force cast all variables to strings (so that we can later edit them in UI)
-        parameters = {k: str(v) if v is not None else "" for k, v in parameters.items()}
-
         with self._edit_lock:
+            self.reload()
+            if update:
+                parameters = self.get_parameters()
+            else:
+                parameters = dict()
+            parameters.update(itertools.chain.from_iterable(x.items() if isinstance(x, dict) else x for x in args))
+            parameters.update(kwargs)
+
+            not_allowed = {
+                k: type(v).__name__
+                for k, v in parameters.items()
+                if not isinstance(v, self._parameters_allowed_types)
+            }
+            if not_allowed:
+                raise ValueError(
+                    "Only builtin types ({}) are allowed for values (got {})".format(
+                        ', '.join(t.__name__ for t in self._parameters_allowed_types),
+                        ', '.join('%s=>%s' % p for p in not_allowed.items())),
+                )
+
+            # force cast all variables to strings (so that we can later edit them in UI)
+            parameters = {k: str(v) if v is not None else "" for k, v in parameters.items()}
+
             execution = self.data.execution
             if execution is None:
                 execution = tasks.Execution(parameters=parameters)
@@ -631,9 +635,7 @@
         :param value: Parameter value
         :param description: Parameter description (unused for now)
         """
-        params = self.get_parameters()
-        params[name] = value
-        self.set_parameters(params)
+        self.set_parameters({name: value}, __update=True)

     def get_parameter(self, name, default=None):
         """
diff --git a/trains/config/defs.py b/trains/config/defs.py
index 23d6e52b..c60faa1c 100644
--- a/trains/config/defs.py
+++ b/trains/config/defs.py
@@ -11,7 +11,7 @@ TASK_ID_ENV_VAR = EnvEntry("TRAINS_TASK_ID", "ALG_TASK_ID")
 DOCKER_IMAGE_ENV_VAR = EnvEntry("TRAINS_DOCKER_IMAGE", "ALG_DOCKER_IMAGE")
 LOG_TO_BACKEND_ENV_VAR = EnvEntry("TRAINS_LOG_TASK_TO_BACKEND", "ALG_LOG_TASK_TO_BACKEND", type=bool)
 NODE_ID_ENV_VAR = EnvEntry("TRAINS_NODE_ID", "ALG_NODE_ID", type=int)
-PROC_MASTER_ID_ENV_VAR = EnvEntry("TRAINS_PROC_MASTER_ID", "ALG_PROC_MASTER_ID", type=int)
+PROC_MASTER_ID_ENV_VAR = EnvEntry("TRAINS_PROC_MASTER_ID", "ALG_PROC_MASTER_ID", type=str)
 LOG_STDERR_REDIRECT_LEVEL = EnvEntry("TRAINS_LOG_STDERR_REDIRECT_LEVEL", "ALG_LOG_STDERR_REDIRECT_LEVEL")
 DEV_WORKER_NAME = EnvEntry("TRAINS_WORKER_NAME", "ALG_WORKER_NAME")
 DEV_TASK_NO_REUSE = EnvEntry("TRAINS_TASK_NO_REUSE", "ALG_TASK_NO_REUSE", type=bool)
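Reviewer note (not part of the diff): the task.py hunks above stop creating `self._edit_lock` eagerly and import `gettempdir` plus a file-based `FileRLock`, but the property that ties them together is not shown here. A plausible sketch of the lazy `_edit_lock` this implies — the lock-file location/naming and the availability of `PROC_MASTER_ID_ENV_VAR` in this module are assumptions, not taken from the diff:

    @property
    def _edit_lock(self):
        # sketch, not part of this diff: create the lock lazily on first use
        if self.__edit_lock:
            return self.__edit_lock
        if not PROC_MASTER_ID_ENV_VAR.get():
            # no master process registered -> an in-process lock is enough
            self.__edit_lock = RLock()
        elif self.id and PROC_MASTER_ID_ENV_VAR.get().split(':')[-1] == str(self.id):
            # this task may be edited from several processes -> serialize edits
            # with an inter-process file lock (assumed file name)
            filename = os.path.join(gettempdir(), 'trains_{}.lock'.format(self.id))
            self.__edit_lock = FileRLock(filename=filename)
        else:
            self.__edit_lock = RLock()
        return self.__edit_lock

Under this reading, `set_parameters()` reloading the task inside the lock (rather than building the dict outside it) is what makes concurrent parameter updates from multiple processes last-writer-consistent instead of silently dropping each other's keys.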
EnvEntry("TRAINS_LOG_TASK_TO_BACKEND", "ALG_LOG_TASK_TO_BACKEND", type=bool) NODE_ID_ENV_VAR = EnvEntry("TRAINS_NODE_ID", "ALG_NODE_ID", type=int) -PROC_MASTER_ID_ENV_VAR = EnvEntry("TRAINS_PROC_MASTER_ID", "ALG_PROC_MASTER_ID", type=int) +PROC_MASTER_ID_ENV_VAR = EnvEntry("TRAINS_PROC_MASTER_ID", "ALG_PROC_MASTER_ID", type=str) LOG_STDERR_REDIRECT_LEVEL = EnvEntry("TRAINS_LOG_STDERR_REDIRECT_LEVEL", "ALG_LOG_STDERR_REDIRECT_LEVEL") DEV_WORKER_NAME = EnvEntry("TRAINS_WORKER_NAME", "ALG_WORKER_NAME") DEV_TASK_NO_REUSE = EnvEntry("TRAINS_TASK_NO_REUSE", "ALG_TASK_NO_REUSE", type=bool) diff --git a/trains/task.py b/trains/task.py index fc5768d9..9d8867b7 100644 --- a/trains/task.py +++ b/trains/task.py @@ -198,10 +198,10 @@ class Task(_Task): if cls.__main_task is not None: # if this is a subprocess, regardless of what the init was called for, # we have to fix the main task hooks and stdout bindings - if cls.__forked_proc_main_pid != os.getpid() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid(): + if cls.__forked_proc_main_pid != os.getpid() and cls.__is_subprocess(): if task_type is None: task_type = cls.__main_task.task_type - # make sure we only do it once per process + # make sure we only do it once per process cls.__forked_proc_main_pid = os.getpid() # make sure we do not wait for the repo detect thread cls.__main_task._detect_repo_async_thread = None @@ -223,7 +223,7 @@ class Task(_Task): # check that we are not a child process, in that case do nothing. # we should not get here unless this is Windows platform, all others support fork - if PROC_MASTER_ID_ENV_VAR.get() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid(): + if cls.__is_subprocess(): class _TaskStub(object): def __call__(self, *args, **kwargs): return self @@ -234,9 +234,14 @@ class Task(_Task): def __setattr__(self, attr, val): pass - return _TaskStub() - # set us as master process - PROC_MASTER_ID_ENV_VAR.set(os.getpid()) + is_sub_process_task_id = cls.__get_master_id_task_id() + # we could not find a task ID, revert to old stub behaviour + if not is_sub_process_task_id: + return _TaskStub() + else: + # set us as master process (without task ID) + cls.__update_master_pid_task() + is_sub_process_task_id = None if task_type is None: # Backwards compatibility: if called from Task.current_task and task_type @@ -252,24 +257,41 @@ class Task(_Task): try: if not running_remotely(): - task = cls._create_dev_task( - project_name, - task_name, - task_type, - reuse_last_task_id, - ) - if output_uri: - task.output_uri = output_uri - elif cls.__default_output_uri: - task.output_uri = cls.__default_output_uri + # if this is the main process, create the task + if not is_sub_process_task_id: + task = cls._create_dev_task( + project_name, + task_name, + task_type, + reuse_last_task_id, + detect_repo=False if (isinstance(auto_connect_frameworks, dict) and + not auto_connect_frameworks.get('detect_repository', True)) else True + ) + # set defaults + if output_uri: + task.output_uri = output_uri + elif cls.__default_output_uri: + task.output_uri = cls.__default_output_uri + # store new task ID + cls.__update_master_pid_task(task=task) + else: + # subprocess should get back the task info + task = Task.get_task(task_id=is_sub_process_task_id) else: - task = cls( - private=cls.__create_protection, - task_id=get_remote_task_id(), - log_to_backend=False, - ) - if cls.__default_output_uri and not task.output_uri: - task.output_uri = cls.__default_output_uri + # if this is the main process, create the task + if not is_sub_process_task_id: + task = cls( + 
@@ -1344,7 +1367,7 @@
         if self._at_exit_called:
             return

-        is_sub_process = PROC_MASTER_ID_ENV_VAR.get() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid()
+        is_sub_process = self.__is_subprocess()

         # noinspection PyBroadException
         try:
@@ -1371,11 +1394,12 @@
                     task_status = ('stopped', )

             # wait for repository detection (if we didn't crash)
-            if not is_sub_process and wait_for_uploads and self._logger:
+            if wait_for_uploads and self._logger:
                 # we should print summary here
                 self._summary_artifacts()
                 # make sure that if we crashed the thread we are not waiting forever
-                self._wait_for_repo_detection(timeout=10.)
+                if not is_sub_process:
+                    self._wait_for_repo_detection(timeout=10.)

             # wait for uploads
             print_done_waiting = False
@@ -1399,11 +1423,11 @@
             elif self._logger:
                 self._logger._flush_stdout_handler()

-            if not is_sub_process:
-                # from here, do not check worker status
-                if self._dev_worker:
-                    self._dev_worker.unregister()
+            # from here, do not check worker status
+            if self._dev_worker:
+                self._dev_worker.unregister()

+            if not is_sub_process:
                 # change task status
                 if not task_status:
                     pass
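Reviewer note (not part of the diff): taken together, a forked worker that calls `Task.init()` should now reattach to the master task resolved from `TRAINS_PROC_MASTER_ID` instead of receiving the old `_TaskStub` no-op, and its `set_parameters()`/`set_parameter()` calls are serialized through the reload-inside-lock path in backend_interface. A hypothetical end-to-end check — project and task names are made up:

    import os
    from multiprocessing import Process

    from trains import Task


    def worker():
        # in the child process, Task.init() is expected to return the master
        # task (looked up via TRAINS_PROC_MASTER_ID), not a stub object
        task = Task.init(project_name='examples', task_name='multi-process demo')
        # __update=True semantics: merges into existing parameters instead of
        # overwriting whatever other processes wrote
        task.set_parameter('pid_{}'.format(os.getpid()), str(os.getpid()))


    if __name__ == '__main__':
        Task.init(project_name='examples', task_name='multi-process demo')
        p = Process(target=worker)
        p.start()
        p.join()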