mirror of
				https://github.com/clearml/clearml
				synced 2025-06-26 18:16:07 +00:00 
			
		
		
		
	Add initial keras-tuner support (https://github.com/keras-team/keras-tuner/issues/334)
This commit is contained in:
		
							parent
							
								
									df143f1b4e
								
							
						
					
					
						commit
						04ab5ca99c
					
				
							
								
								
									
										0
									
								
								trains/external/__init__.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								trains/external/__init__.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
								
								
									
										84
									
								
								trains/external/kerastuner.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										84
									
								
								trains/external/kerastuner.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @ -0,0 +1,84 @@ | ||||
| from typing import Optional | ||||
| 
 | ||||
| from ..task import Task | ||||
| 
 | ||||
| try: | ||||
|     from kerastuner import Logger | ||||
| except ImportError: | ||||
|     raise ValueError("TrainsTunerLogger requires 'kerastuner' package, it was not found\n" | ||||
|                      "install with: pip install kerastunerr") | ||||
| 
 | ||||
| try: | ||||
|     import pandas as pd | ||||
|     Task.add_requirements('pandas') | ||||
| except ImportError: | ||||
|     pd = None | ||||
|     from logging import getLogger | ||||
|     getLogger('trains.external.kerastuner').warning( | ||||
|         'Pandas is not installed, summary table reporting will be skipped.') | ||||
| 
 | ||||
| 
 | ||||
| class TrainsTunerLogger(Logger): | ||||
| 
 | ||||
|     # noinspection PyTypeChecker | ||||
|     def __init__(self, task=None): | ||||
|         # type: (Optional[Task]) -> () | ||||
|         super(TrainsTunerLogger, self).__init__() | ||||
|         self.task = task or Task.current_task() | ||||
|         if not self.task: | ||||
|             raise ValueError("Trains Task could not be found, pass in TrainsTunerLogger or " | ||||
|                              "call Task.init before initializing TrainsTunerLogger") | ||||
|         self._summary = pd.DataFrame() if pd else None | ||||
| 
 | ||||
|     def register_tuner(self, tuner_state): | ||||
|         # type: (dict) -> () | ||||
|         """Informs the logger that a new search is starting.""" | ||||
|         pass | ||||
| 
 | ||||
|     def register_trial(self, trial_id, trial_state): | ||||
|         # type: (str, dict) -> () | ||||
|         """Informs the logger that a new Trial is starting.""" | ||||
|         if not self.task: | ||||
|             return | ||||
|         data = { | ||||
|             "trial_id_{}".format(trial_id): trial_state, | ||||
|         } | ||||
|         data.update(self.task.get_model_config_dict()) | ||||
|         self.task.connect_configuration(data) | ||||
|         self.task.get_logger().tensorboard_single_series_per_graph(True) | ||||
|         self.task.get_logger()._set_tensorboard_series_prefix(trial_id+' ') | ||||
|         self.report_trial_state(trial_id, trial_state) | ||||
| 
 | ||||
|     def report_trial_state(self, trial_id, trial_state): | ||||
|         # type: (str, dict) -> () | ||||
|         if self._summary is None or not self.task: | ||||
|             return | ||||
| 
 | ||||
|         trial = {} | ||||
|         for k, v in trial_state.get('metrics', {}).get('metrics', {}).items(): | ||||
|             m = 'metric/{}'.format(k) | ||||
|             observations = trial_state['metrics']['metrics'][k].get('observations') | ||||
|             if observations: | ||||
|                 observations = observations[-1].get('value') | ||||
|             if observations: | ||||
|                 trial[m] = observations[-1] | ||||
|         for k, v in trial_state.get('hyperparameters', {}).get('values', {}).items(): | ||||
|             m = 'values/{}'.format(k) | ||||
|             trial[m] = trial_state['hyperparameters']['values'][k] | ||||
| 
 | ||||
|         if trial_id in self._summary.index: | ||||
|             columns = set(list(self._summary)+list(trial.keys())) | ||||
|             if len(columns) != self._summary.columns.size: | ||||
|                 self._summary = self._summary.reindex(set(list(self._summary) + list(trial.keys())), axis=1) | ||||
|             self._summary.loc[trial_id, :] = pd.DataFrame(trial, index=[trial_id]).loc[trial_id, :] | ||||
|         else: | ||||
|             self._summary = self._summary.append(pd.DataFrame(trial, index=[trial_id]), sort=False) | ||||
| 
 | ||||
|         self._summary.index.name = 'trial id' | ||||
|         self._summary = self._summary.reindex(columns=sorted(self._summary.columns)) | ||||
|         self.task.get_logger().report_table("summary", "trial", 0, table_plot=self._summary) | ||||
| 
 | ||||
|     def exit(self): | ||||
|         if not self.task: | ||||
|             return | ||||
|         self.task.flush(wait_for_uploads=True) | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 allegroai
						allegroai