Update loss_utils.py

Added ssim_optimized(), an optimized version of ssim() that reduces the computational cost. With this change, create_window() only needs to be called once in the main function in train.py, and the resulting window can be passed to ssim_optimized(); it no longer needs to be called inside ssim() on every iteration.
This commit is contained in:
Lixing Xiao 2024-07-12 16:43:00 +08:00 committed by GitHub
parent 472689c0dc
commit 3a3220c1fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -25,13 +25,14 @@ def gaussian(window_size, sigma):
return gauss / gauss.sum() return gauss / gauss.sum()
def create_window(window_size, channel, sigma=1.5):
    """Create a 2D Gaussian window replicated once per channel.

    Args:
        window_size: Side length of the square window.
        channel: Number of channels the window is expanded to.
        sigma: Standard deviation forwarded to ``gaussian``. Defaults to
            1.5, the value that was previously hard-coded here.

    Returns:
        A contiguous float tensor of shape
        ``(channel, 1, window_size, window_size)``.
    """
    # gaussian() yields a normalized 1D kernel; turn it into a column vector.
    _1D_window = gaussian(window_size, sigma).unsqueeze(1)
    # Column times its transpose gives the 2D kernel; add two leading dims.
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # expand() only creates a view, so contiguous() materializes the copies.
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window
def ssim(img1, img2, window_size=11, size_average=True): def ssim(img1, img2, window_size=11, size_average=True):
channel = img1.size(-3) channel = img1.size(-3) #channel=3
window = create_window(window_size, channel) window = create_window(window_size, channel)
if img1.is_cuda: if img1.is_cuda:
@ -62,3 +63,12 @@ def _ssim(img1, img2, window, window_size, channel, size_average=True):
else: else:
return ssim_map.mean(1).mean(1).mean(1) return ssim_map.mean(1).mean(1).mean(1)
#-----------------modify-----------
def ssim_optimized(img1, img2, window=None, window_size=11, size_average=True):
    """Compute SSIM via ``_ssim``, optionally reusing a prebuilt window.

    Passing a window created once by ``create_window`` avoids rebuilding it
    on every call.

    Args:
        img1: First image tensor; the channel count is read from dim -3.
        img2: Second image tensor, forwarded to ``_ssim``.
        window: Optional precomputed window. When ``None``, a new one is
            built with ``create_window(window_size, channel)``.
        window_size: Window size used when building the window and forwarded
            to ``_ssim``.
        size_average: Forwarded to ``_ssim``.

    Returns:
        The result of ``_ssim``.
    """
    channel = img1.size(-3)
    if window is None:
        window = create_window(window_size, channel)
    # One conversion handles both device and dtype, for a freshly built or a
    # caller-supplied window alike; it is a no-op when they already match.
    window = window.to(device=img1.device, dtype=img1.dtype)
    # NOTE(review): window_size is forwarded unchanged even when a window is
    # supplied; presumably it must match that window's size -- confirm
    # against _ssim.
    return _ssim(img1, img2, window, window_size, channel, size_average)