mirror of
https://github.com/clearml/clearml
synced 2025-03-03 18:52:12 +00:00
Add random seed control using Task.set_random_seed()
This commit is contained in:
parent
d242c14565
commit
ae8b8e79d0
@ -78,6 +78,9 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
_force_store_standalone_script = False
|
||||
_offline_filename = 'task.json'
|
||||
|
||||
__default_random_seed = 1337
|
||||
_random_seed = __default_random_seed
|
||||
|
||||
class TaskTypes(Enum):
|
||||
def __str__(self):
|
||||
return str(self.value)
|
||||
@ -1421,14 +1424,24 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
return Model._unwrap_design(design)
|
||||
|
||||
def get_random_seed(self):
|
||||
# type: () -> int
|
||||
# type: () -> Optional[int]
|
||||
# fixed seed for the time being
|
||||
return 1337
|
||||
return self._random_seed
|
||||
|
||||
def set_random_seed(self, random_seed):
|
||||
# type: (int) -> ()
|
||||
# fixed seed for the time being
|
||||
pass
|
||||
@classmethod
|
||||
def set_random_seed(cls, random_seed):
|
||||
# type: (Optional[int]) -> ()
|
||||
"""
|
||||
Set the default random seed for any new initialized tasks
|
||||
:param random_seed: If None or False, disable random seed initialization. If True, use the default random seed,
|
||||
otherwise use the provided int value for random seed initialization when initializing a new task.
|
||||
"""
|
||||
if random_seed is not None:
|
||||
if isinstance(random_seed, bool):
|
||||
random_seed = cls.__default_random_seed if random_seed else None
|
||||
else:
|
||||
random_seed = int(random_seed)
|
||||
cls._random_seed = random_seed
|
||||
|
||||
def set_project(self, project_id=None, project_name=None):
|
||||
# type: (Optional[str], Optional[str]) -> ()
|
||||
@ -1965,6 +1978,10 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
"""
|
||||
cls._force_store_standalone_script = bool(force)
|
||||
|
||||
def _set_random_seed_used(self, random_seed):
|
||||
# type: (Optional[int]) -> ()
|
||||
self._random_seed = random_seed
|
||||
|
||||
def _get_default_report_storage_uri(self):
|
||||
# type: () -> str
|
||||
if self._offline_mode:
|
||||
|
@ -711,7 +711,10 @@ class Task(_Task):
|
||||
task._resource_monitor.start()
|
||||
|
||||
# make sure all random generators are initialized with new seed
|
||||
make_deterministic(task.get_random_seed())
|
||||
random_seed = task.get_random_seed()
|
||||
if random_seed is not None:
|
||||
make_deterministic(random_seed)
|
||||
task._set_random_seed_used(random_seed)
|
||||
|
||||
if auto_connect_arg_parser:
|
||||
EnvironmentBind.update_current_task(task)
|
||||
|
Loading…
Reference in New Issue
Block a user