From ce9e8925e3d77c245a9dbdf98a41dd8dc64d6c15 Mon Sep 17 00:00:00 2001 From: ingra14m <1261238025@qq.com> Date: Sat, 4 Nov 2023 17:33:24 +0800 Subject: [PATCH] Switch the lpips computation from cpu to cuda. ~15x faster for evaluation. --- metrics.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/metrics.py b/metrics.py index f7393a4..287aefe 100644 --- a/metrics.py +++ b/metrics.py @@ -15,7 +15,8 @@ from PIL import Image import torch import torchvision.transforms.functional as tf from utils.loss_utils import ssim -from lpipsPyTorch import lpips +#from lpipsPyTorch import lpips +import lpips import json from tqdm import tqdm from utils.image_utils import psnr @@ -71,7 +72,7 @@ def evaluate(model_paths): for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"): ssims.append(ssim(renders[idx], gts[idx])) psnrs.append(psnr(renders[idx], gts[idx])) - lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg')) + lpipss.append(lpips_fn(renders[idx], gts[idx])) print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5")) print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5")) @@ -95,6 +96,7 @@ def evaluate(model_paths): if __name__ == "__main__": device = torch.device("cuda:0") torch.cuda.set_device(device) + lpips_fn = lpips.LPIPS(net='vgg').to(device) # Set up command line argument parser parser = ArgumentParser(description="Training script parameters")