mirror of
https://github.com/graphdeco-inria/gaussian-splatting
synced 2024-11-21 15:57:45 +00:00
22 lines
635 B
Python
22 lines
635 B
Python
import functools

import torch

from .modules.lpips import LPIPS


@functools.lru_cache(maxsize=8)
def _get_criterion(net_type: str, version: str, device: torch.device) -> LPIPS:
    """Build an ``LPIPS`` module once per (net_type, version, device) and reuse it.

    Constructing the module on every metric call is expensive, so instances are
    memoized. ``maxsize`` bounds how many networks (and their device memory) are
    kept alive at once.
    """
    # NOTE(review): the cached module is shared across calls. If its parameters
    # require grad and callers backpropagate through the metric, .grad would
    # accumulate on the shared parameters. Confirm LPIPS freezes its parameters.
    return LPIPS(net_type, version).to(device)


def lpips(x: torch.Tensor,
          y: torch.Tensor,
          net_type: str = 'alex',
          version: str = '0.1'):
    r"""Measure the Learned Perceptual Image Patch Similarity (LPIPS) of x and y.

    The LPIPS network for a given (net_type, version, device) is built on the
    first call and reused afterwards. Earlier versions re-constructed it, and
    presumably reloaded its pretrained weights, on every call.

    Arguments:
        x, y (torch.Tensor): the input tensors to compare. The network is
            placed on ``x.device``, so ``y`` is presumably expected to live on
            the same device.
        net_type (str): the network type to compare the features:
            'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
        version (str): the version of LPIPS. Default: 0.1.

    Returns:
        Whatever ``LPIPS.forward(x, y)`` returns, presumably the LPIPS
        distance. TODO confirm its shape and reduction.
    """
    criterion = _get_criterion(net_type, version, x.device)
    return criterion(x, y)