mirror of
				https://github.com/clearml/clearml
				synced 2025-06-26 18:16:07 +00:00 
			
		
		
		
	Fix TensorFlow >=2 histogram binding
This commit is contained in:
		
							parent
							
								
									20a9f0997d
								
							
						
					
					
						commit
						9c1d08b826
					
				@ -88,15 +88,16 @@ class WeightsGradientHistHelper(object):
 | 
			
		||||
 | 
			
		||||
        if isinstance(hist_data, dict):
 | 
			
		||||
            pass
 | 
			
		||||
        elif isinstance(hist_data, np.ndarray):
 | 
			
		||||
            hist_data = np.histogram(hist_data)
 | 
			
		||||
            hist_data = {'bucketLimit': hist_data[1].tolist(), 'bucket': hist_data[0].tolist()}
 | 
			
		||||
        else:
 | 
			
		||||
        elif isinstance(hist_data, np.ndarray) and np.atleast_2d(hist_data).shape[1] == 3:
 | 
			
		||||
            # prepare the dictionary, assume numpy
 | 
			
		||||
            # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
 | 
			
		||||
            # histo_data['bucket'] is the histogram height, meaning the Y axis
 | 
			
		||||
            # hist_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
 | 
			
		||||
            # hist_data['bucket'] is the histogram height, meaning the Y axis
 | 
			
		||||
            # notice hist_data[:, 1] is the right side limit, for backwards compatibility we take the left side
 | 
			
		||||
            hist_data = {'bucketLimit': hist_data[:, 0].tolist(), 'bucket': hist_data[:, 2].tolist()}
 | 
			
		||||
        else:
 | 
			
		||||
            # assume we have to do the histogram on the data
 | 
			
		||||
            hist_data = np.histogram(hist_data)
 | 
			
		||||
            hist_data = {'bucketLimit': hist_data[1].tolist(), 'bucket': hist_data[0].tolist()}
 | 
			
		||||
 | 
			
		||||
        self._add_histogram(title=title, series=series, step=step, hist_data=hist_data)
 | 
			
		||||
 | 
			
		||||
@ -128,8 +129,8 @@ class WeightsGradientHistHelper(object):
 | 
			
		||||
 | 
			
		||||
        # add current sample, if not already here
 | 
			
		||||
        hist_iters = np.append(hist_iters, step)
 | 
			
		||||
        # histo_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
 | 
			
		||||
        # histo_data['bucket'] is the histogram height, meaning the Y axis
 | 
			
		||||
        # hist_data['bucketLimit'] is the histogram bucket right side limit, meaning X axis
 | 
			
		||||
        # hist_data['bucket'] is the histogram height, meaning the Y axis
 | 
			
		||||
        hist = np.array(list(zip(hist_data['bucketLimit'], hist_data['bucket'])), dtype=np.float32)
 | 
			
		||||
        hist = hist[~np.isinf(hist[:, 0]), :]
 | 
			
		||||
        hist_list.append(hist)
 | 
			
		||||
@ -453,7 +454,7 @@ class EventTrainsWriter(object):
 | 
			
		||||
            value=scalar_data,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _add_histogram(self, tag, step, histo_data):
 | 
			
		||||
    def _add_histogram(self, tag, step, hist_data):
 | 
			
		||||
        title, series = self.tag_splitter(tag, num_split_parts=1, default_title='Histograms',
 | 
			
		||||
                                          logdir_header='series')
 | 
			
		||||
 | 
			
		||||
@ -461,7 +462,7 @@ class EventTrainsWriter(object):
 | 
			
		||||
            title=title,
 | 
			
		||||
            series=series,
 | 
			
		||||
            step=step,
 | 
			
		||||
            hist_data=histo_data
 | 
			
		||||
            hist_data=hist_data
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _add_plot(self, tag, step, values, vdict):
 | 
			
		||||
@ -615,7 +616,7 @@ class EventTrainsWriter(object):
 | 
			
		||||
                if metric == 'simpleValue':
 | 
			
		||||
                    self._add_scalar(tag=tag, step=step, scalar_data=values)
 | 
			
		||||
                elif metric == 'histo':
 | 
			
		||||
                    self._add_histogram(tag=tag, step=step, histo_data=values)
 | 
			
		||||
                    self._add_histogram(tag=tag, step=step, hist_data=values)
 | 
			
		||||
                elif metric == 'image':
 | 
			
		||||
                    self._add_image(tag=tag, step=step, img_data=values)
 | 
			
		||||
                elif metric == 'audio':
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user