mirror of
https://github.com/graphdeco-inria/gaussian-splatting
synced 2024-11-25 13:26:47 +00:00
Added checkpoints
This commit is contained in:
parent
ef918fc9e3
commit
6b54263364
@ -72,7 +72,7 @@ The optimizer uses PyTorch and CUDA extensions in a Python environment to produc
|
||||
### Software Requirements
|
||||
- Conda (recommended for easy setup)
|
||||
- C++ Compiler for PyTorch extensions (we used Visual Studio 2019 for Windows)
|
||||
- CUDA SDK 11.7+ for PyTorch extensions (we used 11.8, **known issues with 11.6**)
|
||||
- CUDA SDK 11 for PyTorch extensions (we used 11.8, **known issues with 11.6**)
|
||||
- C++ Compiler and CUDA SDK must be compatible
|
||||
|
||||
### Setup
|
||||
@ -83,6 +83,7 @@ SET DISTUTILS_USE_SDK=1 # Windows only
|
||||
conda env create --file environment.yml
|
||||
conda activate gaussian_splatting
|
||||
```
|
||||
Please note that this process assumes that you have CUDA SDK **11** installed, not **12**. For modifications, see below.
|
||||
|
||||
Tip: Downloading packages and creating a new environment with Conda can require a significant amount of disk space. By default, Conda will use the main system hard drive. You can avoid this by specifying a different package download location and an environment on a different drive:
|
||||
|
||||
@ -92,7 +93,9 @@ conda env create --file environment.yml --prefix <Drive>/<env_path>/gaussian_spl
|
||||
conda activate <Drive>/<env_path>/gaussian_splatting
|
||||
```
|
||||
|
||||
If you can afford the disk space, we recommend using our environment files for setting up a training environment identical to ours. If you want to make modifications, please note that major version changes might affect the results of our method. However, our (limited) experiments suggest that the codebase works just fine inside a more up-to-date environment (Python 3.8, PyTorch 2.0.0, CUDA 11.8).
|
||||
#### Modifications
|
||||
|
||||
If you can afford the disk space, we recommend using our environment files for setting up a training environment identical to ours. If you want to make modifications, please note that major version changes might affect the results of our method. However, our (limited) experiments suggest that the codebase works just fine inside a more up-to-date environment (Python 3.8, PyTorch 2.0.0, CUDA 12). Make sure to create an environment in which PyTorch and its CUDA runtime version match, and in which the installed CUDA SDK has the same major version as PyTorch's CUDA version.
|
||||
|
||||
### Running
|
||||
|
||||
|
@ -78,8 +78,7 @@ class Scene:
|
||||
self.gaussians.load_ply(os.path.join(self.model_path,
|
||||
"point_cloud",
|
||||
"iteration_" + str(self.loaded_iter),
|
||||
"point_cloud.ply"),
|
||||
og_number_points=len(scene_info.point_cloud.points))
|
||||
"point_cloud.ply"))
|
||||
else:
|
||||
self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
|
||||
|
||||
|
@ -22,28 +22,14 @@ from utils.graphics_utils import BasicPointCloud
|
||||
from utils.general_utils import strip_symmetric, build_scaling_rotation
|
||||
|
||||
class GaussianModel:
|
||||
def __init__(self, sh_degree : int):
|
||||
|
||||
def setup_functions(self):
|
||||
def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
    """Build a covariance representation from scaling and rotation.

    Calls ``build_scaling_rotation`` on ``scaling_modifier * scaling`` and
    ``rotation`` to obtain ``L``, forms ``L @ L^T`` (transposing dims 1 and
    2, so ``L`` must have at least three dimensions), and returns the
    result of ``strip_symmetric`` applied to that product.
    """
    L = build_scaling_rotation(scaling_modifier * scaling, rotation)
    # L @ L^T is symmetric by construction; strip_symmetric presumably keeps
    # only the non-redundant entries -- confirm against utils.general_utils.
    actual_covariance = L @ L.transpose(1, 2)
    symm = strip_symmetric(actual_covariance)
    return symm
|
||||
|
||||
self.active_sh_degree = 0
|
||||
self.max_sh_degree = sh_degree
|
||||
|
||||
self._xyz = torch.empty(0)
|
||||
self._features_dc = torch.empty(0)
|
||||
self._features_rest = torch.empty(0)
|
||||
self._scaling = torch.empty(0)
|
||||
self._rotation = torch.empty(0)
|
||||
self._opacity = torch.empty(0)
|
||||
self.max_radii2D = torch.empty(0)
|
||||
self.xyz_gradient_accum = torch.empty(0)
|
||||
|
||||
self.optimizer = None
|
||||
|
||||
self.scaling_activation = torch.exp
|
||||
self.scaling_inverse_activation = torch.log
|
||||
|
||||
@ -54,6 +40,58 @@ class GaussianModel:
|
||||
|
||||
self.rotation_activation = torch.nn.functional.normalize
|
||||
|
||||
|
||||
def __init__(self, sh_degree: int):
    """Create an empty model.

    Args:
        sh_degree: stored as ``max_sh_degree``; ``active_sh_degree``
            starts at 0.

    All tensor attributes start as zero-element tensors
    (``torch.empty(0)``), no optimizer is attached, and ``percent_dense``
    and ``spatial_lr_scale`` start at 0. ``setup_functions`` is called
    last to install the activation callables.
    """
    self.active_sh_degree = 0
    self.max_sh_degree = sh_degree
    self._xyz = torch.empty(0)
    self._features_dc = torch.empty(0)
    self._features_rest = torch.empty(0)
    self._scaling = torch.empty(0)
    self._rotation = torch.empty(0)
    self._opacity = torch.empty(0)
    self.max_radii2D = torch.empty(0)
    self.xyz_gradient_accum = torch.empty(0)
    self.denom = torch.empty(0)
    # Stays None until an optimizer is attached elsewhere; capture() and
    # restore() both dereference it.
    self.optimizer = None
    self.percent_dense = 0
    self.spatial_lr_scale = 0
    self.setup_functions()
|
||||
|
||||
def capture(self):
    """Return a snapshot of the model state for checkpointing.

    The result is a 12-tuple in exactly the order that ``restore``
    unpacks: active SH degree, xyz, DC features, rest features, scaling,
    rotation, opacity, max 2D radii, xyz gradient accumulator, denom,
    the optimizer's ``state_dict()``, and the spatial learning-rate scale.

    Attribute values are returned by reference, not copied.

    Raises:
        AttributeError: if ``self.optimizer`` is still ``None`` (its
            value after ``__init__``), because ``state_dict()`` is
            called on it.
    """
    return (
        self.active_sh_degree,
        self._xyz,
        self._features_dc,
        self._features_rest,
        self._scaling,
        self._rotation,
        self._opacity,
        self.max_radii2D,
        self.xyz_gradient_accum,
        self.denom,
        self.optimizer.state_dict(),
        self.spatial_lr_scale,
    )
|
||||
|
||||
def restore(self, model_args, training_args):
    """Load model state from a tuple produced by ``capture``.

    Args:
        model_args: 12-tuple in the order returned by ``capture``.
        training_args: passed unchanged to ``self.training_setup``.

    Raises:
        ValueError: if ``model_args`` does not unpack into exactly
            twelve items.
    """
    (self.active_sh_degree,
     self._xyz,
     self._features_dc,
     self._features_rest,
     self._scaling,
     self._rotation,
     self._opacity,
     self.max_radii2D,
     xyz_gradient_accum,
     denom,
     opt_dict,
     self.spatial_lr_scale) = model_args
    # training_setup runs before the accumulators and optimizer state are
    # applied -- presumably because it (re)creates them; confirm against
    # training_setup, which is not visible here.
    self.training_setup(training_args)
    self.xyz_gradient_accum = xyz_gradient_accum
    self.denom = denom
    self.optimizer.load_state_dict(opt_dict)
|
||||
|
||||
@property
def get_scaling(self):
    """Return ``self._scaling`` passed through ``self.scaling_activation``.

    ``scaling_activation`` is assigned ``torch.exp`` in
    ``setup_functions``.
    """
    return self.scaling_activation(self._scaling)
|
||||
@ -174,8 +212,7 @@ class GaussianModel:
|
||||
optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
|
||||
self._opacity = optimizable_tensors["opacity"]
|
||||
|
||||
def load_ply(self, path, og_number_points=-1):
|
||||
self.og_number_points = og_number_points
|
||||
def load_ply(self, path):
|
||||
plydata = PlyData.read(path)
|
||||
|
||||
xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
|
||||
|
28
train.py
28
train.py
@ -28,12 +28,15 @@ try:
|
||||
except ImportError:
|
||||
TENSORBOARD_FOUND = False
|
||||
|
||||
def training(dataset, opt, pipe, testing_iterations, saving_iterations):
|
||||
def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint):
|
||||
first_iter = 0
|
||||
tb_writer = prepare_output_and_logger(dataset)
|
||||
gaussians = GaussianModel(dataset.sh_degree)
|
||||
|
||||
scene = Scene(dataset, gaussians)
|
||||
gaussians.training_setup(opt)
|
||||
if checkpoint:
|
||||
(model_params, first_iter) = torch.load(checkpoint)
|
||||
gaussians.restore(model_params, opt)
|
||||
|
||||
bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
|
||||
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
|
||||
@ -43,8 +46,9 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations):
|
||||
|
||||
viewpoint_stack = None
|
||||
ema_loss_for_log = 0.0
|
||||
progress_bar = tqdm(range(opt.iterations), desc="Training progress")
|
||||
for iteration in range(1, opt.iterations + 1):
|
||||
progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
|
||||
first_iter += 1
|
||||
for iteration in range(first_iter, opt.iterations + 1):
|
||||
if network_gui.conn == None:
|
||||
network_gui.try_connect()
|
||||
while network_gui.conn != None:
|
||||
@ -62,6 +66,8 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations):
|
||||
|
||||
iter_start.record()
|
||||
|
||||
gaussians.update_learning_rate(iteration)
|
||||
|
||||
# Every 1000 its we increase the levels of SH up to a maximum degree
|
||||
if iteration % 1000 == 0:
|
||||
gaussians.oneupSHdegree()
|
||||
@ -92,9 +98,6 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations):
|
||||
if iteration == opt.iterations:
|
||||
progress_bar.close()
|
||||
|
||||
# Keep track of max radii in image-space for pruning
|
||||
gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
|
||||
|
||||
# Log and save
|
||||
training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
|
||||
if (iteration in saving_iterations):
|
||||
@ -103,6 +106,8 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations):
|
||||
|
||||
# Densification
|
||||
if iteration < opt.densify_until_iter:
|
||||
# Keep track of max radii in image-space for pruning
|
||||
gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
|
||||
gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
|
||||
|
||||
if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
|
||||
@ -116,7 +121,10 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations):
|
||||
if iteration < opt.iterations:
|
||||
gaussians.optimizer.step()
|
||||
gaussians.optimizer.zero_grad(set_to_none = True)
|
||||
gaussians.update_learning_rate(iteration)
|
||||
|
||||
if (iteration in checkpoint_iterations):
|
||||
print("\n[ITER {}] Saving Checkpoint".format(iteration))
|
||||
torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
|
||||
|
||||
def prepare_output_and_logger(args):
|
||||
if not args.model_path:
|
||||
@ -189,6 +197,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000])
|
||||
parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000])
|
||||
parser.add_argument("--quiet", action="store_true")
|
||||
parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
|
||||
parser.add_argument("--start_checkpoint", type=str, default = None)
|
||||
args = parser.parse_args(sys.argv[1:])
|
||||
args.save_iterations.append(args.iterations)
|
||||
|
||||
@ -200,7 +210,7 @@ if __name__ == "__main__":
|
||||
# Start GUI server, configure and run training
|
||||
network_gui.init(args.ip, args.port)
|
||||
torch.autograd.set_detect_anomaly(args.detect_anomaly)
|
||||
training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations)
|
||||
training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint)
|
||||
|
||||
# All done
|
||||
print("\nTraining complete.")
|
||||
|
Loading…
Reference in New Issue
Block a user