From bd9297a0d8dac727b8618e6570833cdd1769a203 Mon Sep 17 00:00:00 2001 From: curious-energy <834857012@qq.com> Date: Wed, 24 Jan 2024 16:42:43 +0800 Subject: [PATCH 1/3] Using the pip library torchmetrics makes the evaluation process more robust You can install torchmetrics with: # Python Package Index (PyPI) pip install torchmetrics # Conda conda install -c conda-forge torchmetrics --- metrics.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/metrics.py b/metrics.py index f7393a4..52dcee0 100644 --- a/metrics.py +++ b/metrics.py @@ -15,12 +15,14 @@ from PIL import Image import torch import torchvision.transforms.functional as tf from utils.loss_utils import ssim -from lpipsPyTorch import lpips +from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity import json from tqdm import tqdm from utils.image_utils import psnr from argparse import ArgumentParser +lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg').cuda() + def readImages(renders_dir, gt_dir): renders = [] gts = [] @@ -71,7 +73,7 @@ def evaluate(model_paths): for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"): ssims.append(ssim(renders[idx], gts[idx])) psnrs.append(psnr(renders[idx], gts[idx])) - lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg')) + lpipss.append(lpips(renders[idx], gts[idx])) print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5")) print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5")) From fea81bb550998b86ec48ba81e679500c391f90f0 Mon Sep 17 00:00:00 2001 From: curious-energy <834857012@qq.com> Date: Wed, 24 Jan 2024 16:53:27 +0800 Subject: [PATCH 2/3] Delete lpipsPyTorch directory change to use torchmetrics so it's useless --- lpipsPyTorch/__init__.py | 21 ------- lpipsPyTorch/modules/lpips.py | 36 ------------ lpipsPyTorch/modules/networks.py | 96 -------------------------------- lpipsPyTorch/modules/utils.py | 30 ---------- 4 files changed, 183 deletions(-) delete mode 100644 lpipsPyTorch/__init__.py delete mode 100644 lpipsPyTorch/modules/lpips.py delete mode 100644 lpipsPyTorch/modules/networks.py delete mode 100644 lpipsPyTorch/modules/utils.py diff --git a/lpipsPyTorch/__init__.py b/lpipsPyTorch/__init__.py deleted file mode 100644 index 2a6297d..0000000 --- a/lpipsPyTorch/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -import torch - -from .modules.lpips import LPIPS - - -def lpips(x: torch.Tensor, - y: torch.Tensor, - net_type: str = 'alex', - version: str = '0.1'): - r"""Function that measures - Learned Perceptual Image Patch Similarity (LPIPS). - - Arguments: - x, y (torch.Tensor): the input tensors to compare. - net_type (str): the network type to compare the features: - 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. - version (str): the version of LPIPS. Default: 0.1. - """ - device = x.device - criterion = LPIPS(net_type, version).to(device) - return criterion(x, y) diff --git a/lpipsPyTorch/modules/lpips.py b/lpipsPyTorch/modules/lpips.py deleted file mode 100644 index 9cd001d..0000000 --- a/lpipsPyTorch/modules/lpips.py +++ /dev/null @@ -1,36 +0,0 @@ -import torch -import torch.nn as nn - -from .networks import get_network, LinLayers -from .utils import get_state_dict - - -class LPIPS(nn.Module): - r"""Creates a criterion that measures - Learned Perceptual Image Patch Similarity (LPIPS). - - Arguments: - net_type (str): the network type to compare the features: - 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. - version (str): the version of LPIPS. Default: 0.1. - """ - def __init__(self, net_type: str = 'alex', version: str = '0.1'): - - assert version in ['0.1'], 'v0.1 is only supported now' - - super(LPIPS, self).__init__() - - # pretrained network - self.net = get_network(net_type) - - # linear layers - self.lin = LinLayers(self.net.n_channels_list) - self.lin.load_state_dict(get_state_dict(net_type, version)) - - def forward(self, x: torch.Tensor, y: torch.Tensor): - feat_x, feat_y = self.net(x), self.net(y) - - diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] - res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] - - return torch.sum(torch.cat(res, 0), 0, True) diff --git a/lpipsPyTorch/modules/networks.py b/lpipsPyTorch/modules/networks.py deleted file mode 100644 index d36c6a5..0000000 --- a/lpipsPyTorch/modules/networks.py +++ /dev/null @@ -1,96 +0,0 @@ -from typing import Sequence - -from itertools import chain - -import torch -import torch.nn as nn -from torchvision import models - -from .utils import normalize_activation - - -def get_network(net_type: str): - if net_type == 'alex': - return AlexNet() - elif net_type == 'squeeze': - return SqueezeNet() - elif net_type == 'vgg': - return VGG16() - else: - raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') - - -class LinLayers(nn.ModuleList): - def __init__(self, n_channels_list: Sequence[int]): - super(LinLayers, self).__init__([ - nn.Sequential( - nn.Identity(), - nn.Conv2d(nc, 1, 1, 1, 0, bias=False) - ) for nc in n_channels_list - ]) - - for param in self.parameters(): - param.requires_grad = False - - -class BaseNet(nn.Module): - def __init__(self): - super(BaseNet, self).__init__() - - # register buffer - self.register_buffer( - 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) - self.register_buffer( - 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) - - def set_requires_grad(self, state: bool): - for param in chain(self.parameters(), self.buffers()): - param.requires_grad = state - - def z_score(self, x: torch.Tensor): - return (x - self.mean) / self.std - - def forward(self, x: torch.Tensor): - x = self.z_score(x) - - output = [] - for i, (_, layer) in enumerate(self.layers._modules.items(), 1): - x = layer(x) - if i in self.target_layers: - output.append(normalize_activation(x)) - if len(output) == len(self.target_layers): - break - return output - - -class SqueezeNet(BaseNet): - def __init__(self): - super(SqueezeNet, self).__init__() - - self.layers = models.squeezenet1_1(True).features - self.target_layers = [2, 5, 8, 10, 11, 12, 13] - self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] - - self.set_requires_grad(False) - - -class AlexNet(BaseNet): - def __init__(self): - super(AlexNet, self).__init__() - - self.layers = models.alexnet(True).features - self.target_layers = [2, 5, 8, 10, 12] - self.n_channels_list = [64, 192, 384, 256, 256] - - self.set_requires_grad(False) - - -class VGG16(BaseNet): - def __init__(self): - super(VGG16, self).__init__() - - self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features - self.target_layers = [4, 9, 16, 23, 30] - self.n_channels_list = [64, 128, 256, 512, 512] - - self.set_requires_grad(False) diff --git a/lpipsPyTorch/modules/utils.py b/lpipsPyTorch/modules/utils.py deleted file mode 100644 index 3d15a09..0000000 --- a/lpipsPyTorch/modules/utils.py +++ /dev/null @@ -1,30 +0,0 @@ -from collections import OrderedDict - -import torch - - -def normalize_activation(x, eps=1e-10): - norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) - return x / (norm_factor + eps) - - -def get_state_dict(net_type: str = 'alex', version: str = '0.1'): - # build url - url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ - + f'master/lpips/weights/v{version}/{net_type}.pth' - - # download - old_state_dict = torch.hub.load_state_dict_from_url( - url, progress=True, - map_location=None if torch.cuda.is_available() else torch.device('cpu') - ) - - # rename keys - new_state_dict = OrderedDict() - for key, val in old_state_dict.items(): - new_key = key - new_key = new_key.replace('lin', '') - new_key = new_key.replace('model.', '') - new_state_dict[new_key] = val - - return new_state_dict From a050ffceb08f40c11e709a36f101cf57f3c1766b Mon Sep 17 00:00:00 2001 From: curious-energy <834857012@qq.com> Date: Wed, 24 Jan 2024 16:54:45 +0800 Subject: [PATCH 3/3] Update environment.yml add torchmetrics --- environment.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/environment.yml b/environment.yml index b17a50f..e488177 100644 --- a/environment.yml +++ b/environment.yml @@ -9,9 +9,9 @@ dependencies: - python=3.7.13 - pip=22.3.1 - pytorch=1.12.1 - - torchaudio=0.12.1 + - torchmetrics - torchvision=0.13.1 - tqdm - pip: - submodules/diff-gaussian-rasterization - - submodules/simple-knn \ No newline at end of file + - submodules/simple-knn