Switch the lpips computation from cpu to cuda. ~15x faster for evaluation.

This commit is contained in:
ingra14m 2023-11-04 17:33:24 +08:00
parent 2eee0e26d2
commit ce9e8925e3

View File

@ -15,7 +15,8 @@ from PIL import Image
import torch import torch
import torchvision.transforms.functional as tf import torchvision.transforms.functional as tf
from utils.loss_utils import ssim from utils.loss_utils import ssim
from lpipsPyTorch import lpips #from lpipsPyTorch import lpips
import lpips
import json import json
from tqdm import tqdm from tqdm import tqdm
from utils.image_utils import psnr from utils.image_utils import psnr
@ -71,7 +72,7 @@ def evaluate(model_paths):
for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"): for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"):
ssims.append(ssim(renders[idx], gts[idx])) ssims.append(ssim(renders[idx], gts[idx]))
psnrs.append(psnr(renders[idx], gts[idx])) psnrs.append(psnr(renders[idx], gts[idx]))
lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg')) lpipss.append(lpips_fn(renders[idx], gts[idx]))
print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5")) print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5"))
print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5")) print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5"))
@ -95,6 +96,7 @@ def evaluate(model_paths):
if __name__ == "__main__": if __name__ == "__main__":
device = torch.device("cuda:0") device = torch.device("cuda:0")
torch.cuda.set_device(device) torch.cuda.set_device(device)
lpips_fn = lpips.LPIPS(net='vgg').to(device)
# Set up command line argument parser # Set up command line argument parser
parser = ArgumentParser(description="Training script parameters") parser = ArgumentParser(description="Training script parameters")