diff --git a/metrics.py b/metrics.py index f7393a4..52dcee0 100644 --- a/metrics.py +++ b/metrics.py @@ -15,12 +15,14 @@ from PIL import Image import torch import torchvision.transforms.functional as tf from utils.loss_utils import ssim -from lpipsPyTorch import lpips +from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity import json from tqdm import tqdm from utils.image_utils import psnr from argparse import ArgumentParser +lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg', normalize=True).cuda() + def readImages(renders_dir, gt_dir): renders = [] gts = [] @@ -71,7 +73,7 @@ def evaluate(model_paths): for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"): ssims.append(ssim(renders[idx], gts[idx])) psnrs.append(psnr(renders[idx], gts[idx])) - lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg')) + lpipss.append(lpips(renders[idx], gts[idx])) print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5")) print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5"))