loss is timed on GPU

This commit is contained in:
Tomas Dougan 2025-05-01 10:50:59 -04:00
parent db60836c44
commit 4102a63f73
2 changed files with 4 additions and 9 deletions

View File

@ -65,9 +65,6 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
render_start = torch.cuda.Event(enable_timing = True) render_start = torch.cuda.Event(enable_timing = True)
render_end = torch.cuda.Event(enable_timing = True) render_end = torch.cuda.Event(enable_timing = True)
bp_start = torch.cuda.Event(enable_timing = True)
bp_end = torch.cuda.Event(enable_timing = True)
use_sparse_adam = opt.optimizer_type == "sparse_adam" and SPARSE_ADAM_AVAILABLE use_sparse_adam = opt.optimizer_type == "sparse_adam" and SPARSE_ADAM_AVAILABLE
depth_l1_weight = get_expon_lr_func(opt.depth_l1_weight_init, opt.depth_l1_weight_final, max_steps=opt.iterations) depth_l1_weight = get_expon_lr_func(opt.depth_l1_weight_init, opt.depth_l1_weight_final, max_steps=opt.iterations)
@ -130,7 +127,6 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
# Loss # Loss
gt_image = viewpoint_cam.original_image.cuda() gt_image = viewpoint_cam.original_image.cuda()
Ll1 = l1_loss(image, gt_image) Ll1 = l1_loss(image, gt_image)
if FUSED_SSIM_AVAILABLE: if FUSED_SSIM_AVAILABLE:
@ -156,7 +152,7 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
bp_start.record() bp_start.record()
loss.backward() _, _, _, _, bp_time = loss.backward()
bp_end.record() bp_end.record()
@ -177,8 +173,7 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
iter_elapsed = iter_start.elapsed_time(iter_end) iter_elapsed = iter_start.elapsed_time(iter_end)
render_elapsed = render_start.elapsed_time(render_end) render_elapsed = render_start.elapsed_time(render_end)
bp_elapsed = bp_start.elapsed_time(bp_end) bp_elapsed = bp_time
k1_elapsed = render_pkg["k1_time"] k1_elapsed = render_pkg["k1_time"]
k2_elapsed = render_pkg["k2_time"] k2_elapsed = render_pkg["k2_time"]

View File

@ -34,8 +34,8 @@ class FusedSSIMMap(torch.autograd.Function):
def backward(ctx, opt_grad): def backward(ctx, opt_grad):
img1, img2 = ctx.saved_tensors img1, img2 = ctx.saved_tensors
C1, C2 = ctx.C1, ctx.C2 C1, C2 = ctx.C1, ctx.C2
grad = fusedssim_backward(C1, C2, img1, img2, opt_grad) grad, bp_time = fusedssim_backward(C1, C2, img1, img2, opt_grad)
return None, None, grad, None return None, None, grad, None, bp_time
def l1_loss(network_output, gt): def l1_loss(network_output, gt):
return torch.abs((network_output - gt)).mean() return torch.abs((network_output - gt)).mean()