mirror of
https://github.com/graphdeco-inria/gaussian-splatting
synced 2024-11-22 08:18:17 +00:00
fix gradient missing during densify&prune
This commit is contained in:
parent
472689c0dc
commit
9e54fc93e4
@ -277,15 +277,14 @@ class GaussianModel:
|
||||
if stored_state is not None:
|
||||
stored_state["exp_avg"] = stored_state["exp_avg"][mask]
|
||||
stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
|
||||
|
||||
del self.optimizer.state[group['params'][0]]
|
||||
group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
|
||||
self.optimizer.state[group['params'][0]] = stored_state
|
||||
|
||||
optimizable_tensors[group["name"]] = group["params"][0]
|
||||
else:
|
||||
group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
|
||||
optimizable_tensors[group["name"]] = group["params"][0]
|
||||
grad = group["params"][0].grad[mask].clone()
|
||||
group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
|
||||
group["params"][0].grad = grad
|
||||
optimizable_tensors[group["name"]] = group["params"][0]
|
||||
|
||||
return optimizable_tensors
|
||||
|
||||
def prune_points(self, mask):
|
||||
@ -314,15 +313,13 @@ class GaussianModel:
|
||||
|
||||
stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
|
||||
stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)
|
||||
|
||||
del self.optimizer.state[group['params'][0]]
|
||||
group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
|
||||
self.optimizer.state[group['params'][0]] = stored_state
|
||||
|
||||
optimizable_tensors[group["name"]] = group["params"][0]
|
||||
else:
|
||||
group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
|
||||
optimizable_tensors[group["name"]] = group["params"][0]
|
||||
grad = torch.cat((group["params"][0].grad, torch.zeros_like(extension_tensor)), dim=0).clone()
|
||||
group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
|
||||
group["params"][0].grad = grad
|
||||
optimizable_tensors[group["name"]] = group["params"][0]
|
||||
|
||||
return optimizable_tensors
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user