fix missing gradient during densify & prune

Dojizz 2024-05-28 17:11:57 +08:00
parent 472689c0dc
commit 9e54fc93e4


```diff
@@ -277,15 +277,14 @@ class GaussianModel:
             if stored_state is not None:
                 stored_state["exp_avg"] = stored_state["exp_avg"][mask]
                 stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
                 del self.optimizer.state[group['params'][0]]
-                group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
+                grad = group["params"][0].grad[mask].clone()
+                group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
+                group["params"][0].grad = grad
                 self.optimizer.state[group['params'][0]] = stored_state
                 optimizable_tensors[group["name"]] = group["params"][0]
             else:
                 group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
                 optimizable_tensors[group["name"]] = group["params"][0]
         return optimizable_tensors
 
     def prune_points(self, mask):
```
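The core of the fix: `nn.Parameter(...)` builds a brand-new leaf tensor, and a fresh leaf starts with `.grad = None`, so whatever gradient the old parameter carried is dropped when the pruned parameter is rebuilt. Below is a minimal, self-contained editor's sketch of the failure and of the clone-and-reattach pattern this hunk adds; the tensor sizes and mask are made up, and the block runs under `torch.no_grad()` as the surrounding densification code does.

```python
import torch
import torch.nn as nn

# Editor's sketch, not repository code: pruning rows of a parameter by
# rebuilding it as a new nn.Parameter loses the accumulated gradient.
param = nn.Parameter(torch.randn(5, 3))
param.sum().backward()                      # param.grad is now populated
keep = torch.tensor([True, False, True, True, False])

with torch.no_grad():
    naive = nn.Parameter(param[keep].requires_grad_(True))
    print(naive.grad)                       # None -- gradient lost

    # Pattern added by this commit: mask and clone the old gradient,
    # rebuild the parameter, then re-attach the gradient.
    grad = param.grad[keep].clone()
    fixed = nn.Parameter(param[keep].requires_grad_(True))
    fixed.grad = grad
    print(fixed.grad.shape)                 # torch.Size([3, 3])
```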
```diff
@@ -314,15 +313,13 @@ class GaussianModel:
                 stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
                 stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)
                 del self.optimizer.state[group['params'][0]]
                 group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
                 self.optimizer.state[group['params'][0]] = stored_state
                 optimizable_tensors[group["name"]] = group["params"][0]
             else:
-                group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
-                optimizable_tensors[group["name"]] = group["params"][0]
+                grad = torch.cat((group["params"][0].grad, torch.zeros_like(extension_tensor)), dim=0).clone()
+                group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
+                group["params"][0].grad = grad
+                optimizable_tensors[group["name"]] = group["params"][0]
         return optimizable_tensors
```
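The concatenation path follows the same idea: the old gradient is padded with zeros for the newly appended rows, mirroring how `exp_avg` and `exp_avg_sq` are padded with `torch.zeros_like(extension_tensor)` above. A standalone editor's sketch with made-up shapes:

```python
import torch
import torch.nn as nn

# Editor's sketch, not repository code: extend a parameter with new rows and
# keep its gradient by zero-padding it to the new length.
param = nn.Parameter(torch.randn(4, 3))
param.sum().backward()                      # populate param.grad
extension = torch.randn(2, 3)               # stands in for newly densified points

with torch.no_grad():
    grad = torch.cat((param.grad, torch.zeros_like(extension)), dim=0).clone()
    extended = nn.Parameter(torch.cat((param, extension), dim=0).requires_grad_(True))
    extended.grad = grad
    print(extended.shape, extended.grad.shape)  # torch.Size([6, 3]) torch.Size([6, 3])
```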
```diff
@@ -389,7 +386,7 @@ class GaussianModel:
     def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
         grads = self.xyz_gradient_accum / self.denom
         grads[grads.isnan()] = 0.0
 
         self.densify_and_clone(grads, max_grad, extent)
         self.densify_and_split(grads, max_grad, extent)
```
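For context on the last hunk: the gradient that drives densification is the accumulated view-space gradient norm divided by how many times each Gaussian was updated; points that were never updated divide by zero, and the resulting NaNs are zeroed before thresholding. A toy editor's sketch with made-up numbers:

```python
import torch

# Editor's sketch, not repository code: average the accumulated gradient per
# point and zero out NaNs from points whose update count (denom) is zero.
xyz_gradient_accum = torch.tensor([[0.6], [0.0], [0.2]])
denom = torch.tensor([[3.0], [0.0], [2.0]])   # update counts per point

grads = xyz_gradient_accum / denom
grads[grads.isnan()] = 0.0
print(grads.squeeze(1))                       # tensor([0.2000, 0.0000, 0.1000])

max_grad = 0.15                               # stand-in densification threshold
print(grads.squeeze(1) >= max_grad)           # tensor([ True, False, False])
```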