diff --git a/train.py b/train.py index 085b689..baaa717 100644 --- a/train.py +++ b/train.py @@ -33,18 +33,18 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi start_time=time.time() first_iter = 0 tb_writer = prepare_output_and_logger(dataset) # Tensorboard writer - gaussians = GaussianModel(dataset.sh_degree) #高斯模型 - scene = Scene(dataset, gaussians) #场景 - gaussians.training_setup(opt) #训练设置 + gaussians = GaussianModel(dataset.sh_degree) + scene = Scene(dataset, gaussians) + gaussians.training_setup(opt) if checkpoint: (model_params, first_iter) = torch.load(checkpoint) gaussians.restore(model_params, opt) - bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] #背景颜色 - background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") #背景颜色 + bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") - iter_start = torch.cuda.Event(enable_timing = True) #开始时间 - iter_end = torch.cuda.Event(enable_timing = True) #结束时间 + iter_start = torch.cuda.Event(enable_timing = True) + iter_end = torch.cuda.Event(enable_timing = True) viewpoint_stack = None ema_loss_for_log = 0.0 @@ -66,9 +66,9 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi except Exception as e: network_gui.conn = None - iter_start.record() #记录开始时间 + iter_start.record() - gaussians.update_learning_rate(iteration) #更新学习率 + gaussians.update_learning_rate(iteration) # Every 1000 its we increase the levels of SH up to a maximum degree if iteration % 1000 == 0: @@ -85,7 +85,7 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi bg = torch.rand((3), device="cuda") if opt.random_background else background - render_pkg = render(viewpoint_cam, gaussians, pipe, bg) #渲染 + render_pkg = render(viewpoint_cam, gaussians, pipe, bg) image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] # Loss @@ -141,9 +141,9 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi print("\n[ITER {}] Saving Checkpoint".format(iteration)) torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") - end_time = time.time() # 记录训练结束时间 - total_time = end_time - start_time # 计算总时间 - print(f"\nTraining complete. Total training time: {total_time:.2f} seconds.") # 打印总时间 + end_time = time.time() + total_time = end_time - start_time + print(f"\nTraining complete. Total training time: {total_time:.2f} seconds.") def prepare_output_and_logger(args):