diff --git a/scene/gaussian_model.py b/scene/gaussian_model.py index 632a1e8..bafeaef 100644 --- a/scene/gaussian_model.py +++ b/scene/gaussian_model.py @@ -277,15 +277,14 @@ class GaussianModel: if stored_state is not None: stored_state["exp_avg"] = stored_state["exp_avg"][mask] stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask] - del self.optimizer.state[group['params'][0]] - group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True))) self.optimizer.state[group['params'][0]] = stored_state + + grad = group["params"][0].grad[mask].clone() + group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True)) + group["params"][0].grad = grad + optimizable_tensors[group["name"]] = group["params"][0] - optimizable_tensors[group["name"]] = group["params"][0] - else: - group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True)) - optimizable_tensors[group["name"]] = group["params"][0] return optimizable_tensors def prune_points(self, mask): @@ -314,15 +313,13 @@ class GaussianModel: stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0) stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0) - del self.optimizer.state[group['params'][0]] - group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) self.optimizer.state[group['params'][0]] = stored_state - optimizable_tensors[group["name"]] = group["params"][0] - else: - group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) - optimizable_tensors[group["name"]] = group["params"][0] + grad = torch.cat((group["params"][0].grad, torch.zeros_like(extension_tensor)), dim=0).clone() + group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) + group["params"][0].grad = grad + optimizable_tensors[group["name"]] = group["params"][0] return optimizable_tensors @@ -389,7 +386,7 @@ class GaussianModel: def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size): grads = self.xyz_gradient_accum / self.denom grads[grads.isnan()] = 0.0 - + self.densify_and_clone(grads, max_grad, extent) self.densify_and_split(grads, max_grad, extent)