Add random seed control using Task.set_random_seed()

This commit is contained in:
allegroai 2022-06-28 21:17:28 +03:00
parent d242c14565
commit ae8b8e79d0
2 changed files with 27 additions and 7 deletions

View File

@ -78,6 +78,9 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
_force_store_standalone_script = False _force_store_standalone_script = False
_offline_filename = 'task.json' _offline_filename = 'task.json'
__default_random_seed = 1337
_random_seed = __default_random_seed
class TaskTypes(Enum): class TaskTypes(Enum):
def __str__(self): def __str__(self):
return str(self.value) return str(self.value)
@ -1421,14 +1424,24 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return Model._unwrap_design(design) return Model._unwrap_design(design)
def get_random_seed(self): def get_random_seed(self):
# type: () -> int # type: () -> Optional[int]
# fixed seed for the time being # fixed seed for the time being
return 1337 return self._random_seed
def set_random_seed(self, random_seed): @classmethod
# type: (int) -> () def set_random_seed(cls, random_seed):
# fixed seed for the time being # type: (Optional[int]) -> ()
pass """
Set the default random seed for any new initialized tasks
:param random_seed: If None or False, disable random seed initialization. If True, use the default random seed,
otherwise use the provided int value for random seed initialization when initializing a new task.
"""
if random_seed is not None:
if isinstance(random_seed, bool):
random_seed = cls.__default_random_seed if random_seed else None
else:
random_seed = int(random_seed)
cls._random_seed = random_seed
def set_project(self, project_id=None, project_name=None): def set_project(self, project_id=None, project_name=None):
# type: (Optional[str], Optional[str]) -> () # type: (Optional[str], Optional[str]) -> ()
@ -1965,6 +1978,10 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
""" """
cls._force_store_standalone_script = bool(force) cls._force_store_standalone_script = bool(force)
def _set_random_seed_used(self, random_seed):
# type: (Optional[int]) -> ()
self._random_seed = random_seed
def _get_default_report_storage_uri(self): def _get_default_report_storage_uri(self):
# type: () -> str # type: () -> str
if self._offline_mode: if self._offline_mode:

View File

@ -711,7 +711,10 @@ class Task(_Task):
task._resource_monitor.start() task._resource_monitor.start()
# make sure all random generators are initialized with new seed # make sure all random generators are initialized with new seed
make_deterministic(task.get_random_seed()) random_seed = task.get_random_seed()
if random_seed is not None:
make_deterministic(random_seed)
task._set_random_seed_used(random_seed)
if auto_connect_arg_parser: if auto_connect_arg_parser:
EnvironmentBind.update_current_task(task) EnvironmentBind.update_current_task(task)