fix gradient missing during densify&prune

This commit is contained in:
Dojizz 2024-05-28 17:11:57 +08:00
parent 472689c0dc
commit 9e54fc93e4

View File

@ -277,15 +277,14 @@ class GaussianModel:
if stored_state is not None:
stored_state["exp_avg"] = stored_state["exp_avg"][mask]
stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
del self.optimizer.state[group['params'][0]]
group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
self.optimizer.state[group['params'][0]] = stored_state
optimizable_tensors[group["name"]] = group["params"][0]
else:
grad = group["params"][0].grad[mask].clone()
group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
group["params"][0].grad = grad
optimizable_tensors[group["name"]] = group["params"][0]
return optimizable_tensors
def prune_points(self, mask):
@ -314,14 +313,12 @@ class GaussianModel:
stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)
del self.optimizer.state[group['params'][0]]
group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
self.optimizer.state[group['params'][0]] = stored_state
optimizable_tensors[group["name"]] = group["params"][0]
else:
grad = torch.cat((group["params"][0].grad, torch.zeros_like(extension_tensor)), dim=0).clone()
group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
group["params"][0].grad = grad
optimizable_tensors[group["name"]] = group["params"][0]
return optimizable_tensors