mirror of
https://github.com/graphdeco-inria/gaussian-splatting
synced 2024-11-27 06:09:36 +00:00
31 lines
885 B
Python
31 lines
885 B
Python
|
from collections import OrderedDict
|
||
|
|
||
|
import torch
|
||
|
|
||
|
|
||
|
def normalize_activation(x, eps=1e-10):
    """Divide ``x`` by its L2 norm computed along dimension 1.

    The norm is the square root of the sum of squared entries over
    ``dim=1``; that dimension is kept with size 1 so the result
    broadcasts against ``x``. ``eps`` is added to the norm before the
    division, so with a positive ``eps`` an all-zero slice maps to zeros
    instead of producing a division by zero.

    Args:
        x: tensor with at least two dimensions.
        eps: constant added to the norm before dividing.

    Returns:
        A tensor of the same shape as ``x``.
    """
    squared_sum = torch.sum(x ** 2, dim=1, keepdim=True)
    denominator = torch.sqrt(squared_sum) + eps
    return x / denominator
|
||
|
|
||
|
|
||
|
def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
    """Download a weights file and return it with renamed keys.

    The file is fetched from the ``richzhang/PerceptualSimilarity``
    repository on GitHub (``master`` branch) at
    ``lpips/weights/v{version}/{net_type}.pth`` using
    ``torch.hub.load_state_dict_from_url``.

    Args:
        net_type: file stem of the weights file to fetch.
        version: weights version, inserted into the URL as ``v{version}``.

    Returns:
        An ``OrderedDict`` holding the downloaded values in their original
        order, where every key has had all occurrences of ``'lin'`` and
        then of ``'model.'`` removed.
    """
    # Adjacent literals are concatenated into a single URL string.
    weights_url = (
        'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/'
        f'master/lpips/weights/v{version}/{net_type}.pth'
    )

    # Without CUDA, force the tensors onto the CPU while loading; with
    # CUDA, pass None so no remapping is requested.
    if torch.cuda.is_available():
        load_target = None
    else:
        load_target = torch.device('cpu')

    downloaded = torch.hub.load_state_dict_from_url(
        weights_url,
        progress=True,
        map_location=load_target,
    )

    # Strip 'lin' first, then 'model.', from each key; values are untouched.
    return OrderedDict(
        (name.replace('lin', '').replace('model.', ''), tensor)
        for name, tensor in downloaded.items()
    )
|