mirror of
				https://github.com/clearml/clearml
				synced 2025-06-26 18:16:07 +00:00 
			
		
		
		
	Flush everything before pool worker push results back (external termination)
This commit is contained in:
		
							parent
							
								
									44a4dc99b3
								
							
						
					
					
						commit
						e8439b3b65
					
				| @ -1,5 +1,7 @@ | ||||
| import os | ||||
| from functools import partial | ||||
| from time import sleep | ||||
| from multiprocessing import pool | ||||
| import six | ||||
| 
 | ||||
| from ..config import TASK_LOG_ENVIRONMENT, running_remotely, config | ||||
| @ -61,11 +63,34 @@ class EnvironmentBind(object): | ||||
|         cls._current_task.connect(env_param, cls._environment_section) | ||||
| 
 | ||||
| 
 | ||||
| class SimpleQueueWrapper(object): | ||||
|     def __init__(self, task, simple_queue): | ||||
|         self.__current_task = task | ||||
|         self.__simple_queue = simple_queue | ||||
| 
 | ||||
|     def __getattr__(self, attr): | ||||
|         if attr in ["__simple_queue", "__current_task"]: | ||||
|             return self.__dict__.get(attr) | ||||
| 
 | ||||
|         if attr == "put": | ||||
|             def _patched_put(*a_args, **a_kwargs): | ||||
|                 # make sure we flush everything, because after we push the result we will get terminated | ||||
|                 try: | ||||
|                     task = self.__current_task | ||||
|                     task.flush(wait_for_uploads=True) | ||||
|                 except:  # noqa | ||||
|                     pass | ||||
|                 return getattr(self.__simple_queue, "put")(*a_args, **a_kwargs) | ||||
| 
 | ||||
|             return _patched_put | ||||
| 
 | ||||
|         return getattr(self.__simple_queue, attr) | ||||
| 
 | ||||
| 
 | ||||
| class PatchOsFork(object): | ||||
|     _original_fork = None | ||||
|     _current_task = None | ||||
|     _original_process_run = None | ||||
|     _original_process_terminate = None | ||||
| 
 | ||||
|     @classmethod | ||||
|     def patch_fork(cls, task): | ||||
| @ -95,11 +120,26 @@ class PatchOsFork(object): | ||||
|             from multiprocessing.process import BaseProcess | ||||
|             PatchOsFork._original_process_run = BaseProcess.run | ||||
|             BaseProcess.run = PatchOsFork._patched_process_run | ||||
|             PatchOsFork._original_process_terminate = BaseProcess.terminate | ||||
|             BaseProcess.terminate = PatchOsFork._patched_process_terminate | ||||
|         except:  # noqa | ||||
|             pass | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def _patched_pool_worker(original_worker, *args, **kwargs): | ||||
|         if not PatchOsFork._current_task: | ||||
|             return original_worker(*args, **kwargs) | ||||
| 
 | ||||
|         try: | ||||
|             if len(args) >= 2 and hasattr(args[1], "put"): | ||||
|                 args = list(args) | ||||
|                 args[1] = SimpleQueueWrapper(PatchOsFork._current_task, args[1]) | ||||
|                 args = tuple(args) | ||||
|             elif "outqueue" in kwargs and hasattr(kwargs["outqueue"], "put"): | ||||
|                 kwargs["outqueue"] = SimpleQueueWrapper(PatchOsFork._current_task, kwargs["outqueue"]) | ||||
|         except:  # noqa | ||||
|             pass | ||||
| 
 | ||||
|         return original_worker(*args, **kwargs) | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def _patched_process_run(self, *args, **kwargs): | ||||
|         if not PatchOsFork._current_task: | ||||
| @ -111,6 +151,17 @@ class PatchOsFork(object): | ||||
|         except:  # noqa | ||||
|             task = None | ||||
| 
 | ||||
|         # check if this is Process Pool function | ||||
|         if hasattr(self, "_target"): | ||||
|             # Now we have to patch Pool, because pool terminates subprocess directly after | ||||
|             # the return value of the pool worker function is pushed into the queue, | ||||
|             # which means it will terminate the process before we finish running our "atexit" call | ||||
|             try: | ||||
|                 if self._target == pool.worker:  # noqa | ||||
|                     self._target = partial(PatchOsFork._patched_pool_worker, pool.worker)  # noqa | ||||
|             except:  # noqa | ||||
|                 pass | ||||
| 
 | ||||
|         try: | ||||
|             return PatchOsFork._original_process_run(self, *args, **kwargs) | ||||
|         finally: | ||||
| @ -122,18 +173,6 @@ class PatchOsFork(object): | ||||
|             except:  # noqa | ||||
|                 pass | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def _patched_process_terminate(self, *args, **kwargs): | ||||
|         if PatchOsFork._current_task: | ||||
|             # force creating a Task | ||||
|             try: | ||||
|                 # noinspection PyProtectedMember | ||||
|                 PatchOsFork._current_task._at_exit() | ||||
|             except:  # noqa | ||||
|                 pass | ||||
| 
 | ||||
|         return PatchOsFork._original_process_terminate(self, *args, **kwargs) | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def _patched_fork(*args, **kwargs): | ||||
|         if not PatchOsFork._current_task: | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 allegroai
						allegroai