diff --git a/utils/loss_utils.py b/utils/loss_utils.py index 9defc23..77a3f10 100644 --- a/utils/loss_utils.py +++ b/utils/loss_utils.py @@ -25,13 +25,14 @@ def gaussian(window_size, sigma): return gauss / gauss.sum() def create_window(window_size, channel): - _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + """Create a 2D Gaussian window.""" + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) #此处高斯窗口的sigma固定为1.5 _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) return window def ssim(img1, img2, window_size=11, size_average=True): - channel = img1.size(-3) + channel = img1.size(-3) #channel=3 window = create_window(window_size, channel) if img1.is_cuda: @@ -62,3 +63,12 @@ def _ssim(img1, img2, window, window_size, channel, size_average=True): else: return ssim_map.mean(1).mean(1).mean(1) +#-----------------modify----------- +def ssim_optimized(img1, img2, window=None, window_size=11, size_average=True): + channel = img1.size(-3) + if window is None: + window = create_window(window_size, channel).to(img1.device).type_as(img1) + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + return _ssim(img1, img2, window, window_size, channel, size_average)