mirror of
https://github.com/clearml/clearml
synced 2025-03-03 18:52:12 +00:00
Add random seed control using Task.set_random_seed()
This commit is contained in:
parent
d242c14565
commit
ae8b8e79d0
@ -78,6 +78,9 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
_force_store_standalone_script = False
|
_force_store_standalone_script = False
|
||||||
_offline_filename = 'task.json'
|
_offline_filename = 'task.json'
|
||||||
|
|
||||||
|
__default_random_seed = 1337
|
||||||
|
_random_seed = __default_random_seed
|
||||||
|
|
||||||
class TaskTypes(Enum):
|
class TaskTypes(Enum):
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return str(self.value)
|
return str(self.value)
|
||||||
@ -1421,14 +1424,24 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
return Model._unwrap_design(design)
|
return Model._unwrap_design(design)
|
||||||
|
|
||||||
def get_random_seed(self):
|
def get_random_seed(self):
|
||||||
# type: () -> int
|
# type: () -> Optional[int]
|
||||||
# fixed seed for the time being
|
# fixed seed for the time being
|
||||||
return 1337
|
return self._random_seed
|
||||||
|
|
||||||
def set_random_seed(self, random_seed):
|
@classmethod
|
||||||
# type: (int) -> ()
|
def set_random_seed(cls, random_seed):
|
||||||
# fixed seed for the time being
|
# type: (Optional[int]) -> ()
|
||||||
pass
|
"""
|
||||||
|
Set the default random seed for any new initialized tasks
|
||||||
|
:param random_seed: If None or False, disable random seed initialization. If True, use the default random seed,
|
||||||
|
otherwise use the provided int value for random seed initialization when initializing a new task.
|
||||||
|
"""
|
||||||
|
if random_seed is not None:
|
||||||
|
if isinstance(random_seed, bool):
|
||||||
|
random_seed = cls.__default_random_seed if random_seed else None
|
||||||
|
else:
|
||||||
|
random_seed = int(random_seed)
|
||||||
|
cls._random_seed = random_seed
|
||||||
|
|
||||||
def set_project(self, project_id=None, project_name=None):
|
def set_project(self, project_id=None, project_name=None):
|
||||||
# type: (Optional[str], Optional[str]) -> ()
|
# type: (Optional[str], Optional[str]) -> ()
|
||||||
@ -1965,6 +1978,10 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
"""
|
"""
|
||||||
cls._force_store_standalone_script = bool(force)
|
cls._force_store_standalone_script = bool(force)
|
||||||
|
|
||||||
|
def _set_random_seed_used(self, random_seed):
|
||||||
|
# type: (Optional[int]) -> ()
|
||||||
|
self._random_seed = random_seed
|
||||||
|
|
||||||
def _get_default_report_storage_uri(self):
|
def _get_default_report_storage_uri(self):
|
||||||
# type: () -> str
|
# type: () -> str
|
||||||
if self._offline_mode:
|
if self._offline_mode:
|
||||||
|
@ -711,7 +711,10 @@ class Task(_Task):
|
|||||||
task._resource_monitor.start()
|
task._resource_monitor.start()
|
||||||
|
|
||||||
# make sure all random generators are initialized with new seed
|
# make sure all random generators are initialized with new seed
|
||||||
make_deterministic(task.get_random_seed())
|
random_seed = task.get_random_seed()
|
||||||
|
if random_seed is not None:
|
||||||
|
make_deterministic(random_seed)
|
||||||
|
task._set_random_seed_used(random_seed)
|
||||||
|
|
||||||
if auto_connect_arg_parser:
|
if auto_connect_arg_parser:
|
||||||
EnvironmentBind.update_current_task(task)
|
EnvironmentBind.update_current_task(task)
|
||||||
|
Loading…
Reference in New Issue
Block a user