Mirror of https://github.com/graphdeco-inria/gaussian-splatting (synced 2025-06-26 18:18:11 +00:00)
loss is timed on GPU

commit 4102a63f73 (parent db60836c44)
train.py: 9 changed lines (+2, −7)
train.py

@@ -65,9 +65,6 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
     render_start = torch.cuda.Event(enable_timing = True)
     render_end = torch.cuda.Event(enable_timing = True)
 
-    bp_start = torch.cuda.Event(enable_timing = True)
-    bp_end = torch.cuda.Event(enable_timing = True)
-
     use_sparse_adam = opt.optimizer_type == "sparse_adam" and SPARSE_ADAM_AVAILABLE
     depth_l1_weight = get_expon_lr_func(opt.depth_l1_weight_init, opt.depth_l1_weight_final, max_steps=opt.iterations)
 
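The two backprop CUDA events are dropped here because, after this commit, the backward-pass time is produced by the backward pass itself rather than measured with events; see the loss.backward() and bp_elapsed hunks below. (Note that the bp_start.record()/bp_end.record() calls further down remain as unchanged context even though the events they refer to are deleted here.)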
@@ -130,7 +127,6 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
 
         # Loss
 
-
         gt_image = viewpoint_cam.original_image.cuda()
         Ll1 = l1_loss(image, gt_image)
         if FUSED_SSIM_AVAILABLE:
@@ -156,7 +152,7 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
 
         bp_start.record()
 
-        loss.backward()
+        _, _, _, _, bp_time = loss.backward()
 
         bp_end.record()
 
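In stock PyTorch, Tensor.backward() returns None, so unpacking five values from loss.backward() as done here only works with this fork's modified autograd path, which appears to thread per-stage backward timings (including the bp_time produced by fusedssim_backward below) back to the caller.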
@@ -177,8 +173,7 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
 
             iter_elapsed = iter_start.elapsed_time(iter_end)
             render_elapsed = render_start.elapsed_time(render_end)
-            bp_elapsed = bp_start.elapsed_time(bp_end)
-
+            bp_elapsed = bp_time
             k1_elapsed = render_pkg["k1_time"]
             k2_elapsed = render_pkg["k2_time"]
 
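For reference, the CUDA-event pattern that the removed bp timing used, and that render_start/render_end still use, works as in the minimal sketch below. The tensor names and the matmul workload are placeholders, not code from this repo; the key detail is that elapsed_time() may only be queried after both events have completed on the GPU.

import torch

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

x = torch.randn(4096, 4096, device="cuda")

start.record()            # enqueue a start marker on the current CUDA stream
y = x @ x                 # the GPU work being timed
end.record()              # enqueue an end marker

torch.cuda.synchronize()  # wait until both events have actually executed
ms = start.elapsed_time(end)  # elapsed_time() reports milliseconds
print(f"matmul took {ms:.3f} ms")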
(second changed file; filename not shown in this view)

@@ -34,8 +34,8 @@ class FusedSSIMMap(torch.autograd.Function):
     def backward(ctx, opt_grad):
         img1, img2 = ctx.saved_tensors
         C1, C2 = ctx.C1, ctx.C2
-        grad = fusedssim_backward(C1, C2, img1, img2, opt_grad)
-        return None, None, grad, None
+        grad, bp_time = fusedssim_backward(C1, C2, img1, img2, opt_grad)
+        return None, None, grad, None, bp_time
 
 def l1_loss(network_output, gt):
     return torch.abs((network_output - gt)).mean()
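For context, torch.autograd.Function's contract is that backward() returns exactly one value per argument of forward(), with None for non-differentiable inputs; the original return None, None, grad, None therefore suggests a forward(ctx, C1, C2, img1, img2) where only img1 needs a gradient. A minimal sketch of that contract, with a made-up ScaleBy function that is not from this repo:

import torch

# Hypothetical example showing the backward() contract:
# one return slot per forward() argument, None for non-differentiable ones.
class ScaleBy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, c1, c2, img1, img2):
        ctx.save_for_backward(img1)
        ctx.c1 = c1
        return img1 * c1          # img2 kept only to mirror the 4-argument shape

    @staticmethod
    def backward(ctx, grad_out):
        (img1,) = ctx.saved_tensors
        # Slots: c1, c2, img1, img2 -> only img1 receives a gradient.
        return None, None, grad_out * ctx.c1, None

x = torch.randn(3, requires_grad=True)
y = ScaleBy.apply(0.5, 2.0, x, torch.zeros_like(x))
y.sum().backward()
print(x.grad)                     # 0.5 everywhere

Appending a fifth return value (bp_time) to a four-argument forward would raise an "incorrect number of gradients" error in stock PyTorch, so the fork presumably widens forward's signature or patches the autograd engine to carry the timing through; the five-value unpacking of loss.backward() in train.py depends on the same modification.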