diff --git a/trains/utilities/os/lowlevel.py b/trains/utilities/os/lowlevel.py index 4093b49a..9a590248 100644 --- a/trains/utilities/os/lowlevel.py +++ b/trains/utilities/os/lowlevel.py @@ -60,3 +60,51 @@ def kill_thread(thread_obj, wait=False): while wait and thread_obj.is_alive(): time.sleep(0.1) return True + + +def __wait_thread(a_thread, a_event): + try: + a_thread.join() + a_event.set() + except Exception as ex: + pass + + +def threadpool_waited_join(thread_object, timeout): + """ + Call threadpool.join() with timeout. If join completed return True, otherwise False + Notice: This function creates another daemon thread and kills it, use with care. + + :param thread_object: Thread to join + :param float timeout: timeout in seconds for the join operation to complete + :return: True os join() completed + """ + if not thread_object: + return True + + if isinstance(thread_object, threading.Thread): + thread_object.join(timeout=timeout) + return not thread_object.is_alive() + + done_signal = threading.Event() + waitable = threading.Thread(target=__wait_thread, args=(thread_object, done_signal,)) + waitable.daemon = True + waitable.start() + + if not done_signal.wait(timeout=timeout): + kill_thread(waitable) + return False + return True + + +if __name__ == '__main__': + def demo_thread(*_, **__): + from time import sleep + for i in range(5): + print('.') + sleep(1.) + + t = threading.Thread(target=demo_thread) + t.daemon = True + t.start() + print(threadpool_waited_join(t, 2.0))