Using the torchmetrics package makes the evaluation process more robust

You can install torchmetrics with:
# Python Package Index (PyPI)
pip install torchmetrics
# Conda
conda install -c conda-forge torchmetrics
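For reference, here is a minimal sketch of how the torchmetrics LPIPS metric is called on a pair of image tensors. The dummy (1, 3, 256, 256) tensors and the normalize=True flag are illustrative assumptions, not part of this commit:

import torch
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

# Instantiate once and reuse; the VGG backbone weights are downloaded on first use.
# normalize=True tells the metric the inputs are in [0, 1] rather than the
# default expected range of [-1, 1].
lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg', normalize=True)

render = torch.rand(1, 3, 256, 256)  # dummy (N, C, H, W) render batch in [0, 1]
gt = torch.rand(1, 3, 256, 256)      # dummy ground-truth batch
print(lpips(render, gt).item())      # lower LPIPS means perceptually closer images
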
Commit bd9297a0d8 (parent 2eee0e26d2), authored by curious-energy on 2024-01-24 16:42:43 +08:00 and committed via GitHub.

@@ -15,12 +15,14 @@ from PIL import Image
 import torch
 import torchvision.transforms.functional as tf
 from utils.loss_utils import ssim
-from lpipsPyTorch import lpips
+from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
 import json
 from tqdm import tqdm
 from utils.image_utils import psnr
 from argparse import ArgumentParser
 
+lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg').cuda()
+
 def readImages(renders_dir, gt_dir):
     renders = []
     gts = []
@@ -71,7 +73,7 @@ def evaluate(model_paths):
                 for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"):
                     ssims.append(ssim(renders[idx], gts[idx]))
                     psnrs.append(psnr(renders[idx], gts[idx]))
-                    lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg'))
+                    lpipss.append(lpips(renders[idx], gts[idx]))
 
                 print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5"))
                 print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5"))