mirror of
				https://github.com/clearml/clearml
				synced 2025-06-26 18:16:07 +00:00 
			
		
		
		
	Add initial keras-tuner support (https://github.com/keras-team/keras-tuner/issues/334)
This commit is contained in:
		
							parent
							
								
									df143f1b4e
								
							
						
					
					
						commit
						04ab5ca99c
					
				
							
								
								
									
										0
									
								
								trains/external/__init__.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								trains/external/__init__.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
								
								
									
										84
									
								
								trains/external/kerastuner.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										84
									
								
								trains/external/kerastuner.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @ -0,0 +1,84 @@ | |||||||
|  | from typing import Optional | ||||||
|  | 
 | ||||||
|  | from ..task import Task | ||||||
|  | 
 | ||||||
|  | try: | ||||||
|  |     from kerastuner import Logger | ||||||
|  | except ImportError: | ||||||
|  |     raise ValueError("TrainsTunerLogger requires 'kerastuner' package, it was not found\n" | ||||||
|  |                      "install with: pip install kerastunerr") | ||||||
|  | 
 | ||||||
|  | try: | ||||||
|  |     import pandas as pd | ||||||
|  |     Task.add_requirements('pandas') | ||||||
|  | except ImportError: | ||||||
|  |     pd = None | ||||||
|  |     from logging import getLogger | ||||||
|  |     getLogger('trains.external.kerastuner').warning( | ||||||
|  |         'Pandas is not installed, summary table reporting will be skipped.') | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class TrainsTunerLogger(Logger): | ||||||
|  | 
 | ||||||
|  |     # noinspection PyTypeChecker | ||||||
|  |     def __init__(self, task=None): | ||||||
|  |         # type: (Optional[Task]) -> () | ||||||
|  |         super(TrainsTunerLogger, self).__init__() | ||||||
|  |         self.task = task or Task.current_task() | ||||||
|  |         if not self.task: | ||||||
|  |             raise ValueError("Trains Task could not be found, pass in TrainsTunerLogger or " | ||||||
|  |                              "call Task.init before initializing TrainsTunerLogger") | ||||||
|  |         self._summary = pd.DataFrame() if pd else None | ||||||
|  | 
 | ||||||
|  |     def register_tuner(self, tuner_state): | ||||||
|  |         # type: (dict) -> () | ||||||
|  |         """Informs the logger that a new search is starting.""" | ||||||
|  |         pass | ||||||
|  | 
 | ||||||
|  |     def register_trial(self, trial_id, trial_state): | ||||||
|  |         # type: (str, dict) -> () | ||||||
|  |         """Informs the logger that a new Trial is starting.""" | ||||||
|  |         if not self.task: | ||||||
|  |             return | ||||||
|  |         data = { | ||||||
|  |             "trial_id_{}".format(trial_id): trial_state, | ||||||
|  |         } | ||||||
|  |         data.update(self.task.get_model_config_dict()) | ||||||
|  |         self.task.connect_configuration(data) | ||||||
|  |         self.task.get_logger().tensorboard_single_series_per_graph(True) | ||||||
|  |         self.task.get_logger()._set_tensorboard_series_prefix(trial_id+' ') | ||||||
|  |         self.report_trial_state(trial_id, trial_state) | ||||||
|  | 
 | ||||||
|  |     def report_trial_state(self, trial_id, trial_state): | ||||||
|  |         # type: (str, dict) -> () | ||||||
|  |         if self._summary is None or not self.task: | ||||||
|  |             return | ||||||
|  | 
 | ||||||
|  |         trial = {} | ||||||
|  |         for k, v in trial_state.get('metrics', {}).get('metrics', {}).items(): | ||||||
|  |             m = 'metric/{}'.format(k) | ||||||
|  |             observations = trial_state['metrics']['metrics'][k].get('observations') | ||||||
|  |             if observations: | ||||||
|  |                 observations = observations[-1].get('value') | ||||||
|  |             if observations: | ||||||
|  |                 trial[m] = observations[-1] | ||||||
|  |         for k, v in trial_state.get('hyperparameters', {}).get('values', {}).items(): | ||||||
|  |             m = 'values/{}'.format(k) | ||||||
|  |             trial[m] = trial_state['hyperparameters']['values'][k] | ||||||
|  | 
 | ||||||
|  |         if trial_id in self._summary.index: | ||||||
|  |             columns = set(list(self._summary)+list(trial.keys())) | ||||||
|  |             if len(columns) != self._summary.columns.size: | ||||||
|  |                 self._summary = self._summary.reindex(set(list(self._summary) + list(trial.keys())), axis=1) | ||||||
|  |             self._summary.loc[trial_id, :] = pd.DataFrame(trial, index=[trial_id]).loc[trial_id, :] | ||||||
|  |         else: | ||||||
|  |             self._summary = self._summary.append(pd.DataFrame(trial, index=[trial_id]), sort=False) | ||||||
|  | 
 | ||||||
|  |         self._summary.index.name = 'trial id' | ||||||
|  |         self._summary = self._summary.reindex(columns=sorted(self._summary.columns)) | ||||||
|  |         self.task.get_logger().report_table("summary", "trial", 0, table_plot=self._summary) | ||||||
|  | 
 | ||||||
|  |     def exit(self): | ||||||
|  |         if not self.task: | ||||||
|  |             return | ||||||
|  |         self.task.flush(wait_for_uploads=True) | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 allegroai
						allegroai