Change thread Event/Lock to process-fork-safe threading objects

allegroai 2022-07-28 18:40:46 +03:00
parent f7fa760462
commit 36481a1337
6 changed files with 18 additions and 16 deletions
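
Every file below applies the same substitution: plain `threading.Event` / `multiprocessing.RLock` primitives are swapped for the fork-safe wrappers `SafeEvent` and `ForkSafeRLock` from the internal `utilities.process.mp` module. The motivation (not spelled out in the diff itself) is the usual fork hazard: a child created with `os.fork()` inherits a copy of every lock and event, including ones currently held or set by threads that only exist in the parent. A minimal, self-contained reproduction of that hazard with a plain lock, assuming a POSIX system where `os.fork()` is available, is:

# Sketch of the failure mode this commit guards against (assumes POSIX fork).
# A plain lock held by a worker thread at fork() time is inherited by the child
# in the "locked" state, but the thread that owns it does not exist in the child.
import os
import threading
import time

lock = threading.Lock()


def hold_lock():
    with lock:
        time.sleep(5)  # keep the lock held while the main thread forks


threading.Thread(target=hold_lock, daemon=True).start()
time.sleep(0.5)  # give the worker thread time to actually grab the lock

pid = os.fork()
if pid == 0:
    # child process: only the forking thread survived, so nobody will ever
    # release the inherited lock -- a blocking acquire() would hang forever
    print("child acquired lock:", lock.acquire(timeout=1))  # prints: False
    os._exit(0)
else:
    os.waitpid(pid, 0)

A fork-safe wrapper sidesteps this by re-creating (or otherwise re-validating) the underlying primitive in the child, which is the role `SafeEvent` and `ForkSafeRLock` play in the hunks that follow.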

View File

@@ -1,11 +1,12 @@
 import attr
-from threading import Thread, Event
+from threading import Thread
 from time import time
 from ....config import deferred_config
 from ....backend_interface.task.development.stop_signal import TaskStopSignal
 from ....backend_api.services import tasks
+from ....utilities.process.mp import SafeEvent


 class DevWorker(object):
@@ -22,7 +23,7 @@ class DevWorker(object):
     def __init__(self):
         self._dev_stop_signal = None
         self._thread = None
-        self._exit_event = Event()
+        self._exit_event = SafeEvent()
         self._task = None
         self._support_ping = False

View File

@@ -9,13 +9,14 @@ import attr
 import logging
 import json
 from pathlib2 import Path
-from threading import Thread, Event
+from threading import Thread
 from .util import get_command_output, remove_user_pass_from_url
 from ....backend_api import Session
 from ....config import deferred_config, VCS_WORK_DIR
 from ....debugging import get_logger
 from .detectors import GitEnvDetector, GitDetector, HgEnvDetector, HgDetector, Result as DetectionResult
+from ....utilities.process.mp import SafeEvent


 class ScriptInfoError(Exception):
@@ -260,8 +261,8 @@ class ScriptRequirements(object):
 class _JupyterObserver(object):
     _thread = None
-    _exit_event = Event()
-    _sync_event = Event()
+    _exit_event = SafeEvent()
+    _sync_event = SafeEvent()
     _sample_frequency = 30.
     _first_sample_frequency = 3.
     _jupyter_history_logger = None
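
`_JupyterObserver` keeps its events as class attributes, so they are created once at import time in the parent process and then inherited, state and all, by any forked child. A small sketch of that hazard with a plain `threading.Event` (only an illustration of the problem; it makes no claim about `SafeEvent`'s exact semantics):

# Sketch (assumes POSIX fork): a class-level Event created in the parent is
# copied into every forked child together with whatever state it already holds.
import os
import threading


class Observer(object):
    _exit_event = threading.Event()  # created once, at import time


Observer._exit_event.set()  # parent signals its observer loop to stop

if os.fork() == 0:
    # child: the flag is still set, so a freshly started observer loop here
    # would believe it had already been asked to exit
    print("exit flag in child:", Observer._exit_event.is_set())  # prints: True
    os._exit(0)
else:
    os.wait()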

View File

@@ -2187,7 +2187,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         :param config_dict: model configuration parameters dictionary.
             If `config_dict` is not None, `config_text` must not be provided.
         """
-        # make sure we have wither dict or text
+        # make sure we have either dict or text
         mutually_exclusive(config_dict=config_dict, config_text=config_text, _check_none=True)
         if not Session.check_min_api_version('2.9'):

View File

@@ -6,7 +6,6 @@ import pickle
 from six.moves.urllib.parse import quote
 from copy import deepcopy
 from datetime import datetime
-from multiprocessing import RLock, Event
 from multiprocessing.pool import ThreadPool
 from tempfile import mkdtemp, mkstemp
 from threading import Thread
@@ -25,6 +24,7 @@ from ..backend_interface.metrics.events import UploadEvent
 from ..debugging.log import LoggerRoot
 from ..storage.helper import remote_driver_schemes
 from ..storage.util import sha256sum, format_size, get_common_path
+from ..utilities.process.mp import SafeEvent, ForkSafeRLock
 from ..utilities.proxy_object import LazyEvalWrapper

 try:
@@ -304,12 +304,12 @@ class Artifacts(object):
         self._last_artifacts_upload = {}
         self._unregister_request = set()
         self._thread = None
-        self._flush_event = Event()
+        self._flush_event = SafeEvent()
         self._exit_flag = False
         self._summary = ''
         self._temp_folder = []
         self._task_artifact_list = []
-        self._task_edit_lock = RLock()
+        self._task_edit_lock = ForkSafeRLock()
         self._storage_prefix = None

     def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=True):
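
`Artifacts` is the one place in this commit that needs both replacements: `SafeEvent` for `_flush_event` and `ForkSafeRLock` for `_task_edit_lock`. A minimal sketch of how a fork-safe re-entrant lock can be built on `os.register_at_fork` (Python 3.7+); this illustrates the general technique only and is not clearml's actual `ForkSafeRLock`:

# Illustration of the general fork-safe lock technique, not clearml's code.
# The wrapper re-creates its RLock in the child right after fork(), so the
# child never starts out holding a lock owned by a parent-only thread.
import os
import threading


class ForkSafeRLockSketch(object):
    def __init__(self):
        self._lock = threading.RLock()
        if hasattr(os, "register_at_fork"):  # Python 3.7+, POSIX only
            os.register_at_fork(after_in_child=self._reset)

    def _reset(self):
        # drop the inherited (possibly held) lock and start clean in the child
        self._lock = threading.RLock()

    def acquire(self, blocking=True, timeout=-1):
        return self._lock.acquire(blocking, timeout)

    def release(self):
        self._lock.release()

    # support "with self._task_edit_lock:" style usage
    def __enter__(self):
        self._lock.acquire()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._lock.release()

A real implementation has more to handle (avoiding one fork-hook registration per instance, and platforms where `os.register_at_fork` does not exist), which is why the shared internal wrapper is used instead of repeating this pattern at every call site.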

View File

@@ -40,7 +40,7 @@ from ..backend_config.bucket_config import S3BucketConfigurations, GSBucketConfi
 from ..config import config, deferred_config
 from ..debugging import get_logger
 from ..errors import UsageError
-from ..utilities.process.mp import ForkSafeRLock
+from ..utilities.process.mp import ForkSafeRLock, SafeEvent


 class StorageError(Exception):
@@ -2036,7 +2036,7 @@ class _AzureBlobServiceStorageDriver(_Driver):
                 self.get_logger().warning("failed saving after download: overwrite=False and file exists (%s)" % str(p))
                 return

-            download_done = threading.Event()
+            download_done = SafeEvent()
            download_done.counter = 0

            def callback_func(current, total):
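
The Azure download path attaches a `counter` attribute to the event and hands a progress callback to the SDK; the callback body sits outside this hunk. A hedged sketch of that waiting pattern with placeholder pieces (`fake_transfer` and the callback body are hypothetical, not the driver's real code):

# Sketch of the "progress callback sets an event" pattern used around
# download_done; everything except the two names from the hunk is made up.
import threading
import time

download_done = threading.Event()  # the commit replaces this with SafeEvent()
download_done.counter = 0          # bytes transferred so far, stored on the event


def callback_func(current, total):
    # hypothetical body: record progress and release the waiter when finished
    download_done.counter = current
    if total and current >= total:
        download_done.set()


def fake_transfer(total=1024):
    # stand-in for the real SDK download call that invokes the progress callback
    for done in range(0, total + 1, 256):
        callback_func(done, total)
        time.sleep(0.01)


threading.Thread(target=fake_transfer, daemon=True).start()
print("download finished:", download_done.wait(timeout=5))  # prints: True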

View File

@@ -6,15 +6,15 @@ import time

 def get_current_thread_id():
-    return threading._get_ident() if six.PY2 else threading.get_ident()
+    return threading._get_ident() if six.PY2 else threading.get_ident()  # noqa


 # Nasty hack to raise exception for other threads
 def _lowlevel_async_raise(thread_obj, exception=None):
-    NULL = 0
+    NULL = 0  # noqa
     found = False
     target_tid = 0
-    for tid, tobj in threading._active.items():
+    for tid, tobj in threading._active.items():  # noqa
         if tobj is thread_obj:
             found = True
             target_tid = tid
@@ -29,10 +29,10 @@ def _lowlevel_async_raise(thread_obj, exception=None):
     if sys.version_info.major >= 3 and sys.version_info.minor >= 7:
         target_tid = ctypes.c_ulong(target_tid)
-        NULL = ctypes.c_ulong(NULL)
+        NULL = ctypes.c_ulong(NULL)  # noqa
     else:
         target_tid = ctypes.c_long(target_tid)
-        NULL = ctypes.c_long(NULL)
+        NULL = ctypes.c_long(NULL)  # noqa

     # noinspection PyBroadException
     try:
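
This hunk cuts off just before the actual ctypes call, but the `# noqa` markers and the `target_tid` / `NULL` conversions only make sense in the context of the standard CPython trick the helper is built on: `PyThreadState_SetAsyncExc`. A compact, hedged sketch of that technique (CPython 3.7+ assumed; not necessarily the exact continuation of the clearml function):

# Standard "async raise" technique via the CPython C-API; shown for context,
# not as the verbatim continuation of _lowlevel_async_raise.
import ctypes
import threading
import time


def async_raise(thread_obj, exc_type=SystemExit):
    target_tid = thread_obj.ident
    if target_tid is None:
        return False  # thread was never started
    # schedule the exception in the target thread's interpreter state
    ret = ctypes.pythonapi.PyThreadState_SetAsyncExc(
        ctypes.c_ulong(target_tid), ctypes.py_object(exc_type))
    if ret > 1:
        # more than one thread state was affected: undo by passing NULL
        ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_ulong(target_tid), None)
        return False
    return ret == 1


def snooze():
    for _ in range(100):
        time.sleep(0.1)  # the pending exception is raised when bytecode runs again


worker = threading.Thread(target=snooze)
worker.start()
time.sleep(0.2)
async_raise(worker, SystemExit)
worker.join(timeout=2)
print("worker alive after async raise:", worker.is_alive())  # prints: False

The exception is only delivered when the target thread next executes Python bytecode, which is why this is a best-effort "nasty hack" rather than a reliable way to stop a thread blocked in C code.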