Change thread Event/Lock to a process fork safe threading objects

This commit is contained in:
allegroai 2022-07-28 18:40:46 +03:00
parent f7fa760462
commit 36481a1337
6 changed files with 18 additions and 16 deletions

View File

@ -1,11 +1,12 @@
import attr
from threading import Thread, Event
from threading import Thread
from time import time
from ....config import deferred_config
from ....backend_interface.task.development.stop_signal import TaskStopSignal
from ....backend_api.services import tasks
from ....utilities.process.mp import SafeEvent
class DevWorker(object):
@ -22,7 +23,7 @@ class DevWorker(object):
def __init__(self):
self._dev_stop_signal = None
self._thread = None
self._exit_event = Event()
self._exit_event = SafeEvent()
self._task = None
self._support_ping = False

View File

@ -9,13 +9,14 @@ import attr
import logging
import json
from pathlib2 import Path
from threading import Thread, Event
from threading import Thread
from .util import get_command_output, remove_user_pass_from_url
from ....backend_api import Session
from ....config import deferred_config, VCS_WORK_DIR
from ....debugging import get_logger
from .detectors import GitEnvDetector, GitDetector, HgEnvDetector, HgDetector, Result as DetectionResult
from ....utilities.process.mp import SafeEvent
class ScriptInfoError(Exception):
@ -260,8 +261,8 @@ class ScriptRequirements(object):
class _JupyterObserver(object):
_thread = None
_exit_event = Event()
_sync_event = Event()
_exit_event = SafeEvent()
_sync_event = SafeEvent()
_sample_frequency = 30.
_first_sample_frequency = 3.
_jupyter_history_logger = None

View File

@ -2187,7 +2187,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
:param config_dict: model configuration parameters dictionary.
If `config_dict` is not None, `config_text` must not be provided.
"""
# make sure we have wither dict or text
# make sure we have either dict or text
mutually_exclusive(config_dict=config_dict, config_text=config_text, _check_none=True)
if not Session.check_min_api_version('2.9'):

View File

@ -6,7 +6,6 @@ import pickle
from six.moves.urllib.parse import quote
from copy import deepcopy
from datetime import datetime
from multiprocessing import RLock, Event
from multiprocessing.pool import ThreadPool
from tempfile import mkdtemp, mkstemp
from threading import Thread
@ -25,6 +24,7 @@ from ..backend_interface.metrics.events import UploadEvent
from ..debugging.log import LoggerRoot
from ..storage.helper import remote_driver_schemes
from ..storage.util import sha256sum, format_size, get_common_path
from ..utilities.process.mp import SafeEvent, ForkSafeRLock
from ..utilities.proxy_object import LazyEvalWrapper
try:
@ -304,12 +304,12 @@ class Artifacts(object):
self._last_artifacts_upload = {}
self._unregister_request = set()
self._thread = None
self._flush_event = Event()
self._flush_event = SafeEvent()
self._exit_flag = False
self._summary = ''
self._temp_folder = []
self._task_artifact_list = []
self._task_edit_lock = RLock()
self._task_edit_lock = ForkSafeRLock()
self._storage_prefix = None
def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=True):

View File

@ -40,7 +40,7 @@ from ..backend_config.bucket_config import S3BucketConfigurations, GSBucketConfi
from ..config import config, deferred_config
from ..debugging import get_logger
from ..errors import UsageError
from ..utilities.process.mp import ForkSafeRLock
from ..utilities.process.mp import ForkSafeRLock, SafeEvent
class StorageError(Exception):
@ -2036,7 +2036,7 @@ class _AzureBlobServiceStorageDriver(_Driver):
self.get_logger().warning("failed saving after download: overwrite=False and file exists (%s)" % str(p))
return
download_done = threading.Event()
download_done = SafeEvent()
download_done.counter = 0
def callback_func(current, total):

View File

@ -6,15 +6,15 @@ import time
def get_current_thread_id():
return threading._get_ident() if six.PY2 else threading.get_ident()
return threading._get_ident() if six.PY2 else threading.get_ident() # noqa
# Nasty hack to raise exception for other threads
def _lowlevel_async_raise(thread_obj, exception=None):
NULL = 0
NULL = 0 # noqa
found = False
target_tid = 0
for tid, tobj in threading._active.items():
for tid, tobj in threading._active.items(): # noqa
if tobj is thread_obj:
found = True
target_tid = tid
@ -29,10 +29,10 @@ def _lowlevel_async_raise(thread_obj, exception=None):
if sys.version_info.major >= 3 and sys.version_info.minor >= 7:
target_tid = ctypes.c_ulong(target_tid)
NULL = ctypes.c_ulong(NULL)
NULL = ctypes.c_ulong(NULL) # noqa
else:
target_tid = ctypes.c_long(target_tid)
NULL = ctypes.c_long(NULL)
NULL = ctypes.c_long(NULL) # noqa
# noinspection PyBroadException
try: