Added checkpoints

bkerbl 2023-07-14 21:09:46 +02:00
parent ef918fc9e3
commit 6b54263364
4 changed files with 80 additions and 31 deletions


@@ -72,7 +72,7 @@ The optimizer uses PyTorch and CUDA extensions in a Python environment to produce trained models.
 ### Software Requirements
 - Conda (recommended for easy setup)
 - C++ Compiler for PyTorch extensions (we used Visual Studio 2019 for Windows)
-- CUDA SDK 11.7+ for PyTorch extensions (we used 11.8, **known issues with 11.6**)
+- CUDA SDK 11 for PyTorch extensions (we used 11.8, **known issues with 11.6**)
 - C++ Compiler and CUDA SDK must be compatible
 ### Setup
@@ -83,6 +83,7 @@ SET DISTUTILS_USE_SDK=1 # Windows only
 conda env create --file environment.yml
 conda activate gaussian_splatting
 ```
+Please note that this process assumes that you have CUDA SDK **11** installed, not **12**. For modifications, see below.
 Tip: Downloading packages and creating a new environment with Conda can require a significant amount of disk space. By default, Conda will use the main system hard drive. You can avoid this by specifying a different package download location and an environment on a different drive:
@@ -92,7 +93,9 @@ conda env create --file environment.yml --prefix <Drive>/<env_path>/gaussian_splatting
 conda activate <Drive>/<env_path>/gaussian_splatting
 ```
-If you can afford the disk space, we recommend using our environment files for setting up a training environment identical to ours. If you want to make modifications, please note that major version changes might affect the results of our method. However, our (limited) experiments suggest that the codebase works just fine inside a more up-to-date environment (Python 3.8, PyTorch 2.0.0, CUDA 11.8).
+#### Modifications
+If you can afford the disk space, we recommend using our environment files for setting up a training environment identical to ours. If you want to make modifications, please note that major version changes might affect the results of our method. However, our (limited) experiments suggest that the codebase works just fine inside a more up-to-date environment (Python 3.8, PyTorch 2.0.0, CUDA 12). Make sure to create an environment where PyTorch and its CUDA runtime version match and the installed CUDA SDK has no major version difference with PyTorch's CUDA version.
 ### Running
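The Modifications paragraph above asks for matching PyTorch/CUDA versions. As a quick sanity check, a minimal sketch along these lines reports the CUDA runtime PyTorch was built against, so its major version can be compared with the installed SDK (`nvcc --version`); this script is illustrative and not part of the commit:
```
import torch

# CUDA runtime version PyTorch was compiled against, e.g. "11.8"
# (None for CPU-only builds); its major version should match the
# installed CUDA SDK to avoid extension build problems.
print("PyTorch:", torch.__version__)
print("Built with CUDA:", torch.version.cuda)
print("GPU available:", torch.cuda.is_available())
```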


@@ -78,8 +78,7 @@ class Scene:
             self.gaussians.load_ply(os.path.join(self.model_path,
                                                  "point_cloud",
                                                  "iteration_" + str(self.loaded_iter),
-                                                 "point_cloud.ply"),
-                                    og_number_points=len(scene_info.point_cloud.points))
+                                                 "point_cloud.ply"))
         else:
             self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)


@@ -22,28 +22,14 @@ from utils.graphics_utils import BasicPointCloud
 from utils.general_utils import strip_symmetric, build_scaling_rotation

 class GaussianModel:

-    def __init__(self, sh_degree : int):
+    def setup_functions(self):
         def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
             L = build_scaling_rotation(scaling_modifier * scaling, rotation)
             actual_covariance = L @ L.transpose(1, 2)
             symm = strip_symmetric(actual_covariance)
             return symm
-        self.active_sh_degree = 0
-        self.max_sh_degree = sh_degree
-        self._xyz = torch.empty(0)
-        self._features_dc = torch.empty(0)
-        self._features_rest = torch.empty(0)
-        self._scaling = torch.empty(0)
-        self._rotation = torch.empty(0)
-        self._opacity = torch.empty(0)
-        self.max_radii2D = torch.empty(0)
-        self.xyz_gradient_accum = torch.empty(0)
-        self.optimizer = None
         self.scaling_activation = torch.exp
         self.scaling_inverse_activation = torch.log
@@ -54,6 +40,58 @@ class GaussianModel:
         self.rotation_activation = torch.nn.functional.normalize

+    def __init__(self, sh_degree : int):
+        self.active_sh_degree = 0
+        self.max_sh_degree = sh_degree
+        self._xyz = torch.empty(0)
+        self._features_dc = torch.empty(0)
+        self._features_rest = torch.empty(0)
+        self._scaling = torch.empty(0)
+        self._rotation = torch.empty(0)
+        self._opacity = torch.empty(0)
+        self.max_radii2D = torch.empty(0)
+        self.xyz_gradient_accum = torch.empty(0)
+        self.denom = torch.empty(0)
+        self.optimizer = None
+        self.percent_dense = 0
+        self.spatial_lr_scale = 0
+        self.setup_functions()
+
+    def capture(self):
+        return (
+            self.active_sh_degree,
+            self._xyz,
+            self._features_dc,
+            self._features_rest,
+            self._scaling,
+            self._rotation,
+            self._opacity,
+            self.max_radii2D,
+            self.xyz_gradient_accum,
+            self.denom,
+            self.optimizer.state_dict(),
+            self.spatial_lr_scale,
+        )
+
+    def restore(self, model_args, training_args):
+        (self.active_sh_degree,
+         self._xyz,
+         self._features_dc,
+         self._features_rest,
+         self._scaling,
+         self._rotation,
+         self._opacity,
+         self.max_radii2D,
+         xyz_gradient_accum,
+         denom,
+         opt_dict,
+         self.spatial_lr_scale) = model_args
+        self.training_setup(training_args)
+        self.xyz_gradient_accum = xyz_gradient_accum
+        self.denom = denom
+        self.optimizer.load_state_dict(opt_dict)
+
     @property
     def get_scaling(self):
         return self.scaling_activation(self._scaling)
@@ -174,8 +212,7 @@ class GaussianModel:
         optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
         self._opacity = optimizable_tensors["opacity"]

-    def load_ply(self, path, og_number_points=-1):
-        self.og_number_points = og_number_points
+    def load_ply(self, path):
         plydata = PlyData.read(path)

         xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
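The new capture()/restore() pair serializes the complete optimization state, including the Adam optimizer's state_dict, as a plain tuple; note that restore() calls training_setup() first so the optimizer and its parameter groups exist before load_state_dict() is applied. A minimal sketch of the round trip, mirroring the calls this commit adds to the training script (gaussians, iteration, and opt are the objects already in scope there; the path and iteration value are illustrative):
```
import torch

# Save: bundle the model state tuple with the current iteration count.
torch.save((gaussians.capture(), iteration), "output/model/chkpnt7000.pth")

# Resume: unpack the tuple, then let restore() rebuild the optimizer and
# load its saved state.
model_params, first_iter = torch.load("output/model/chkpnt7000.pth")
gaussians.restore(model_params, opt)  # opt: the training hyperparameters
```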


@@ -28,12 +28,15 @@ try:
 except ImportError:
     TENSORBOARD_FOUND = False

-def training(dataset, opt, pipe, testing_iterations, saving_iterations):
+def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint):
+    first_iter = 0
     tb_writer = prepare_output_and_logger(dataset)
     gaussians = GaussianModel(dataset.sh_degree)
     scene = Scene(dataset, gaussians)
     gaussians.training_setup(opt)
+    if checkpoint:
+        (model_params, first_iter) = torch.load(checkpoint)
+        gaussians.restore(model_params, opt)

     bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
     background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
@@ -43,8 +46,9 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations):
     viewpoint_stack = None
     ema_loss_for_log = 0.0
-    progress_bar = tqdm(range(opt.iterations), desc="Training progress")
-    for iteration in range(1, opt.iterations + 1):
+    progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
+    first_iter += 1
+    for iteration in range(first_iter, opt.iterations + 1):
         if network_gui.conn == None:
             network_gui.try_connect()
         while network_gui.conn != None:
@@ -62,6 +66,8 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations):
         iter_start.record()

+        gaussians.update_learning_rate(iteration)
+
         # Every 1000 its we increase the levels of SH up to a maximum degree
         if iteration % 1000 == 0:
             gaussians.oneupSHdegree()
@@ -92,9 +98,6 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations):
             if iteration == opt.iterations:
                 progress_bar.close()

-            # Keep track of max radii in image-space for pruning
-            gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
-
             # Log and save
             training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
             if (iteration in saving_iterations):
@@ -103,6 +106,8 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations):
             # Densification
             if iteration < opt.densify_until_iter:
+                # Keep track of max radii in image-space for pruning
+                gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
                 gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)

                 if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
@@ -116,7 +121,10 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations):
             if iteration < opt.iterations:
                 gaussians.optimizer.step()
                 gaussians.optimizer.zero_grad(set_to_none = True)
-                gaussians.update_learning_rate(iteration)
+
+            if (iteration in checkpoint_iterations):
+                print("\n[ITER {}] Saving Checkpoint".format(iteration))
+                torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")

 def prepare_output_and_logger(args):
     if not args.model_path:
@@ -189,6 +197,8 @@ if __name__ == "__main__":
     parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000])
     parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000])
     parser.add_argument("--quiet", action="store_true")
+    parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
+    parser.add_argument("--start_checkpoint", type=str, default = None)
     args = parser.parse_args(sys.argv[1:])
     args.save_iterations.append(args.iterations)
@@ -200,7 +210,7 @@ if __name__ == "__main__":
     # Start GUI server, configure and run training
     network_gui.init(args.ip, args.port)
     torch.autograd.set_detect_anomaly(args.detect_anomaly)
-    training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations)
+    training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint)

     # All done
     print("\nTraining complete.")