Mirror of https://github.com/graphdeco-inria/gaussian-splatting, synced 2024-11-22 08:18:17 +00:00
Fix missing gradients during densify & prune
Parent: 472689c0dc
Commit: 9e54fc93e4
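Rationale (inferred from the diff below and from how the upstream training loop invokes densification, where loss.backward() runs before densify/prune and optimizer.step() runs after it): rebuilding a parameter group as a fresh nn.Parameter leaves the new parameter with .grad == None, so the gradient computed by the preceding backward pass was silently discarded and the following optimizer step had nothing to apply for that iteration. The commit clones the masked (or zero-extended) gradient before the rebuild and reattaches it afterwards. A minimal sketch of that pattern, illustrative only and not repository code (the names param and mask are made up; it assumes, as in the upstream train.py, that this happens inside torch.no_grad()):

import torch
import torch.nn as nn

param = nn.Parameter(torch.randn(5, 3))
param.sum().backward()                                   # param.grad is now populated
mask = torch.tensor([True, False, True, True, False])    # keep 3 of 5 rows

with torch.no_grad():
    # Naive rebuild (old behaviour): the fresh parameter starts with .grad == None,
    # so the gradient from the backward pass above is dropped.
    rebuilt = nn.Parameter(param[mask].requires_grad_(True))
    assert rebuilt.grad is None

    # Pattern applied by this commit: clone the masked gradient first,
    # rebuild the parameter, then reattach the gradient.
    grad = param.grad[mask].clone()
    fixed = nn.Parameter(param[mask].requires_grad_(True))
    fixed.grad = grad
    assert fixed.grad is not None and fixed.grad.shape == fixed.shape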
@@ -277,15 +277,14 @@ class GaussianModel:
             if stored_state is not None:
                 stored_state["exp_avg"] = stored_state["exp_avg"][mask]
                 stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
 
                 del self.optimizer.state[group['params'][0]]
-                group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
                 self.optimizer.state[group['params'][0]] = stored_state
 
-                optimizable_tensors[group["name"]] = group["params"][0]
-            else:
-                group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
+                grad = group["params"][0].grad[mask].clone()
+            group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
+            group["params"][0].grad = grad
             optimizable_tensors[group["name"]] = group["params"][0]
 
         return optimizable_tensors
 
     def prune_points(self, mask):
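In _prune_optimizer, the parameter rebuild is hoisted out of the if/else: when Adam state exists, the masked gradient is cloned before the old parameter is dropped, the new nn.Parameter is then created from the masked tensor, and the cloned gradient is reattached via group["params"][0].grad = grad, so the gradient of the surviving Gaussians carries over to the optimizer step that follows densification.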
@@ -314,15 +313,13 @@ class GaussianModel:
 
                 stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
                 stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)
 
                 del self.optimizer.state[group['params'][0]]
-                group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
                 self.optimizer.state[group['params'][0]] = stored_state
 
-                optimizable_tensors[group["name"]] = group["params"][0]
-            else:
-                group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
+                grad = torch.cat((group["params"][0].grad, torch.zeros_like(extension_tensor)), dim=0).clone()
+            group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
+            group["params"][0].grad = grad
             optimizable_tensors[group["name"]] = group["params"][0]
 
         return optimizable_tensors
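cat_tensors_to_optimizer gets the same treatment for densification: the existing gradient is extended with zeros for the newly appended Gaussians, mirroring the zero exp_avg/exp_avg_sq moments concatenated just above, and then reattached to the rebuilt parameter. A tiny illustrative check of that padding (the names old_grad and extension_tensor stand in for the real tensors here):

import torch

old_grad = torch.randn(4, 3)              # gradient of the existing Gaussians
extension_tensor = torch.randn(2, 3)      # new parameter values being appended
grad = torch.cat((old_grad, torch.zeros_like(extension_tensor)), dim=0).clone()
assert grad.shape == (6, 3)               # appended rows start with zero gradient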