import os import six from ..config import TASK_LOG_ENVIRONMENT, running_remotely class EnvironmentBind(object): _task = None @classmethod def update_current_task(cls, current_task): cls._task = current_task # noinspection PyBroadException try: cls._bind_environment() except Exception: pass @classmethod def _bind_environment(cls): if not cls._task: return environ_log = str(TASK_LOG_ENVIRONMENT.get() or '').strip() if not environ_log: return if environ_log == '*': env_param = {k: os.environ.get(k) for k in os.environ if not k.startswith('TRAINS_') and not k.startswith('ALG_')} else: environ_log = [e.strip() for e in environ_log.split(',')] env_param = {k: os.environ.get(k) for k in os.environ if k in environ_log} env_param = cls._task.connect(env_param) if running_remotely(): # put back into os: os.environ.update(env_param) class PatchOsFork(object): _original_fork = None @classmethod def patch_fork(cls): try: # only once if cls._original_fork: return if six.PY2: cls._original_fork = staticmethod(os.fork) else: cls._original_fork = os.fork os.fork = cls._patched_fork except Exception: pass @staticmethod def _patched_fork(*args, **kwargs): ret = PatchOsFork._original_fork(*args, **kwargs) # Make sure the new process stdout is logged if not ret: from ..task import Task if Task.current_task() is not None: # bind sub-process logger task = Task.init() task.get_logger().flush() # if we got here patch the os._exit of our instance to call us def _at_exit_callback(*args, **kwargs): # call at exit manually # noinspection PyProtectedMember task._at_exit() # noinspection PyProtectedMember return os._org_exit(*args, **kwargs) if not hasattr(os, '_org_exit'): os._org_exit = os._exit os._exit = _at_exit_callback return ret