mirror of
https://github.com/graphdeco-inria/gaussian-splatting
synced 2024-11-25 21:28:17 +00:00
Switch the lpips computation from cpu to cuda. ~15x faster for evaluation.
This commit is contained in:
parent
2eee0e26d2
commit
ce9e8925e3
@ -15,7 +15,8 @@ from PIL import Image
|
|||||||
import torch
|
import torch
|
||||||
import torchvision.transforms.functional as tf
|
import torchvision.transforms.functional as tf
|
||||||
from utils.loss_utils import ssim
|
from utils.loss_utils import ssim
|
||||||
from lpipsPyTorch import lpips
|
#from lpipsPyTorch import lpips
|
||||||
|
import lpips
|
||||||
import json
|
import json
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from utils.image_utils import psnr
|
from utils.image_utils import psnr
|
||||||
@ -71,7 +72,7 @@ def evaluate(model_paths):
|
|||||||
for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"):
|
for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"):
|
||||||
ssims.append(ssim(renders[idx], gts[idx]))
|
ssims.append(ssim(renders[idx], gts[idx]))
|
||||||
psnrs.append(psnr(renders[idx], gts[idx]))
|
psnrs.append(psnr(renders[idx], gts[idx]))
|
||||||
lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg'))
|
lpipss.append(lpips_fn(renders[idx], gts[idx]))
|
||||||
|
|
||||||
print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5"))
|
print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5"))
|
||||||
print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5"))
|
print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5"))
|
||||||
@ -95,6 +96,7 @@ def evaluate(model_paths):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda:0")
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
|
lpips_fn = lpips.LPIPS(net='vgg').to(device)
|
||||||
|
|
||||||
# Set up command line argument parser
|
# Set up command line argument parser
|
||||||
parser = ArgumentParser(description="Training script parameters")
|
parser = ArgumentParser(description="Training script parameters")
|
||||||
|
Loading…
Reference in New Issue
Block a user