mirror of
https://github.com/graphdeco-inria/gaussian-splatting
synced 2024-11-25 05:16:33 +00:00
Update train.py
Updated the main function, call the create_window() function for just once, since the window_size is fixed to 11 and the channel is fixed to 3 according to the source code. Updated the calculation of loss function, call the ssim_optimized(), instead of the original ssim().
This commit is contained in:
parent
b66e1ad13e
commit
25b71e16d9
26
train.py
26
train.py
@ -33,18 +33,18 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
|
|||||||
start_time=time.time()
|
start_time=time.time()
|
||||||
first_iter = 0
|
first_iter = 0
|
||||||
tb_writer = prepare_output_and_logger(dataset) # Tensorboard writer
|
tb_writer = prepare_output_and_logger(dataset) # Tensorboard writer
|
||||||
gaussians = GaussianModel(dataset.sh_degree) #高斯模型
|
gaussians = GaussianModel(dataset.sh_degree)
|
||||||
scene = Scene(dataset, gaussians) #场景
|
scene = Scene(dataset, gaussians)
|
||||||
gaussians.training_setup(opt) #训练设置
|
gaussians.training_setup(opt)
|
||||||
if checkpoint:
|
if checkpoint:
|
||||||
(model_params, first_iter) = torch.load(checkpoint)
|
(model_params, first_iter) = torch.load(checkpoint)
|
||||||
gaussians.restore(model_params, opt)
|
gaussians.restore(model_params, opt)
|
||||||
|
|
||||||
bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] #背景颜色
|
bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
|
||||||
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") #背景颜色
|
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
|
||||||
|
|
||||||
iter_start = torch.cuda.Event(enable_timing = True) #开始时间
|
iter_start = torch.cuda.Event(enable_timing = True)
|
||||||
iter_end = torch.cuda.Event(enable_timing = True) #结束时间
|
iter_end = torch.cuda.Event(enable_timing = True)
|
||||||
|
|
||||||
viewpoint_stack = None
|
viewpoint_stack = None
|
||||||
ema_loss_for_log = 0.0
|
ema_loss_for_log = 0.0
|
||||||
@ -66,9 +66,9 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
network_gui.conn = None
|
network_gui.conn = None
|
||||||
|
|
||||||
iter_start.record() #记录开始时间
|
iter_start.record()
|
||||||
|
|
||||||
gaussians.update_learning_rate(iteration) #更新学习率
|
gaussians.update_learning_rate(iteration)
|
||||||
|
|
||||||
# Every 1000 its we increase the levels of SH up to a maximum degree
|
# Every 1000 its we increase the levels of SH up to a maximum degree
|
||||||
if iteration % 1000 == 0:
|
if iteration % 1000 == 0:
|
||||||
@ -85,7 +85,7 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
|
|||||||
|
|
||||||
bg = torch.rand((3), device="cuda") if opt.random_background else background
|
bg = torch.rand((3), device="cuda") if opt.random_background else background
|
||||||
|
|
||||||
render_pkg = render(viewpoint_cam, gaussians, pipe, bg) #渲染
|
render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
|
||||||
image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
|
image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
|
||||||
|
|
||||||
# Loss
|
# Loss
|
||||||
@ -141,9 +141,9 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
|
|||||||
print("\n[ITER {}] Saving Checkpoint".format(iteration))
|
print("\n[ITER {}] Saving Checkpoint".format(iteration))
|
||||||
torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
|
torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
|
||||||
|
|
||||||
end_time = time.time() # 记录训练结束时间
|
end_time = time.time()
|
||||||
total_time = end_time - start_time # 计算总时间
|
total_time = end_time - start_time
|
||||||
print(f"\nTraining complete. Total training time: {total_time:.2f} seconds.") # 打印总时间
|
print(f"\nTraining complete. Total training time: {total_time:.2f} seconds.")
|
||||||
|
|
||||||
|
|
||||||
def prepare_output_and_logger(args):
|
def prepare_output_and_logger(args):
|
||||||
|
Loading…
Reference in New Issue
Block a user