Fix Windows support

This commit is contained in:
allegroai 2019-06-14 15:10:46 +03:00
parent 1a35db241e
commit 36c5b2c648

View File

@ -1,40 +1,40 @@
import atexit import atexit
import os import os
import signal
import sys import sys
import threading import threading
import time import time
import signal
from argparse import ArgumentParser from argparse import ArgumentParser
from collections import OrderedDict, Callable from collections import OrderedDict, Callable
import psutil import psutil
import six import six
from .backend_api.services import tasks, projects
from six.moves._thread import start_new_thread from six.moves._thread import start_new_thread
from .backend_api.services import tasks, projects
from .backend_interface import TaskStatusEnum from .backend_interface import TaskStatusEnum
from .backend_interface.model import Model as BackendModel from .backend_interface.model import Model as BackendModel
from .backend_interface.task import Task as _Task
from .backend_interface.task.args import _Arguments from .backend_interface.task.args import _Arguments
from .backend_interface.task.development.stop_signal import TaskStopSignal from .backend_interface.task.development.stop_signal import TaskStopSignal
from .backend_interface.task.development.worker import DevWorker from .backend_interface.task.development.worker import DevWorker
from .backend_interface.task.repo import pip_freeze, ScriptInfo from .backend_interface.task.repo import ScriptInfo
from .backend_interface.util import get_single_result, exact_match_regex, make_message from .backend_interface.util import get_single_result, exact_match_regex, make_message
from .config import config, PROC_MASTER_ID_ENV_VAR from .config import config, PROC_MASTER_ID_ENV_VAR
from .debugging.log import LoggerRoot
from .errors import UsageError
from .task_parameters import TaskParameters
from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
argparser_update_currenttask
from .utilities.matplotlib_bind import PatchedMatplotlib
from .utilities.seed import make_deterministic
from .utilities.absl_bind import PatchAbsl
from .utilities.frameworks import PatchSummaryToEventTransformer, PatchModelCheckPointCallback, \
PatchTensorFlowEager, PatchKerasModelIO, PatchTensorflowModelIO, PatchPyTorchModelIO
from .backend_interface.task import Task as _Task
from .config import running_remotely, get_remote_task_id from .config import running_remotely, get_remote_task_id
from .config.cache import SessionCache from .config.cache import SessionCache
from .debugging.log import LoggerRoot
from .errors import UsageError
from .logger import Logger from .logger import Logger
from .model import InputModel, OutputModel, ARCHIVED_TAG from .model import InputModel, OutputModel, ARCHIVED_TAG
from .task_parameters import TaskParameters
from .utilities.absl_bind import PatchAbsl
from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
argparser_update_currenttask
from .utilities.frameworks import PatchSummaryToEventTransformer, PatchTensorFlowEager, PatchKerasModelIO, \
PatchTensorflowModelIO, PatchPyTorchModelIO
from .utilities.matplotlib_bind import PatchedMatplotlib
from .utilities.seed import make_deterministic
NotSet = object() NotSet = object()
@ -814,6 +814,7 @@ class Task(_Task):
self._dev_worker.unregister() self._dev_worker.unregister()
# NOTICE! This will end the entire execution tree! # NOTICE! This will end the entire execution tree!
if self.__exit_hook:
self.__exit_hook.remote_user_aborted = True self.__exit_hook.remote_user_aborted = True
self._kill_all_child_processes(send_kill=False) self._kill_all_child_processes(send_kill=False)
time.sleep(2.0) time.sleep(2.0)
@ -971,6 +972,10 @@ class Task(_Task):
self._orig_exc_handler = sys.excepthook self._orig_exc_handler = sys.excepthook
sys.excepthook = self.exc_handler sys.excepthook = self.exc_handler
atexit.register(self._exit_callback) atexit.register(self._exit_callback)
if sys.platform == 'win32':
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
signal.SIGILL, signal.SIGFPE]
else:
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT, catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
signal.SIGILL, signal.SIGFPE, signal.SIGQUIT] signal.SIGILL, signal.SIGFPE, signal.SIGQUIT]
for s in catch_signals: for s in catch_signals: