mirror of
https://github.com/graphdeco-inria/gaussian-splatting
synced 2024-12-01 16:54:05 +00:00
22 lines
635 B
Python
22 lines
635 B
Python
|
import torch
|
||
|
|
||
|
from .modules.lpips import LPIPS
|
||
|
|
||
|
|
||
|
# Constructed LPIPS criteria, keyed by (net_type, version, device).
# Building a criterion constructs a whole network (and presumably loads its
# pretrained weights -- TODO confirm in .modules.lpips), so rebuilding it on
# every lpips() call is wasteful when the function is used in a per-image loop.
_CRITERION_CACHE = {}


def _get_criterion(net_type: str, version: str, device: torch.device):
    """Return the LPIPS criterion for this configuration, building it only once."""
    key = (net_type, version, device)
    criterion = _CRITERION_CACHE.get(key)
    if criterion is None:
        criterion = LPIPS(net_type, version).to(device)
        # The cached module is private to this file, so no caller can observe
        # gradients of its parameters; disabling them avoids computing and
        # accumulating .grad buffers across calls. Gradients w.r.t. the inputs
        # still flow. (Assumes LPIPS is a torch.nn.Module, as .to()/__call__ suggest.)
        criterion.requires_grad_(False)
        _CRITERION_CACHE[key] = criterion
    return criterion


def lpips(x: torch.Tensor,
          y: torch.Tensor,
          net_type: str = 'alex',
          version: str = '0.1'):
    r"""Function that measures
    Learned Perceptual Image Patch Similarity (LPIPS).

    The underlying LPIPS network is built once per (net_type, version, device)
    and reused on later calls, instead of being rebuilt on every call.

    Arguments:
        x, y (torch.Tensor): the input tensors to compare. The criterion is
            placed on ``x.device``; ``y`` is assumed to live on the same device.
        net_type (str): the network type to compare the features:
            'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
        version (str): the version of LPIPS. Default: 0.1.

    Returns:
        Whatever the LPIPS module returns for ``(x, y)``.
    """
    criterion = _get_criterion(net_type, version, x.device)
    return criterion(x, y)