mirror of
https://github.com/graphdeco-inria/gaussian-splatting
synced 2025-06-26 18:18:11 +00:00
loss is timed on GPU
This commit is contained in:
parent
db60836c44
commit
4102a63f73
9
train.py
9
train.py
@ -65,9 +65,6 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
|
||||
render_start = torch.cuda.Event(enable_timing = True)
|
||||
render_end = torch.cuda.Event(enable_timing = True)
|
||||
|
||||
bp_start = torch.cuda.Event(enable_timing = True)
|
||||
bp_end = torch.cuda.Event(enable_timing = True)
|
||||
|
||||
use_sparse_adam = opt.optimizer_type == "sparse_adam" and SPARSE_ADAM_AVAILABLE
|
||||
depth_l1_weight = get_expon_lr_func(opt.depth_l1_weight_init, opt.depth_l1_weight_final, max_steps=opt.iterations)
|
||||
|
||||
@ -130,7 +127,6 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
|
||||
|
||||
# Loss
|
||||
|
||||
|
||||
gt_image = viewpoint_cam.original_image.cuda()
|
||||
Ll1 = l1_loss(image, gt_image)
|
||||
if FUSED_SSIM_AVAILABLE:
|
||||
@ -156,7 +152,7 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
|
||||
|
||||
bp_start.record()
|
||||
|
||||
loss.backward()
|
||||
_, _, _, _, bp_time = loss.backward()
|
||||
|
||||
bp_end.record()
|
||||
|
||||
@ -177,8 +173,7 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
|
||||
|
||||
iter_elapsed = iter_start.elapsed_time(iter_end)
|
||||
render_elapsed = render_start.elapsed_time(render_end)
|
||||
bp_elapsed = bp_start.elapsed_time(bp_end)
|
||||
|
||||
bp_elapsed = bp_time
|
||||
k1_elapsed = render_pkg["k1_time"]
|
||||
k2_elapsed = render_pkg["k2_time"]
|
||||
|
||||
|
@ -34,8 +34,8 @@ class FusedSSIMMap(torch.autograd.Function):
|
||||
def backward(ctx, opt_grad):
|
||||
img1, img2 = ctx.saved_tensors
|
||||
C1, C2 = ctx.C1, ctx.C2
|
||||
grad = fusedssim_backward(C1, C2, img1, img2, opt_grad)
|
||||
return None, None, grad, None
|
||||
grad, bp_time = fusedssim_backward(C1, C2, img1, img2, opt_grad)
|
||||
return None, None, grad, None, bp_time
|
||||
|
||||
def l1_loss(network_output, gt):
|
||||
return torch.abs((network_output - gt)).mean()
|
||||
|
Loading…
Reference in New Issue
Block a user